/*
 * Decompiled with CFR 0.152.
 */
package io.ballerina.stdlib.ai.plugin;

import io.ballerina.compiler.syntax.tree.AnnotationNode;
import io.ballerina.compiler.syntax.tree.ClassDefinitionNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode;
import io.ballerina.compiler.syntax.tree.IdentifierToken;
import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode;
import io.ballerina.compiler.syntax.tree.MappingFieldNode;
import io.ballerina.compiler.syntax.tree.MetadataNode;
import io.ballerina.compiler.syntax.tree.Minutiae;
import io.ballerina.compiler.syntax.tree.MinutiaeList;
import io.ballerina.compiler.syntax.tree.ModuleMemberDeclarationNode;
import io.ballerina.compiler.syntax.tree.ModulePartNode;
import io.ballerina.compiler.syntax.tree.ModuleVariableDeclarationNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NodeFactory;
import io.ballerina.compiler.syntax.tree.NodeList;
import io.ballerina.compiler.syntax.tree.NodeParser;
import io.ballerina.compiler.syntax.tree.QualifiedNameReferenceNode;
import io.ballerina.compiler.syntax.tree.SeparatedNodeList;
import io.ballerina.compiler.syntax.tree.SpecificFieldNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.compiler.syntax.tree.SyntaxTree;
import io.ballerina.compiler.syntax.tree.Token;
import io.ballerina.projects.DocumentId;
import io.ballerina.projects.Module;
import io.ballerina.projects.ModuleId;
import io.ballerina.projects.plugins.ModifierTask;
import io.ballerina.projects.plugins.SourceModifierContext;
import io.ballerina.stdlib.ai.plugin.ModifierContext;
import io.ballerina.stdlib.ai.plugin.ToolAnnotationConfig;
import io.ballerina.tools.text.TextDocument;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Source modifier task that rewrites every document recorded in the supplied modifier-context map.
 *
 * <p>For each document it:
 * <ul>
 *   <li>replaces the annotations recorded in {@link ModifierContext#getAnnotationConfigMap()} (on module-level
 *       functions and on class methods) with versions whose mapping constructor carries the {@code name},
 *       {@code description} and {@code parameters} fields;</li>
 *   <li>removes the initializer from module-level agent declarations; and</li>
 *   <li>once per module, unless the module is listed as already defining an {@code init} function, appends a
 *       generated {@code init} function containing the removed agent initializations.</li>
 * </ul>
 */
class AiSourceModifier implements ModifierTask<SourceModifierContext> {
    private static final String EMPTY_STRING = "";
    /** Matches any line terminator ({@code \n}, {@code \r\n}, ...). */
    private static final String LINE_BREAK_REGEX = "\\R";
    /**
     * Matches, within a single line, everything up to and including the FIRST {@code :Agent} (i.e. the
     * qualifiers and type of an agent declaration). Reluctant so a later {@code :Agent...} reference on the
     * same line (e.g. in the initializer) is not swallowed.
     */
    private static final String AGENT_TYPE_PREFIX_REGEX = ".*?:Agent";
    /** Fields every tool annotation mapping constructor must end up with, in insertion order. */
    private static final List<String> REQUIRED_TOOL_FIELDS = List.of("name", "description", "parameters");

    private final Map<DocumentId, ModifierContext> modifierContextMap;
    private final Set<ModuleId> modulesWithPredefinedInitMethods;
    /** Modules for which the generated {@code init} function has already been appended by this instance. */
    private final Set<ModuleId> modulesWithDesugaredAgentsWithInitMethod = new HashSet<>();

    /**
     * @param modifierContextMap               per-document data describing what must be rewritten
     * @param modulesWithPredefinedInitMethods modules that must not receive a generated {@code init} function
     */
    AiSourceModifier(Map<DocumentId, ModifierContext> modifierContextMap,
                     Set<ModuleId> modulesWithPredefinedInitMethods) {
        this.modifierContextMap = modifierContextMap;
        this.modulesWithPredefinedInitMethods = modulesWithPredefinedInitMethods;
    }

