/*
 * Decompiled with CFR 0.152.
 */
package io.ballerina.stdlib.mcp.plugin;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.AnnotationSymbol;
import io.ballerina.compiler.api.symbols.ConstantSymbol;
import io.ballerina.compiler.api.symbols.Documentable;
import io.ballerina.compiler.api.symbols.Documentation;
import io.ballerina.compiler.api.symbols.FunctionSymbol;
import io.ballerina.compiler.api.symbols.FunctionTypeSymbol;
import io.ballerina.compiler.api.symbols.ModuleSymbol;
import io.ballerina.compiler.api.symbols.ParameterSymbol;
import io.ballerina.compiler.api.symbols.ServiceDeclarationSymbol;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.SymbolKind;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.AnnotationNode;
import io.ballerina.compiler.syntax.tree.BasicLiteralNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode;
import io.ballerina.compiler.syntax.tree.IdentifierToken;
import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode;
import io.ballerina.compiler.syntax.tree.MappingFieldNode;
import io.ballerina.compiler.syntax.tree.MetadataNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NodeList;
import io.ballerina.compiler.syntax.tree.NodeLocation;
import io.ballerina.compiler.syntax.tree.QualifiedNameReferenceNode;
import io.ballerina.compiler.syntax.tree.SeparatedNodeList;
import io.ballerina.compiler.syntax.tree.ServiceDeclarationNode;
import io.ballerina.compiler.syntax.tree.SpecificFieldNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext;
import io.ballerina.stdlib.mcp.plugin.diagnostics.CompilationDiagnostic;
import io.ballerina.tools.diagnostics.Diagnostic;
import io.ballerina.tools.diagnostics.Location;
import java.util.List;
import java.util.Optional;

/**
 * Static helpers used by the MCP compiler plugin to recognise MCP symbols
 * (tool annotations, services, the session type) and to validate the
 * parameter lists of functions declared inside an MCP service.
 *
 * <p>The class only has static members and cannot be instantiated.
 */
public final class Utils {
    public static final String BALLERINA_ORG = "ballerina";
    public static final String TOOL_ANNOTATION_NAME = "Tool";
    public static final String MCP_PACKAGE_NAME = "mcp";
    public static final String MCP_BASIC_SERVICE_NAME = "Service";
    public static final String SESSION_TYPE_NAME = "Session";
    public static final String UNKNOWN_SYMBOL = "unknown";
    public static final String SERVICE_CONFIG_ANNOTATION_NAME = "ServiceConfig";
    public static final String SESSION_MODE_FIELD = "sessionMode";

    private Utils() {
    }

    /**
     * Tells whether the annotation is the {@code Tool} annotation declared in the
     * {@code ballerina/mcp} module.
     *
     * @param annotationSymbol the annotation to inspect
     * @return {@code true} if the annotation belongs to the MCP module and is named {@code Tool}
     */
    public static boolean isMcpToolAnnotation(AnnotationSymbol annotationSymbol) {
        // isMcpModuleSymbol() resolves the owning module itself, so the annotation
        // symbol is passed directly instead of unwrapping its module first.
        return isMcpModuleSymbol(annotationSymbol)
                && annotationSymbol.getName().filter(TOOL_ANNOTATION_NAME::equals).isPresent();
    }

    /**
     * Tells whether the symbol is declared in the {@code ballerina/mcp} module.
     *
     * @param symbol the symbol whose owning module is checked
     * @return {@code true} if the symbol has a module whose name is {@code mcp} and whose
     *         organisation is {@code ballerina}
     */
    public static boolean isMcpModuleSymbol(Symbol symbol) {
        Optional<ModuleSymbol> owningModule = symbol.getModule();
        return owningModule.isPresent()
                && MCP_PACKAGE_NAME.equals(owningModule.get().id().moduleName())
                && BALLERINA_ORG.equals(owningModule.get().id().orgName());
    }

    /**
     * Looks up the documented description of one parameter of a function.
     *
     * @param functionSymbol the function whose documentation is read
     * @param parameterName  the parameter to look up
     * @return the parameter description, or {@code null} if the function has no documentation,
     *         has no function-level description, or does not document that parameter
     */
    public static String getParameterDescription(FunctionSymbol functionSymbol, String parameterName) {
        Optional<Documentation> documentation = functionSymbol.documentation();
        if (documentation.isEmpty() || documentation.get().description().isEmpty()) {
            return null;
        }
        return documentation.get().parameterMap().get(parameterName);
    }

