/*
 * Decompiled with CFR 0.152.
 */
package org.wso2.ballerinalang.compiler.semantics.analyzer.cyclefind;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.ballerinalang.model.symbols.SymbolKind;
import org.ballerinalang.model.tree.Node;
import org.ballerinalang.model.tree.NodeKind;
import org.ballerinalang.model.tree.TopLevelNode;
import org.ballerinalang.util.diagnostic.DiagnosticErrorCode;
import org.wso2.ballerinalang.compiler.diagnostic.BLangDiagnosticLog;
import org.wso2.ballerinalang.compiler.semantics.analyzer.Types;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BInvokableSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BVarSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.tree.BLangClassDefinition;
import org.wso2.ballerinalang.compiler.tree.BLangFunction;
import org.wso2.ballerinalang.compiler.tree.BLangIdentifier;
import org.wso2.ballerinalang.compiler.tree.BLangPackage;
import org.wso2.ballerinalang.compiler.tree.BLangSimpleVariable;
import org.wso2.ballerinalang.compiler.tree.BLangTypeDefinition;
import org.wso2.ballerinalang.compiler.tree.BLangVariable;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangConstant;
import org.wso2.ballerinalang.compiler.tree.types.BLangStructureTypeNode;
import org.wso2.ballerinalang.compiler.tree.types.BLangType;
import org.wso2.ballerinalang.compiler.util.CompilerContext;

/**
 * Analyzes which module-level symbols depend on which others.
 *
 * <p>The class offers two independent entry points, both of which reset the analyzer state first:
 * <ul>
 *   <li>{@link #populateFunctionDependencies(Map, List)} walks the dependency graph starting from every
 *       function symbol and records the global variables each visited symbol (transitively) depends on.</li>
 *   <li>{@link #analyzeAndReOrder(BLangPackage, Map, Map)} walks the dependency graph to detect cycles that
 *       involve a global variable, logs a diagnostic for each such cycle, and, when no such cycle exists,
 *       re-orders the package's top-level nodes so that providers precede their dependents.</li>
 * </ul>
 *
 * <p>The instance is cached in the {@link CompilerContext} it was created for and keeps mutable traversal
 * state in fields.
 */
public class GlobalVariableRefAnalyzer {
    private static final CompilerContext.Key<GlobalVariableRefAnalyzer> REF_ANALYZER_KEY = new CompilerContext.Key<>();

    // Symbol tag bit patterns. The numeric values are the literals this class has always compared against.
    // NOTE(review): these look like inlined symbol-tag constants of the compiler — confirm and, if so,
    // reference the originals instead of the raw numbers.
    /** Tag bits tested by {@link #isFunction(BSymbol)} (0x334 == 820). */
    private static final long FUNCTION_TAG = 0x334L;
    /** Tag bits a symbol must carry to be considered a variable (0x34 == 52). */
    private static final long VARIABLE_TAG = 0x34L;
    /** Tag the owner of a symbol must have for the symbol to count as a global variable. */
    private static final long GLOBAL_OWNER_TAG = 4097L;

    private final BLangDiagnosticLog dlog;
    private BLangPackage pkgNode;
    /** Dependency graph: symbol -> the symbols it depends on ("providers"). Supplied by the caller. */
    private Map<BSymbol, Set<BSymbol>> globalNodeDependsOn;
    /** Maps a provider symbol to the symbol that should represent it in the graph. Supplied by the caller. */
    Map<BSymbol, BSymbol> symbolOwner;
    /** Result of {@link #populateFunctionDependencies}: non-function symbol -> global variables it depends on. */
    private Map<BSymbol, Set<BVarSymbol>> globalVariablesDependsOn;
    /** Traversal bookkeeping for every symbol encountered so far. */
    private final Map<BSymbol, NodeInfo> dependencyNodes;
    /** Nodes of the depth-first walk that have not yet been assigned to a component. */
    private final Deque<NodeInfo> nodeInfoStack;
    /** Every component popped off {@link #nodeInfoStack}, including single-node ones. */
    private final List<List<NodeInfo>> cycles;
    /** Nodes in the order their traversal finished; drained by {@link #analyzeDependenciesStartingFrom}. */
    private final List<NodeInfo> dependencyOrder;
    private int curNodeId;
    private boolean cyclicErrorFound;

