/*
 * Decompiled with CFR 0.152.
 */
package org.wso2.ballerinalang.compiler.bir.codegen.methodgen;

import io.ballerina.identifier.Utils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.ballerinalang.compiler.BLangCompilerException;
import org.ballerinalang.model.elements.PackageID;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.wso2.ballerinalang.compiler.bir.codegen.BallerinaClassWriter;
import org.wso2.ballerinalang.compiler.bir.codegen.JarEntries;
import org.wso2.ballerinalang.compiler.bir.codegen.JvmCastGen;
import org.wso2.ballerinalang.compiler.bir.codegen.JvmCodeGenUtil;
import org.wso2.ballerinalang.compiler.bir.codegen.JvmPackageGen;
import org.wso2.ballerinalang.compiler.bir.codegen.internal.AsyncDataCollector;
import org.wso2.ballerinalang.compiler.bir.codegen.internal.LambdaClass;
import org.wso2.ballerinalang.compiler.bir.codegen.internal.LambdaFunction;
import org.wso2.ballerinalang.compiler.bir.codegen.methodgen.MethodGenUtils;
import org.wso2.ballerinalang.compiler.bir.codegen.model.BIRFunctionWrapper;
import org.wso2.ballerinalang.compiler.bir.codegen.split.constants.JvmConstantGenCommons;
import org.wso2.ballerinalang.compiler.bir.model.BIRAbstractInstruction;
import org.wso2.ballerinalang.compiler.bir.model.BIRInstruction;
import org.wso2.ballerinalang.compiler.bir.model.BIRNode;
import org.wso2.ballerinalang.compiler.bir.model.BIRNonTerminator;
import org.wso2.ballerinalang.compiler.bir.model.BIROperand;
import org.wso2.ballerinalang.compiler.bir.model.BIRTerminator;
import org.wso2.ballerinalang.compiler.bir.model.InstructionKind;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BInvokableSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BPackageSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.types.BFutureType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BInvokableType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.util.Name;
import org.wso2.ballerinalang.compiler.util.TypeTags;

/**
 * Generates the JVM classes that host static "lambda" methods for async calls
 * ({@link InstructionKind#ASYNC_CALL}) and function-pointer loads ({@link InstructionKind#FP_LOAD}).
 *
 * <p>Every generated lambda method is {@code public static} and takes {@code closureMapsCount} {@code MapValue}
 * parameters followed by one {@code Object[]}, returning {@code Object}. Local slot {@code closureMapsCount}
 * therefore holds the {@code Object[]} argument array: element 0 is the {@code Strand}, and the following elements
 * carry the receiver (virtual async calls only) and the call arguments.
 */
public class LambdaGen {

    private static final String OBJECT_CLASS = "java/lang/Object";
    private static final String STRAND_CLASS = "io/ballerina/runtime/internal/scheduling/Strand";
    private static final String B_OBJECT_CLASS = "io/ballerina/runtime/api/values/BObject";
    private static final String MAP_VALUE_DESC = "Lio/ballerina/runtime/internal/values/MapValue;";
    private static final String FUNCTION_CALLS_CLASS = "creators/$_function_calls";
    private static final String CALL_METHOD = "call";
    /** Descriptor of the name-dispatched {@code call(Strand, String, Object[]) -> Object} entry points. */
    private static final String GENERIC_CALL_DESC = "(Lio/ballerina/runtime/internal/scheduling/Strand;"
            + "Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/Object;";

    private final JvmPackageGen jvmPackageGen;
    private final JvmCastGen jvmCastGen;
    private final BIRNode.BIRPackage module;

    public LambdaGen(JvmPackageGen jvmPackageGen, JvmCastGen jvmCastGen, BIRNode.BIRPackage module) {
        this.jvmPackageGen = jvmPackageGen;
        this.jvmCastGen = jvmCastGen;
        this.module = module;
    }

