/*
 * Decompiled with CFR 0.152.
 */
package org.wso2.ballerinalang.compiler.semantics.analyzer;

import io.ballerina.tools.diagnostics.Location;
import io.ballerina.types.SemType;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.ballerinalang.model.elements.Flag;
import org.ballerinalang.model.symbols.SymbolKind;
import org.ballerinalang.model.symbols.SymbolOrigin;
import org.ballerinalang.model.tree.NodeKind;
import org.ballerinalang.model.tree.OperatorKind;
import org.wso2.ballerinalang.compiler.semantics.analyzer.SemTypeHelper;
import org.wso2.ballerinalang.compiler.semantics.analyzer.SymbolEnter;
import org.wso2.ballerinalang.compiler.semantics.analyzer.TypeChecker;
import org.wso2.ballerinalang.compiler.semantics.analyzer.Types;
import org.wso2.ballerinalang.compiler.semantics.model.SymbolEnv;
import org.wso2.ballerinalang.compiler.semantics.model.SymbolTable;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BTypeSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BVarSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.Symbols;
import org.wso2.ballerinalang.compiler.semantics.model.types.BFiniteType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BUnionType;
import org.wso2.ballerinalang.compiler.tree.BLangBlockFunctionBody;
import org.wso2.ballerinalang.compiler.tree.BLangNode;
import org.wso2.ballerinalang.compiler.tree.BLangNodeVisitor;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangBinaryExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangExpression;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangGroupExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangLiteral;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangSimpleVarRef;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangTypeTestExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangUnaryExpr;
import org.wso2.ballerinalang.compiler.tree.statements.BLangBlockStmt;
import org.wso2.ballerinalang.compiler.util.CompilerContext;
import org.wso2.ballerinalang.compiler.util.Name;
import org.wso2.ballerinalang.compiler.util.Names;
import org.wso2.ballerinalang.util.Flags;

/**
 * Computes flow-sensitive (narrowed) types for variables referenced in conditional expressions.
 *
 * <p>For conditions such as {@code x is T}, {@code x == literal}, {@code x != CONST}, {@code !e}, {@code a && b}
 * and {@code a || b}, this visitor stores on each analyzed expression a {@code narrowedTypeInfo} map. The map goes
 * from a referenced variable's original symbol to a {@link BType.NarrowedTypes} pair: the variable's type when the
 * expression is true, and its type when the expression is false. The public {@code evaluateTruth} and
 * {@code evaluateFalsity} methods then create a type-narrowed environment for a target node. They define the
 * narrowed symbols there through {@link SymbolEnter#defineTypeNarrowedSymbol}.
 *
 * <p>One instance is registered per {@link CompilerContext}; obtain it through {@link #getInstance}.
 * NOTE(review): the environment being analyzed is kept in a mutable field, so an instance is presumably not meant
 * to be shared across threads — confirm.
 */
public class TypeNarrower extends BLangNodeVisitor {

    private static final CompilerContext.Key<TypeNarrower> TYPE_NARROWER_KEY = new CompilerContext.Key<>();

    /** Environment of the expression currently being analyzed; swapped in and restored by analyzeExpr. */
    private SymbolEnv env;
    private final SymbolTable symTable;
    private final Types types;
    private final SymbolEnter symbolEnter;
    private final TypeChecker typeChecker;

    private TypeNarrower(CompilerContext context) {
        context.put(TYPE_NARROWER_KEY, this);
        this.symTable = SymbolTable.getInstance(context);
        this.typeChecker = TypeChecker.getInstance(context);
        this.types = Types.getInstance(context);
        this.symbolEnter = SymbolEnter.getInstance(context);
    }

    /**
     * Returns the instance registered in {@code context}, creating one on first use.
     * The constructor registers the new instance itself.
     *
     * @param context the compiler context
     * @return the context-wide TypeNarrower
     */
    public static TypeNarrower getInstance(CompilerContext context) {
        TypeNarrower typeNarrower = context.get(TYPE_NARROWER_KEY);
        if (typeNarrower == null) {
            typeNarrower = new TypeNarrower(context);
        }
        return typeNarrower;
    }

