/*
 * Decompiled with CFR 0.152.
 */
package org.wso2.ballerinalang.compiler.bir.optimizer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.wso2.ballerinalang.compiler.bir.model.BIRNode;
import org.wso2.ballerinalang.compiler.bir.model.BIRNonTerminator;
import org.wso2.ballerinalang.compiler.bir.model.BIROperand;
import org.wso2.ballerinalang.compiler.bir.model.BIRTerminator;
import org.wso2.ballerinalang.compiler.bir.model.BIRVisitor;
import org.wso2.ballerinalang.compiler.bir.model.InstructionKind;
import org.wso2.ballerinalang.compiler.bir.model.VarKind;
import org.wso2.ballerinalang.compiler.bir.model.VarScope;
import org.wso2.ballerinalang.compiler.semantics.analyzer.Types;
import org.wso2.ballerinalang.compiler.semantics.model.types.BRecordType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.util.Name;

/**
 * BIR optimization pass that inlines trivial function-pointer calls made while building record values.
 *
 * <p>While visiting a function, the pass tracks record type descriptors created by {@code NEW_TYPEDESC}
 * instructions until a {@code NEW_STRUCTURE} consumes them. Consider a basic block that ends in an {@code FP_CALL}
 * that meets both of these conditions:
 * <ul>
 *   <li>its function-pointer variable name contains the name of the most recently tracked record type;</li>
 *   <li>the module function with exactly that name only loads a constant (optionally followed by a type cast) into
 *   its return variable.</li>
 * </ul>
 * Such a call is replaced by equivalent instructions appended to the calling block. Subsequent blocks are then merged
 * into that calling block. The merge run ends at the first block whose terminator is not an inlinable
 * {@code FP_CALL}, and that terminator is moved onto the merged block.
 * Presumably the targeted functions are record field default-value closures — TODO confirm against BIRGen.
 *
 * <p>Instances hold mutable traversal state; per-function state is reset at the start of every function.
 */
public class BIRRecordValueOptimizer extends BIRVisitor {

    // Tag value identifying record types; handleNewTypeDesc casts to BRecordType when it matches.
    // NOTE(review): presumably equal to TypeTags.RECORD — kept as the literal the original compared against.
    private static final int RECORD_TYPE_TAG = 12;

    // Operands defined by record NEW_TYPEDESC instructions that no NEW_STRUCTURE has consumed yet.
    private final List<BIROperand> recordOperandList = new ArrayList<>();
    // Record type behind each tracked typedesc operand.
    private final Map<BIROperand, BRecordType> recordOperandTypeMap = new HashMap<>();
    // Block receiving inlined instructions and merged successor blocks; null when no merge run is open.
    private BIRNode.BIRBasicBlock lastBB = null;
    private BIRNode.BIRFunction currentFunction = null;
    // Rebuilt block list for the function currently being visited.
    private List<BIRNode.BIRBasicBlock> newBBs = new ArrayList<>();
    // Candidate call targets, looked up by name.
    private List<BIRNode.BIRFunction> moduleFunctions = new ArrayList<>();
    // Temporaries created for inlined type casts, keyed by "<target function name>/<cast source variable name>".
    private final Map<String, BIRNode.BIRVariableDcl> typecastVars = new HashMap<>();
    // True while a merge run is open, i.e. lastBB's FP_CALL terminator has been removed.
    private boolean fpRemoved = false;

    /**
     * Runs the pass over the given BIR node.
     *
     * @param node the BIR node to traverse and optimize
     */
    public void optimizeNode(BIRNode node) {
        node.accept(this);
    }

    @Override
    public void visit(BIRNode.BIRPackage birPackage) {
        // Only module-level functions are considered as inlining targets.
        this.moduleFunctions = birPackage.functions;
        birPackage.typeDefs.forEach(tDef -> tDef.accept(this));
        birPackage.functions.forEach(func -> func.accept(this));
    }

    @Override
    public void visit(BIRNode.BIRTypeDefinition birTypeDefinition) {
        birTypeDefinition.attachedFuncs.forEach(func -> func.accept(this));
    }

    @Override
    public void visit(BIRNode.BIRFunction birFunction) {
        // Typedesc operands, merge state and cast temporaries belong to a single function; never carry them over.
        resetFunctionState();
        this.currentFunction = birFunction;
        birFunction.basicBlocks.forEach(bb -> bb.accept(this));
        birFunction.basicBlocks = this.newBBs;
        this.newBBs = new ArrayList<>();
    }