    /**
     * Returns the analyzer associated with the given compiler context, creating it on first use.
     *
     * @param context the compiler context to look the analyzer up in
     * @return the analyzer registered in {@code context}
     */
    public static GlobalVariableRefAnalyzer getInstance(CompilerContext context) {
        GlobalVariableRefAnalyzer refAnalyzer = context.get(REF_ANALYZER_KEY);
        if (refAnalyzer == null) {
            // The constructor registers the new instance in the context.
            refAnalyzer = new GlobalVariableRefAnalyzer(context);
        }
        return refAnalyzer;
    }

    private GlobalVariableRefAnalyzer(CompilerContext context) {
        context.put(REF_ANALYZER_KEY, this);
        this.dlog = BLangDiagnosticLog.getInstance(context);
        this.dependencyNodes = new HashMap<>();
        this.cycles = new ArrayList<>();
        this.nodeInfoStack = new ArrayDeque<>();
        this.dependencyOrder = new ArrayList<>();
        this.globalVariablesDependsOn = new HashMap<>();
    }

    /** Discards all traversal state so that a new analysis starts from scratch. */
    private void resetAnalyzer() {
        this.dependencyNodes.clear();
        this.cycles.clear();
        this.nodeInfoStack.clear();
        this.dependencyOrder.clear();
        this.curNodeId = 0;
        // A fresh map (rather than clear()) so a map previously handed out by
        // getGlobalVariablesDependsOn() is not emptied behind the caller's back.
        this.globalVariablesDependsOn = new HashMap<>();
        this.cyclicErrorFound = false;
    }

    /**
     * Computes, for every function symbol that appears as a key of {@code globalNodeDependsOn}, the global
     * variables it transitively depends on. Results for functions are accumulated into the function symbol's
     * {@code dependentGlobalVars}; results for other symbols are available through
     * {@link #getGlobalVariablesDependsOn()}.
     *
     * @param globalNodeDependsOn dependency graph (symbol -> symbols it depends on)
     * @param globalVars          the variables whose symbols are to be treated as global variables
     */
    public void populateFunctionDependencies(Map<BSymbol, Set<BSymbol>> globalNodeDependsOn, List<BLangVariable> globalVars) {
        this.resetAnalyzer();
        this.globalNodeDependsOn = globalNodeDependsOn;
        // The symbol set does not change during the walk, so it is built once — lazily, on the first
        // function encountered — instead of once per function.
        Set<BSymbol> globalVarSymbols = null;
        for (BSymbol dependent : this.globalNodeDependsOn.keySet()) {
            if (dependent.kind != SymbolKind.FUNCTION) {
                continue;
            }
            if (globalVarSymbols == null) {
                globalVarSymbols = new HashSet<>();
                for (BLangVariable globalVar : globalVars) {
                    globalVarSymbols.add(globalVar.symbol);
                }
            }
            this.analyzeDependenciesRecursively(dependent, globalVarSymbols);
        }
    }

    /**
     * @return the map populated by the most recent {@link #populateFunctionDependencies} run, keyed by
     *         non-function symbol; the map is replaced (not cleared) whenever the analyzer is reset
     */
    public Map<BSymbol, Set<BVarSymbol>> getGlobalVariablesDependsOn() {
        return this.globalVariablesDependsOn;
    }

    /** Starts a walk from {@code dependent} unless a node for it already exists. */
    private void analyzeDependenciesRecursively(BSymbol dependent, Set<BSymbol> globalVars) {
        if (!this.dependencyNodes.containsKey(dependent)) {
            NodeInfo node = new NodeInfo(this.curNodeId++, dependent);
            this.dependencyNodes.put(dependent, node);
            this.analyzeDependenciesRecursively(node, globalVars);
        }
    }

