/*
 * Decompiled with CFR 0.152.
 */
package io.ballerina.persist.nodegenerator.syntax.utils;

import io.ballerina.compiler.syntax.tree.AnnotationNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode;
import io.ballerina.compiler.syntax.tree.MappingFieldNode;
import io.ballerina.compiler.syntax.tree.SpecificFieldNode;
import io.ballerina.persist.BalException;
import io.ballerina.persist.PersistToolsConstants;
import io.ballerina.persist.models.Entity;
import io.ballerina.persist.models.EntityField;
import io.ballerina.persist.models.Enum;
import io.ballerina.persist.models.EnumMember;
import io.ballerina.persist.models.Index;
import io.ballerina.persist.models.Relation;
import io.ballerina.persist.models.SqlType;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Generates SQL DDL scripts (DROP TABLE, CREATE TABLE and CREATE INDEX statements) for persist entities.
 *
 * <p>Data-source specific syntax is selected from the data source name: {@code "mssql"}, {@code "postgresql"},
 * {@code "mysql"} and {@code "h2"}.
 */
public class SqlScriptUtils {
    private static final String NEW_LINE = System.lineSeparator();
    private static final String TAB = "\t";
    private static final String COMMA_WITH_SPACE = ", ";
    private static final String PRIMARY_KEY_START_SCRIPT = NEW_LINE + "\tPRIMARY KEY(";
    private static final String ENUM_START_SCRIPT = "ENUM(";
    private static final String ENUM_END_SCRIPT = ")";
    private static final String SINGLE_QUOTE = "'";

    private static final String MSSQL = "mssql";
    private static final String POSTGRESQL = "postgresql";
    private static final String MYSQL = "mysql";
    private static final String H2 = "h2";

    private SqlScriptUtils() {
    }

    /**
     * Generates the complete SQL script for the given entities.
     *
     * <p>The returned array holds one DROP TABLE statement per table, then one CREATE TABLE statement per table
     * (in the reverse order of the drops), then a line separator, and finally the CREATE INDEX statements.
     * Entities that contain unsupported types are skipped entirely, including their indexes.
     *
     * @param entities   the entities to generate tables for
     * @param datasource the target data source name
     * @return the ordered script statements
     * @throws BalException if a field type has no SQL equivalent or the data source is not supported
     */
    public static String[] generateSqlScript(Collection<Entity> entities, String datasource) throws BalException {
        HashMap<String, List<String>> referenceTables = new HashMap<>();
        HashMap<String, List<String>> tableScripts = new HashMap<>();
        for (Entity entity : entities) {
            if (entity.containsUnsupportedTypes()) {
                continue;
            }
            List<String> tableScript = new ArrayList<>();
            String tableName = getTableNameWithSchema(entity, datasource);
            tableScript.add(generateDropTableQuery(tableName));
            tableScript.add(generateCreateTableQuery(entity, referenceTables, tableName, datasource));
            tableScripts.put(removeSingleQuote(entity.getTableName()), tableScript);
        }
        List<String> indexScripts = new ArrayList<>();
        for (Entity entity : entities) {
            // No table is created for these entities, so an index on them would reference a missing table.
            if (entity.containsUnsupportedTypes()) {
                continue;
            }
            entity.getIndexes().forEach(index ->
                    indexScripts.add(generateCreateIndexQuery(index, entity, datasource, index.isUnique())));
            entity.getUniqueIndexes().forEach(index ->
                    indexScripts.add(generateCreateIndexQuery(index, entity, datasource, index.isUnique())));
        }
        List<String> script = new ArrayList<>(Arrays.asList(
                rearrangeScriptsWithReference(tableScripts.keySet(), referenceTables, tableScripts)));
        script.add(NEW_LINE);
        script.addAll(indexScripts);
        return script.toArray(new String[0]);
    }

    private static String generateDropTableQuery(String tableName) {
        return MessageFormat.format("DROP TABLE IF EXISTS {0};", tableName);
    }