    @Override
    public void visit(BIRNode.BIRBasicBlock basicBlock) {
        for (BIRNonTerminator inst : basicBlock.instructions) {
            switch (inst.kind) {
                case NEW_TYPEDESC -> handleNewTypeDesc(inst);
                case NEW_STRUCTURE -> handleNewStructure((BIRNonTerminator.NewStructure) inst);
                default -> {
                    // Other instructions do not affect record tracking.
                }
            }
        }
        if (this.fpRemoved) {
            // NOTE(review): assumes this block is the continuation (thenBB) of the inlined FP_CALL, i.e. that BIR
            // block order follows the call's control flow.
            this.lastBB.instructions.addAll(basicBlock.instructions);
        } else {
            this.newBBs.add(basicBlock);
        }
        if (basicBlock.terminator.kind == InstructionKind.FP_CALL) {
            handleFPCall(basicBlock);
        } else if (this.fpRemoved) {
            // Any other terminator (call, branch, goto, return, ...) must survive on the merged block; dropping it
            // would lose control flow.
            resetBasicBlock(basicBlock);
        }
    }

    /**
     * Tries to inline the {@code FP_CALL} that terminates {@code basicBlock}.
     *
     * <p>If the call is not inlinable, any open merge run is ended with this block's terminator.
     *
     * @param basicBlock a block whose terminator is an {@code FP_CALL}
     */
    private void handleFPCall(BIRNode.BIRBasicBlock basicBlock) {
        BIRTerminator.FPCall fpCall = (BIRTerminator.FPCall) basicBlock.terminator;
        BIROperand recOperand = this.recordOperandList.isEmpty()
                ? null
                : this.recordOperandList.get(this.recordOperandList.size() - 1);
        BRecordType recordType = this.recordOperandTypeMap.get(recOperand);
        String fpName = fpCall.fp.variableDcl.name.value;
        if (recordType == null || recordType.tsymbol == null || !fpName.contains(recordType.tsymbol.name.value)) {
            resetBasicBlock(basicBlock);
            return;
        }
        BIRNode.BIRFunction defaultFunction = getDefaultBIRFunction(fpName);
        // Check the target's shape before touching its blocks: a function without blocks has no index 0.
        if (defaultFunction == null || !containsOnlyConstantLoad(defaultFunction)) {
            resetBasicBlock(basicBlock);
            return;
        }
        if (this.lastBB == null) {
            this.lastBB = basicBlock;
        }
        moveConstLoadInstruction(fpCall, defaultFunction.basicBlocks.get(0));
        // The call is replaced by the inlined instructions; resetBasicBlock restores a terminator when the run ends.
        this.lastBB.terminator = null;
        this.fpRemoved = true;
    }

    /**
     * Stops tracking the structure's {@code rhsOp} (presumably its typedesc operand) as pending.
     */
    private void handleNewStructure(BIRNonTerminator.NewStructure inst) {
        this.recordOperandList.remove(inst.rhsOp);
    }

    /**
     * Starts tracking the operand defined by a {@code NEW_TYPEDESC} when its referred type is a record type.
     */
    private void handleNewTypeDesc(BIRNonTerminator inst) {
        BType referredType = Types.getReferredType(((BIRNonTerminator.NewTypeDesc) inst).type);
        if (referredType.tag == RECORD_TYPE_TAG) {
            this.recordOperandList.add(inst.lhsOp);
            this.recordOperandTypeMap.put(inst.lhsOp, (BRecordType) referredType);
        }
    }

    /**
     * Appends to {@code lastBB} the instructions that assign the call target's constant to the call's result.
     *
     * @param fpCall  the call being inlined; its {@code lhsOp} receives the value
     * @param firstBB the target's first block, containing a constant load and optionally a type cast
     */
    private void moveConstLoadInstruction(BIRTerminator.FPCall fpCall, BIRNode.BIRBasicBlock firstBB) {
        BIRNonTerminator.ConstantLoad constantLoad = (BIRNonTerminator.ConstantLoad) firstBB.instructions.get(0);
        if (firstBB.instructions.size() != 2) {
            // The constant goes straight into the return variable: load it into the call's result instead.
            BIRNonTerminator.ConstantLoad directLoad = new BIRNonTerminator.ConstantLoad(constantLoad.pos,
                    constantLoad.value, constantLoad.type, fpCall.lhsOp);
            directLoad.scope = fpCall.scope;
            this.lastBB.instructions.add(directLoad);
            return;
        }
        // Load the constant into a temporary, then cast it into the call's result.
        BIRNonTerminator.TypeCast typeCast = (BIRNonTerminator.TypeCast) firstBB.instructions.get(1);
        BIROperand tempVarOperand = new BIROperand(getTypecastTempVar(fpCall, typeCast, constantLoad));
        BIRNonTerminator.ConstantLoad tempLoad = new BIRNonTerminator.ConstantLoad(constantLoad.pos,
                constantLoad.value, constantLoad.type, tempVarOperand);
        tempLoad.scope = fpCall.scope;
        this.lastBB.instructions.add(tempLoad);
        BIRNonTerminator.TypeCast newTypeCast = new BIRNonTerminator.TypeCast(typeCast.pos, fpCall.lhsOp,
                tempVarOperand, typeCast.type, typeCast.checkTypes);
        newTypeCast.scope = fpCall.scope;
        this.lastBB.instructions.add(newTypeCast);
    }