    /**
     * Depth-first walk over the providers of {@code node} that collects the global variables reachable
     * from it.
     *
     * @param node       the node to expand
     * @param globalVars symbols that count as global variables
     * @return the global variables {@code node} depends on; for a fully processed node this is the live
     *         set stored for its symbol
     */
    private Set<BVarSymbol> analyzeDependenciesRecursively(NodeInfo node, Set<BSymbol> globalVars) {
        if (node.onStack) {
            // Back edge: the node is still being expanded, so only its direct global variables are known.
            return this.getGlobalVarFromCurrentNode(node, globalVars);
        }
        if (node.visited) {
            return this.getDependentsFromSymbol(node.symbol, globalVars);
        }
        node.visited = true;
        node.onStack = true;
        Set<BSymbol> providers = this.globalNodeDependsOn.getOrDefault(node.symbol, new LinkedHashSet<>());
        if (providers.isEmpty()) {
            // NOTE(review): onStack is left true on this path. A later visit therefore takes the
            // onStack branch above, which yields an empty set for a node without providers — confirm
            // that this is intended before changing it.
            return new HashSet<>(0);
        }
        Set<BVarSymbol> currentDependencies = new HashSet<>();
        for (BSymbol providerSym : providers) {
            NodeInfo providerNode = this.dependencyNodes.computeIfAbsent(providerSym, s -> new NodeInfo(this.curNodeId++, providerSym));
            if (this.isGlobalVarSymbol(providerSym, globalVars)) {
                currentDependencies.add((BVarSymbol) providerSym);
            }
            currentDependencies.addAll(this.analyzeDependenciesRecursively(providerNode, globalVars));
        }
        node.onStack = false;
        // Functions keep their result on the symbol itself; everything else goes into the analyzer's map.
        Set<BVarSymbol> dependentGlobalVars;
        if (node.symbol.kind == SymbolKind.FUNCTION) {
            dependentGlobalVars = ((BInvokableSymbol) node.symbol).dependentGlobalVars;
        } else {
            dependentGlobalVars = this.globalVariablesDependsOn.computeIfAbsent(node.symbol, s -> new HashSet<>());
        }
        dependentGlobalVars.addAll(currentDependencies);
        return dependentGlobalVars;
    }

    /** Returns the direct providers of {@code node} that are global variables. */
    private Set<BVarSymbol> getGlobalVarFromCurrentNode(NodeInfo node, Set<BSymbol> globalVars) {
        Set<BVarSymbol> globalVarsForCurrentNode = new HashSet<>();
        Set<BSymbol> providers = this.globalNodeDependsOn.getOrDefault(node.symbol, new LinkedHashSet<>());
        for (BSymbol provider : providers) {
            if (this.isGlobalVarSymbol(provider, globalVars)) {
                globalVarsForCurrentNode.add((BVarSymbol) provider);
            }
        }
        return globalVarsForCurrentNode;
    }

    /** Returns the global variables already recorded for {@code symbol}, or an empty set if none apply. */
    private Set<BVarSymbol> getDependentsFromSymbol(BSymbol symbol, Set<BSymbol> globalVars) {
        if (this.isFunction(symbol)) {
            return ((BInvokableSymbol) symbol).dependentGlobalVars;
        }
        if (this.isGlobalVarSymbol(symbol, globalVars)) {
            return this.globalVariablesDependsOn.getOrDefault(symbol, new HashSet<>());
        }
        return new HashSet<>(0);
    }

    /** Returns whether all bits of {@code tag} are set in the symbol's tag. */
    private static boolean hasTag(BSymbol symbol, long tag) {
        return (symbol.tag & tag) == tag;
    }

    private boolean isFunction(BSymbol symbol) {
        return hasTag(symbol, FUNCTION_TAG);
    }

    /**
     * Returns whether {@code symbol} is a non-function variable symbol whose owner carries
     * {@link #GLOBAL_OWNER_TAG} and which is contained in {@code globalVars}.
     */
    private boolean isGlobalVarSymbol(BSymbol symbol, Set<BSymbol> globalVars) {
        if (symbol == null || symbol.owner == null) {
            return false;
        }
        if (symbol.owner.tag != GLOBAL_OWNER_TAG) {
            return false;
        }
        // The function bit pattern includes the variable bits, so functions must be ruled out first.
        if (hasTag(symbol, FUNCTION_TAG)) {
            return false;
        }
        return hasTag(symbol, VARIABLE_TAG) && globalVars.contains(symbol);
    }

