// Copyright (c) 2023 WSO2 LLC (http://www.wso2.com).
//
// WSO2 LLC. licenses this file to you under the Apache License,
// Version 2.0 (the "License"); you may not use this file except
// in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

import ai.observe;

import ballerina/io;
import ballerina/log;

# Progress of a single agent execution: the inputs it was started with and the steps taken so far.
type ExecutionProgress record {|
    # Unique identifier for this execution
    string executionId;
    # Question to the agent
    string query;
    # Instruction used by the agent during the execution
    string instruction;
    # Execution history up to the current action (may be pre-populated, e.g. to continue a paused execution)
    ExecutionStep[] history = [];
    # Contextual information to be used by the tools during the execution
    Context context;
|};

# A single entry of the execution history: the raw LLM response and the observation recorded for it.
public type ExecutionStep record {|
    # Response generated by the LLM
    json llmResponse;
    # Observations produced by the tool during the execution
    anydata|error observation;
|};

# Result of executing the tool selected by the LLM.
public type ExecutionResult record {|
    # Tool decided by the LLM during the reasoning
    LlmToolResponse tool;
    # Observations produced by the tool during the execution (may be an error value returned by the tool)
    anydata|error observation;
|};

# Information about an execution step that failed, e.g. because the LLM response could not be parsed
# or the selected tool could not be executed.
public type ExecutionError record {|
    # Response generated by the LLM
    json llmResponse;
    # Error caused during the execution
    LlmInvalidGenerationError|ToolExecutionError|MemoryError 'error;
    # Observation on the caused error as additional instruction to the LLM
    string observation;
|};

# A chat (plain text) response generated by the LLM instead of a tool call.
type LlmChatResponse record {|
    # A text response to the question
    string content;
|};

# Tool selected by the LLM to be executed by the agent.
public type LlmToolResponse record {|
    # Name of the selected tool
    string name;
    # Input to the tool (defaults to an empty map; may also be nil)
    map<json>? arguments = {};
    # Identifier for the tool call, when one is provided
    string id?;
|};

# Output from executing a tool.
public type ToolOutput record {|
    # Output value of the tool (may be an error value returned by the tool)
    anydata|error value;
|};

# Common contract for agents: the model, tool store and memory an agent works with, plus the hooks used to
# select and parse the next step of an execution.
type BaseAgent distinct isolated object {
    # Model provider of the agent (not used directly in this section; presumably backs the LLM calls)
    ModelProvider model;
    # Store used to look up, describe and execute the tools selected by the LLM
    ToolStore toolStore;
    # Memory holding the chat messages of each session
    Memory memory;
    # When true, the session memory is deleted at the end of each execution by the module-level `run` function
    boolean stateless;

    # Parse the llm response and extract the tool to be executed.
    #
    # + llmResponse - Raw LLM response
    # + return - A record containing the tool decided by the LLM, chat response or an error if the response is invalid
    isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError;

    # Use LLM to decide the next tool/step.
    #
    # + progress - Execution progress with the current query and execution history
    # + sessionId - The ID associated with the agent memory
    # + return - LLM response containing the tool or chat response (or an error if the call fails)
    isolated function selectNextTool(ExecutionProgress progress, string sessionId = DEFAULT_SESSION_ID) returns json|Error;

    # Execute the agent for the given query.
    #
    # + query - Natural language query to the agent
    # + instruction - Instruction that the agent uses to execute the task
    # + maxIter - Maximum number of iterations the agent runs to execute the task
    # + verbose - If true, the reasoning steps are expected to be printed
    # + sessionId - The ID associated with the agent memory
    # + context - Contextual information to be used by the tools during the execution
    # + executionId - Unique identifier for this execution
    # + return - Execution trace with the steps taken and the final answer, if any
    isolated function run(string query, string instruction, int maxIter = 5, boolean verbose = true,
            string sessionId = DEFAULT_SESSION_ID, Context context = new, string executionId = DEFAULT_EXECUTION_ID)
            returns ExecutionTrace;
};

# An iterator to iterate over the agent's execution, one step at a time.
class Iterator {
    *object:Iterable;
    private final Executor executor;

    # Initialize the iterator with the agent and the execution progress.
    #
    # + agent - Agent instance to be executed
    # + sessionId - The ID associated with the agent memory
    # + progress - Execution progress to start from: execution ID, query, instruction, context and any
    # history of a previously paused execution
    isolated function init(BaseAgent agent, string sessionId, *ExecutionProgress progress) {
        self.executor = new (agent, sessionId, progress);
    }

    # Iterate over the agent's execution steps.
    #
    # + return - The underlying executor, whose `next` yields a record with the execution step (or an error if
    # the agent failed), and nil once the agent has produced a final chat response
    public function iterator() returns object {
        public function next() returns record {|ExecutionResult|LlmChatResponse|ExecutionError|Error value;|}?;
    } {
        return self.executor;
    }
}