    /**
     * Returns the temporary that receives an inlined constant before it is cast. On first use, the temporary is
     * created and registered as a local of the current function.
     *
     * <p>The cache key includes the call target's name. Different targets can use the same local name for their
     * constant (presumably per-function numbering such as {@code %1}) while loading constants of different types.
     * Sharing one temporary between them would give it the wrong type. The counter suffix keeps generated names
     * unique within the function.
     */
    private BIRNode.BIRVariableDcl getTypecastTempVar(BIRTerminator.FPCall fpCall, BIRNonTerminator.TypeCast typeCast,
                                                     BIRNonTerminator.ConstantLoad constantLoad) {
        String sourceVarName = typeCast.rhsOp.variableDcl.name.value;
        String cacheKey = fpCall.fp.variableDcl.name.value + "/" + sourceVarName;
        BIRNode.BIRVariableDcl tempVar = this.typecastVars.get(cacheKey);
        if (tempVar == null) {
            String tempVarName = "%temp_" + sourceVarName + "_" + this.typecastVars.size();
            tempVar = new BIRNode.BIRVariableDcl(null, constantLoad.type, new Name(tempVarName), VarScope.FUNCTION,
                    VarKind.TEMP, null);
            this.typecastVars.put(cacheKey, tempVar);
            this.currentFunction.localVars.add(tempVar);
        }
        return tempVar;
    }

    /**
     * Checks whether {@code defaultFunction} only returns a constant. The function must have exactly two blocks:
     * <ul>
     *   <li>the first holds a {@code CONST_LOAD} into the return variable, or a {@code CONST_LOAD} followed by a
     *   {@code TYPE_CAST} into the return variable;</li>
     *   <li>the second is empty and ends in {@code RETURN}.</li>
     * </ul>
     */
    private boolean containsOnlyConstantLoad(BIRNode.BIRFunction defaultFunction) {
        List<BIRNode.BIRBasicBlock> blocks = defaultFunction.basicBlocks;
        if (blocks.size() != 2) {
            return false;
        }
        BIRNode.BIRBasicBlock secondBB = blocks.get(1);
        if (!secondBB.instructions.isEmpty() || secondBB.terminator.kind != InstructionKind.RETURN) {
            return false;
        }
        List<BIRNonTerminator> body = blocks.get(0).instructions;
        return switch (body.size()) {
            case 1 -> body.get(0).kind == InstructionKind.CONST_LOAD
                    && body.get(0).lhsOp.variableDcl.kind == VarKind.RETURN;
            case 2 -> body.get(0).kind == InstructionKind.CONST_LOAD
                    && body.get(1).kind == InstructionKind.TYPE_CAST
                    && body.get(1).lhsOp.variableDcl.kind == VarKind.RETURN;
            default -> false;
        };
    }

    /**
     * Ends the open merge run, if any: the merged block takes over {@code basicBlock}'s terminator.
     */
    private void resetBasicBlock(BIRNode.BIRBasicBlock basicBlock) {
        if (this.lastBB != null) {
            this.lastBB.terminator = basicBlock.terminator;
            this.lastBB = null;
        }
        this.fpRemoved = false;
    }

    /**
     * Clears all state that must not leak from one function into the next.
     */
    private void resetFunctionState() {
        this.recordOperandList.clear();
        this.recordOperandTypeMap.clear();
        this.typecastVars.clear();
        this.lastBB = null;
        this.fpRemoved = false;
    }

    /**
     * Returns the first module function whose name equals {@code funcName}, or {@code null} if none does.
     */
    private BIRNode.BIRFunction getDefaultBIRFunction(String funcName) {
        for (BIRNode.BIRFunction func : this.moduleFunctions) {
            if (func.name.value.equals(funcName)) {
                return func;
            }
        }
        return null;
    }
}