    /**
     * Returns the description text attached to a documentable construct.
     *
     * @param documentable the construct whose documentation is read
     * @return the description, or {@code null} if there is no documentation or no description
     */
    public static String getDescription(Documentable documentable) {
        return documentable.documentation().flatMap(Documentation::description).orElse(null);
    }

    /**
     * Prefixes every double quote in the input with a backslash.
     *
     * @param input the text to escape; must not be {@code null}
     * @return the input with each {@code "} replaced by {@code \"}
     */
    public static String escapeDoubleQuotes(String input) {
        return input.replace("\"", "\\\"");
    }

    /**
     * Wraps the input in a pair of double quotes. The content is not escaped.
     *
     * @param input the text to wrap
     * @return the input surrounded by {@code "}
     */
    public static String addDoubleQuotes(String input) {
        return "\"" + input + "\"";
    }

    /**
     * Finds the first annotation on the function that resolves to the MCP {@code Tool} annotation.
     *
     * @param semanticModel          the semantic model used to resolve annotation symbols
     * @param functionDefinitionNode the function whose metadata is searched
     * @return the matching annotation node, or empty if the function has no metadata or no
     *         MCP tool annotation
     */
    public static Optional<AnnotationNode> getToolAnnotationNode(SemanticModel semanticModel,
                                                                 FunctionDefinitionNode functionDefinitionNode) {
        Optional<MetadataNode> metadata = functionDefinitionNode.metadata();
        if (metadata.isEmpty()) {
            return Optional.empty();
        }
        NodeList<AnnotationNode> annotations = metadata.get().annotations();
        return annotations.stream()
                .filter(annotation -> semanticModel.symbol(annotation)
                        .filter(symbol -> symbol.kind() == SymbolKind.ANNOTATION)
                        .filter(symbol -> isMcpToolAnnotation((AnnotationSymbol) symbol))
                        .isPresent())
                .findFirst();
    }

    /**
     * Tells whether the function is declared directly inside an MCP service, i.e. a service
     * declaration whose first listener type comes from the {@code ballerina/mcp} module and
     * whose service type is named {@code Service}.
     *
     * @param semanticModel          the semantic model used to resolve the parent symbol
     * @param functionDefinitionNode the function to classify
     * @return {@code true} if the enclosing construct is such a service declaration
     */
    public static boolean isMcpServiceFunction(SemanticModel semanticModel,
                                               FunctionDefinitionNode functionDefinitionNode) {
        Optional<Symbol> parentSymbol = semanticModel.symbol(functionDefinitionNode.parent());
        if (parentSymbol.isEmpty() || parentSymbol.get().kind() != SymbolKind.SERVICE_DECLARATION) {
            return false;
        }
        ServiceDeclarationSymbol serviceSymbol = (ServiceDeclarationSymbol) parentSymbol.get();

        // Only the first listener is considered. The module check goes through
        // isMcpModuleSymbol() so that the organisation is verified as well, in line
        // with the other MCP checks in this class.
        boolean isFromMcpModule = serviceSymbol.listenerTypes().stream()
                .findFirst()
                .filter(Utils::isMcpModuleSymbol)
                .isPresent();
        boolean isServiceType = serviceSymbol.typeDescriptor()
                .flatMap(TypeSymbol::getName)
                .filter(MCP_BASIC_SERVICE_NAME::equals)
                .isPresent();
        return isFromMcpModule && isServiceType;
    }

    /**
     * Tells whether the type is a subtype of {@code anydata}.
     *
     * @param typeSymbol the type to test
     * @param context    the analysis context that supplies the semantic model
     * @return {@code true} if the type is a subtype of {@code anydata}
     */
    public static boolean isAnydataType(TypeSymbol typeSymbol, SyntaxNodeAnalysisContext context) {
        return typeSymbol.subtypeOf(context.semanticModel().types().ANYDATA);
    }