    /**
     * Builds the CREATE TABLE statement for an entity.
     *
     * @param entity          the entity to create a table for
     * @param referenceTables accumulator of table name to referenced table names; updated for every owned relation
     * @param tableName       the already escaped (and possibly schema-qualified) table name
     * @param datasource      the target data source name
     * @return the CREATE TABLE statement, prefixed with a line separator
     * @throws BalException if a field type has no SQL equivalent or the data source is not supported
     */
    public static String generateCreateTableQuery(Entity entity, HashMap<String, List<String>> referenceTables,
                                                  String tableName, String datasource) throws BalException {
        String fieldDefinitions = generateFieldsDefinitionSegments(entity, referenceTables, datasource);
        return MessageFormat.format("{0}CREATE TABLE {1} ({2}{3});", NEW_LINE, tableName, fieldDefinitions, NEW_LINE);
    }

    private static String generateCreateIndexQuery(Index index, Entity entity, String datasource, boolean unique) {
        String tableName = getTableNameWithSchema(entity, datasource);
        List<String> columns = index.getFields().stream()
                .map(field -> escape(removeSingleQuote(field.getFieldColumnName()), datasource))
                .toList();
        return MessageFormat.format("CREATE{0} INDEX {1} ON {2} ({3});", unique ? " UNIQUE" : "",
                escape(index.getIndexName(), datasource), tableName, String.join(COMMA_WITH_SPACE, columns));
    }

    /**
     * Returns the escaped table name of an entity. The name is prefixed with the entity's schema name when the
     * data source supports custom schemas and a non-empty schema name is set.
     *
     * @param entity     the entity
     * @param datasource the target data source name
     * @return the escaped, possibly schema-qualified, table name
     */
    public static String getTableNameWithSchema(Entity entity, String datasource) {
        String tableName = escape(removeSingleQuote(entity.getTableName()), datasource);
        String schemaName = entity.getSchemaName();
        if (PersistToolsConstants.CUSTOM_SCHEMA_SUPPORTED_DB_PROVIDERS.contains(datasource)
                && schemaName != null && !schemaName.isEmpty()) {
            return schemaName + "." + tableName;
        }
        return tableName;
    }

    /**
     * Builds the column, foreign-key and primary-key definitions that go between the parentheses of CREATE TABLE.
     */
    private static String generateFieldsDefinitionSegments(Entity entity, HashMap<String, List<String>> referenceTables,
                                                           String datasource) throws BalException {
        StringBuilder sqlScript = new StringBuilder(getColumnsScript(entity, datasource));
        HashMap<String, List<EntityField>> relationFields = getMapOfRelationFields(entity, true);
        List<String> associations = entity.getFields().stream()
                .filter(entityField -> entityField.getRelation() != null && entityField.getRelation().isOwner())
                .map(EntityField::getFieldType)
                .toList();
        for (int i = 0; i < associations.size(); i++) {
            // Several owned relations may point to the same entity type; pick the matching one by occurrence.
            int occurrence = findOccurrence(associations, i);
            EntityField relationField = relationFields.get(associations.get(i)).get(occurrence);
            sqlScript.append(getRelationScripts(entity, relationField, occurrence, referenceTables, datasource));
        }
        sqlScript.append(addPrimaryKey(entity.getKeys(), datasource));
        if (sqlScript.length() == 0) {
            return "";
        }
        // Every segment ends with a trailing comma; drop the final one.
        return sqlScript.substring(0, sqlScript.length() - 1);
    }

    /** Counts how many entries before {@code index} are equal to the entry at {@code index}. */
    private static int findOccurrence(List<String> associations, int index) {
        int occurred = 0;
        for (int i = 0; i < index; i++) {
            if (Objects.equals(associations.get(i), associations.get(index))) {
                occurred++;
            }
        }
        return occurred;
    }

