/*
 * Decompiled with CFR 0.152.
 */
package io.ballerina.lib.ai.np.compilerplugin;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.AnnotationAttachmentSymbol;
import io.ballerina.compiler.api.symbols.AnnotationSymbol;
import io.ballerina.compiler.api.symbols.ExternalFunctionSymbol;
import io.ballerina.compiler.api.symbols.ModuleSymbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.AbstractNodeFactory;
import io.ballerina.compiler.syntax.tree.BaseNodeModifier;
import io.ballerina.compiler.syntax.tree.DefaultableParameterNode;
import io.ballerina.compiler.syntax.tree.ExpressionFunctionBodyNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.ExternalFunctionBodyNode;
import io.ballerina.compiler.syntax.tree.FunctionBodyNode;
import io.ballerina.compiler.syntax.tree.FunctionCallExpressionNode;
import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode;
import io.ballerina.compiler.syntax.tree.ImportDeclarationNode;
import io.ballerina.compiler.syntax.tree.ImportOrgNameNode;
import io.ballerina.compiler.syntax.tree.IncludedRecordParameterNode;
import io.ballerina.compiler.syntax.tree.ModuleMemberDeclarationNode;
import io.ballerina.compiler.syntax.tree.ModulePartNode;
import io.ballerina.compiler.syntax.tree.NaturalExpressionNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NodeFactory;
import io.ballerina.compiler.syntax.tree.NodeList;
import io.ballerina.compiler.syntax.tree.NodeParser;
import io.ballerina.compiler.syntax.tree.NodeTransformer;
import io.ballerina.compiler.syntax.tree.ParameterNode;
import io.ballerina.compiler.syntax.tree.RequiredParameterNode;
import io.ballerina.compiler.syntax.tree.RestParameterNode;
import io.ballerina.compiler.syntax.tree.SeparatedNodeList;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.compiler.syntax.tree.Token;
import io.ballerina.lib.ai.np.compilerplugin.CodeGenerationUtils;
import io.ballerina.lib.ai.np.compilerplugin.Commons;
import io.ballerina.projects.Document;
import io.ballerina.projects.DocumentId;
import io.ballerina.projects.Module;
import io.ballerina.projects.ModuleId;
import io.ballerina.projects.Package;
import io.ballerina.projects.ProjectKind;
import io.ballerina.projects.plugins.ModifierTask;
import io.ballerina.projects.plugins.SourceModifierContext;
import io.ballerina.tools.text.TextDocument;
import io.ballerina.tools.text.TextDocuments;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.http.HttpClient;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.ballerinalang.formatter.core.Formatter;
import org.ballerinalang.formatter.core.FormatterException;