    /**
     * Narrows the variables referenced by {@code expr} to the types they have when {@code expr} is true.
     *
     * @param expr         condition assumed to be true
     * @param targetNode   node for which a new type-narrowed environment is created
     * @param env          environment in which {@code expr} is analyzed
     * @param isBinaryExpr when {@code true}, a variable whose true-type is the semantic-error type is narrowed to
     *                     its false-type instead
     * @return the new environment for {@code targetNode}, or {@code env} itself when {@code expr} narrows nothing
     */
    public SymbolEnv evaluateTruth(BLangExpression expr, BLangNode targetNode, SymbolEnv env, boolean isBinaryExpr) {
        Map<BVarSymbol, BType.NarrowedTypes> narrowedTypes = getNarrowedTypes(expr, env);
        if (narrowedTypes.isEmpty()) {
            return env;
        }
        SymbolEnv targetEnv = getTargetEnv(targetNode, env);
        for (Map.Entry<BVarSymbol, BType.NarrowedTypes> entry : narrowedTypes.entrySet()) {
            BType.NarrowedTypes typeInfo = entry.getValue();
            BType narrowedType = isBinaryExpr && typeInfo.trueType == this.symTable.semanticError
                    ? typeInfo.falseType
                    : typeInfo.trueType;
            defineNarrowedSymbol(expr.pos, targetEnv, getOriginalVarSymbol(entry.getKey()), narrowedType);
        }
        return targetEnv;
    }

    /**
     * Narrows a simple variable reference to what remains of its type once {@code typeToRemove} is excluded.
     *
     * @param expr         expression to narrow; anything other than a simple variable reference is ignored
     * @param typeToRemove type to exclude; {@code null} means nothing is narrowed
     * @param targetNode   node for which a new type-narrowed environment is created
     * @param env          current environment
     * @return the new environment for {@code targetNode}, or {@code env} when no narrowing applies (the variable's
     *         type is the semantic-error type, the remainder is the null set or the semantic-error type, or the
     *         reference does not resolve to a variable symbol)
     */
    public SymbolEnv evaluateTruth(BLangExpression expr, BType typeToRemove, BLangNode targetNode, SymbolEnv env) {
        if (expr.getKind() != NodeKind.SIMPLE_VARIABLE_REF || typeToRemove == null) {
            return env;
        }
        BLangSimpleVarRef varRef = (BLangSimpleVarRef) expr;
        Name varName = new Name(varRef.variableName.value);
        // Prefer the type of the symbol bound to this name in env's own scope, if any, over the reference's type.
        BType originalType = env.scope.entries.containsKey(varName)
                ? env.scope.entries.get(varName).symbol.type
                : varRef.getBType();
        if (originalType == this.symTable.semanticError) {
            return env;
        }
        BType remainingType = this.types.getRemainingMatchExprType(originalType, typeToRemove, env);
        if (remainingType == this.symTable.nullSet || remainingType == this.symTable.semanticError) {
            return env;
        }
        // Only variable symbols can be redefined with a narrowed type; guard the cast below.
        if (!(varRef.symbol instanceof BVarSymbol)) {
            return env;
        }
        SymbolEnv targetEnv = getTargetEnv(targetNode, env);
        BVarSymbol originalVarSym = getOriginalVarSymbol((BVarSymbol) varRef.symbol);
        defineNarrowedSymbol(varRef.pos, targetEnv, originalVarSym, remainingType);
        return targetEnv;
    }

    /**
     * Same as {@link #evaluateTruth(BLangExpression, BLangNode, SymbolEnv, boolean)} with {@code isBinaryExpr}
     * set to {@code false}.
     */
    public SymbolEnv evaluateTruth(BLangExpression expr, BLangNode targetNode, SymbolEnv env) {
        return evaluateTruth(expr, targetNode, env, false);
    }