    /**
     * Validates the parameters of a function and reports a diagnostic for the first violation.
     *
     * <p>Rules enforced, in parameter order:
     * <ul>
     *   <li>a parameter of the MCP session type must be the first parameter;</li>
     *   <li>a session parameter is rejected when the enclosing service is configured with the
     *       stateless session mode;</li>
     *   <li>every other parameter must be a subtype of {@code anydata}.</li>
     * </ul>
     *
     * @param functionSymbol         the symbol of the function being validated
     * @param functionDefinitionNode the syntax node of the same function; its location is used
     *                               when a parameter has no location of its own
     * @param context                the analysis context used to report diagnostics
     * @return {@code true} if all parameters are valid, {@code false} once a diagnostic has
     *         been reported
     */
    public static boolean validateParameterTypes(FunctionSymbol functionSymbol,
                                                 FunctionDefinitionNode functionDefinitionNode,
                                                 SyntaxNodeAnalysisContext context) {
        FunctionTypeSymbol functionTypeSymbol = functionSymbol.typeDescriptor();
        Optional<List<ParameterSymbol>> declaredParams = functionTypeSymbol.params();
        if (declaredParams.isEmpty()) {
            return true;
        }
        String functionName = functionSymbol.getName().orElse(UNKNOWN_SYMBOL);
        NodeLocation fallbackLocation = functionDefinitionNode.location();
        List<ParameterSymbol> parameters = declaredParams.get();

        for (int i = 0; i < parameters.size(); i++) {
            ParameterSymbol parameter = parameters.get(i);
            TypeSymbol parameterType = parameter.typeDescriptor();
            String parameterName = parameter.getName().orElse(UNKNOWN_SYMBOL);
            Location location = parameter.getLocation().orElse(fallbackLocation);

            if (isSessionType(parameterType)) {
                // A session parameter is only accepted at index 0, so this single check
                // also rejects a second session parameter.
                if (i != 0) {
                    Diagnostic diagnostic = CompilationDiagnostic.getDiagnostic(
                            CompilationDiagnostic.SESSION_PARAM_MUST_BE_FIRST, location, functionName,
                            parameterName);
                    context.reportDiagnostic(diagnostic);
                    return false;
                }
                // The session mode is only needed here, so it is resolved on demand.
                SessionMode sessionMode = getSessionMode(functionDefinitionNode, context.semanticModel());
                if (sessionMode == SessionMode.STATELESS) {
                    Diagnostic diagnostic = CompilationDiagnostic.getDiagnostic(
                            CompilationDiagnostic.SESSION_PARAM_NOT_ALLOWED_IN_STATELESS_MODE, location,
                            functionName, parameterName);
                    context.reportDiagnostic(diagnostic);
                    return false;
                }
                continue;
            }

            if (!isAnydataType(parameterType, context)) {
                Diagnostic diagnostic = CompilationDiagnostic.getDiagnostic(
                        CompilationDiagnostic.INVALID_PARAMETER_TYPE, location, functionName, parameterName);
                context.reportDiagnostic(diagnostic);
                return false;
            }
        }
        return true;
    }

    /**
     * Tells whether the type is the {@code Session} type of the {@code ballerina/mcp} module.
     *
     * @param typeSymbol the type to test
     * @return {@code true} if the type is named {@code Session} and belongs to the MCP module
     */
    static boolean isSessionType(TypeSymbol typeSymbol) {
        return SESSION_TYPE_NAME.equals(typeSymbol.getName().orElse("")) && isMcpModuleSymbol(typeSymbol);
    }

    /**
     * Reads the {@code sessionMode} field of the MCP {@code ServiceConfig} annotation on the
     * service that encloses the function.
     *
     * @return the configured mode, or {@link SessionMode#AUTO} when the function is not directly
     *         inside a service declaration, the service has no such annotation, or the
     *         annotation does not set {@code sessionMode}
     */
    private static SessionMode getSessionMode(FunctionDefinitionNode functionDefinitionNode,
                                              SemanticModel semanticModel) {
        Node parent = functionDefinitionNode.parent();
        if (!(parent instanceof ServiceDeclarationNode)) {
            // Not a service member: there is no service configuration to read.
            return SessionMode.AUTO;
        }
        Optional<MetadataNode> metadata = ((ServiceDeclarationNode) parent).metadata();
        if (metadata.isEmpty()) {
            return SessionMode.AUTO;
        }
        for (AnnotationNode annotation : metadata.get().annotations()) {
            if (!isMcpServiceConfigAnnotation(annotation, semanticModel)) {
                continue;
            }
            // Only the first ServiceConfig annotation is consulted.
            Optional<MappingConstructorExpressionNode> annotationValue = annotation.annotValue();
            if (annotationValue.isEmpty()) {
                return SessionMode.AUTO;
            }
            SeparatedNodeList<MappingFieldNode> fields = annotationValue.get().fields();
            for (MappingFieldNode field : fields) {
                if (field.kind() != SyntaxKind.SPECIFIC_FIELD) {
                    continue;
                }
                SpecificFieldNode specificField = (SpecificFieldNode) field;
                if (!SESSION_MODE_FIELD.equals(fieldNameText(specificField.fieldName()))) {
                    continue;
                }
                Optional<ExpressionNode> valueExpr = specificField.valueExpr();
                if (valueExpr.isPresent()) {
                    return resolveSessionModeValue(valueExpr.get(), semanticModel);
                }
            }
            return SessionMode.AUTO;
        }
        return SessionMode.AUTO;
    }