    /** Rewrites every document present in the modifier-context map. */
    @Override
    public void modify(SourceModifierContext context) {
        for (Map.Entry<DocumentId, ModifierContext> entry : modifierContextMap.entrySet()) {
            modifyDocumentWithTools(context, entry.getKey(), entry.getValue());
        }
    }

    /** Rewrites a single document and pushes the new text back through {@code context}. */
    private void modifyDocumentWithTools(SourceModifierContext context, DocumentId documentId,
                                         ModifierContext modifierContext) {
        Module module = context.currentPackage().module(documentId.moduleId());
        ModulePartNode rootNode = (ModulePartNode) module.document(documentId).syntaxTree().rootNode();
        ModulePartNode updatedRoot = modifyModulePartRoot(rootNode, modifierContext, documentId);
        updateDocument(context, module, documentId, updatedRoot);
    }

    /** Returns a copy of {@code modulePartNode} whose members have been rewritten. */
    private ModulePartNode modifyModulePartRoot(ModulePartNode modulePartNode, ModifierContext modifierContext,
                                                DocumentId documentId) {
        List<ModuleMemberDeclarationNode> modifiedMembers =
                getModifiedModuleMembers(modulePartNode.members(), modifierContext, documentId);
        return modulePartNode.modify().withMembers(NodeFactory.createNodeList(modifiedMembers)).apply();
    }

    /**
     * Rewrites each module member and, for the first document processed in a module without a predefined
     * {@code init}, appends the generated agent-initializing {@code init} function.
     */
    private List<ModuleMemberDeclarationNode> getModifiedModuleMembers(NodeList<ModuleMemberDeclarationNode> members,
                                                                        ModifierContext modifierContext,
                                                                        DocumentId documentId) {
        Map<AnnotationNode, AnnotationNode> modifiedAnnotations = getModifiedAnnotations(modifierContext);
        Set<ModuleVariableDeclarationNode> agentDeclarations = modifierContext.getModuleLevelAgentDeclarations();
        List<ModuleMemberDeclarationNode> modifiedMembers = new ArrayList<>();
        for (ModuleMemberDeclarationNode member : members) {
            modifiedMembers.add(getModifiedModuleMember(member, modifiedAnnotations, agentDeclarations));
        }
        ModuleId moduleId = documentId.moduleId();
        if (!modulesWithPredefinedInitMethods.contains(moduleId)
                && !modulesWithDesugaredAgentsWithInitMethod.contains(moduleId)) {
            ModuleMemberDeclarationNode initFunctionDeclaration = desugarAgentsWithinInitFunction(moduleId);
            modulesWithDesugaredAgentsWithInitMethod.add(moduleId);
            modifiedMembers.add(initFunctionDeclaration);
        }
        return modifiedMembers;
    }

    /**
     * Builds {@code function init() returns error? {...}} whose body is the source of every module-level agent
     * declaration recorded for {@code moduleId}, with the declaration's type prefix (up to {@code :Agent}) removed
     * so each one becomes an assignment.
     */
    private ModuleMemberDeclarationNode desugarAgentsWithinInitFunction(ModuleId moduleId) {
        String agentInitializationSourceCode = modifierContextMap.entrySet().stream()
                .filter(entry -> entry.getKey().moduleId().equals(moduleId))
                .map(Map.Entry::getValue)
                .map(ModifierContext::getModuleLevelAgentDeclarations)
                .flatMap(Collection::stream)
                .map(Node::toSourceCode)
                .map(code -> code.replaceFirst(AGENT_TYPE_PREFIX_REGEX, EMPTY_STRING))
                .collect(Collectors.joining(EMPTY_STRING));
        return NodeParser.parseModuleMemberDeclaration(
                "function init() returns error? {" + agentInitializationSourceCode + "}");
    }

    /** Maps each annotation recorded in {@code modifierContext} to its rewritten replacement. */
    private Map<AnnotationNode, AnnotationNode> getModifiedAnnotations(ModifierContext modifierContext) {
        Map<AnnotationNode, AnnotationNode> updatedAnnotationMap = new HashMap<>();
        for (Map.Entry<AnnotationNode, ToolAnnotationConfig> entry
                : modifierContext.getAnnotationConfigMap().entrySet()) {
            updatedAnnotationMap.put(entry.getKey(), getModifiedAnnotation(entry.getKey(), entry.getValue()));
        }
        return updatedAnnotationMap;
    }