    /**
     * Narrows the variables referenced by {@code expr} to the types they have when {@code expr} is false.
     *
     * @param expr               condition assumed to be false
     * @param targetNode         node for which a new type-narrowed environment is created
     * @param env                environment in which {@code expr} is analyzed
     * @param isLogicalOrContext when {@code true}, a semantic-error false-type is recomputed as the variable's
     *                           type minus its true-type; otherwise a null-set false-type becomes {@code never}
     * @return the new environment for {@code targetNode}, or {@code env} itself when {@code expr} narrows nothing
     */
    public SymbolEnv evaluateFalsity(BLangExpression expr, BLangNode targetNode, SymbolEnv env,
                                     boolean isLogicalOrContext) {
        Map<BVarSymbol, BType.NarrowedTypes> narrowedTypes = getNarrowedTypes(expr, env);
        if (narrowedTypes.isEmpty()) {
            return env;
        }
        SymbolEnv targetEnv = getTargetEnv(targetNode, env);
        for (Map.Entry<BVarSymbol, BType.NarrowedTypes> entry : narrowedTypes.entrySet()) {
            BType.NarrowedTypes typeInfo = entry.getValue();
            BVarSymbol originalSym = getOriginalVarSymbol(entry.getKey());
            BType falseType = typeInfo.falseType;
            if (isLogicalOrContext) {
                if (falseType == this.symTable.semanticError) {
                    falseType = this.types.getRemainingType(originalSym.type, typeInfo.trueType, env);
                }
            } else if (falseType == this.symTable.nullSet) {
                falseType = this.symTable.neverType;
            }
            // A type that is still erroneous at this point is defined as never.
            if (falseType == this.symTable.semanticError) {
                falseType = this.symTable.neverType;
            }
            defineNarrowedSymbol(expr.pos, targetEnv, originalSym, falseType);
        }
        return targetEnv;
    }

    /** For {@code !e}, records e's narrowed types with the true and false types swapped. */
    @Override
    public void visit(BLangUnaryExpr unaryExpr) {
        if (unaryExpr.operator != OperatorKind.NOT) {
            return;
        }
        Map<BVarSymbol, BType.NarrowedTypes> operandTypes = getNarrowedTypes(unaryExpr.expr, this.env);
        Map<BVarSymbol, BType.NarrowedTypes> swapped = new HashMap<>(operandTypes.size());
        operandTypes.forEach((symbol, types) ->
                swapped.put(symbol, new BType.NarrowedTypes(types.falseType, types.trueType)));
        unaryExpr.narrowedTypeInfo = swapped;
    }

    /**
     * Handles {@code ==}/{@code !=} against literals and constants.
     * Also combines the operands' narrowed types for {@code &&}/{@code ||}.
     */
    @Override
    public void visit(BLangBinaryExpr binaryExpr) {
        BLangExpression lhsExpr = binaryExpr.lhsExpr;
        BLangExpression rhsExpr = binaryExpr.rhsExpr;
        OperatorKind opKind = binaryExpr.opKind;
        if (opKind == OperatorKind.EQUAL || opKind == OperatorKind.NOT_EQUAL) {
            // Either side may be the variable being compared, so try both orientations.
            narrowTypeForEqualOrNotEqual(binaryExpr, lhsExpr, rhsExpr);
            narrowTypeForEqualOrNotEqual(binaryExpr, rhsExpr, lhsExpr);
            return;
        }
        // Both operands are analyzed for every other operator, even though only && and || use the result.
        Map<BVarSymbol, BType.NarrowedTypes> lhsTypes = getNarrowedTypes(lhsExpr, this.env);
        Map<BVarSymbol, BType.NarrowedTypes> rhsTypes = getNarrowedTypes(rhsExpr, this.env);
        if (opKind != OperatorKind.AND && opKind != OperatorKind.OR) {
            return;
        }
        Set<BVarSymbol> updatedSymbols = new LinkedHashSet<>(lhsTypes.keySet());
        updatedSymbols.addAll(rhsTypes.keySet());
        for (BVarSymbol symbol : updatedSymbols) {
            BVarSymbol originalSym = getOriginalVarSymbol(symbol);
            binaryExpr.narrowedTypeInfo.put(originalSym,
                    getNarrowedTypesForBinaryOp(lhsTypes, rhsTypes, originalSym, opKind));
        }
    }