    /**
     * Detects cyclic references that involve a global variable and, if there are none, re-orders
     * {@code pkgNode.topLevelNodes} so that providers come before their dependents.
     *
     * @param pkgNode             the package whose top-level nodes are analyzed and re-ordered in place
     * @param globalNodeDependsOn dependency graph (symbol -> symbols it depends on)
     * @param symbolOwner         maps a provider symbol to the symbol that represents it in the graph
     */
    public void analyzeAndReOrder(BLangPackage pkgNode, Map<BSymbol, Set<BSymbol>> globalNodeDependsOn, Map<BSymbol, BSymbol> symbolOwner) {
        this.dlog.setCurrentPackageId(pkgNode.packageID);
        this.pkgNode = pkgNode;
        this.globalNodeDependsOn = globalNodeDependsOn;
        this.symbolOwner = symbolOwner;
        this.resetAnalyzer();
        this.reOrderTopLevelNodeList();
    }

    /**
     * Walks the providers of {@code symbol} (if it has not been walked yet) and returns the symbols whose
     * traversal finished since the previous call, providers first.
     */
    private List<BSymbol> analyzeDependenciesStartingFrom(BSymbol symbol) {
        if (!this.dependencyNodes.containsKey(symbol)) {
            NodeInfo node = new NodeInfo(this.curNodeId++, symbol);
            this.dependencyNodes.put(symbol, node);
            this.analyzeProvidersRecursively(node);
        }
        if (!this.dependencyOrder.isEmpty()) {
            List<BSymbol> symbolsProvidersOrdered = this.dependencyOrder.stream().map(nodeInfo -> nodeInfo.symbol).toList();
            this.dependencyOrder.clear();
            return symbolsProvidersOrdered;
        }
        return new ArrayList<>();
    }

    /** Rewrites {@code pkgNode.topLevelNodes} in dependency order unless a cyclic error was reported. */
    private void reOrderTopLevelNodeList() {
        Map<BSymbol, TopLevelNode> varMap = this.collectAssociateSymbolsWithTopLevelNodes();
        Set<BSymbol> sorted = new LinkedHashSet<>();
        for (BSymbol symbol : varMap.keySet()) {
            sorted.addAll(this.analyzeDependenciesStartingFrom(symbol));
        }
        if (this.cyclicErrorFound) {
            return;
        }
        Set<TopLevelNode> sortedTopLevelNodes = new LinkedHashSet<>();
        for (BSymbol symbol : sorted) {
            if (varMap.containsKey(symbol)) {
                sortedTopLevelNodes.add(varMap.get(symbol));
            }
        }
        // Insertion-ordered set: nodes already placed keep their slot, all remaining nodes are appended
        // in their original relative order.
        sortedTopLevelNodes.addAll(this.pkgNode.topLevelNodes);
        this.pkgNode.topLevelNodes.clear();
        this.pkgNode.topLevelNodes.addAll(sortedTopLevelNodes);
    }

    /**
     * Maps the symbol of each supported top-level node to that node. Function symbols are placed before
     * all other symbols in the iteration order of the returned map.
     */
    private Map<BSymbol, TopLevelNode> collectAssociateSymbolsWithTopLevelNodes() {
        Map<BSymbol, TopLevelNode> resultMap = new LinkedHashMap<>();
        Map<BSymbol, TopLevelNode> tempVarMap = new LinkedHashMap<>();
        for (TopLevelNode topLevelNode : this.pkgNode.topLevelNodes) {
            BSymbol symbol = this.getSymbolFromTopLevelNode(topLevelNode);
            if (symbol == null) {
                continue;
            }
            if (this.isFunction(symbol)) {
                resultMap.put(symbol, topLevelNode);
            } else {
                tempVarMap.put(symbol, topLevelNode);
            }
        }
        resultMap.putAll(tempVarMap);
        return resultMap;
    }

    /** Returns the symbol used as graph key for a top-level node, or {@code null} for unsupported kinds. */
    private BSymbol getSymbolFromTopLevelNode(TopLevelNode topLevelNode) {
        return switch (topLevelNode.getKind()) {
            case VARIABLE, RECORD_VARIABLE, TUPLE_VARIABLE, ERROR_VARIABLE -> ((BLangVariable) topLevelNode).symbol;
            case TYPE_DEFINITION -> Types.getImpliedType((BType) ((BLangTypeDefinition) topLevelNode).symbol.type).tsymbol;
            case CONSTANT -> ((BLangConstant) topLevelNode).symbol;
            case FUNCTION -> ((BLangFunction) topLevelNode).symbol;
            default -> null;
        };
    }