    private AnnotationNode getModifiedAnnotation(AnnotationNode targetNode, ToolAnnotationConfig config) {
        if (targetNode.annotValue().isEmpty()) {
            return handleAnnotationWithoutMappingConstructor(targetNode, config);
        }
        return handleAnnotationWithMappingConstructor(targetNode, config);
    }

    /**
     * Attaches a freshly generated mapping constructor to an annotation that has none. For a qualified
     * annotation reference the reference is rebuilt without minutiae and the closing brace is given a trailing
     * line separator.
     */
    private AnnotationNode handleAnnotationWithoutMappingConstructor(AnnotationNode targetNode,
                                                                     ToolAnnotationConfig config) {
        String mappingConstructorExpression = generateConfigMappingConstructor(config);
        MappingConstructorExpressionNode mappingConstructorNode =
                (MappingConstructorExpressionNode) NodeParser.parseExpression(mappingConstructorExpression);
        Node annotationReference = targetNode.annotReference();
        if (annotationReference.kind() == SyntaxKind.QUALIFIED_NAME_REFERENCE) {
            QualifiedNameReferenceNode qualifiedNameReferenceNode = (QualifiedNameReferenceNode) annotationReference;
            String identifier = qualifiedNameReferenceNode.identifier().text()
                    .replaceAll(LINE_BREAK_REGEX, EMPTY_STRING);
            String modulePrefix = qualifiedNameReferenceNode.modulePrefix().text();
            annotationReference = NodeFactory.createQualifiedNameReferenceNode(
                    NodeFactory.createIdentifierToken(modulePrefix),
                    NodeFactory.createToken(SyntaxKind.COLON_TOKEN),
                    NodeFactory.createIdentifierToken(identifier));
            // Keep the annotated construct on the following line once the reference loses its own minutiae.
            Token closeBraceTokenWithNewLine = NodeFactory.createToken(SyntaxKind.CLOSE_BRACE_TOKEN,
                    NodeFactory.createEmptyMinutiaeList(),
                    NodeFactory.createMinutiaeList(NodeFactory.createEndOfLineMinutiae(System.lineSeparator())));
            mappingConstructorNode = mappingConstructorNode.modify().withCloseBrace(closeBraceTokenWithNewLine).apply();
        }
        return NodeFactory.createAnnotationNode(targetNode.atToken(), annotationReference, mappingConstructorNode);
    }

    private String generateConfigMappingConstructor(ToolAnnotationConfig config) {
        return generateConfigMappingConstructor(config, SyntaxKind.OPEN_BRACE_TOKEN.stringValue(),
                SyntaxKind.CLOSE_BRACE_TOKEN.stringValue());
    }

    /**
     * Produces {@code <open>name:N,description:D,parameters:P<close>} from {@code config}. Line breaks in the
     * name and description become spaces; a {@code null} description falls back to the name.
     */
    private String generateConfigMappingConstructor(ToolAnnotationConfig config, String openBraceSource,
                                                    String closeBraceSource) {
        String name = config.name().replaceAll(LINE_BREAK_REGEX, " ");
        String description = config.description() != null
                ? config.description().replaceAll(LINE_BREAK_REGEX, " ")
                : name;
        return openBraceSource
                + String.format("name:%s,description:%s,parameters:%s", name, description, config.parameterSchema())
                + closeBraceSource;
    }

    /**
     * Adds whichever of the required fields are missing to an annotation that already has a mapping
     * constructor; returns {@code targetNode} unchanged when nothing is missing.
     */
    private AnnotationNode handleAnnotationWithMappingConstructor(AnnotationNode targetNode,
                                                                  ToolAnnotationConfig config) {
        MappingConstructorExpressionNode mappingConstructorNode = getMappingConstructorExpressionNode(targetNode);
        SeparatedNodeList<MappingFieldNode> fields = mappingConstructorNode.fields();
        List<MappingFieldNode> missingFields = getMissingFields(extractFieldNames(fields), config);
        if (missingFields.isEmpty()) {
            return targetNode;
        }
        String missingFieldSourceCode = generateMissingFieldSourceCode(missingFields);
        String modifiedAnnotationSourceCode = getModifiedAnnotationSourceCode(
                targetNode.toSourceCode(), missingFieldSourceCode, !fields.isEmpty());
        return NodeParser.parseAnnotation(modifiedAnnotationSourceCode);
    }