    /**
     * Groups an entity's relation fields by field type, keeping only fields whose owner flag equals
     * {@code isOwner}. Fields of the same type keep their declaration order.
     */
    private static HashMap<String, List<EntityField>> getMapOfRelationFields(Entity entity, boolean isOwner) {
        HashMap<String, List<EntityField>> relationFields = new HashMap<>();
        for (EntityField entityField : entity.getFields()) {
            if (entityField.getRelation() == null || entityField.getRelation().isOwner() != isOwner) {
                continue;
            }
            relationFields.computeIfAbsent(entityField.getFieldType(), key -> new ArrayList<>()).add(entityField);
        }
        return relationFields;
    }

    /**
     * Builds one column definition per non-relation field, each prefixed with a line separator and a tab and
     * terminated by a comma.
     *
     * @throws BalException if a field type has no SQL equivalent, or a required column is generated for an
     *                      unsupported data source
     */
    private static String getColumnsScript(Entity entity, String datasource) throws BalException {
        StringBuilder columnScript = new StringBuilder();
        for (EntityField entityField : entity.getFields()) {
            if (entityField.getRelation() != null) {
                // Relation fields are not plain columns; owned relations are emitted by getRelationScripts.
                continue;
            }
            String fieldName = escape(removeSingleQuote(entityField.getFieldColumnName()), datasource);
            Enum enumValue = entityField.getEnum();
            String sqlType = enumValue == null
                    ? getSqlType(entityField, datasource)
                    : getEnumType(enumValue, fieldName, datasource);
            if (entityField.isOptionalType()) {
                columnScript.append(MessageFormat.format("{0}{1}{2} {3},", NEW_LINE, TAB, fieldName, sqlType));
                continue;
            }
            boolean generated = entityField.isDbGenerated();
            switch (datasource) {
                case MSSQL:
                    columnScript.append(MessageFormat.format("{0}{1}{2} {3}{4},", NEW_LINE, TAB, fieldName, sqlType,
                            generated ? " IDENTITY(1,1)" : " NOT NULL"));
                    break;
                case POSTGRESQL:
                    // SERIAL replaces the column type, so the type slot is left empty and the type is only
                    // emitted for non-generated columns.
                    columnScript.append(MessageFormat.format("{0}{1}{2} {3}{4},", NEW_LINE, TAB, fieldName, "",
                            generated ? " SERIAL" : sqlType + " NOT NULL"));
                    break;
                case MYSQL:
                case H2:
                    columnScript.append(MessageFormat.format("{0}{1}{2} {3}{4},", NEW_LINE, TAB, fieldName, sqlType,
                            generated ? " AUTO_INCREMENT" : " NOT NULL"));
                    break;
                default:
                    // Previously the column was silently dropped, producing an incomplete table definition.
                    throw new BalException("unsupported data source for SQL script generation: " + datasource);
            }
        }
        return columnScript.toString();
    }