    /**
     * Extracts the key text of a mapping field. The key may be written as a plain identifier
     * or as a string literal; any other node kind yields {@code null}.
     */
    private static String fieldNameText(Node fieldName) {
        if (fieldName instanceof IdentifierToken) {
            return ((IdentifierToken) fieldName).text();
        }
        if (fieldName instanceof BasicLiteralNode) {
            return stripSurroundingQuotes(((BasicLiteralNode) fieldName).literalToken().text());
        }
        return null;
    }

    /**
     * Removes one leading and one trailing double quote when both are present; otherwise
     * returns the text unchanged.
     */
    private static String stripSurroundingQuotes(String text) {
        // The length guard keeps a lone quote character from producing an invalid range.
        if (text.length() >= 2 && text.startsWith("\"") && text.endsWith("\"")) {
            return text.substring(1, text.length() - 1);
        }
        return text;
    }

    /**
     * Converts the value expression of the {@code sessionMode} field into a {@link SessionMode}.
     * An enum member of the MCP module is resolved through its constant value; a string literal
     * is resolved from its text. Anything else yields {@link SessionMode#AUTO}.
     */
    private static SessionMode resolveSessionModeValue(ExpressionNode valueExpr, SemanticModel semanticModel) {
        Optional<Symbol> symbol = semanticModel.symbol(valueExpr);
        if (symbol.isPresent()
                && symbol.get().kind() == SymbolKind.ENUM_MEMBER
                && isMcpModuleSymbol(symbol.get())) {
            Object constValue = ((ConstantSymbol) symbol.get()).constValue();
            if (constValue instanceof ConstantValue) {
                Object rawValue = ((ConstantValue) constValue).value();
                return SessionMode.fromString(rawValue == null ? null : rawValue.toString());
            }
        }
        if (valueExpr.kind() == SyntaxKind.STRING_LITERAL) {
            String literalText = ((BasicLiteralNode) valueExpr).literalToken().text();
            return SessionMode.fromString(stripSurroundingQuotes(literalText));
        }
        return SessionMode.AUTO;
    }

    /**
     * Tells whether the annotation is the MCP {@code ServiceConfig} annotation.
     *
     * <p>The annotation is resolved through the semantic model first, so the check follows the
     * module the annotation actually comes from rather than the prefix text used in the source.
     * If no annotation symbol can be resolved, the check falls back to comparing the
     * qualified reference text against {@code mcp:ServiceConfig}.
     */
    private static boolean isMcpServiceConfigAnnotation(AnnotationNode annotation, SemanticModel semanticModel) {
        Optional<Symbol> symbol = semanticModel.symbol(annotation);
        if (symbol.isPresent() && symbol.get().kind() == SymbolKind.ANNOTATION) {
            return isMcpModuleSymbol(symbol.get())
                    && symbol.get().getName().filter(SERVICE_CONFIG_ANNOTATION_NAME::equals).isPresent();
        }

        Node reference = annotation.annotReference();
        if (reference.kind() != SyntaxKind.QUALIFIED_NAME_REFERENCE) {
            return false;
        }
        QualifiedNameReferenceNode qualifiedRef = (QualifiedNameReferenceNode) reference;
        return MCP_PACKAGE_NAME.equals(qualifiedRef.modulePrefix().text())
                && SERVICE_CONFIG_ANNOTATION_NAME.equals(qualifiedRef.identifier().text());
    }

    /**
     * Session handling modes that can be configured on an MCP service.
     */
    public enum SessionMode {
        STATEFUL("stateful"),
        STATELESS("stateless"),
        AUTO("auto");

        private final String value;

        SessionMode(String value) {
            this.value = value;
        }

        /**
         * Returns the lower-case textual form of this mode.
         *
         * @return the text associated with this constant
         */
        public String getValue() {
            return this.value;
        }

        /**
         * Parses a mode from text, ignoring case.
         *
         * @param value the text to parse; may be {@code null}
         * @return the matching mode, or {@link #AUTO} if the text is {@code null} or matches
         *         no mode
         */
        public static SessionMode fromString(String value) {
            if (value == null) {
                return AUTO;
            }
            for (SessionMode mode : SessionMode.values()) {
                if (mode.value.equalsIgnoreCase(value)) {
                    return mode;
                }
            }
            return AUTO;
        }
    }
}