    /**
     * Depth-first walk that groups mutually dependent nodes using id / lowLink / stack bookkeeping (the
     * structure resembles Tarjan's strongly-connected-components algorithm) and appends every node to
     * {@link #dependencyOrder} once all of its providers have been processed.
     *
     * @return the lowLink of {@code node}
     */
    private int analyzeProvidersRecursively(NodeInfo node) {
        if (node.visited) {
            return node.lowLink;
        }
        node.visited = true;
        node.lowLink = node.id;
        node.onStack = true;
        this.nodeInfoStack.push(node);
        Set<BSymbol> providers = this.globalNodeDependsOn.getOrDefault(node.symbol, new LinkedHashSet<>());
        for (BSymbol providerSym : providers) {
            // The node is keyed by the provider symbol but carries the owner-mapped symbol.
            BSymbol symbol = this.symbolOwner.getOrDefault(providerSym, providerSym);
            NodeInfo providerNode = this.dependencyNodes.computeIfAbsent(providerSym, s -> new NodeInfo(this.curNodeId++, symbol));
            int lastLowLink = this.analyzeProvidersRecursively(providerNode);
            // Only providers that are still on the stack belong to the component being built.
            if (providerNode.onStack) {
                node.lowLink = Math.min(node.lowLink, lastLowLink);
            }
        }
        if (node.id == node.lowLink) {
            // node is the root of its component: pop the component and report it if it is a real cycle.
            this.handleCyclicReferenceError(node);
        }
        this.dependencyOrder.add(node);
        return node.lowLink;
    }

    /**
     * Pops the component rooted at {@code node} off the stack, records it, and reports an error when the
     * component has more than one node and contains a global variable.
     */
    private void handleCyclicReferenceError(NodeInfo node) {
        List<NodeInfo> cycle = new ArrayList<>();
        while (!this.nodeInfoStack.isEmpty()) {
            NodeInfo cNode = this.nodeInfoStack.pop();
            cNode.onStack = false;
            cNode.lowLink = node.id;
            cycle.add(cNode);
            if (cNode.id == node.id) {
                break;
            }
        }
        this.cycles.add(cycle);
        if (cycle.size() > 1) {
            // Reverse a copy so the list stored in this.cycles keeps its pop order.
            List<NodeInfo> ordered = new ArrayList<>(cycle);
            Collections.reverse(ordered);
            List<BSymbol> symbolsOfCycle = ordered.stream().map(n -> n.symbol).toList();
            if (this.doesContainAGlobalVar(symbolsOfCycle)) {
                this.emitErrorMessage(symbolsOfCycle);
                this.cyclicErrorFound = true;
            }
        }
    }

    /**
     * Logs {@code GLOBAL_VARIABLE_CYCLIC_DEFINITION} for the given cycle. The diagnostic is positioned at
     * the variable node of the cycle with the smallest start line, and the reported names are rotated so
     * that they start with that node.
     *
     * <p>If the cycle has no top-level node of kind {@code VARIABLE}, the earliest positioned node of any
     * kind is used instead, so the cycle is still reported rather than failing on an empty
     * {@link Optional}.
     *
     * @throws IllegalStateException if no top-level node of the package belongs to the cycle, i.e. there
     *                               is nothing to position the diagnostic at
     */
    private void emitErrorMessage(List<BSymbol> symbolsOfCycle) {
        List<TopLevelNode> nodesInCycle = new ArrayList<>();
        for (TopLevelNode topLevelNode : this.pkgNode.topLevelNodes) {
            BSymbol topLevelSymbol = this.getSymbol(topLevelNode);
            for (BSymbol symbol : symbolsOfCycle) {
                // Identity comparison on purpose: the same symbol instance must be part of the cycle.
                if (topLevelSymbol == symbol) {
                    nodesInCycle.add(topLevelNode);
                }
            }
        }
        Comparator<TopLevelNode> byStartLine =
                Comparator.comparingInt(n -> n.getPosition().lineRange().startLine().line());
        Optional<TopLevelNode> firstNode = nodesInCycle.stream()
                .filter(n -> n.getKind() == NodeKind.VARIABLE)
                .min(byStartLine);
        if (firstNode.isEmpty()) {
            // No variable node to anchor on: fall back to the earliest node that has a position.
            firstNode = nodesInCycle.stream()
                    .filter(n -> n.getPosition() != null)
                    .min(byStartLine);
        }
        TopLevelNode anchor = firstNode.orElseThrow(() -> new IllegalStateException(
                "cyclic reference detected but no top-level node found to report it at: " + symbolsOfCycle));

        BSymbol firstNodeSymbol = this.getSymbol(anchor);
        int splitFrom = symbolsOfCycle.indexOf(firstNodeSymbol);
        // Rotate the cycle so it starts at the anchor symbol.
        List<BSymbol> rotated = new ArrayList<>(symbolsOfCycle.subList(splitFrom, symbolsOfCycle.size()));
        rotated.addAll(symbolsOfCycle.subList(0, splitFrom));
        List<BLangIdentifier> names = rotated.stream().map(this::getNodeName).filter(Objects::nonNull).toList();
        this.dlog.error(anchor.getPosition(), DiagnosticErrorCode.GLOBAL_VARIABLE_CYCLIC_DEFINITION, names);
    }