    /**
     * Builds the foreign-key column definitions, an optional UNIQUE constraint and the FOREIGN KEY clause for an
     * owned relation. It also records the referenced table in {@code referenceTables}.
     *
     * @param entity          the owning entity
     * @param entityField     the owning relation field
     * @param index           occurrence of this relation among the owner's relations to the same entity type
     * @param referenceTables accumulator of table name to referenced table names
     * @param datasource      the target data source name
     */
    private static String getRelationScripts(Entity entity, EntityField entityField, int index,
                                             HashMap<String, List<String>> referenceTables, String datasource)
            throws BalException {
        StringBuilder relationScripts = new StringBuilder();
        Relation relation = entityField.getRelation();
        List<Relation.Key> keyColumns = relation.getKeyColumns();
        List<String> references = keyColumns.stream().map(Relation.Key::getReferenceColumnName).toList();
        Entity assocEntity = relation.getAssocEntity();
        EntityField assocEntityField = getMapOfRelationFields(assocEntity, false)
                .get(entity.getEntityName()).get(index);
        boolean oneToOne = Relation.RelationType.ONE.equals(relation.getRelationType())
                && Relation.RelationType.ONE.equals(assocEntityField.getRelation().getRelationType());
        int noOfReferenceKeys = references.size();
        List<String> keyColumnNames = keyColumns.stream().map(Relation.Key::getColumnName).toList();
        // A user-declared unique index over exactly the FK columns already enforces one-to-one uniqueness.
        boolean uniqueIndexExists = entity.getUniqueIndexes().stream()
                .anyMatch(idx -> idx.getFields().stream().map(EntityField::getFieldColumnName).toList()
                        .equals(keyColumnNames));

        List<String> foreignKeyColumns = new ArrayList<>();
        List<String> referencedColumns = new ArrayList<>();
        for (int i = 0; i < noOfReferenceKeys; i++) {
            // NOTE(review): stays null (rendered as "null") if no plain column of the associated entity matches.
            String referenceSqlType = null;
            for (EntityField assocField : assocEntity.getFields()) {
                if (assocField.getRelation() == null && assocField.getFieldColumnName().equals(references.get(i))) {
                    referenceSqlType = getSqlType(assocField, datasource);
                    break;
                }
            }
            if (oneToOne && noOfReferenceKeys == 1 && !uniqueIndexExists) {
                referenceSqlType = referenceSqlType + " UNIQUE";
            }
            String foreignKeyColumn = escape(removeSingleQuote(keyColumns.get(i).getColumnName()), datasource);
            foreignKeyColumns.add(foreignKeyColumn);
            referencedColumns.add(escape(removeSingleQuote(references.get(i)), datasource));
            relationScripts.append(MessageFormat.format("{0}{1}{2} {3}{4},", NEW_LINE, TAB, foreignKeyColumn,
                    referenceSqlType, " NOT NULL"));
        }
        String foreignKey = String.join(COMMA_WITH_SPACE, foreignKeyColumns);
        if (noOfReferenceKeys > 1 && oneToOne && !uniqueIndexExists) {
            // Composite keys need a table-level UNIQUE constraint instead of a per-column one.
            relationScripts.append(MessageFormat.format("{0}{1}UNIQUE ({2}),", NEW_LINE, TAB, foreignKey));
        }
        relationScripts.append(MessageFormat.format("{0}{1}FOREIGN KEY({2}) REFERENCES {3}({4}),", NEW_LINE, TAB,
                foreignKey, getTableNameWithSchema(assocEntity, datasource),
                String.join(COMMA_WITH_SPACE, referencedColumns)));
        // Normalise both names the same way so table ordering can match them against each other.
        updateReferenceTable(removeSingleQuote(entity.getTableName()),
                removeSingleQuote(assocEntity.getTableName()), referenceTables);
        return relationScripts.toString();
    }

    /** Strips a single leading quote (Ballerina quoted-identifier prefix) from a name, if present. */
    private static String removeSingleQuote(String fieldName) {
        return fieldName.startsWith(SINGLE_QUOTE) ? fieldName.substring(1) : fieldName;
    }

    /**
     * Records that {@code tableName} references {@code referenceTableName}.
     *
     * @param tableName          the referencing table
     * @param referenceTableName the referenced table
     * @param referenceTables    map from table name to the list of tables it references; updated in place
     */
    public static void updateReferenceTable(String tableName, String referenceTableName,
                                            HashMap<String, List<String>> referenceTables) {
        referenceTables.computeIfAbsent(tableName, key -> new ArrayList<>()).add(referenceTableName);
    }

    private static String addPrimaryKey(List<EntityField> primaryKeys, String datasource) {
        return createKeysScript(primaryKeys, datasource);
    }

    /** Builds a {@code PRIMARY KEY(...)} segment (with a trailing comma), or an empty string when there are no keys. */
    private static String createKeysScript(List<EntityField> keys, String datasource) {
        if (keys.isEmpty()) {
            return "";
        }
        List<String> keyColumns = new ArrayList<>(keys.size());
        for (EntityField key : keys) {
            keyColumns.add(escape(removeSingleQuote(key.getFieldColumnName()), datasource));
        }
        return PRIMARY_KEY_START_SCRIPT + String.join(",", keyColumns) + "),";
    }