# An executor to perform step-by-step execution of the agent.
class Executor {
    private boolean isCompleted = false;
    private final string sessionId;
    private final BaseAgent agent;
    # Contains the current execution progress for the agent and the query
    public ExecutionProgress progress;

    # Initialize the executor with the agent and the execution progress.
    #
    # + agent - Agent instance to be executed
    # + sessionId - The ID associated with the agent memory
    # + progress - Execution progress to start from: execution ID, query, instruction, context and any
    # history of a previously paused execution
    isolated function init(BaseAgent agent, string sessionId, *ExecutionProgress progress) {
        self.sessionId = sessionId;
        self.agent = agent;
        self.progress = progress;
    }

    # Checks whether agent has more steps to execute.
    #
    # + return - True if agent has more steps to execute, false otherwise
    public isolated function hasNext() returns boolean {
        return !self.isCompleted;
    }

    # Reason the next step of the agent by asking the agent to select the next tool.
    #
    # + return - Generated LLM response, a `TaskCompletedError` if a final chat response was already produced,
    # or an error if the reasoning fails
    public isolated function reason() returns json|Error {
        if self.isCompleted {
            return error TaskCompletedError("Task is already completed. No more reasoning is needed.");
        }
        log:printDebug("LLM reasoning started",
            executionId = self.progress.executionId,
            sessionId = self.sessionId,
            history = self.progress.history.toString()
        );
        return self.agent.selectNextTool(self.progress, self.sessionId);
    }

    # Execute the next step of the agent.
    #
    # A chat response marks the execution as completed and is not added to the history. Every other outcome
    # (tool result, tool failure, unparsable response) is recorded in the history with its observation.
    #
    # + llmResponse - Raw LLM response containing the tool to be executed or a chat response
    # + return - An `ExecutionResult` when a tool was executed (its observation may be an error value returned
    # by the tool), an `LlmChatResponse` when the LLM answered directly, or an `ExecutionError` when the
    # response could not be parsed or the tool execution failed
    public isolated function act(json llmResponse) returns ExecutionResult|LlmChatResponse|ExecutionError {
        LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError parsedOutput = self.agent.parseLlmResponse(llmResponse);
        if parsedOutput is LlmChatResponse {
            log:printDebug("Parsed LLM response as chat response",
                executionId = self.progress.executionId,
                sessionId = self.sessionId,
                response = parsedOutput.content
            );
            self.isCompleted = true;
            return parsedOutput;
        }

        anydata observation;
        ExecutionResult|ExecutionError executionResult;
        if parsedOutput is LlmToolResponse {
            string toolName = parsedOutput.name;
            log:printDebug("Parsed LLM response as tool call",
                executionId = self.progress.executionId,
                sessionId = self.sessionId,
                toolName = toolName,
                arguments = parsedOutput.arguments
            );
            observe:ExecuteToolSpan span = observe:createExecuteToolSpan(toolName);
            string? toolCallId = parsedOutput.id;
            if toolCallId is string {
                span.addId(toolCallId);
            }
            string? toolDescription = self.agent.toolStore.getToolDescription(toolName);
            if toolDescription is string {
                span.addDescription(toolDescription);
            }
            // MCP tools are reported as extensions; every other tool as a plain function.
            span.addType(self.agent.toolStore.isMcpTool(toolName) ? observe:EXTENTION : observe:FUNCTION);
            span.addArguments(parsedOutput.arguments);

            ToolOutput|ToolExecutionError|LlmInvalidGenerationError output = self.agent.toolStore.execute(parsedOutput,
                self.progress.context);
            if output is Error {
                string failureReason;
                if output is ToolNotFoundError {
                    failureReason = "Tool is not found. Please check the tool name and retry.";
                } else if output is ToolInvalidInputError {
                    failureReason = "Tool execution failed due to invalid inputs. Retry with correct inputs.";
                } else {
                    failureReason = "Tool execution failed. Retry with correct inputs.";
                }
                // The error details are appended to the observation that is recorded in the history.
                string errorObservation = string `${failureReason} <detail>${output.toString()}</detail>`;
                observation = errorObservation;
                executionResult = {
                    llmResponse,
                    'error: output,
                    observation: errorObservation
                };

                log:printDebug("Tool execution resulted in error",
                    executionId = self.progress.executionId,
                    observation = errorObservation,
                    sessionId = self.sessionId,
                    toolName = toolName
                );

                Error toolExecutionError = error Error(errorObservation, details = {parsedOutput});
                span.close(toolExecutionError);
            } else {
                anydata|error value = output.value;
                // Errors returned by the tool are stringified for the history, but kept as-is in the result.
                observation = value is error ? value.toString() : value;
                log:printDebug("Tool execution successful",
                    executionId = self.progress.executionId,
                    sessionId = self.sessionId,
                    toolName = toolName,
                    output = observation
                );
                executionResult = {
                    tool: parsedOutput,
                    observation: value
                };

                span.addOutput(observation);
                span.close();
            }
        } else {
            log:printDebug("Failed to parse LLM response as valid tool or chat",
                executionId = self.progress.executionId,
                sessionId = self.sessionId,
                errorMessage = parsedOutput.message()
            );
            string invalidGenerationObservation =
                "Tool extraction failed due to invalid JSON_BLOB. Retry with correct JSON_BLOB.";
            observation = invalidGenerationObservation;
            executionResult = {
                llmResponse,
                'error: parsedOutput,
                observation: invalidGenerationObservation
            };
        }
        self.update({
            llmResponse,
            observation
        });
        return executionResult;
    }