    /** Returns whether the symbol of any of the package's global variables is part of the cycle. */
    private boolean doesContainAGlobalVar(List<BSymbol> symbolsOfCycle) {
        return this.pkgNode.globalVars.stream().map(v -> v.symbol).anyMatch(symbolsOfCycle::contains);
    }

    /**
     * Returns the name identifier of the first top-level node whose symbol is {@code symbol}, or
     * {@code null} if there is no variable, function, class, or object/record type definition for it.
     */
    private BLangIdentifier getNodeName(BSymbol symbol) {
        for (TopLevelNode node : this.pkgNode.topLevelNodes) {
            if (this.getSymbol(node) != symbol) {
                continue;
            }
            NodeKind kind = node.getKind();
            if (kind == NodeKind.VARIABLE) {
                return ((BLangSimpleVariable) node).name;
            }
            if (kind == NodeKind.FUNCTION) {
                return ((BLangFunction) node).name;
            }
            if (kind == NodeKind.CLASS_DEFN) {
                return ((BLangClassDefinition) node).name;
            }
            if (kind == NodeKind.TYPE_DEFINITION) {
                BLangType typeNode = ((BLangTypeDefinition) node).typeNode;
                if (typeNode.getKind() == NodeKind.OBJECT_TYPE || typeNode.getKind() == NodeKind.RECORD_TYPE) {
                    return ((BLangTypeDefinition) node).name;
                }
            }
        }
        return null;
    }

    /**
     * Returns the symbol of a variable, function, class definition, or object/record type definition
     * node, and {@code null} for every other kind of node.
     */
    private BSymbol getSymbol(Node node) {
        NodeKind kind = node.getKind();
        if (kind == NodeKind.VARIABLE) {
            return ((BLangVariable) node).symbol;
        }
        if (kind == NodeKind.FUNCTION) {
            return ((BLangFunction) node).symbol;
        }
        if (kind == NodeKind.CLASS_DEFN) {
            return ((BLangClassDefinition) node).symbol;
        }
        if (kind == NodeKind.TYPE_DEFINITION) {
            BLangType typeNode = ((BLangTypeDefinition) node).typeNode;
            if (typeNode.getKind() == NodeKind.OBJECT_TYPE || typeNode.getKind() == NodeKind.RECORD_TYPE) {
                return ((BLangStructureTypeNode) typeNode).symbol;
            }
        }
        return null;
    }

    /** Per-symbol traversal state used by both analyses. */
    private static class NodeInfo {
        final int id;
        final BSymbol symbol;
        int lowLink;
        boolean visited;
        boolean onStack;

        NodeInfo(int id, BSymbol symbol) {
            this.id = id;
            this.symbol = symbol;
        }

        @Override
        public String toString() {
            return "NodeInfo{id=" + this.id + ", lowLink=" + this.lowLink + ", visited=" + this.visited + ", onStack=" + this.onStack + ", symbol=" + this.symbol + "}";
        }
    }
}