    /**
     * Maps a field to its SQL column type. A plain {@code VARCHAR} gets a length of 191 unless a
     * {@code @constraint:String} annotation supplies {@code maxLength} or {@code length}; the last such value wins.
     *
     * @throws BalException if the field type has no SQL equivalent
     */
    public static String getSqlType(EntityField entityField, String datasource) throws BalException {
        String sqlType = entityField.isArrayType()
                ? getTypeArray(entityField.getFieldType(), datasource)
                : getTypeNonArray(entityField.getFieldType(), entityField.getSqlType(), datasource);
        if (!sqlType.equals("VARCHAR")) {
            return sqlType;
        }
        String length = "191";
        for (AnnotationNode annotationNode : entityField.getAnnotation()) {
            String annotationName = annotationNode.annotReference().toSourceCode().trim();
            if (!annotationName.equals("constraint:String") || annotationNode.annotValue().isEmpty()) {
                continue;
            }
            MappingConstructorExpressionNode annotValue = annotationNode.annotValue().get();
            for (MappingFieldNode mappingFieldNode : annotValue.fields()) {
                // Other mapping field kinds (e.g. spread fields) have no static name; skip instead of failing a cast.
                if (!(mappingFieldNode instanceof SpecificFieldNode specificFieldNode)) {
                    continue;
                }
                String fieldName = specificFieldNode.fieldName().toSourceCode().trim();
                if ((fieldName.equals("maxLength") || fieldName.equals("length"))
                        && specificFieldNode.valueExpr().isPresent()) {
                    ExpressionNode valueExpr = specificFieldNode.valueExpr().get();
                    length = valueExpr.toSourceCode().trim();
                }
            }
        }
        return sqlType + String.format("(%s)", length);
    }

    /**
     * Maps a non-array Ballerina field type to an SQL type. An explicit DECIMAL/VARCHAR/CHAR {@link SqlType}
     * takes precedence; otherwise the mapping is based on the (unquoted) field type name.
     *
     * @throws BalException if the field type has no SQL equivalent
     */
    public static String getTypeNonArray(String field, SqlType sqlType, String datasource) throws BalException {
        if (sqlType != null) {
            switch (sqlType.getTypeName()) {
                case "DECIMAL":
                    return "DECIMAL" + String.format("(%s,%s)", sqlType.getNumericPrecision(),
                            sqlType.getNumericScale());
                case "VARCHAR":
                    return "VARCHAR" + String.format("(%s)", sqlType.getMaxLength());
                case "CHAR":
                    return "CHAR" + String.format("(%s)", sqlType.getMaxLength());
                default:
                    break;
            }
        }
        switch (removeSingleQuote(field)) {
            case "int":
                return "INT";
            case "boolean":
                return datasource.equals(MSSQL) ? "BIT" : "BOOLEAN";
            case "decimal":
                // MSSQL caps DECIMAL precision at 38.
                return datasource.equals(MSSQL) ? "DECIMAL(38,30)" : "DECIMAL(65,30)";
            case "float":
                return datasource.equals(MYSQL) ? "DOUBLE" : "FLOAT";
            case "time:Date":
                return "DATE";
            case "time:TimeOfDay":
                return "TIME";
            case "time:Utc":
                return datasource.equals(MSSQL) ? "DATETIME2" : "TIMESTAMP";
            case "time:Civil":
                if (datasource.equals(MSSQL)) {
                    return "DATETIME2";
                }
                return datasource.equals(POSTGRESQL) ? "TIMESTAMP" : "DATETIME";
            case "string":
                return "VARCHAR";
            default:
                throw new BalException("couldn't find equivalent SQL type for the field type: " + field);
        }
    }