    /** A parenthesized expression narrows exactly like its inner expression. */
    @Override
    public void visit(BLangGroupExpr groupExpr) {
        analyzeExpr(groupExpr.expression, this.env);
        groupExpr.narrowedTypeInfo.putAll(groupExpr.expression.narrowedTypeInfo);
    }

    /** Records narrowed types for {@code x is T} / {@code x !is T} when {@code x} is a narrowable variable. */
    @Override
    public void visit(BLangTypeTestExpr typeTestExpr) {
        BLangExpression lhsExpression = typeTestExpr.expr;
        analyzeExpr(lhsExpression, this.env);
        if (lhsExpression.getKind() != NodeKind.SIMPLE_VARIABLE_REF) {
            return;
        }
        BSymbol symbol = ((BLangSimpleVarRef) lhsExpression).symbol;
        if (symbol == this.symTable.notFoundSymbol) {
            return;
        }
        if (isExcludedFromNarrowing(symbol, lhsExpression.pos)) {
            return;
        }
        // Guard the cast below; only variable symbols carry narrowed types.
        if (!(symbol instanceof BVarSymbol)) {
            return;
        }
        setNarrowedTypeInfo(typeTestExpr, (BVarSymbol) symbol, typeTestExpr.typeNode.getBType(), typeTestExpr.pos);
    }

    /** Analyzes {@code expr} in {@code env} (if not done already) and returns its narrowed-type map. */
    private Map<BVarSymbol, BType.NarrowedTypes> getNarrowedTypes(BLangExpression expr, SymbolEnv env) {
        analyzeExpr(expr, env);
        return expr.narrowedTypeInfo;
    }

    /**
     * Ensures {@code expr.narrowedTypeInfo} is populated.
     * Only binary, type-test, group and unary expressions are visited, and each at most once. A non-null map is
     * treated as an already-computed result.
     */
    private void analyzeExpr(BLangExpression expr, SymbolEnv env) {
        switch (expr.getKind()) {
            case BINARY_EXPR:
            case TYPE_TEST_EXPR:
            case GROUP_EXPR:
            case UNARY_EXPR:
                break;
            default:
                // Other expression kinds are not visited; they only get an empty map if they have none yet.
                if (expr.narrowedTypeInfo == null) {
                    expr.narrowedTypeInfo = new HashMap<>();
                }
                return;
        }
        if (expr.narrowedTypeInfo != null) {
            return;
        }
        SymbolEnv prevEnv = this.env;
        this.env = env;
        try {
            expr.narrowedTypeInfo = new HashMap<>();
            expr.accept(this);
        } finally {
            // Restore the caller's environment even if the visit fails.
            this.env = prevEnv;
        }
    }

