/*
 * Decompiled with CFR 0.152.
 */
package org.ballerinalang.langserver.codeaction.providers.changetype;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.FunctionTypeSymbol;
import io.ballerina.compiler.api.symbols.TypeDescKind;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.syntax.tree.CheckExpressionNode;
import io.ballerina.compiler.syntax.tree.ExplicitAnonymousFunctionExpressionNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode;
import io.ballerina.compiler.syntax.tree.FunctionSignatureNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NodeVisitor;
import io.ballerina.compiler.syntax.tree.NonTerminalNode;
import io.ballerina.compiler.syntax.tree.ReturnStatementNode;
import io.ballerina.compiler.syntax.tree.ReturnTypeDescriptorNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.projects.Document;
import io.ballerina.tools.diagnostics.Diagnostic;
import io.ballerina.tools.text.LinePosition;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.ballerinalang.langserver.codeaction.CodeActionNodeValidator;
import org.ballerinalang.langserver.codeaction.CodeActionUtil;
import org.ballerinalang.langserver.codeaction.ReturnStatementFinder;
import org.ballerinalang.langserver.common.utils.PositionUtil;
import org.ballerinalang.langserver.commons.CodeActionContext;
import org.ballerinalang.langserver.commons.codeaction.spi.DiagBasedPositionDetails;
import org.ballerinalang.langserver.commons.codeaction.spi.DiagnosticBasedCodeActionProvider;
import org.eclipse.lsp4j.CodeAction;
import org.eclipse.lsp4j.Position;
import org.eclipse.lsp4j.Range;
import org.eclipse.lsp4j.TextEdit;