public class CompileTimePromptAsCodeCodeModificationTask
implements ModifierTask<SourceModifierContext> {
    // Pre-built `;` token used to terminate the expression function body created for generated functions.
    private static final Token SEMICOLON = AbstractNodeFactory.createToken((SyntaxKind)SyntaxKind.SEMICOLON_TOKEN);
    // Pre-built `=>` token used to start the expression function body created for generated functions.
    private static final Token RIGHT_DOUBLE_ARROW = AbstractNodeFactory.createToken((SyntaxKind)SyntaxKind.RIGHT_DOUBLE_ARROW_TOKEN);
    // Key of the prompt entry inside the `code` annotation's constant value.
    private static final String PROMPT = "prompt";
    // Suffix appended to the original function name to form the generated function's name.
    private static final String GENERATED_FUNCTION_SUFFIX = "NPGenerated";
    // Name of the directory (under the project source root) where generated code is persisted.
    private static final String GENERATED_DIRECTORY = "generated";
    // File-name suffix of persisted generated files; also used to recognise such files among module documents.
    private static final String GENERATED_FUNC_FILE_NAME_SUFFIX = "_np_generated.bal";
    // Name of the environment variable holding the code-generation service URL.
    private static final String BAL_CODEGEN_URL = "BAL_CODEGEN_URL";
    // Name of the environment variable holding the code-generation service access token.
    private static final String BAL_CODEGEN_TOKEN = "BAL_CODEGEN_TOKEN";
    // Values read from the environment once, at class initialisation; System.getenv yields null
    // when the variable is not set, and no null check is performed here.
    private static final String copilotUrl = System.getenv("BAL_CODEGEN_URL");
    private static final String copilotAccessToken = System.getenv("BAL_CODEGEN_TOKEN");

    /**
     * Rewrites every source and test document of the current package.
     *
     * <p>Nothing is modified when the compilation already reports errors. Otherwise, for each module:
     * <ul>
     *   <li>source documents whose name ends with the generated-file suffix are replaced with empty
     *       content, so previously persisted generated code is not compiled alongside the code
     *       generated in this run;</li>
     *   <li>all other source documents, and all test documents, are passed through
     *       {@code modifyDocument}.</li>
     * </ul>
     *
     * @param modifierContext the compiler-provided context used to read the package and register
     *                        the modified documents
     */
    @Override
    public void modify(SourceModifierContext modifierContext) {
        // Bail out before doing any other work: code generation is pointless on a broken compilation.
        if (modifierContext.compilation().diagnosticResult().errorCount() > 0) {
            return;
        }
        Package currentPackage = modifierContext.currentPackage();
        boolean isSingleBalFileMode = currentPackage.project().kind() == ProjectKind.SINGLE_FILE_PROJECT;
        Path sourceRoot = currentPackage.project().sourceRoot();
        for (ModuleId moduleId : currentPackage.moduleIds()) {
            Module module = currentPackage.module(moduleId);
            SemanticModel semanticModel = currentPackage.getCompilation().getSemanticModel(moduleId);
            for (DocumentId documentId : module.documentIds()) {
                Document document = module.document(documentId);
                if (npGeneratedFile(document)) {
                    modifierContext.modifySourceFile(TextDocuments.from(""), documentId);
                    continue;
                }
                modifierContext.modifySourceFile(
                        modifyDocument(document, semanticModel, module, isSingleBalFileMode, sourceRoot), documentId);
            }
            for (DocumentId documentId : module.testDocumentIds()) {
                Document testDocument = module.document(documentId);
                modifierContext.modifyTestSourceFile(
                        modifyDocument(testDocument, semanticModel, module, isSingleBalFileMode, sourceRoot), documentId);
            }
        }
    }

    /**
     * Runs the {@code CodeGenerator} over a document's syntax tree and returns the resulting text.
     *
     * <p>Imports and module members collected by the generator while it walks the tree are appended
     * to the transformed root; imports are first filtered through {@code getNewImports}.
     *
     * @param document            the document to transform
     * @param semanticModel       semantic model of the module that owns the document
     * @param module              the module that owns the document
     * @param isSingleBalFileMode whether the project is a single-file project
     * @param sourceRoot          the project source root
     * @return the text document of the modified syntax tree
     */
    private static TextDocument modifyDocument(Document document, SemanticModel semanticModel, Module module, boolean isSingleBalFileMode, Path sourceRoot) {
        List<ImportDeclarationNode> generatedImports = new ArrayList<>();
        List<ModuleMemberDeclarationNode> generatedMembers = new ArrayList<>();
        CodeGenerator generator = new CodeGenerator(
                semanticModel, module, generatedImports, generatedMembers, isSingleBalFileMode, sourceRoot, document);

        ModulePartNode originalRoot = (ModulePartNode) document.syntaxTree().rootNode();
        // The two lists above are filled as a side effect of this traversal.
        ModulePartNode transformedRoot = (ModulePartNode) originalRoot.apply(generator);

        NodeList<ImportDeclarationNode> mergedImports =
                transformedRoot.imports().addAll(getNewImports(transformedRoot.imports(), generatedImports));
        NodeList<ModuleMemberDeclarationNode> mergedMembers = transformedRoot.members().addAll(generatedMembers);
        ModulePartNode finalRoot = transformedRoot.modify(mergedImports, mergedMembers, transformedRoot.eofToken());
        return document.syntaxTree().modifyWith(finalRoot).textDocument();
    }

    /**
     * Builds the name of the file in which the code generated for a function is persisted.
     *
     * @param originalFuncName name of the function the code was generated for
     * @return {@code originalFuncName} followed by the generated-file suffix
     */
    private static String getGeneratedBalFileName(String originalFuncName) {
        StringBuilder fileName = new StringBuilder();
        fileName.append(originalFuncName);
        fileName.append(GENERATED_FUNC_FILE_NAME_SUFFIX);
        return fileName.toString();
    }

    /**
     * Extracts the prompt string from the {@code code} annotation attached to the external body of
     * the given function.
     *
     * <p>An annotation is considered only when its module satisfies
     * {@code Commons.isLangNaturalModule} and its name is {@code code}. The prompt is the string
     * stored under the {@code "prompt"} key of the annotation's constant value.
     *
     * @param functionDefinition the function definition carrying the annotation
     * @param semanticModel      semantic model used to resolve the function's symbol
     * @return the prompt text
     * @throws IllegalStateException if the function does not resolve to an external function symbol,
     *                               or the annotation value does not carry a string prompt
     * @throws RuntimeException      if no matching annotation is attached
     */
    private static String getPrompt(FunctionDefinitionNode functionDefinition, SemanticModel semanticModel) {
        String functionName = functionDefinition.functionName().text();
        if (!(semanticModel.symbol(functionDefinition).orElse(null)
                instanceof ExternalFunctionSymbol externalFunctionSymbol)) {
            throw new IllegalStateException(
                    "cannot resolve an external function symbol for function '" + functionName + "'");
        }
        for (AnnotationAttachmentSymbol attachment : externalFunctionSymbol.annotAttachmentsOnExternal()) {
            AnnotationSymbol annotationSymbol = attachment.typeDescriptor();
            Optional<ModuleSymbol> annotationModule = annotationSymbol.getModule();
            boolean isCodeAnnotation = annotationModule.isPresent()
                    && Commons.isLangNaturalModule(annotationModule.get())
                    && "code".equals(annotationSymbol.getName().orElse(null));
            if (!isCodeAnnotation) {
                continue;
            }
            // Validate each step of the value lookup instead of casting blindly, so a malformed
            // annotation value yields a descriptive error rather than a ClassCastException/NPE.
            Object annotationValue = attachment.attachmentValue().map(ConstantValue::value).orElse(null);
            if (annotationValue instanceof LinkedHashMap<?, ?> fields
                    && fields.get(PROMPT) instanceof ConstantValue promptConstant
                    && promptConstant.value() instanceof String prompt) {
                return prompt;
            }
            throw new IllegalStateException(
                    "the 'code' annotation on function '" + functionName + "' does not have a constant string '"
                            + PROMPT + "' value");
        }
        throw new RuntimeException("cannot find the annotation");
    }

    /**
     * Builds a call expression that forwards all parameters of {@code functionDefinition} to the
     * function named {@code generatedFunctionName}.
     *
     * <p>Required, defaultable and included-record parameters are passed by name; any other
     * parameter is treated as a rest parameter and passed as {@code ...name}.
     *
     * @param functionDefinition    the function whose parameters are forwarded
     * @param generatedFunctionName name of the function to call
     * @return the parsed function call expression
     */
    private static FunctionCallExpressionNode createGeneratedFunctionCallExpression(FunctionDefinitionNode functionDefinition, String generatedFunctionName) {
        List<String> argumentTexts = new ArrayList<>();
        for (ParameterNode parameter : functionDefinition.functionSignature().parameters()) {
            SyntaxKind kind = parameter.kind();
            String argumentText;
            if (kind == SyntaxKind.REQUIRED_PARAM) {
                argumentText = ((RequiredParameterNode) parameter).paramName().get().text();
            } else if (kind == SyntaxKind.DEFAULTABLE_PARAM) {
                argumentText = ((DefaultableParameterNode) parameter).paramName().get().text();
            } else if (kind == SyntaxKind.INCLUDED_RECORD_PARAM) {
                argumentText = ((IncludedRecordParameterNode) parameter).paramName().get().text();
            } else {
                // Everything else is handled as a rest parameter and spread into the call.
                argumentText = "..." + ((RestParameterNode) parameter).paramName().get().text();
            }
            argumentTexts.add(argumentText);
        }
        String callSource = generatedFunctionName + "(" + String.join(", ", argumentTexts) + ")";
        return (FunctionCallExpressionNode) NodeParser.parseExpression(callSource);
    }

    /**
     * Tells whether a document is a persisted generated file, judged by its name alone.
     *
     * @param document the document to check
     * @return {@code true} if the document name ends with the generated-file suffix
     */
    private static boolean npGeneratedFile(Document document) {
        String documentName = document.name();
        return documentName.endsWith(GENERATED_FUNC_FILE_NAME_SUFFIX);
    }

    /**
     * Checks whether any annotation on an external function body is recognised by
     * {@code Commons.isCodeAnnotation}.
     *
     * @param externalFunctionBody the external function body whose annotations are inspected
     * @param semanticModel        semantic model passed through to the annotation check
     * @return {@code true} as soon as one annotation matches, {@code false} if none does
     */
    private static boolean hasCodeAnnotation(ExternalFunctionBodyNode externalFunctionBody, SemanticModel semanticModel) {
        for (var annotationNode : externalFunctionBody.annotations()) {
            if (Commons.isCodeAnnotation(annotationNode, semanticModel)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Selects the imports coming from generated code that still have to be added to a document.
     *
     * <p>Imports are compared by their fully qualified module name (see {@code getModuleFQN}). An
     * import is skipped when the document already imports that module, or when an earlier generated
     * import refers to the same module. The latter happens when several generated functions in one
     * document import the same module, and would otherwise add the same import more than once.
     * The order of first occurrence is preserved.
     *
     * @param currentImports           imports already present in the document
     * @param importsFromGeneratedCode imports collected from generated code
     * @return the imports to append; {@code importsFromGeneratedCode} itself when it is empty
     */
    private static List<ImportDeclarationNode> getNewImports(NodeList<ImportDeclarationNode> currentImports, List<ImportDeclarationNode> importsFromGeneratedCode) {
        if (importsFromGeneratedCode.isEmpty()) {
            return importsFromGeneratedCode;
        }
        Set<String> currentImportsModules = currentImports.stream()
                .map(CompileTimePromptAsCodeCodeModificationTask::getModuleFQN)
                .collect(Collectors.toSet());
        // Keyed by module FQN so repeated generated imports collapse to their first occurrence,
        // while insertion order is kept for the resulting import list.
        LinkedHashMap<String, ImportDeclarationNode> importsToAddByModule = new LinkedHashMap<>();
        for (ImportDeclarationNode generatedImport : importsFromGeneratedCode) {
            String moduleFQN = getModuleFQN(generatedImport);
            if (!currentImportsModules.contains(moduleFQN)) {
                importsToAddByModule.putIfAbsent(moduleFQN, generatedImport);
            }
        }
        return new ArrayList<>(importsToAddByModule.values());
    }

    /**
     * Renders the module referenced by an import declaration as {@code org/part1.part2...}, or as
     * {@code part1.part2...} when the import has no organisation name.
     *
     * @param currentImport the import declaration
     * @return the module name parts joined with {@code .}, prefixed with {@code org/} if present
     */
    private static String getModuleFQN(ImportDeclarationNode currentImport) {
        String dottedModuleName = currentImport.moduleName().stream()
                .map(Token::text)
                .collect(Collectors.joining("."));
        return currentImport.orgName()
                .map(orgNameNode -> orgNameNode.orgName().text() + "/" + dottedModuleName)
                .orElse(dottedModuleName);
    }

    private static class CodeGenerator
    extends BaseNodeModifier {
        /** Semantic model of the module owning the document being transformed. */
        private final SemanticModel semanticModel;
        /** Module owning the document; source of the files sent along with generation requests. */
        private final Module module;
        /** Caller-owned sink receiving the imports found in generated code. */
        private final List<ImportDeclarationNode> newImports;
        /** Caller-owned sink receiving the module members found in generated code. */
        private final List<ModuleMemberDeclarationNode> newMembers;
        /** When set, annotated external functions are left untouched. */
        private final boolean isSingleBalFileMode;
        /** Project source root under which generated code is persisted. */
        private final Path sourceRoot;
        /** The document being transformed. */
        private final Document document;
        /** HTTP client, created on first use by {@code getHttpClient}. */
        private HttpClient client;
        /** Module source files as JSON, built on first use by {@code getSourceFiles}. */
        private JsonArray sourceFiles;

        /**
         * Creates a generator for a single document.
         *
         * @param semanticModel       semantic model of the owning module
         * @param module              the owning module
         * @param newImports          list to which imports from generated code are added
         * @param newMembers          list to which members from generated code are added
         * @param isSingleBalFileMode whether the project is a single-file project
         * @param sourceRoot          the project source root
         * @param document            the document being transformed
         */
        public CodeGenerator(SemanticModel semanticModel, Module module, List<ImportDeclarationNode> newImports, List<ModuleMemberDeclarationNode> newMembers, boolean isSingleBalFileMode, Path sourceRoot, Document document) {
            this.document = document;
            this.module = module;
            this.semanticModel = semanticModel;
            this.sourceRoot = sourceRoot;
            this.isSingleBalFileMode = isSingleBalFileMode;
            this.newImports = newImports;
            this.newMembers = newMembers;
        }

        /**
         * Replaces the external body of a {@code code}-annotated function with an expression body
         * that delegates to a freshly generated function.
         *
         * <p>Functions that do not have an external body, lack the annotation, or belong to a
         * single-file project are handed to the superclass transformation unchanged. For the rest,
         * code is requested from {@code CodeGenerationUtils.generateCodeForFunction}, recorded via
         * {@code handleGeneratedCode}, and the function body becomes
         * {@code => <name>NPGenerated(<args>);}.
         *
         * @param functionDefinition the function definition being visited
         * @return the resulting function definition
         */
        public FunctionDefinitionNode transform(FunctionDefinitionNode functionDefinition) {
            boolean needsGeneration = functionDefinition.functionBody() instanceof ExternalFunctionBodyNode externalBody
                    && CompileTimePromptAsCodeCodeModificationTask.hasCodeAnnotation(externalBody, this.semanticModel)
                    && !this.isSingleBalFileMode;
            if (!needsGeneration) {
                return (FunctionDefinitionNode) super.transform(functionDefinition);
            }

            String originalName = functionDefinition.functionName().text();
            String generatedName = originalName.concat(CompileTimePromptAsCodeCodeModificationTask.GENERATED_FUNCTION_SUFFIX);
            String prompt = CompileTimePromptAsCodeCodeModificationTask.getPrompt(functionDefinition, this.semanticModel);

            HttpClient httpClient = this.getHttpClient();
            // The file persisted for this function by an earlier run is left out of the request.
            JsonArray otherSourceFiles = this.getSourceFilesWithoutFileGeneratedForCurrentFunc(originalName);
            var moduleDescriptor = this.module.descriptor();
            String packageOrg = this.document.module().project().currentPackage().packageOrg().value();
            String generatedCode = CodeGenerationUtils.generateCodeForFunction(
                    copilotUrl, copilotAccessToken, originalName, generatedName, prompt,
                    httpClient, otherSourceFiles, moduleDescriptor, packageOrg);
            this.handleGeneratedCode(originalName, generatedCode);

            FunctionCallExpressionNode delegateCall =
                    CompileTimePromptAsCodeCodeModificationTask.createGeneratedFunctionCallExpression(functionDefinition, generatedName);
            ExpressionFunctionBodyNode delegatingBody =
                    NodeFactory.createExpressionFunctionBodyNode(RIGHT_DOUBLE_ARROW, delegateCall, SEMICOLON);
            return functionDefinition.modify().withFunctionBody(delegatingBody).apply();
        }

        /**
         * Replaces a natural expression that carries the {@code const} keyword with the expression
         * produced by {@code CodeGenerationUtils.generateCodeForNaturalExpression}; natural
         * expressions without the keyword are returned as they are.
         *
         * @param naturalExpressionNode the natural expression being visited
         * @return the original node, or the generated replacement expression
         */
        public ExpressionNode transform(NaturalExpressionNode naturalExpressionNode) {
            boolean isConstNaturalExpression = naturalExpressionNode.constKeyword().isPresent();
            if (!isConstNaturalExpression) {
                return naturalExpressionNode;
            }
            HttpClient httpClient = this.getHttpClient();
            JsonArray moduleSourceFiles = this.getSourceFiles();
            // Expected type is looked up at the start line of the expression's line range.
            TypeSymbol expectedType = (TypeSymbol) this.semanticModel
                    .expectedType(this.document, naturalExpressionNode.lineRange().startLine())
                    .get();
            return CodeGenerationUtils.generateCodeForNaturalExpression(
                    naturalExpressionNode, copilotUrl, copilotAccessToken, httpClient, moduleSourceFiles,
                    this.semanticModel, expectedType, this.document);
        }

        /**
         * Parses generated code, persists it, and records its imports and members so the caller can
         * append them to the document.
         *
         * @param originalFuncName name of the function the code was generated for
         * @param generatedCode    the generated source text
         */
        private void handleGeneratedCode(String originalFuncName, String generatedCode) {
            ModulePartNode generatedModulePart = NodeParser.parseModulePart(generatedCode);
            this.persistInGeneratedDirectory(originalFuncName, generatedCode);
            generatedModulePart.imports().forEach(this.newImports::add);
            generatedModulePart.members().forEach(this.newMembers::add);
        }

        /**
         * Returns the HTTP client of this generator, creating it on the first call.
         *
         * @return the cached {@link HttpClient}
         */
        private HttpClient getHttpClient() {
            if (this.client == null) {
                this.client = HttpClient.newHttpClient();
            }
            return this.client;
        }

        /**
         * Returns the module's source files except the one whose {@code filePath} equals the
         * generated file name of the given function.
         *
         * @param originalFunctionName name of the function whose generated file is excluded
         * @return a new array holding the remaining source file entries
         */
        private JsonArray getSourceFilesWithoutFileGeneratedForCurrentFunc(String originalFunctionName) {
            String excludedFileName =
                    CompileTimePromptAsCodeCodeModificationTask.getGeneratedBalFileName(originalFunctionName);
            JsonArray allSourceFiles = this.getSourceFiles();
            JsonArray remainingSourceFiles = new JsonArray(allSourceFiles.size());
            for (JsonElement candidate : allSourceFiles) {
                String filePath = candidate.getAsJsonObject().get("filePath").getAsString();
                if (!excludedFileName.equals(filePath)) {
                    remainingSourceFiles.add(candidate);
                }
            }
            return remainingSourceFiles;
        }

        /**
         * Returns the JSON representation of the module's source files, building it on the first
         * call and reusing it afterwards.
         *
         * @return the cached array of source file entries
         */
        private JsonArray getSourceFiles() {
            if (this.sourceFiles == null) {
                this.sourceFiles = CodeGenerator.getSourceFiles(this.module);
            }
            return this.sourceFiles;
        }

        /**
         * Converts every source document of a module into a JSON object with two properties:
         * {@code filePath} (the document name) and {@code content} (the document's text lines
         * joined with {@code \n}).
         *
         * @param module the module whose documents are collected
         * @return an array with one entry per source document
         */
        private static JsonArray getSourceFiles(Module module) {
            JsonArray collectedFiles = new JsonArray();
            module.documentIds().forEach(documentId -> {
                Document sourceDocument = module.document(documentId);
                String content = String.join((CharSequence)"\n", sourceDocument.textDocument().textLines());
                JsonObject fileEntry = new JsonObject();
                fileEntry.addProperty("filePath", sourceDocument.name());
                fileEntry.addProperty("content", content);
                collectedFiles.add((JsonElement) fileEntry);
            });
            return collectedFiles;
        }

        /**
         * Writes the formatted generated code to
         * {@code <sourceRoot>/generated/<originalFuncName>_np_generated.bal}, creating the directory
         * if it does not exist.
         *
         * <p>This is best effort: failures to create the directory, open the file, or format the code
         * are ignored and never propagate to the caller. The target file is opened before the code
         * is formatted, so a formatting failure leaves the file empty.
         *
         * @param originalFuncName name of the function the code was generated for
         * @param generatedCode    the generated source text
         */
        private void persistInGeneratedDirectory(String originalFuncName, String generatedCode) {
            Path generatedDir = Paths.get(this.sourceRoot.toString(), CompileTimePromptAsCodeCodeModificationTask.GENERATED_DIRECTORY);
            boolean directoryMissing = !Files.exists(generatedDir);
            if (directoryMissing) {
                try {
                    Files.createDirectories(generatedDir);
                } catch (IOException ignored) {
                    // Without the directory there is nowhere to persist to; give up silently.
                    return;
                }
            }
            String fileName = CompileTimePromptAsCodeCodeModificationTask.getGeneratedBalFileName(originalFuncName);
            String targetFile = Paths.get(generatedDir.toString(), fileName).toString();
            try (PrintWriter writer = new PrintWriter(targetFile, StandardCharsets.UTF_8)) {
                writer.println(Formatter.format(generatedCode));
            } catch (IOException | FormatterException ignored) {
                // Persisting is optional; the generated code is still added to the document by the caller.
            }
        }
    }
}