    /** Returns the trimmed names of all specific (non-spread, non-computed) fields. */
    private Set<String> extractFieldNames(SeparatedNodeList<MappingFieldNode> fields) {
        return fields.stream()
                .filter(field -> field.kind() == SyntaxKind.SPECIFIC_FIELD)
                .map(field -> ((SpecificFieldNode) field).fieldName().toSourceCode().trim())
                .collect(Collectors.toSet());
    }

    /** Joins the source of {@code missingFields} with commas. */
    private String generateMissingFieldSourceCode(List<MappingFieldNode> missingFields) {
        return missingFields.stream()
                .map(Node::toSourceCode)
                .collect(Collectors.joining(SyntaxKind.COMMA_TOKEN.stringValue()));
    }

    /**
     * Inserts {@code missingFieldSourceCode} just before the last {@code '}'} of {@code annotationSourceCode}.
     *
     * <p>A separating comma is added only when the constructor already has fields. This is decided from the
     * parsed node rather than by regex-matching the text: the annotation source can begin with leading minutiae
     * (blank lines, comments) spanning several lines, which a non-DOTALL {@code ".*"} cannot cross, and an
     * empty {@code {}} would then wrongly become {@code {,name:...}}.
     */
    private static String getModifiedAnnotationSourceCode(String annotationSourceCode, String missingFieldSourceCode,
                                                          boolean hasExistingFields) {
        int closeBraceTokenIndex = annotationSourceCode.lastIndexOf(SyntaxKind.CLOSE_BRACE_TOKEN.stringValue());
        String sourceBeforeCloseBrace = annotationSourceCode.substring(0, closeBraceTokenIndex);
        String sourceAfterCloseBrace = annotationSourceCode.substring(closeBraceTokenIndex);
        String separator = hasExistingFields ? SyntaxKind.COMMA_TOKEN.stringValue() : EMPTY_STRING;
        return sourceBeforeCloseBrace + separator + missingFieldSourceCode + sourceAfterCloseBrace;
    }

    /** Returns the annotation's mapping constructor; callers only use this when {@code annotValue()} is present. */
    private MappingConstructorExpressionNode getMappingConstructorExpressionNode(AnnotationNode targetNode) {
        return targetNode.annotValue().orElseThrow();
    }

    /** Creates {@code field: <config value>} nodes for every required field not in {@code existingFieldNames}. */
    private List<MappingFieldNode> getMissingFields(Set<String> existingFieldNames, ToolAnnotationConfig config) {
        List<MappingFieldNode> missingFields = new ArrayList<>();
        for (String fieldName : REQUIRED_TOOL_FIELDS) {
            if (existingFieldNames.contains(fieldName)) {
                continue;
            }
            missingFields.add(NodeFactory.createSpecificFieldNode(null,
                    NodeFactory.createIdentifierToken(fieldName),
                    NodeFactory.createToken(SyntaxKind.COLON_TOKEN),
                    NodeParser.parseExpression(config.get(fieldName))));
        }
        return missingFields;
    }

    /** Dispatches a module member to the rewrite appropriate for its kind; other kinds are returned as-is. */
    private ModuleMemberDeclarationNode getModifiedModuleMember(ModuleMemberDeclarationNode member,
                                                                Map<AnnotationNode, AnnotationNode> modifiedAnnotations,
                                                                Set<ModuleVariableDeclarationNode> agentDeclarations) {
        return switch (member.kind()) {
            case FUNCTION_DEFINITION -> modifyFunction((FunctionDefinitionNode) member, modifiedAnnotations);
            case MODULE_VAR_DECL -> modifyVariableDeclaration((ModuleVariableDeclarationNode) member,
                    agentDeclarations);
            case CLASS_DEFINITION -> modifyClassDefinition((ClassDefinitionNode) member, modifiedAnnotations);
            default -> member;
        };
    }