    /**
     * Combines the narrowed types two operands give {@code symbol} into those of {@code lhs && rhs} or
     * {@code lhs || rhs}. An operand that does not mention {@code symbol} contributes its current in-scope type
     * for both outcomes.
     */
    private BType.NarrowedTypes getNarrowedTypesForBinaryOp(Map<BVarSymbol, BType.NarrowedTypes> lhsTypes,
                                                            Map<BVarSymbol, BType.NarrowedTypes> rhsTypes,
                                                            BVarSymbol symbol, OperatorKind operator) {
        boolean narrowedByLhs = lhsTypes.containsKey(symbol);
        BType lhsTrueType;
        BType lhsFalseType;
        if (narrowedByLhs) {
            BType.NarrowedTypes lhsNarrowed = lhsTypes.get(symbol);
            lhsTrueType = lhsNarrowed.trueType;
            lhsFalseType = lhsNarrowed.falseType;
        } else {
            lhsTrueType = lhsFalseType = getValidTypeInScope(symbol);
        }

        BType rhsTrueType;
        BType rhsFalseType;
        if (rhsTypes.containsKey(symbol)) {
            BType.NarrowedTypes rhsNarrowed = rhsTypes.get(symbol);
            rhsTrueType = rhsNarrowed.trueType;
            rhsFalseType = rhsNarrowed.falseType;
            // NOTE(review): tag 28 is a decompiled constant, presumably TypeTags.SEMANTIC_ERROR — confirm.
            if (rhsTrueType.tag == 28 && operator == OperatorKind.AND) {
                rhsTrueType = rhsFalseType;
                rhsFalseType = this.types.getRemainingType(symbol.type, rhsTrueType, this.env);
            }
        } else {
            rhsTrueType = rhsFalseType = getValidTypeInScope(symbol);
        }

        Types.IntersectionContext nonLoggingContext =
                Types.IntersectionContext.typeTestIntersectionCalculationContext();
        BType trueType;
        BType falseType;
        if (operator == OperatorKind.AND) {
            // a && b: true when both are true; false when a is false, or a is true and b is false.
            trueType = this.types.getTypeIntersection(nonLoggingContext, lhsTrueType, rhsTrueType, this.env);
            BType lhsTrueRhsFalse =
                    this.types.getTypeIntersection(nonLoggingContext, lhsTrueType, rhsFalseType, this.env);
            BType lhsFalseRhsTrue =
                    this.types.getTypeIntersection(nonLoggingContext, lhsFalseType, rhsTrueType, this.env);
            // NOTE(review): same decompiled tag 28 as above; falls back to the other intersection.
            if (lhsTrueRhsFalse.tag == 28) {
                lhsTrueRhsFalse = lhsFalseRhsTrue;
            }
            falseType = getTypeUnion(lhsFalseType, lhsTrueRhsFalse);
        } else {
            // a || b: true when a is true, or a is false and b is true; false only when both are false.
            BType lhsFalseRhsTrue =
                    this.types.getTypeIntersection(nonLoggingContext, lhsFalseType, rhsTrueType, this.env);
            trueType = narrowedByLhs
                    ? getTypeUnion(lhsTrueType, lhsFalseRhsTrue)
                    : getTypeUnion(lhsFalseRhsTrue, lhsTrueType);
            falseType = this.types.getTypeIntersection(nonLoggingContext, lhsFalseType, rhsFalseType, this.env);
        }
        return new BType.NarrowedTypes(trueType, falseType);
    }

    /**
     * Returns the implied type of the variable bound to {@code symbol}'s name in the current scope.
     * That type is used only when it is a variable symbol whose type is not a supertype of {@code symbol}'s
     * declared type; otherwise the implied declared type is returned.
     */
    private BType getValidTypeInScope(BVarSymbol symbol) {
        if (this.env.scope.entries.containsKey(symbol.name)) {
            // Only the type is needed, so read it from BSymbol instead of casting (a non-variable entry would
            // otherwise throw ClassCastException).
            BSymbol symbolInScope = this.env.scope.entries.get(symbol.name).symbol;
            if (symbolInScope instanceof BVarSymbol) {
                BType typeInScope = symbolInScope.type;
                if (!this.types.isAssignable(symbol.type, typeInScope)) {
                    return Types.getImpliedType(typeInScope);
                }
            }
        }
        return Types.getImpliedType(symbol.type);
    }

    /**
     * Builds the union of {@code currentType}'s members with {@code targetType}'s members.
     *
     * <p>A target member is added unless it is assignable to every member already collected. Members with tag
     * 51 are skipped. NOTE(review): 51 is a decompiled TypeTags constant, presumably an empty type (never /
     * null-set) — confirm.
     *
     * @return the semantic-error type if any member is erroneous, the single member if only one remains,
     *         otherwise a new union type
     */
    private BType getTypeUnion(BType currentType, BType targetType) {
        LinkedHashSet<BType> union = new LinkedHashSet<>(this.types.getAllTypes(currentType, true));
        List<BType> targetComponentTypes = this.types.getAllTypes(targetType, true);
        for (BType newType : targetComponentTypes) {
            if (newType.tag == 51) {
                continue;
            }
            // Decide first, then add, so the set is never mutated while it is being iterated.
            boolean coveredByAll = union.stream().allMatch(existing -> this.types.isAssignable(newType, existing));
            if (!coveredByAll) {
                union.add(newType);
            }
        }
        if (union.contains(this.symTable.semanticError)) {
            return this.symTable.semanticError;
        }
        if (union.size() == 1) {
            return union.iterator().next();
        }
        return BUnionType.create(this.symTable.typeEnv(), null, union);
    }