public class FixReturnTypeCodeAction
implements DiagnosticBasedCodeActionProvider {
    public static final String NAME = "Fix Return Type";
    public static final Set<String> DIAGNOSTIC_CODES = Set.of("BCE2066", "BCE2068", "BCE3032");

    public boolean validate(Diagnostic diagnostic, DiagBasedPositionDetails positionDetails, CodeActionContext context) {
        NonTerminalNode ancestorNode;
        if (!DIAGNOSTIC_CODES.contains(diagnostic.diagnosticInfo().code())) {
            return false;
        }
        NonTerminalNode matchedNode = positionDetails.matchedNode();
        if (matchedNode.kind() == SyntaxKind.CHECK_ACTION || matchedNode.kind() == SyntaxKind.CHECK_EXPRESSION) {
            return CodeActionNodeValidator.validate(context.nodeAtRange());
        }
        NonTerminalNode parentNode = matchedNode.parent();
        if (parentNode == null) {
            return false;
        }
        if (parentNode.kind() == SyntaxKind.RETURN_KEYWORD) {
            return CodeActionNodeValidator.validate(context.nodeAtRange());
        }
        NonTerminalNode grandParentNode = parentNode.parent();
        if (grandParentNode == null) {
            return false;
        }
        if (grandParentNode.kind() == SyntaxKind.RETURN_STATEMENT) {
            return CodeActionNodeValidator.validate(context.nodeAtRange());
        }
        if (matchedNode.kind() == SyntaxKind.COLLECT_CLAUSE && ((ancestorNode = grandParentNode.parent()) == null || ancestorNode.kind() != SyntaxKind.RETURN_STATEMENT)) {
            return false;
        }
        return CodeActionNodeValidator.validate(context.nodeAtRange());
    }

    public List<CodeAction> getCodeActions(Diagnostic diagnostic, DiagBasedPositionDetails positionDetails, CodeActionContext context) {
        FunctionSignatureNode functionSignatureNode;
        Optional foundType = Optional.empty();
        if ("BCE2068".equals(diagnostic.diagnosticInfo().code())) {
            foundType = positionDetails.diagnosticProperty(CodeActionUtil.getDiagPropertyFilterFunction(1));
        } else if ("BCE2066".equals(diagnostic.diagnosticInfo().code())) {
            foundType = positionDetails.diagnosticProperty(1);
        }
        boolean checkExprDiagnostic = "BCE3032".equals(diagnostic.diagnosticInfo().code());
        if (foundType.isEmpty() && !checkExprDiagnostic) {
            return Collections.emptyList();
        }
        if (context.currentSemanticModel().isEmpty() || context.currentDocument().isEmpty()) {
            return Collections.emptyList();
        }
        SemanticModel semanticModel = (SemanticModel)context.currentSemanticModel().get();
        Optional<Object> funcDef = Optional.empty();
        Optional<ExplicitAnonymousFunctionExpressionNode> anonFunc = CodeActionUtil.getEnclosingAnonFuncExpr((Node)positionDetails.matchedNode());
        boolean isMainFunction = false;
        Optional expectedReturnType = Optional.empty();
        if (anonFunc.isEmpty()) {
            funcDef = CodeActionUtil.getEnclosedFunction((Node)positionDetails.matchedNode());
            if (funcDef.isEmpty()) {
                return Collections.emptyList();
            }
            isMainFunction = "main".equals(((FunctionDefinitionNode)funcDef.get()).functionName().text());
            functionSignatureNode = ((FunctionDefinitionNode)funcDef.get()).functionSignature();
        } else {
            functionSignatureNode = anonFunc.get().functionSignature();
            Optional annoExpectedTypeSymbol = semanticModel.expectedType((Document)context.currentDocument().get(), anonFunc.get().lineRange().startLine());
            if (annoExpectedTypeSymbol.isEmpty() || ((TypeSymbol)annoExpectedTypeSymbol.get()).typeKind() != TypeDescKind.FUNCTION) {
                return Collections.emptyList();
            }
            expectedReturnType = ((FunctionTypeSymbol)annoExpectedTypeSymbol.get()).returnTypeDescriptor();
        }
        StartEndPositionDetails startEndPositionDetails = FixReturnTypeCodeAction.extractStartAndEndPosOfReturnNode(functionSignatureNode);
        if (checkExprDiagnostic) {
            if (expectedReturnType.isPresent() && !semanticModel.types().ERROR.subtypeOf((TypeSymbol)expectedReturnType.get())) {
                return Collections.emptyList();
            }
            String returnTypeDesc = functionSignatureNode.returnTypeDesc().isEmpty() ? "error?" : ((ReturnTypeDescriptorNode)functionSignatureNode.returnTypeDesc().get()).type().toString().trim().concat("|").concat("error");
            return List.of(FixReturnTypeCodeAction.getReturnTypeChangeCodeAction(context, Set.of(returnTypeDesc), startEndPositionDetails, functionSignatureNode, new ArrayList<TextEdit>()));
        }
        if (isMainFunction) {
            return Collections.emptyList();
        }
        ArrayList<TextEdit> importEdits = new ArrayList<TextEdit>();
        List<Set<String>> types = new ArrayList<Set<String>>();
        ArrayList<CodeAction> codeActions = new ArrayList<CodeAction>();
        ArrayList<List<String>> combinedTypes = new ArrayList<List<String>>();
        ReturnStatementFinder returnStatementFinder = new ReturnStatementFinder();
        if (funcDef.isPresent()) {
            returnStatementFinder.visit((FunctionDefinitionNode)funcDef.get());
        } else {
            returnStatementFinder.visit(anonFunc.get());
        }
        List<ReturnStatementNode> nodeList = returnStatementFinder.getNodeList();
        for (ReturnStatementNode returnStatementNode : nodeList) {
            if (returnStatementNode.expression().isEmpty() || context.currentSemanticModel().isEmpty()) {
                return Collections.emptyList();
            }
            ExpressionNode expression = (ExpressionNode)returnStatementNode.expression().get();
            Optional typeSymbol = semanticModel.typeOf((Node)expression);
            if (typeSymbol.isEmpty() || ((TypeSymbol)typeSymbol.get()).typeKind() == TypeDescKind.COMPILATION_ERROR) {
                return Collections.emptyList();
            }
            if (expectedReturnType.isPresent() && !((TypeSymbol)typeSymbol.get()).subtypeOf((TypeSymbol)expectedReturnType.get())) continue;
            if (((TypeSymbol)typeSymbol.get()).typeKind() == TypeDescKind.FUNCTION) {
                combinedTypes.add(Collections.singletonList("(" + CodeActionUtil.getPossibleTypes((TypeSymbol)typeSymbol.get(), importEdits, context).get(0) + ")"));
                continue;
            }
            combinedTypes.add(CodeActionUtil.getPossibleTypes((TypeSymbol)typeSymbol.get(), importEdits, context));
        }
        CheckExprNodeFinder checkExprNodeFinder = new CheckExprNodeFinder();
        if (funcDef.isPresent()) {
            ((FunctionDefinitionNode)funcDef.get()).accept((NodeVisitor)checkExprNodeFinder);
        } else {
            anonFunc.get().accept((NodeVisitor)checkExprNodeFinder);
        }
        if (checkExprNodeFinder.containCheckExprNode() && (expectedReturnType.isEmpty() || semanticModel.types().ERROR.subtypeOf((TypeSymbol)expectedReturnType.get()))) {
            combinedTypes.add(Collections.singletonList("error"));
        }
        types = FixReturnTypeCodeAction.getPossibleCombinations(combinedTypes, types);
        types.forEach(type -> codeActions.add(FixReturnTypeCodeAction.getReturnTypeChangeCodeAction(context, type, startEndPositionDetails, functionSignatureNode, importEdits)));
        return codeActions;
    }

    public String getName() {
        return NAME;
    }

    private static CodeAction getReturnTypeChangeCodeAction(CodeActionContext context, Set<String> typeSet, StartEndPositionDetails positionDetails, FunctionSignatureNode functionSignatureNode, List<TextEdit> importEdits) {
        ArrayList<TextEdit> edits = new ArrayList<TextEdit>();
        String newType = String.join((CharSequence)"|", typeSet);
        Object editText = functionSignatureNode.returnTypeDesc().isEmpty() ? " returns " + newType : newType;
        edits.add(new TextEdit(new Range(positionDetails.start, positionDetails.end), (String)editText));
        edits.addAll(importEdits);
        String commandTitle = String.format("Change return type to '%s'", newType);
        return CodeActionUtil.createCodeAction(commandTitle, edits, context.fileUri(), "quickfix");
    }

    private static StartEndPositionDetails extractStartAndEndPosOfReturnNode(FunctionSignatureNode signatureNode) {
        if (signatureNode.returnTypeDesc().isPresent()) {
            ReturnTypeDescriptorNode returnTypeDesc = (ReturnTypeDescriptorNode)signatureNode.returnTypeDesc().get();
            LinePosition retStart = returnTypeDesc.type().lineRange().startLine();
            LinePosition retEnd = returnTypeDesc.type().lineRange().endLine();
            return new StartEndPositionDetails(new Position(retStart.line(), retStart.offset()), new Position(retEnd.line(), retEnd.offset()));
        }
        Position funcBodyStart = PositionUtil.toPosition(signatureNode.lineRange().endLine());
        return new StartEndPositionDetails(funcBodyStart, funcBodyStart);
    }

    private static List<Set<String>> getPossibleCombinations(List<List<String>> combinedTypes, List<Set<String>> typeList) {
        for (List<String> possibleTypes : combinedTypes) {
            if (typeList.isEmpty()) {
                for (String type : possibleTypes) {
                    typeList.add(Set.of(type));
                }
                continue;
            }
            ArrayList<Set<String>> updatedTypes = new ArrayList<Set<String>>();
            for (String type : possibleTypes) {
                for (Set<String> strings : typeList) {
                    HashSet<String> combination = new HashSet<String>(strings);
                    combination.add(type);
                    updatedTypes.add(combination);
                }
            }
            if (updatedTypes.isEmpty()) continue;
            typeList = updatedTypes;
        }
        return typeList;
    }

    private record StartEndPositionDetails(Position start, Position end) {
    }

    static class CheckExprNodeFinder
    extends NodeVisitor {
        private CheckExpressionNode checkExpressionNode = null;

        CheckExprNodeFinder() {
        }

        public void visit(FunctionDefinitionNode functionDefinitionNode) {
            functionDefinitionNode.functionBody().accept((NodeVisitor)this);
        }

        public void visit(CheckExpressionNode checkExpressionNode) {
            this.checkExpressionNode = checkExpressionNode;
        }

        boolean containCheckExprNode() {
            return this.checkExpressionNode != null;
        }
    }
}