    # Update the agent with an execution step.
    #
    # + step - Latest step to be added to the history
    public isolated function update(ExecutionStep step) {
        self.progress.history.push(step);
    }

    # Reason and execute the next step of the agent.
    #
    # + return - A record with ExecutionResult, chat response or an error, or nil once the execution is completed
    public isolated function next() returns record {|ExecutionResult|LlmChatResponse|ExecutionError|Error value;|}? {
        if self.isCompleted {
            return ();
        }
        json|Error llmResponse = self.reason();
        if llmResponse is Error {
            return {value: llmResponse};
        }
        return {value: self.act(llmResponse)};
    }
}

# Execute the agent for a given user's query.
#
# The loop ends when the agent produces a final answer, when a step fails with an `Error`, or once `maxIter`
# tool/error steps have been recorded. The limit is checked right after each recorded step, so no further LLM
# call or tool execution is made once it has been reached. A non-positive `maxIter` runs no iterations.
#
# + agent - Agent to be executed
# + instruction - Instruction that the agent uses to execute the task
# + query - Natural language commands to the agent
# + maxIter - Maximum number of iterations that the agent will run to execute the task
# + verbose - If true, then print the reasoning steps
# + sessionId - The ID associated with the memory
# + context - Context values to be used by the agent to execute the task
# + executionId - Unique identifier for this execution
# + return - Returns the execution steps tracing the agent's reasoning and outputs from the tools
isolated function run(BaseAgent agent, string instruction, string query, int maxIter, boolean verbose,
        string sessionId = DEFAULT_SESSION_ID, Context context = new, string executionId = DEFAULT_EXECUTION_ID)
        returns ExecutionTrace {
    lock {
        log:printDebug("Agent execution loop started",
            executionId = executionId,
            sessionId = sessionId,
            maxIterations = maxIter,
            tools = agent.toolStore.tools.toString(),
            isStateless = agent.stateless
        );

        (ExecutionResult|ExecutionError|Error)[] steps = [];

        string? content = ();
        int iter = 0;
        ChatSystemMessage systemMessage = {role: SYSTEM, content: instruction};
        updateMemory(agent.memory, sessionId, systemMessage);

        ChatUserMessage userMessage = {role: USER, content: query};
        updateMemory(agent.memory, sessionId, userMessage);

        // Messages produced during this run; written to the agent memory once the loop ends.
        ChatMessage[] temporaryMemory = [];
        if maxIter > 0 {
            Iterator iterator = new (agent, sessionId, instruction = instruction, query = query, context = context,
                executionId = executionId);
            foreach ExecutionResult|LlmChatResponse|ExecutionError|Error step in iterator {
                if step is Error {
                    error? cause = step.cause();
                    log:printDebug("Error occurred during agent iteration",
                        step,
                        executionId = executionId,
                        iteration = iter,
                        sessionId = sessionId,
                        cause = cause !is () ? cause.toString() : "none"
                    );
                    steps.push(step);
                    break;
                }
                if step is LlmChatResponse {
                    content = step.content;
                    log:printDebug("Final answer generated by agent",
                        executionId = executionId,
                        iteration = iter,
                        answer = step.content,
                        sessionId = sessionId
                    );
                    if verbose {
                        io:println(string `${"\n\n"}Final Answer: ${step.content}${"\n\n"}`);
                    }
                    ChatAssistantMessage assistantMessage = {role: ASSISTANT, content: step.content};
                    temporaryMemory.push(assistantMessage);
                    break;
                }
                iter += 1;
                log:printDebug("Agent iteration started",
                    executionId = executionId,
                    iteration = iter,
                    maxIterations = maxIter,
                    stepsCompleted = steps.length(),
                    sessionId = sessionId
                );
                if verbose {
                    io:println(string `${"\n\n"}Agent Iteration ${iter.toString()}`);
                    if step is ExecutionResult {
                        LlmToolResponse tool = step.tool;
                        io:println(string `Action:
    ${BACKTICKS}
    {
        ${ACTION_NAME_KEY}: ${tool.name},
        ${ACTION_ARGUEMENTS_KEY}: ${(tool.arguments ?: "None").toString()}
    }
    ${BACKTICKS}`);
                        anydata|error observation = step.observation;
                        if observation is error {
                            io:println(string `${OBSERVATION_KEY} (Error): ${observation.toString()}`);
                        } else if observation !is () {
                            io:println(string `${OBSERVATION_KEY}: ${observation.toString()}`);
                        }
                    } else {
                        error? cause = step.'error.cause();
                        io:println(string `LLM Generation Error: 
    ${BACKTICKS}
    {
        message: ${step.'error.message()},
        cause: ${(cause is error ? cause.message() : "Unspecified")},
        llmResponse: ${step.llmResponse.toString()}
    }
    ${BACKTICKS}`);
                    }
                }
                updateExecutionResultInMemory(step, temporaryMemory);
                steps.push(step);
                // Stop before pulling the next step: pulling it would trigger another LLM call (and possibly a
                // tool execution) whose result would be discarded.
                if iter >= maxIter {
                    break;
                }
            }
        }

        // Reaching the limit is the only way to exit with `iter >= maxIter`: chat responses and errors break
        // out before the counter is incremented.
        if iter >= maxIter {
            log:printDebug("Maximum iterations reached without final answer",
                executionId = executionId,
                iterations = iter,
                stepsCompleted = steps.length(),
                sessionId = sessionId
            );
        }

        foreach ChatMessage message in temporaryMemory {
            updateMemory(agent.memory, sessionId, message);
        }

        if agent.stateless {
            // A stateless agent discards its session memory after every run. Deletion is not expected to fail,
            // but a failure is logged instead of being silently ignored.
            MemoryError? deleteResult = agent.memory.delete(sessionId);
            if deleteResult is MemoryError {
                log:printError("Error occurred while deleting the memory", deleteResult, sessionId = sessionId);
            }
        }
        return {steps, answer: content};
    }
}