    /** Follows the {@code originalSymbol} chain to the symbol that is not itself derived from another. */
    BVarSymbol getOriginalVarSymbol(BVarSymbol varSymbol) {
        BVarSymbol current = varSymbol;
        while (current.originalSymbol != null) {
            current = current.originalSymbol;
        }
        return current;
    }

    /**
     * Creates a type-narrowed environment for {@code targetNode}.
     * When the target is a block statement or block function body, the new scope is also installed on it.
     */
    private SymbolEnv getTargetEnv(BLangNode targetNode, SymbolEnv env) {
        SymbolEnv targetEnv = SymbolEnv.createTypeNarrowedEnv(targetNode, env);
        switch (targetNode.getKind()) {
            case BLOCK:
                ((BLangBlockStmt) targetNode).scope = targetEnv.scope;
                break;
            case BLOCK_FUNCTION_BODY:
                ((BLangBlockFunctionBody) targetNode).scope = targetEnv.scope;
                break;
            default:
                break;
        }
        return targetEnv;
    }

    /** Defines {@code originalSym} in {@code targetEnv} with {@code narrowedType}, flagging virtual symbols. */
    private void defineNarrowedSymbol(Location pos, SymbolEnv targetEnv, BVarSymbol originalSym,
                                      BType narrowedType) {
        this.symbolEnter.defineTypeNarrowedSymbol(pos, targetEnv, originalSym, narrowedType,
                originalSym.origin == SymbolOrigin.VIRTUAL);
    }

    /**
     * Builds a singleton finite type for a literal or an allowed unary (numeric) expression.
     * For the literal path, {@code expr}'s own type is reset to the symbol table's type for its tag as a side
     * effect.
     */
    private BFiniteType createFiniteType(BLangExpression expr) {
        // NOTE(review): 557084L is a decompiled symbol-tag constant, presumably SymTag.FINITE_TYPE — confirm.
        BTypeSymbol finiteTypeSymbol = Symbols.createTypeSymbol(557084L, Flags.asMask(EnumSet.noneOf(Flag.class)),
                Names.EMPTY, this.env.enclPkg.symbol.pkgID, null, this.env.scope.owner, expr.pos,
                SymbolOrigin.SOURCE);
        SemType semType;
        if (expr.getKind() == NodeKind.UNARY_EXPR) {
            semType = SemTypeHelper.resolveSingletonType(
                    Types.constructNumericLiteralFromUnaryExpr((BLangUnaryExpr) expr));
        } else {
            expr.setBType(this.symTable.getTypeFromTag(expr.getBType().tag));
            semType = SemTypeHelper.resolveSingletonType((BLangLiteral) expr);
        }
        BFiniteType finiteType = BFiniteType.newSingletonBFiniteType(finiteTypeSymbol, semType);
        finiteTypeSymbol.type = finiteType;
        return finiteType;
    }

    /**
     * Lets the type checker mark/register {@code symbol} as a closure variable, then reports whether narrowing
     * must be skipped for it. That is the case when {@code symbol.closure} is set (checked after the call, which
     * presumably may set it), or when its owner's tag matches the decompiled mask 0x1001.
     * NOTE(review): 0x1001 is presumably SymTag.LET — confirm.
     */
    private boolean isExcludedFromNarrowing(BSymbol symbol, Location pos) {
        TypeChecker.AnalyzerData data = new TypeChecker.AnalyzerData();
        data.env = this.env;
        this.typeChecker.markAndRegisterClosureVariable(symbol, pos, this.env, data);
        return symbol.closure || (symbol.owner.tag & 0x1001L) == 4097L;
    }