    /**
     * Generates one class per lambda class registered in {@code asyncDataCollector} and stores its bytes in
     * {@code jarEntries} under {@code <className>.class}.
     *
     * @param asyncDataCollector source of the lambda classes (keyed by JVM class name) to generate
     * @param jarEntries         destination for the generated class files
     */
    public void generateLambdaClasses(AsyncDataCollector asyncDataCollector, JarEntries jarEntries) {
        Map<String, LambdaClass> lambdaClasses = asyncDataCollector.getLambdaClasses();
        for (Map.Entry<String, LambdaClass> entry : lambdaClasses.entrySet()) {
            String lambdaClassName = entry.getKey();
            LambdaClass lambdaClass = entry.getValue();
            BallerinaClassWriter cw = new BallerinaClassWriter(ClassWriter.COMPUTE_FRAMES);
            cw.visitSource(lambdaClass.sourceFileName, null);
            generateClassInit(cw, lambdaClassName);
            for (LambdaFunction lambdaFunction : lambdaClass.lambdaFunctionList) {
                generateLambdaMethod(lambdaFunction.callInstruction, cw, lambdaFunction.lambdaName, lambdaClassName);
            }
            cw.visitEnd();
            jarEntries.put(lambdaClassName + ".class", cw.toByteArray());
        }
    }