# Convert a tool observation into a string.
#
# + observation - Value (or error) produced by a tool
# + return - A fixed hint when the observation is nil, a description of the error (message and, when present,
# the cause's message) for errors, or the trimmed string form of any other value
isolated function getObservationString(anydata|error observation) returns string {
    if observation is error {
        error? rootCause = observation.cause();
        record {|string message; string cause?;|} failure = {
            message: observation.message().trim()
        };
        if rootCause is error {
            failure.cause = rootCause.message().trim();
        }
        return "Error occured while trying to execute the tool: " + failure.toString();
    }
    if observation is () {
        return "Tool didn't return anything. Probably it is successful. Should we verify using another tool?";
    }
    return observation.toString().trim();
}

# Get the tools registered with the agent.
#
# + agent - Agent instance whose tool store is inspected
# + return - Array containing every tool registered in the agent's tool store
public isolated function getTools(Agent agent) returns Tool[] {
    var registeredTools = agent.functionCallAgent.toolStore.tools;
    return registeredTools.toArray();
}

# Store a chat message in the memory of the given session. A failure is logged rather than returned.
#
# + memory - Memory to be updated
# + sessionId - The ID associated with the memory
# + message - Chat message to be stored
isolated function updateMemory(Memory memory, string sessionId, ChatMessage message) {
    error? updateResult = memory.update(sessionId, message);
    if updateResult is () {
        return;
    }
    log:printError("Error occured while updating the memory", updateResult);
}

# Append the chat messages describing an executed tool call to the temporary memory: an assistant message
# carrying the tool call, followed by a function message carrying the tool's observation. Steps that are not
# `ExecutionResult`s are ignored.
#
# + step - Execution step produced by the agent
# + temporaryMemory - Message buffer that is appended to
isolated function updateExecutionResultInMemory(ExecutionResult|LlmChatResponse|ExecutionError|Error step, ChatMessage[] temporaryMemory) {
    if step !is ExecutionResult {
        return;
    }
    LlmToolResponse toolCall = step.tool;
    anydata|error toolObservation = step.observation;

    // A nil observation becomes empty content; errors and other values are stringified.
    string observationText = "";
    if toolObservation is error {
        observationText = toolObservation.toString();
    } else if toolObservation !is () {
        observationText = toolObservation.toString();
    }

    ChatAssistantMessage toolCallMessage = {
        role: ASSISTANT,
        toolCalls: [{name: toolCall.name, id: toolCall.id, arguments: toolCall.arguments}]
    };
    temporaryMemory.push(toolCallMessage);

    ChatFunctionMessage toolResultMessage = {
        role: FUNCTION,
        name: toolCall.name,
        content: observationText,
        id: toolCall.id
    };
    temporaryMemory.push(toolResultMessage);
}