    /**
     * Narrows a variable compared with a literal, an allowed unary literal, or a constant.
     * Handles {@code lhsExpr == rhsExpr} and {@code lhsExpr != rhsExpr}, where {@code lhsExpr} must be the
     * variable.
     */
    private void narrowTypeForEqualOrNotEqual(BLangBinaryExpr binaryExpr, BLangExpression lhsExpr,
                                              BLangExpression rhsExpr) {
        if (lhsExpr.getKind() != NodeKind.SIMPLE_VARIABLE_REF) {
            return;
        }
        BSymbol lhsVarSymbol = ((BLangSimpleVarRef) lhsExpr).symbol;
        // NOTE(review): 0x34 (52) is a decompiled mask, presumably SymTag.VARIABLE — confirm.
        if ((lhsVarSymbol.tag & 0x34L) != 52L) {
            return;
        }
        if (isExcludedFromNarrowing(lhsVarSymbol, lhsExpr.pos)) {
            return;
        }
        NodeKind rhsExprKind = rhsExpr.getKind();
        if (rhsExprKind == NodeKind.LITERAL || rhsExprKind == NodeKind.NUMERIC_LITERAL
                || this.types.isExpressionAnAllowedUnaryType(rhsExpr, rhsExprKind)) {
            setNarrowedTypeInfo(binaryExpr, (BVarSymbol) lhsVarSymbol, createFiniteType(rhsExpr), binaryExpr.pos);
            return;
        }
        if (rhsExprKind != NodeKind.SIMPLE_VARIABLE_REF) {
            return;
        }
        BSymbol rhsVarSymbol = ((BLangSimpleVarRef) rhsExpr).symbol;
        if (rhsVarSymbol != this.symTable.notFoundSymbol && rhsVarSymbol.kind == SymbolKind.CONSTANT) {
            setNarrowedTypeInfo(binaryExpr, (BVarSymbol) lhsVarSymbol, rhsVarSymbol.type, binaryExpr.pos);
        }
    }

    /**
     * Records on {@code expr} the true/false types of {@code varSymbol} relative to {@code narrowWithType}.
     *
     * <p>For {@code !=} and negated type tests, the true-type is the remainder of the variable's type and the
     * false-type is the intersection. In all other cases it is the reverse. For type tests, when both
     * computations return the identical type object, the false-type is recorded as {@code symTable.nullSet}.
     */
    private void setNarrowedTypeInfo(BLangExpression expr, BVarSymbol varSymbol, BType narrowWithType,
                                     Location intersectionPos) {
        Types.IntersectionContext nonLoggingContext =
                Types.IntersectionContext.typeTestIntersectionCalculationContext(intersectionPos);
        NodeKind kind = expr.getKind();
        BType trueType;
        BType falseType;
        if (kind == NodeKind.BINARY_EXPR && ((BLangBinaryExpr) expr).opKind == OperatorKind.NOT_EQUAL) {
            trueType = this.types.getRemainingType(varSymbol.type, narrowWithType, this.env);
            falseType = this.types.getTypeIntersection(nonLoggingContext, varSymbol.type, narrowWithType, this.env);
        } else if (kind == NodeKind.TYPE_TEST_EXPR) {
            if (((BLangTypeTestExpr) expr).isNegation) {
                trueType = this.types.getRemainingType(varSymbol.type, narrowWithType, this.env);
                falseType = this.types.getTypeIntersection(nonLoggingContext, varSymbol.type, narrowWithType,
                        this.env);
            } else {
                trueType = this.types.getTypeIntersection(nonLoggingContext, varSymbol.type, narrowWithType,
                        this.env);
                falseType = this.types.getRemainingType(varSymbol.type, narrowWithType, this.env);
            }
            if (falseType == trueType) {
                falseType = this.symTable.nullSet;
            }
        } else {
            trueType = this.types.getTypeIntersection(nonLoggingContext, varSymbol.type, narrowWithType, this.env);
            falseType = this.types.getRemainingType(varSymbol.type, narrowWithType, this.env);
        }
        expr.narrowedTypeInfo.put(getOriginalVarSymbol(varSymbol), new BType.NarrowedTypes(trueType, falseType));
    }
}