    /** Emits the class header (public, extends Object) and a private no-arg constructor. */
    private void generateClassInit(ClassWriter cw, String lambdaClassName) {
        cw.visit(Opcodes.V21, Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER, lambdaClassName, null, OBJECT_CLASS, null);
        MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PRIVATE, "<init>", "()V", null, null);
        mv.visitCode();
        Label startLabel = new Label();
        mv.visitLabel(startLabel);
        mv.visitVarInsn(Opcodes.ALOAD, 0);
        mv.visitMethodInsn(Opcodes.INVOKESPECIAL, OBJECT_CLASS, "<init>", "()V", false);
        Label endLabel = new Label();
        mv.visitLabel(endLabel);
        mv.visitLocalVariable("self", "Ljava/lang/Object;", null, startLabel, endLabel, 0);
        JvmConstantGenCommons.genMethodReturn(mv);
    }

    /** Emits one static lambda method that forwards to the function referenced by {@code ins}. */
    private void generateLambdaMethod(BIRInstruction ins, ClassWriter cw, String lambdaName, String className) {
        LambdaDetails lambdaDetails = getLambdaDetails(ins);
        boolean isSamePkg = JvmCodeGenUtil.isSameModule(module.packageID, lambdaDetails.packageID);
        MethodVisitor mv = getMethodVisitorAndLoadFirst(cw, lambdaName, lambdaDetails, ins, isSamePkg);
        // Filled with the callee's parameter types only on the direct (same-module, non-virtual) call paths.
        List<BType> paramBTypes = new ArrayList<>();
        if (ins.getKind() == InstructionKind.ASYNC_CALL) {
            handleAsyncCallLambda((BIRTerminator.AsyncCall) ins, lambdaDetails, mv, paramBTypes, isSamePkg);
        } else {
            handleFpLambda((BIRNonTerminator.FPLoad) ins, lambdaDetails, mv, paramBTypes, isSamePkg);
        }
        MethodGenUtils.visitReturn(mv, lambdaName, className);
    }

    /**
     * Emits the static invocation of a non-virtual callee.
     *
     * <p>Across modules, the call goes through the callee module's name-dispatched {@code call} entry point, which
     * already returns {@code Object}. Within a module, the callee is invoked directly with a typed descriptor, and the
     * result is boxed.
     */
    private void genNonVirtual(LambdaDetails lambdaDetails, MethodVisitor mv, List<BType> paramBTypes,
                               boolean isSamePkg) {
        if (!isSamePkg) {
            String jvmClass = JvmCodeGenUtil.getModuleLevelClassName(lambdaDetails.packageID, FUNCTION_CALLS_CLASS);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, jvmClass, CALL_METHOD, GENERIC_CALL_DESC, false);
            return;
        }
        String jvmClass = lambdaDetails.functionWrapper != null
                ? lambdaDetails.functionWrapper.fullQualifiedClassName()
                : JvmCodeGenUtil.getModuleLevelClassName(lambdaDetails.packageID, FUNCTION_CALLS_CLASS);
        String methodDesc = getLambdaMethodDesc(paramBTypes, lambdaDetails.returnType, lambdaDetails.closureMapsCount);
        mv.visitMethodInsn(Opcodes.INVOKESTATIC, jvmClass, lambdaDetails.encodedFuncName, methodDesc, false);
        jvmCastGen.addBoxInsn(mv, lambdaDetails.returnType);
    }

    private void handleAsyncCallLambda(BIRTerminator.AsyncCall ins, LambdaDetails lambdaDetails, MethodVisitor mv,
                                       List<BType> paramBTypes, boolean isSamePkg) {
        if (ins.isVirtual) {
            handleLambdaVirtual(ins, lambdaDetails, mv);
        } else {
            handleAsyncNonVirtual(lambdaDetails, mv, paramBTypes, isSamePkg);
        }
    }

    /** Emits a virtual (object-attached) async call through {@code BObject.call(Strand, String, Object[])}. */
    private void handleLambdaVirtual(BIRTerminator.AsyncCall ins, LambdaDetails lambdaDetails, MethodVisitor mv) {
        List<BIROperand> args = ins.args;
        int argsSlot = lambdaDetails.closureMapsCount;
        genLoadDataForObjectAttachedLambdas(ins, mv, argsSlot, args);
        // args.get(0) is the receiver (incoming array element 1). The remaining args sit at incoming elements
        // 2.. and are copied into the new array starting at index 0.
        for (int paramIndex = 1; paramIndex < args.size(); paramIndex++) {
            copyArgIntoArray(mv, argsSlot, paramIndex - 1, paramIndex + 1);
        }
        mv.visitMethodInsn(Opcodes.INVOKEINTERFACE, B_OBJECT_CLASS, CALL_METHOD, GENERIC_CALL_DESC, true);
    }

    /**
     * Pushes the operands of {@code BObject.call} (receiver, strand, rewritten method name) and a fresh
     * {@code Object[]} sized for the non-receiver arguments.
     */
    private void genLoadDataForObjectAttachedLambdas(BIRTerminator.AsyncCall ins, MethodVisitor mv,
                                                     int closureMapsCount, List<BIROperand> paramTypes) {
        // Drop the value last pushed by getMethodVisitorAndLoadFirst, because the receiver must precede the strand.
        // NOTE(review): when caller and callee modules differ, that method also pushed the function name, so one
        // value stays below the receiver. Confirm that virtual async calls are always same-module.
        mv.visitInsn(Opcodes.POP);
        BIROperand ref = ins.args.getFirst();
        // receiver: args[1], unboxed to the receiver operand's type
        mv.visitVarInsn(Opcodes.ALOAD, closureMapsCount);
        mv.visitInsn(Opcodes.ICONST_1);
        mv.visitInsn(Opcodes.AALOAD);
        jvmCastGen.addUnboxInsn(mv, ref.variableDcl.type);
        // strand: (Strand) args[0]
        mv.visitVarInsn(Opcodes.ALOAD, closureMapsCount);
        mv.visitInsn(Opcodes.ICONST_0);
        mv.visitInsn(Opcodes.AALOAD);
        mv.visitTypeInsn(Opcodes.CHECKCAST, STRAND_CLASS);
        mv.visitLdcInsn(JvmCodeGenUtil.rewriteVirtualCallTypeName(ins.name.value, ref.variableDcl.type));
        loadIntConstant(mv, paramTypes.size() - 1);
        mv.visitTypeInsn(Opcodes.ANEWARRAY, OBJECT_CLASS);
    }

    /** Emits a non-virtual async call, either directly (same module) or through the generic entry point. */
    private void handleAsyncNonVirtual(LambdaDetails lambdaDetails, MethodVisitor mv, List<BType> paramBTypes,
                                       boolean isSamePkg) {
        List<BType> paramTypes = getFpParamTypes(lambdaDetails);
        // Async lambdas never set closureMapsCount, so this is slot 0.
        int argsSlot = lambdaDetails.closureMapsCount;
        if (isSamePkg) {
            int argIndex = 1;
            for (BType paramType : paramTypes) {
                mv.visitVarInsn(Opcodes.ALOAD, argsSlot);
                loadIntConstant(mv, argIndex);
                mv.visitInsn(Opcodes.AALOAD);
                jvmCastGen.addUnboxInsn(mv, paramType);
                paramBTypes.add(paramType);
                argIndex++;
            }
        } else {
            loadIntConstant(mv, paramTypes.size());
            mv.visitTypeInsn(Opcodes.ANEWARRAY, OBJECT_CLASS);
            for (int paramIndex = 0; paramIndex < paramTypes.size(); paramIndex++) {
                copyArgIntoArray(mv, argsSlot, paramIndex, paramIndex + 1);
            }
        }
        genNonVirtual(lambdaDetails, mv, paramBTypes, isSamePkg);
    }

    /**
     * Returns the callee's parameter types. These are the first {@code argsCount} declared types when the BIR
     * function is known. Otherwise they come from the symbol's invokable type, with the rest type appended when
     * one is present.
     */
    private List<BType> getFpParamTypes(LambdaDetails lambdaDetails) {
        if (lambdaDetails.functionWrapper != null) {
            var func = lambdaDetails.functionWrapper.func();
            return new ArrayList<>(func.type.paramTypes.subList(0, func.argsCount));
        }
        BInvokableType type = (BInvokableType) lambdaDetails.funcSymbol.type;
        if (type.restType == null) {
            return type.paramTypes;
        }
        List<BType> paramTypes = new ArrayList<>(type.paramTypes);
        paramTypes.add(type.restType);
        return paramTypes;
    }

    /** Emits the forwarding call for a function-pointer load. */
    private void handleFpLambda(BIRNonTerminator.FPLoad ins, LambdaDetails lambdaDetails, MethodVisitor mv,
                                List<BType> paramBTypes, boolean isSamePkg) {
        loadClosureMaps(lambdaDetails, mv);
        loadAndCastParamValues(ins, lambdaDetails, mv, paramBTypes, isSamePkg);
        genNonVirtual(lambdaDetails, mv, paramBTypes, isSamePkg);
    }

    /**
     * Pushes the FP arguments. Within a module, each argument is pushed unboxed to its declared type. Across
     * modules, the arguments are packed into a new {@code Object[]}.
     */
    private void loadAndCastParamValues(BIRNonTerminator.FPLoad ins, LambdaDetails lambdaDetails, MethodVisitor mv,
                                        List<BType> paramBTypes, boolean isSamePkg) {
        // The Object[] parameter follows the closure-map parameters in the lambda's descriptor.
        int argsSlot = lambdaDetails.closureMapsCount;
        if (isSamePkg) {
            int argIndex = 1;
            for (BIRNode.BIRVariableDcl dcl : ins.params) {
                mv.visitVarInsn(Opcodes.ALOAD, argsSlot);
                loadIntConstant(mv, argIndex);
                mv.visitInsn(Opcodes.AALOAD);
                jvmCastGen.addUnboxInsn(mv, dcl.type);
                paramBTypes.add(dcl.type);
                argIndex++;
            }
        } else {
            int paramCount = ins.params.size();
            loadIntConstant(mv, paramCount);
            mv.visitTypeInsn(Opcodes.ANEWARRAY, OBJECT_CLASS);
            for (int paramIndex = 0; paramIndex < paramCount; paramIndex++) {
                copyArgIntoArray(mv, argsSlot, paramIndex, paramIndex + 1);
            }
        }
    }

    /** Pushes the closure-map parameters, which occupy local slots {@code 0 .. closureMapsCount-1}. */
    private void loadClosureMaps(LambdaDetails lambdaDetails, MethodVisitor mv) {
        for (int slot = 0; slot < lambdaDetails.closureMapsCount; slot++) {
            mv.visitVarInsn(Opcodes.ALOAD, slot);
        }
    }

    /**
     * Emits code for {@code arr[targetIndex] = args[sourceIndex]}. Here {@code arr} is the {@code Object[]} on top
     * of the operand stack (it remains there afterwards) and {@code args} is the array parameter in
     * {@code argsSlot}.
     */
    private static void copyArgIntoArray(MethodVisitor mv, int argsSlot, int targetIndex, int sourceIndex) {
        mv.visitInsn(Opcodes.DUP);
        loadIntConstant(mv, targetIndex);
        mv.visitVarInsn(Opcodes.ALOAD, argsSlot);
        loadIntConstant(mv, sourceIndex);
        mv.visitInsn(Opcodes.AALOAD);
        mv.visitInsn(Opcodes.AASTORE);
    }

    /**
     * Pushes an int constant using the narrowest instruction that can encode it. BIPUSH only holds a signed byte,
     * so larger array indices or lengths need SIPUSH or LDC to avoid silent truncation.
     */
    private static void loadIntConstant(MethodVisitor mv, int value) {
        if (value >= Byte.MIN_VALUE && value <= Byte.MAX_VALUE) {
            mv.visitIntInsn(Opcodes.BIPUSH, value);
        } else if (value >= Short.MIN_VALUE && value <= Short.MAX_VALUE) {
            mv.visitIntInsn(Opcodes.SIPUSH, value);
        } else {
            mv.visitLdcInsn(value);
        }
    }

    /**
     * Opens the lambda method and pushes {@code (Strand) args[0]}. When the callee is in another module, it also
     * pushes the encoded function name for the generic {@code call} entry point.
     */
    private MethodVisitor getMethodVisitorAndLoadFirst(ClassWriter cw, String lambdaName, LambdaDetails lambdaDetails,
                                                       BIRInstruction ins, boolean isSamePkg) {
        String methodDesc = "(" + MAP_VALUE_DESC.repeat(lambdaDetails.closureMapsCount)
                + "[Ljava/lang/Object;)Ljava/lang/Object;";
        MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, lambdaName, methodDesc, null, null);
        mv.visitCode();
        JvmCodeGenUtil.generateDiagnosticPos(((BIRAbstractInstruction) ins).pos, mv);
        mv.visitVarInsn(Opcodes.ALOAD, lambdaDetails.closureMapsCount);
        mv.visitInsn(Opcodes.ICONST_0);
        mv.visitInsn(Opcodes.AALOAD);
        mv.visitTypeInsn(Opcodes.CHECKCAST, STRAND_CLASS);
        if (!isSamePkg) {
            mv.visitLdcInsn(lambdaDetails.encodedFuncName);
        }
        return mv;
    }

    /**
     * Collects the callee information needed to emit a lambda for {@code ins}.
     *
     * @throws BLangCompilerException if {@code ins} is neither an async call nor an FP load, or the return type
     *                                cannot be determined
     */
    private LambdaDetails getLambdaDetails(BIRInstruction ins) {
        LambdaDetails lambdaDetails = switch (ins.getKind()) {
            case ASYNC_CALL -> populateAsyncLambdaDetails((BIRTerminator.AsyncCall) ins);
            case FP_LOAD -> populateFpLambdaDetails((BIRNonTerminator.FPLoad) ins);
            default -> throw new BLangCompilerException(
                    "JVM lambda method generation is not supported for instruction " + ins);
        };
        lambdaDetails.isExternFunction = isExternStaticFunctionCall(ins);
        populateLambdaReturnType(ins, lambdaDetails);
        return lambdaDetails;
    }

    private LambdaDetails populateAsyncLambdaDetails(BIRTerminator.AsyncCall asyncIns) {
        LambdaDetails lambdaDetails = new LambdaDetails();
        lambdaDetails.lhsType = asyncIns.lhsOp != null ? asyncIns.lhsOp.variableDcl.type : null;
        lambdaDetails.packageID = asyncIns.calleePkg;
        lambdaDetails.funcName = asyncIns.name.getValue();
        lambdaDetails.encodedFuncName = Utils.encodeFunctionIdentifier(lambdaDetails.funcName);
        if (!asyncIns.isVirtual) {
            populateLambdaFunctionDetails(lambdaDetails);
        }
        return lambdaDetails;
    }

    private LambdaDetails populateFpLambdaDetails(BIRNonTerminator.FPLoad fpIns) {
        LambdaDetails lambdaDetails = new LambdaDetails();
        lambdaDetails.lhsType = fpIns.lhsOp.variableDcl.type;
        lambdaDetails.packageID = fpIns.pkgId;
        lambdaDetails.funcName = fpIns.funcName.getValue();
        lambdaDetails.closureMapsCount = fpIns.closureMaps.size();
        populateLambdaFunctionDetails(lambdaDetails);
        return lambdaDetails;
    }

    /**
     * Resolves the callee, preferring the BIR function wrapper of the package being compiled. If no wrapper is
     * found, it falls back to the function symbol from the package cache.
     *
     * @throws BLangCompilerException if neither a wrapper nor the callee's package symbol can be found
     */
    private void populateLambdaFunctionDetails(LambdaDetails lambdaDetails) {
        PackageID packageID = lambdaDetails.packageID;
        lambdaDetails.encodedFuncName = Utils.encodeFunctionIdentifier(lambdaDetails.funcName);
        lambdaDetails.lookupKey = JvmCodeGenUtil.getPackageName(packageID) + lambdaDetails.encodedFuncName;
        lambdaDetails.functionWrapper = jvmPackageGen.lookupBIRFunctionWrapper(lambdaDetails.lookupKey);
        if (lambdaDetails.functionWrapper != null) {
            return;
        }
        String packageKey = packageID.orgName + "/" + packageID.name;
        BPackageSymbol symbol = jvmPackageGen.packageCache.getSymbol(packageKey);
        if (symbol == null) {
            throw new BLangCompilerException("JVM lambda method generation failed: package symbol '" + packageKey
                    + "' not found for function '" + lambdaDetails.funcName + "'");
        }
        lambdaDetails.funcSymbol = (BInvokableSymbol) symbol.scope.lookup(new Name(lambdaDetails.funcName)).symbol;
    }

    /** Returns whether the static (non-virtual) callee of {@code callIns} is a known extern BIR function. */
    private boolean isExternStaticFunctionCall(BIRInstruction callIns) {
        String methodName;
        PackageID packageID;
        switch (callIns.getKind()) {
            case CALL -> {
                BIRTerminator.Call call = (BIRTerminator.Call) callIns;
                if (call.isVirtual) {
                    return false;
                }
                methodName = call.name.value;
                packageID = call.calleePkg;
            }
            case ASYNC_CALL -> {
                BIRTerminator.AsyncCall asyncCall = (BIRTerminator.AsyncCall) callIns;
                methodName = asyncCall.name.value;
                packageID = asyncCall.calleePkg;
            }
            case FP_LOAD -> {
                BIRNonTerminator.FPLoad fpLoad = (BIRNonTerminator.FPLoad) callIns;
                methodName = fpLoad.funcName.value;
                packageID = fpLoad.pkgId;
            }
            default -> throw new BLangCompilerException(
                    "JVM static function call generation is not supported for instruction " + callIns);
        }
        String key = JvmCodeGenUtil.getPackageName(packageID) + methodName;
        BIRFunctionWrapper functionWrapper = jvmPackageGen.lookupBIRFunctionWrapper(key);
        return functionWrapper != null && JvmCodeGenUtil.isExternFunc(functionWrapper.func());
    }

    /**
     * Sets the callee's return type. For a future-typed LHS, this is the future's constraint. Otherwise, for an
     * FP load, it is the function type's return type.
     *
     * @throws BLangCompilerException if neither case applies
     */
    private void populateLambdaReturnType(BIRInstruction ins, LambdaDetails lambdaDetails) {
        BType lhsType = JvmCodeGenUtil.getImpliedType(lambdaDetails.lhsType);
        if (lhsType.tag == TypeTags.FUTURE) {
            lambdaDetails.returnType = ((BFutureType) lhsType).constraint;
        } else if (ins instanceof BIRNonTerminator.FPLoad fpLoad) {
            lambdaDetails.returnType = ((BInvokableType) fpLoad.type).retType;
        } else {
            throw new BLangCompilerException(
                    "JVM generation is not supported for async return type " + lambdaDetails.lhsType);
        }
    }

    /** Builds the descriptor of a direct callee: strand, closure maps, parameters, then the return type. */
    private String getLambdaMethodDesc(List<BType> paramTypes, BType retType, int closureMapsCount) {
        StringBuilder desc = new StringBuilder("(L").append(STRAND_CLASS).append(';');
        appendClosureMaps(closureMapsCount, desc);
        appendParamTypes(paramTypes, desc);
        desc.append(JvmCodeGenUtil.generateReturnType(retType, jvmCastGen.typeEnv()));
        return desc.toString();
    }

    private void appendParamTypes(List<BType> paramTypes, StringBuilder desc) {
        for (BType paramType : paramTypes) {
            desc.append(JvmCodeGenUtil.getArgTypeSignature(paramType));
        }
    }

    private void appendClosureMaps(int closureMapsCount, StringBuilder desc) {
        desc.append(MAP_VALUE_DESC.repeat(closureMapsCount));
    }

    /** Mutable scratch record of the callee being wrapped by a single lambda method. */
    private static class LambdaDetails {
        /** Type of the instruction's LHS operand; null for an async call without one. */
        BType lhsType;
        /** Module that declares the callee. */
        PackageID packageID;
        /** Callee name as written in BIR. */
        String funcName;
        /** NOTE(review): assigned in getLambdaDetails but not read anywhere in this class. */
        boolean isExternFunction;
        /** Callee name encoded as a JVM identifier. */
        String encodedFuncName;
        /** Key used to look up {@link #functionWrapper}. */
        String lookupKey;
        /** BIR function of the callee, when it is known to the package generator. */
        BIRFunctionWrapper functionWrapper;
        /** Fallback callee symbol, set only when {@link #functionWrapper} is null. */
        BInvokableSymbol funcSymbol;
        /** Callee's return type. */
        BType returnType;
        /** Number of closure-map parameters (non-zero only for FP loads with closure maps). */
        int closureMapsCount;

        private LambdaDetails() {
        }
    }
}