    /**
     * Maps an array field type to an SQL type; only {@code byte[]} is supported.
     *
     * @throws BalException for any other array element type
     */
    public static String getTypeArray(String field, String datasource) throws BalException {
        if (!"byte".equals(field)) {
            throw new BalException("couldn't find equivalent SQL type for the field type: " + field);
        }
        if (datasource.equals(MSSQL)) {
            return "VARBINARY(MAX)";
        }
        return datasource.equals(POSTGRESQL) ? "BYTEA" : "LONGBLOB";
    }

    /**
     * Builds the column type for an enum field. MSSQL, PostgreSQL and H2 get a VARCHAR sized to the longest value
     * with a CHECK constraint. Other data sources get a native {@code ENUM(...)} type. Each member contributes its
     * value, or its identifier when no value is set.
     */
    private static String getEnumType(Enum enumValue, String fieldName, String datasource) {
        List<EnumMember> members = enumValue.getMembers();
        List<String> literals = new ArrayList<>(members.size());
        int maxLength = 0;
        for (EnumMember member : members) {
            String value = member.getValue() != null ? member.getValue() : member.getIdentifier();
            literals.add(toSqlStringLiteral(value));
            maxLength = Math.max(maxLength, value.length());
        }
        String values = String.join(COMMA_WITH_SPACE, literals);
        if (datasource.equals(MSSQL) || datasource.equals(POSTGRESQL) || datasource.equals(H2)) {
            return String.format("VARCHAR(%s) CHECK (%s IN (%s))", maxLength, fieldName, values);
        }
        return ENUM_START_SCRIPT + values + ENUM_END_SCRIPT;
    }

    /** Wraps a value in single quotes, doubling any embedded single quote as standard SQL requires. */
    private static String toSqlStringLiteral(String value) {
        return SINGLE_QUOTE + value.replace(SINGLE_QUOTE, SINGLE_QUOTE + SINGLE_QUOTE) + SINGLE_QUOTE;
    }

    /**
     * Orders the table scripts so that a table is created after the tables it references and dropped before them.
     *
     * <p>Drop statements fill the front of the returned array and create statements fill it from the back, so
     * creates run in {@code tableOrder} order and drops in the reverse order.
     *
     * <p>NOTE(review): this is an insertion heuristic driven by map iteration order, not a full topological sort.
     * A table is placed right after the last of its references that has already been placed.
     */
    private static String[] rearrangeScriptsWithReference(Set<String> tables,
                                                          HashMap<String, List<String>> referenceTables,
                                                          HashMap<String, List<String>> tableScripts) {
        List<String> tableOrder = new ArrayList<>();
        for (Map.Entry<String, List<String>> entry : referenceTables.entrySet()) {
            int insertAt = 0;
            for (String referenceTableName : entry.getValue()) {
                int index = tableOrder.indexOf(referenceTableName);
                if (index >= 0) {
                    // Must land strictly after every referenced table that is already placed.
                    insertAt = Math.max(insertAt, index + 1);
                }
            }
            tableOrder.add(insertAt, removeSingleQuote(entry.getKey()));
        }
        for (String tableName : tables) {
            if (!tableOrder.contains(tableName)) {
                tableOrder.add(0, tableName);
            }
        }
        int length = tables.size() * 2;
        int size = tableOrder.size();
        String[] tableScriptsInOrder = new String[length];
        for (int i = 0; i < size; i++) {
            List<String> script = tableScripts.get(removeSingleQuote(tableOrder.get(size - (i + 1))));
            tableScriptsInOrder[i] = script.get(0);
            tableScriptsInOrder[length - (i + 1)] = script.get(1);
        }
        return tableScriptsInOrder;
    }

    /** Quotes an identifier for the data source: brackets for MSSQL, double quotes for PostgreSQL/H2, backticks otherwise. */
    private static String escape(String name, String datasource) {
        if (datasource.equals(MSSQL)) {
            return "[" + name + "]";
        }
        if (datasource.equals(POSTGRESQL) || datasource.equals(H2)) {
            return "\"" + name + "\"";
        }
        return "`" + name + "`";
    }
}