    /** Rewrites the annotations on every method of a class definition; other members are kept unchanged. */
    private ModuleMemberDeclarationNode modifyClassDefinition(ClassDefinitionNode classDefinitionNode,
                                                              Map<AnnotationNode, AnnotationNode> modifiedAnnotations) {
        List<Node> modifiedMembers = new ArrayList<>();
        for (Node member : classDefinitionNode.members()) {
            if (member.kind() == SyntaxKind.OBJECT_METHOD_DEFINITION) {
                modifiedMembers.add(modifyFunction((FunctionDefinitionNode) member, modifiedAnnotations));
            } else {
                modifiedMembers.add(member);
            }
        }
        return classDefinitionNode.modify().withMembers(NodeFactory.createNodeList(modifiedMembers)).apply();
    }

    /**
     * For a recorded agent declaration, drops the initializer (everything from the first {@code '='}) and the
     * leading minutiae. It then appends one line separator per line terminator in the original source so the
     * line count is preserved. Other declarations are returned unchanged.
     */
    private ModuleMemberDeclarationNode modifyVariableDeclaration(ModuleVariableDeclarationNode member,
                                                                  Set<ModuleVariableDeclarationNode> agentDeclarations) {
        if (!agentDeclarations.contains(member)) {
            return member;
        }
        String sourceCode = member.toSourceCode();
        // Count terminators (CRLF counts once), not removed characters.
        int numberOfNewLines = sourceCode.split(LINE_BREAK_REGEX, -1).length - 1;
        // Strip only the leading prefix; String.replace would also delete identical text (e.g. a single space)
        // anywhere else in the declaration.
        String leadingMinutiae = member.leadingMinutiae().toString();
        if (!leadingMinutiae.isEmpty() && sourceCode.startsWith(leadingMinutiae)) {
            sourceCode = sourceCode.substring(leadingMinutiae.length());
        }
        String modifiedSource = sourceCode.split(SyntaxKind.EQUAL_TOKEN.stringValue())[0].trim()
                + SyntaxKind.SEMICOLON_TOKEN.stringValue()
                + System.lineSeparator().repeat(numberOfNewLines);
        return NodeParser.parseModuleMemberDeclaration(modifiedSource);
    }

    /** Rewrites the annotations in a function's metadata; functions without metadata are returned as-is. */
    private FunctionDefinitionNode modifyFunction(FunctionDefinitionNode functionNode,
                                                  Map<AnnotationNode, AnnotationNode> modifiedAnnotations) {
        if (functionNode.metadata().isEmpty()) {
            return functionNode;
        }
        MetadataNode modifiedMetadata = modifyMetadata(functionNode.metadata().get(), modifiedAnnotations);
        return functionNode.modify().withMetadata(modifiedMetadata).apply();
    }

    /** Replaces each annotation that has an entry in {@code modifiedAnnotations}; others are kept. */
    private MetadataNode modifyMetadata(MetadataNode metadata,
                                        Map<AnnotationNode, AnnotationNode> modifiedAnnotations) {
        List<AnnotationNode> updatedAnnotations = new ArrayList<>();
        for (AnnotationNode annotation : metadata.annotations()) {
            updatedAnnotations.add(modifiedAnnotations.getOrDefault(annotation, annotation));
        }
        return metadata.modify().withAnnotations(NodeFactory.createNodeList(updatedAnnotations)).apply();
    }

    /**
     * Writes the rewritten tree back. Documents listed in {@code module.documentIds()} are treated as main
     * sources; anything else is written as a test source.
     */
    private void updateDocument(SourceModifierContext context, Module module, DocumentId documentId,
                                ModulePartNode updatedRoot) {
        SyntaxTree syntaxTree = module.document(documentId).syntaxTree().modifyWith(updatedRoot);
        TextDocument textDocument = syntaxTree.textDocument();
        if (module.documentIds().contains(documentId)) {
            context.modifySourceFile(textDocument, documentId);
        } else {
            context.modifyTestSourceFile(textDocument, documentId);
        }
    }
}

