/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.hologres.client.impl;

import com.alibaba.hologres.client.HoloConfig;
import com.alibaba.hologres.client.Put;
import com.alibaba.hologres.client.impl.FillPreparedStatementFunc;
import com.alibaba.hologres.client.impl.PreparedStatementWithBatchInfo;
import com.alibaba.hologres.client.model.Column;
import com.alibaba.hologres.client.model.HoloVersion;
import com.alibaba.hologres.client.model.Record;
import com.alibaba.hologres.client.model.TableName;
import com.alibaba.hologres.client.model.TableSchema;
import com.alibaba.hologres.client.model.WriteMode;
import com.alibaba.hologres.client.type.PGroaringbitmap;
import com.alibaba.hologres.client.utils.IdentifierUtil;
import com.alibaba.hologres.client.utils.Tuple;
import com.alibaba.hologres.client.utils.Tuple3;
import com.alibaba.hologres.org.postgresql.util.PSQLState;
import java.math.BigDecimal;
import java.sql.Array;
import java.sql.Connection;
import java.sql.Date;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PrimitiveIterator;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class UpsertStatementBuilder {
    public static final Logger LOGGER = LoggerFactory.getLogger(UpsertStatementBuilder.class);
    /** Joins per-row predicates of a multi-row DELETE. */
    protected static final String DELIMITER_OR = " OR ";
    /** Joins per-row VALUES tuples of a multi-row INSERT. */
    protected static final String DELIMITER_DOT = ", ";
    protected HoloConfig config;
    // Settings copied from the HoloConfig in the constructor.
    boolean enableDefaultValue;
    String defaultTimeStampText;
    boolean inputNumberAsEpochMsForDatetimeColumn;
    boolean inputStringAsEpochMsForDatetimeColumn;
    boolean removeU0000InTextColumnValue;
    /** Insert templates keyed by (schema, table, mode), then by (written columns, insert-only columns). */
    SqlCache<Tuple<BitSet, BitSet>> insertCache = new SqlCache<>();
    /** Delete templates keyed by (schema, table). */
    Map<Tuple<TableSchema, TableName>, SqlTemplate> deleteCache = new HashMap<>();
    // Scratch "first element" flag used by some SQL-building code paths.
    boolean first = true;
    private static final int WARN_SKIP_COUNT = 10000;
    // Starts at the threshold so the very first rate-limited warning is emitted.
    long warnCount = WARN_SKIP_COUNT;
    // MySQL-style zero datetime; bound as epoch 0 instead of being sent verbatim.
    private static final String MYSQL_0000 = "0000-00-00 00:00:00";

    public UpsertStatementBuilder(HoloConfig config) {
        this.config = config;
        this.enableDefaultValue = config.isEnableDefaultForNotNullColumn();
        this.defaultTimeStampText = config.getDefaultTimestampText();
        this.inputNumberAsEpochMsForDatetimeColumn = config.isInputNumberAsEpochMsForDatetimeColumn();
        this.inputStringAsEpochMsForDatetimeColumn = config.isInputStringAsEpochMsForDatetimeColumn();
        this.removeU0000InTextColumnValue = config.isRemoveU0000InTextColumnValue();
    }

    /**
     * Builds the DELETE template for a table: {@code delete from <table> where (k1=? and k2=?) OR (...)}.
     *
     * @param tuple (schema, table name) pair identifying the target table
     * @return a template whose row text matches one record by all key columns
     * @throws IllegalArgumentException if the table has no key columns, since rows could not be
     *     identified (previously this surfaced as a division by zero)
     */
    private SqlTemplate buildDeleteSqlTemplate(Tuple<TableSchema, TableName> tuple) {
        TableSchema schema = (TableSchema) tuple.l;
        TableName tableName = (TableName) tuple.r;
        int keyCount = schema.getKeyIndex().length;
        if (keyCount == 0) {
            throw new IllegalArgumentException("delete requires key columns, but table "
                    + tableName.getFullName() + " has none");
        }
        String header = "delete from " + tableName.getFullName() + " where ";
        StringBuilder row = new StringBuilder("(");
        String separator = "";
        for (int index : schema.getKeyIndex()) {
            row.append(separator)
                    .append(IdentifierUtil.quoteIdentifier(schema.getColumnSchema()[index].getName(), true))
                    .append("=?");
            separator = " and ";
        }
        row.append(")");
        // Largest level L with 2^L * keyCount <= Short.MAX_VALUE bind parameters per statement.
        int maxLevel = 32 - Integer.numberOfLeadingZeros(Short.MAX_VALUE / keyCount) - 1;
        return new SqlTemplate(header, null, row.toString(), DELIMITER_OR, maxLevel);
    }

    /**
     * Builds the multi-row INSERT template for one combination of written columns.
     *
     * <p>Shape: {@code insert into <table>(c1,c2) values (?,?), (?,?) on conflict (keys) do ...}.
     * The ON CONFLICT clause is only emitted when the table has key columns. It is
     * {@code do nothing} for {@link WriteMode#INSERT_OR_IGNORE}, and also when no written column is
     * updatable because all are insert-only (an empty {@code update set} would be invalid SQL).
     * Otherwise it is {@code do update set c=excluded.c, ...} over the non-insert-only columns.
     *
     * @param tuple (schema, table name, write mode)
     * @param input (columns written by the records, columns that must only be set on insert)
     * @return the insert template
     * @throws IllegalArgumentException if no column is written
     */
    private SqlTemplate buildInsertSql(Tuple3<TableSchema, TableName, WriteMode> tuple, Tuple<BitSet, BitSet> input) {
        TableSchema schema = (TableSchema) tuple.l;
        TableName tableName = (TableName) tuple.m;
        WriteMode mode = (WriteMode) tuple.r;
        BitSet set = (BitSet) input.l;
        BitSet onlyInsertSet = (BitSet) input.r;
        int columnCount = set.cardinality();
        if (columnCount == 0) {
            throw new IllegalArgumentException("no column is set for insert into " + tableName.getFullName());
        }

        // Header: column list.
        StringBuilder sb = new StringBuilder();
        sb.append("insert into ").append(tableName.getFullName()).append("(");
        String separator = "";
        for (int index = set.nextSetBit(0); index >= 0; index = set.nextSetBit(index + 1)) {
            sb.append(separator).append(IdentifierUtil.quoteIdentifier(schema.getColumn(index).getName(), true));
            separator = ",";
        }
        sb.append(") values ");
        String header = sb.toString();

        // One VALUES tuple; bit/varbit placeholders need an explicit cast carrying the precision.
        sb.setLength(0);
        sb.append("(");
        separator = "";
        for (int index = set.nextSetBit(0); index >= 0; index = set.nextSetBit(index + 1)) {
            sb.append(separator).append("?");
            separator = ",";
            Column column = schema.getColumn(index);
            if (Types.BIT == column.getType() && "bit".equals(column.getTypeName())) {
                sb.append("::bit(").append(column.getPrecision()).append(")");
            } else if (Types.OTHER == column.getType() && "varbit".equals(column.getTypeName())) {
                sb.append("::bit varying(").append(column.getPrecision()).append(")");
            }
        }
        sb.append(")");
        String rowText = sb.toString();

        // Tail: conflict handling on the key columns.
        sb.setLength(0);
        if (schema.getKeyIndex().length > 0) {
            sb.append(" on conflict (");
            separator = "";
            for (int keyIndex : schema.getKeyIndex()) {
                sb.append(separator).append(IdentifierUtil.quoteIdentifier(schema.getColumnSchema()[keyIndex].getName(), true));
                separator = ",";
            }
            sb.append(") do ");
            StringBuilder updates = new StringBuilder();
            if (WriteMode.INSERT_OR_IGNORE != mode) {
                for (int index = set.nextSetBit(0); index >= 0; index = set.nextSetBit(index + 1)) {
                    if (onlyInsertSet.get(index)) {
                        continue;
                    }
                    if (updates.length() > 0) {
                        updates.append(",");
                    }
                    String columnName = IdentifierUtil.quoteIdentifier(schema.getColumnSchema()[index].getName(), true);
                    updates.append(columnName).append("=excluded.").append(columnName);
                }
            }
            if (updates.length() == 0) {
                sb.append("nothing");
            } else {
                sb.append("update set ").append(updates);
            }
        }
        String tail = sb.toString();
        // Largest level L with 2^L * columnCount <= Short.MAX_VALUE bind parameters per statement.
        int maxLevel = 32 - Integer.numberOfLeadingZeros(Short.MAX_VALUE / columnCount) - 1;
        SqlTemplate sqlTemplate = new SqlTemplate(header, tail, rowText, DELIMITER_DOT, maxLevel);
        if (LOGGER.isDebugEnabled()) {
            // Rendering the max-level SQL can be large; only do it when it will be logged.
            LOGGER.debug("new sql:{}", sqlTemplate.getSql(maxLevel));
        }
        return sqlTemplate;
    }

    /**
     * Splits a column default expression into its value part and cast-type part, unquoting the
     * value when it is a single-quoted string literal.
     *
     * <p>Examples: {@code "'abc'::text"} gives {@code ["abc", "text"]}; {@code "0"} gives
     * {@code ["0", null]}; {@code "'a::b'::text"} gives {@code ["a::b", "text"]}.
     *
     * <p>The {@code ::} cast separator is only recognized outside single quotes, so literals that
     * contain {@code ::} stay intact. Doubled quotes inside a quoted literal are collapsed to one
     * quote, which is how string literals with embedded quotes are written in SQL.
     *
     * <p>Package-private (rather than private) so it can be unit tested.
     *
     * @param defaultValue default expression text; must not be null
     * @return two elements: the value, and the cast type (null when there is no cast)
     */
    static String[] handleDefaultValue(String defaultValue) {
        int castPos = indexOfCastOutsideQuotes(defaultValue);
        String value = castPos < 0 ? defaultValue : defaultValue.substring(0, castPos);
        String type = castPos < 0 ? null : defaultValue.substring(castPos + 2);
        if (type != null && type.isEmpty()) {
            type = null;
        }
        if (value.length() > 1 && value.startsWith("'") && value.endsWith("'")) {
            value = value.substring(1, value.length() - 1).replace("''", "'");
        }
        return new String[]{value, type};
    }

    /** Returns the index of the first {@code ::} not inside a single-quoted literal, or -1. */
    private static int indexOfCastOutsideQuotes(String expr) {
        boolean inQuotes = false;
        for (int i = 0; i < expr.length(); ++i) {
            char c = expr.charAt(i);
            if (c == '\'') {
                // A doubled quote toggles twice, leaving the state unchanged.
                inQuotes = !inQuotes;
            } else if (!inQuotes && c == ':' && i + 1 < expr.length() && expr.charAt(i + 1) == ':') {
                return i;
            }
        }
        return -1;
    }

    /**
     * Emits a WARN log at most once per {@code WARN_SKIP_COUNT + 1} calls.
     *
     * <p>The counter starts at the threshold, so the first call always logs; after each emitted
     * warning the counter is reset to zero.
     *
     * @param s2 SLF4J message pattern
     * @param obj pattern arguments
     */
    private void logWarnSeldom(String s2, Object ... obj) {
        this.warnCount += 1;
        if (this.warnCount <= WARN_SKIP_COUNT) {
            return;
        }
        LOGGER.warn(s2, obj);
        this.warnCount = 0L;
    }

    /**
     * Supplies a value for a NOT NULL, non-serial column that the record left null.
     *
     * <p>If the column declares a default expression, it is converted to a Java value by column
     * type. {@code now()} and {@code current_timestamp} become the current time. Otherwise, when
     * {@code enableDefaultValue} is set, a type-appropriate zero value is used: 0, 0.0,
     * BigDecimal.ZERO, false or the empty string, and for temporal columns either epoch 0 or the
     * configured default timestamp text. Unsupported types are reported via a rate-limited warning
     * and left untouched.
     *
     * @param record record being prepared
     * @param column schema of column {@code i}
     * @param i column index within the record
     * @throws NumberFormatException if a numeric default expression cannot be parsed
     */
    private void fillDefaultValue(Record record, Column column, int i) {
        if (record.getObject(i) != null || column.getAllowNull().booleanValue() || column.isSerial()) {
            // Value present, NULL permitted, or generated by the database: nothing to fill.
            return;
        }
        if (column.getDefaultValue() != null) {
            String defaultValue = UpsertStatementBuilder.handleDefaultValue(String.valueOf(column.getDefaultValue()))[0];
            switch (column.getType()) {
                case Types.BIGINT:
                case Types.INTEGER:
                case Types.SMALLINT:
                    record.setObject(i, Long.parseLong(defaultValue));
                    break;
                case Types.FLOAT:
                case Types.REAL:
                case Types.DOUBLE:
                    record.setObject(i, Double.parseDouble(defaultValue));
                    break;
                case Types.NUMERIC:
                case Types.DECIMAL:
                    record.setObject(i, new BigDecimal(defaultValue));
                    break;
                case Types.BIT:
                case Types.BOOLEAN:
                    record.setObject(i, Boolean.valueOf(defaultValue));
                    break;
                case Types.CHAR:
                case Types.VARCHAR:
                case Types.LONGNVARCHAR:
                    record.setObject(i, defaultValue);
                    break;
                case Types.DATE:
                case Types.TIME:
                case Types.TIMESTAMP:
                case Types.TIME_WITH_TIMEZONE:
                case Types.TIMESTAMP_WITH_TIMEZONE:
                    if ("now()".equalsIgnoreCase(defaultValue) || "current_timestamp".equalsIgnoreCase(defaultValue)) {
                        record.setObject(i, new java.util.Date());
                    } else {
                        record.setObject(i, defaultValue);
                    }
                    break;
                default:
                    this.logWarnSeldom("unsupported default type,{}({})", column.getType(), column.getTypeName());
                    break;
            }
        } else if (this.enableDefaultValue) {
            switch (column.getType()) {
                case Types.BIGINT:
                case Types.INTEGER:
                case Types.SMALLINT:
                    record.setObject(i, 0L);
                    break;
                case Types.FLOAT:
                case Types.REAL:
                case Types.DOUBLE:
                    record.setObject(i, 0.0);
                    break;
                case Types.NUMERIC:
                case Types.DECIMAL:
                    record.setObject(i, BigDecimal.ZERO);
                    break;
                case Types.BIT:
                case Types.BOOLEAN:
                    record.setObject(i, false);
                    break;
                case Types.CHAR:
                case Types.VARCHAR:
                case Types.LONGNVARCHAR:
                    record.setObject(i, "");
                    break;
                case Types.DATE:
                case Types.TIME:
                case Types.TIMESTAMP:
                case Types.TIME_WITH_TIMEZONE:
                case Types.TIMESTAMP_WITH_TIMEZONE:
                    if (this.defaultTimeStampText == null) {
                        record.setObject(i, new java.util.Date(0L));
                    } else {
                        record.setObject(i, this.defaultTimeStampText);
                    }
                    break;
                default:
                    this.logWarnSeldom("unsupported default type,{}({})", column.getType(), column.getTypeName());
                    break;
            }
        }
    }

    /**
     * Explicitly assigns null to column {@code i} when the record never set it, except for serial
     * columns. Presumably this makes the column part of the written column set; confirm against
     * {@code Record.setObject}.
     *
     * @param record record being prepared
     * @param column schema of column {@code i}
     * @param i column index within the record
     */
    private void fillNotSetValue(Record record, Column column, int i) {
        boolean untouched = !record.isSet(i);
        if (untouched && !column.isSerial()) {
            record.setObject(i, null);
        }
    }

    /**
     * Converts a List or Object[] value of an array column into a JDBC {@link Array}.
     *
     * <p>The element type name is the column type name minus its first character. This presumably
     * relies on PostgreSQL naming array types {@code _<element>}; confirm for exotic types. Other
     * value kinds, including null, are left unchanged.
     *
     * @param conn connection used to create the array
     * @param record record holding the value
     * @param column schema of the array column
     * @param index column index within the record
     * @throws SQLException if the driver cannot create the array
     */
    private void handleArrayColumn(Connection conn, Record record, Column column, int index) throws SQLException {
        Object value = record.getObject(index);
        Object[] elements;
        if (value instanceof List) {
            elements = ((List<?>) value).toArray();
        } else if (value instanceof Object[]) {
            elements = (Object[]) value;
        } else {
            return;
        }
        String elementType = column.getTypeName().substring(1);
        record.setObject(index, conn.createArrayOf(elementType, elements));
    }

    /**
     * Normalizes a record's values before binding.
     *
     * <p>For INSERT records written in any mode other than {@link WriteMode#INSERT_OR_UPDATE}, this
     * fills defaults for null NOT NULL columns and explicitly nulls never-set columns. It does so
     * via {@code fillDefaultValue} and {@code fillNotSetValue}. Array columns holding a List or
     * Object[] are converted to JDBC arrays.
     *
     * @param conn connection used to create JDBC arrays
     * @param record record to normalize in place
     * @param mode write mode of the batch
     * @throws SQLException with SQLState {@code INVALID_PARAMETER_VALUE} wrapping any failure
     */
    public void prepareRecord(Connection conn, Record record, WriteMode mode) throws SQLException {
        try {
            boolean fillMissing = record.getType() == Put.MutationType.INSERT && mode != WriteMode.INSERT_OR_UPDATE;
            for (int i = 0; i < record.getSize(); ++i) {
                Column column = record.getSchema().getColumn(i);
                if (fillMissing) {
                    this.fillDefaultValue(record, column, i);
                    this.fillNotSetValue(record, column, i);
                }
                if (column.getType() == Types.ARRAY) {
                    this.handleArrayColumn(conn, record, column, i);
                }
            }
        }
        catch (Exception e) {
            // Use the (reason, SQLState, cause) constructor: the previous (String, Throwable) call
            // put the state code into the message and left getSQLState() null.
            throw new SQLException("failed to prepare record: " + e.getMessage(),
                    PSQLState.INVALID_PARAMETER_VALUE.getState(), e);
        }
    }

    /**
     * Removes NUL (U+0000) characters from a text value when {@code removeU0000InTextColumnValue}
     * is enabled.
     *
     * <p>Previously the configuration flag was read but never consulted, so NULs were always
     * stripped.
     *
     * @param in text value, may be null
     * @return {@code in} without NUL characters when the flag is on; otherwise {@code in} unchanged
     */
    private String removeU0000(String in) {
        if (this.removeU0000InTextColumnValue && in != null && in.contains("\u0000")) {
            // Literal replacement; no need to compile a regex as replaceAll would.
            return in.replace("\u0000", "");
        }
        return in;
    }

    /**
     * Binds one value to parameter {@code index}, converting it according to the column's JDBC
     * type and type name.
     *
     * <ul>
     *   <li>OTHER: byte[] into a roaringbitmap column is wrapped in {@link PGroaringbitmap};
     *       varbit is bound as a string.</li>
     *   <li>CHAR/VARCHAR/LONGNVARCHAR: bound as text via {@code removeU0000}; null uses setNull.</li>
     *   <li>BIT with type name "bit": Boolean is bound as "1"/"0"; other values as their string form.</li>
     *   <li>DATE/TIME/TIMESTAMP (with or without time zone): see {@code fillTemporalParameter}.</li>
     *   <li>Everything else: {@code setObject} with the column type.</li>
     * </ul>
     *
     * @throws SQLException if the driver rejects the value
     */
    private void fillPreparedStatement(PreparedStatement ps, int index, Object obj, Column column) throws SQLException {
        int type = column.getType();
        switch (type) {
            case Types.OTHER: {
                if (obj instanceof byte[] && "roaringbitmap".equalsIgnoreCase(column.getTypeName())) {
                    PGroaringbitmap bitmap = new PGroaringbitmap();
                    bitmap.setByteValue((byte[]) obj, 0);
                    ps.setObject(index, bitmap, type);
                } else if ("varbit".equals(column.getTypeName())) {
                    ps.setString(index, obj == null ? null : String.valueOf(obj));
                } else {
                    ps.setObject(index, obj, type);
                }
                break;
            }
            case Types.LONGNVARCHAR:
            case Types.CHAR:
            case Types.VARCHAR: {
                if (obj == null) {
                    ps.setNull(index, type);
                } else {
                    ps.setObject(index, this.removeU0000(obj.toString()), type);
                }
                break;
            }
            case Types.BIT: {
                if ("bit".equals(column.getTypeName())) {
                    if (obj instanceof Boolean) {
                        ps.setString(index, (Boolean) obj ? "1" : "0");
                    } else {
                        ps.setString(index, obj == null ? null : String.valueOf(obj));
                    }
                } else {
                    ps.setObject(index, obj, type);
                }
                break;
            }
            case Types.DATE:
            case Types.TIME:
            case Types.TIME_WITH_TIMEZONE:
            case Types.TIMESTAMP:
            case Types.TIMESTAMP_WITH_TIMEZONE: {
                this.fillTemporalParameter(ps, index, obj, type);
                break;
            }
            default: {
                ps.setObject(index, obj, type);
            }
        }
    }

    /**
     * Binds a DATE/TIME/TIMESTAMP parameter.
     *
     * <p>A Number (when {@code inputNumberAsEpochMsForDatetimeColumn} is set) or numeric String
     * (when {@code inputStringAsEpochMsForDatetimeColumn} is set) is treated as epoch
     * milliseconds. The MySQL zero datetime string is bound as epoch 0. Anything else is passed
     * through to {@code setObject}.
     */
    private void fillTemporalParameter(PreparedStatement ps, int index, Object obj, int type) throws SQLException {
        if (obj instanceof Number && this.inputNumberAsEpochMsForDatetimeColumn) {
            ps.setObject(index, epochMillisToTemporal(((Number) obj).longValue(), type), type);
            return;
        }
        if (obj instanceof String && this.inputStringAsEpochMsForDatetimeColumn) {
            try {
                ps.setObject(index, epochMillisToTemporal(Long.parseLong((String) obj), type), type);
            }
            catch (NumberFormatException e) {
                // Not an epoch number: map the zero datetime to epoch 0, else send the text as-is.
                ps.setObject(index, MYSQL_0000.equals(obj) ? epochMillisToTemporal(0L, type) : obj, type);
            }
            return;
        }
        ps.setObject(index, MYSQL_0000.equals(obj) ? epochMillisToTemporal(0L, type) : obj, type);
    }

    /** Wraps epoch milliseconds in the java.sql type matching the JDBC temporal type. */
    private static Object epochMillisToTemporal(long millis, int type) {
        switch (type) {
            case Types.DATE:
                return new Date(millis);
            case Types.TIME:
            case Types.TIME_WITH_TIMEZONE:
                return new Time(millis);
            default:
                return new Timestamp(millis);
        }
    }

    /**
     * Binds the record's written columns, in ascending column order, to the parameters following
     * {@code psIndex}.
     *
     * @return the last parameter index used
     */
    private int fillPreparedStatementForInsert(PreparedStatement ps, int psIndex, Record record) throws SQLException {
        BitSet columns = record.getBitSet();
        int position = psIndex;
        for (int col = columns.nextSetBit(0); col >= 0; col = columns.nextSetBit(col + 1)) {
            Column columnSchema = record.getSchema().getColumn(col);
            position += 1;
            this.fillPreparedStatement(ps, position, record.getObject(col), columnSchema);
        }
        return position;
    }

    /**
     * Prepares INSERT statements for records sharing one column set and appends them to
     * {@code list}. Does nothing for an empty record list.
     */
    protected void buildInsertStatement(Connection conn, HoloVersion version, TableSchema schema, TableName tableName, Tuple<BitSet, BitSet> columnSet, List<Record> recordList, List<PreparedStatementWithBatchInfo> list, WriteMode mode) throws SQLException {
        if (!recordList.isEmpty()) {
            Tuple3<TableSchema, TableName, WriteMode> tableKey = new Tuple3<>(schema, tableName, mode);
            SqlTemplate template = this.insertCache.computeIfAbsent(tableKey, columnSet, this::buildInsertSql);
            this.fillPreparedStatement(conn, template, list, recordList, Put.MutationType.INSERT, this::fillPreparedStatementForInsert);
        }
    }

    /**
     * Prepares statements from {@code sqlTemplate} for {@code recordList}, binds every record, and
     * appends each prepared statement to {@code list}.
     *
     * <p>Records are grouped into "value blocks" of {@code 2^level} rows, one SQL execution per
     * block. The first {@code recordList.size() / 2^maxLevel} full blocks share a single statement
     * prepared at the template's max level. When there is more than one such block, each completed
     * block is queued with {@code addBatch()} and the statement is flagged as batched. The remaining
     * records are split into successively smaller power-of-two blocks, largest first, each with its
     * own non-batched statement. When a statement is finished, its accumulated record byte size and
     * block count are stored on it.
     *
     * @param conn connection used to prepare statements
     * @param sqlTemplate template rendering the SQL for a given level
     * @param list output list; every statement prepared here is appended
     * @param recordList records to bind, in order
     * @param type mutation type recorded on each statement
     * @param func binds one record starting after the given parameter index and returns the index
     *     to continue from
     * @throws SQLException if preparing or binding fails
     */
    private void fillPreparedStatement(Connection conn, SqlTemplate sqlTemplate, List<PreparedStatementWithBatchInfo> list, List<Record> recordList, Put.MutationType type, FillPreparedStatementFunc func) throws SQLException {
        int fullValueBlocksCount;
        // Rows per statement at the template's max level.
        int maxValueBlocks = 1 << sqlTemplate.maxLevel;
        int unprocessedBatchCount = recordList.size();
        int remainFullValueBlocksCount = fullValueBlocksCount = unprocessedBatchCount / maxValueBlocks;
        boolean first = true;
        int currentLevel = 0;
        int rows = 0;
        int psIndex = 0;
        PreparedStatementWithBatchInfo ps = null;
        boolean batchMode = false;
        long byteSize = 0L;
        int batchCount = 0;
        for (Record record : recordList) {
            // Start of a new value block.
            if (first) {
                if (remainFullValueBlocksCount > 0) {
                    // Full-size block: all of them share one max-level statement.
                    batchMode = fullValueBlocksCount > 1;
                    currentLevel = sqlTemplate.getMaxLevel();
                    if (ps == null) {
                        ps = new PreparedStatementWithBatchInfo(conn.prepareStatement(sqlTemplate.getSql(currentLevel)), fullValueBlocksCount > 1, type);
                        list.add(ps);
                    }
                    --remainFullValueBlocksCount;
                } else {
                    // Remainder: finalize stats of the previous statement, then prepare one sized to
                    // the largest power of two not exceeding the records still unprocessed.
                    if (ps != null) {
                        ps.setByteSize(byteSize);
                        ps.setBatchCount(batchCount);
                        byteSize = 0L;
                        batchCount = 0;
                    }
                    batchMode = false;
                    currentLevel = 31 - Integer.numberOfLeadingZeros(unprocessedBatchCount);
                    ps = new PreparedStatementWithBatchInfo(conn.prepareStatement(sqlTemplate.getSql(currentLevel)), false, type);
                    list.add(ps);
                }
                first = false;
                rows = 1 << currentLevel;
                psIndex = 0;
                ++batchCount;
            }
            --unprocessedBatchCount;
            // Bind this record into the current block.
            if (rows > 0) {
                psIndex = func.apply((PreparedStatement)ps.l, psIndex, record);
                byteSize += record.getByteSize();
                --rows;
            }
            if (rows != 0) continue;
            // Block complete; in batch mode queue it on the shared statement.
            first = true;
            if (!batchMode) continue;
            ((PreparedStatement)ps.l).addBatch();
        }
        // Record stats for the last statement.
        if (ps != null) {
            ps.setByteSize(byteSize);
            ps.setBatchCount(batchCount);
        }
    }

    /**
     * Binds the record's key-column values, in key order, to the parameters following
     * {@code psIndex}.
     *
     * @return the last parameter index used
     */
    private int fillPreparedStatementForDelete(PreparedStatement ps, int psIndex, Record record) throws SQLException {
        int position = psIndex;
        for (int keyColumn : record.getSchema().getKeyIndex()) {
            Column keySchema = record.getSchema().getColumn(keyColumn);
            Object keyValue = record.getObject(keyColumn);
            position += 1;
            this.fillPreparedStatement(ps, position, keyValue, keySchema);
        }
        return position;
    }

    /**
     * Prepares DELETE statements for {@code recordList} and appends them to {@code list}. Does
     * nothing for an empty record list.
     */
    protected void buildDeleteStatement(Connection conn, HoloVersion version, TableSchema schema, TableName tableName, List<Record> recordList, List<PreparedStatementWithBatchInfo> list) throws SQLException {
        if (recordList.isEmpty()) {
            return;
        }
        Tuple<TableSchema, TableName> cacheKey = new Tuple<>(schema, tableName);
        SqlTemplate template = this.deleteCache.computeIfAbsent(cacheKey, this::buildDeleteSqlTemplate);
        this.fillPreparedStatement(conn, template, list, recordList, Put.MutationType.DELETE, this::fillPreparedStatementForDelete);
    }

    /**
     * Normalizes the records and prepares all DELETE and INSERT statements for them.
     *
     * <p>Deletes are prepared first, then one group of inserts per distinct (written columns,
     * insert-only columns) pair. If preparing fails, every statement prepared so far is closed
     * before the error propagates. This covers both SQLException and runtime failures; previously
     * only SQLException triggered cleanup. Close failures are attached as suppressed exceptions.
     * Afterwards, oversized template caches are cleared.
     *
     * @return the prepared statements, deletes first
     * @throws SQLException on unsupported mutation types or any preparation failure (non-SQL
     *     failures are wrapped)
     */
    public List<PreparedStatementWithBatchInfo> buildStatements(Connection conn, HoloVersion version, TableSchema schema, TableName tableName, Collection<Record> recordList, WriteMode mode) throws SQLException {
        List<Record> deleteRecordList = new ArrayList<>();
        Map<Tuple<BitSet, BitSet>, List<Record>> insertRecordMap = new HashMap<>();
        List<PreparedStatementWithBatchInfo> preparedStatementList = new ArrayList<>();
        try {
            for (Record record : recordList) {
                this.prepareRecord(conn, record, mode);
                switch (record.getType()) {
                    case DELETE:
                        deleteRecordList.add(record);
                        break;
                    case INSERT:
                        insertRecordMap.computeIfAbsent(
                                new Tuple<BitSet, BitSet>(record.getBitSet(), record.getOnlyInsertColumnSet()),
                                key -> new ArrayList<>()).add(record);
                        break;
                    default:
                        throw new SQLException("unsupported type:" + record.getType() + " for record:" + record);
                }
            }
            try {
                if (!deleteRecordList.isEmpty()) {
                    this.buildDeleteStatement(conn, version, schema, tableName, deleteRecordList, preparedStatementList);
                }
                for (Map.Entry<Tuple<BitSet, BitSet>, List<Record>> entry : insertRecordMap.entrySet()) {
                    this.buildInsertStatement(conn, version, schema, tableName, entry.getKey(), entry.getValue(), preparedStatementList, mode);
                }
            }
            catch (SQLException | RuntimeException e) {
                closeAllOnFailure(preparedStatementList, e);
                throw e;
            }
            return preparedStatementList;
        }
        catch (SQLException e) {
            throw e;
        }
        catch (Exception e) {
            throw new SQLException(e);
        }
        finally {
            // Bound cache growth; many distinct column sets would otherwise accumulate templates.
            if (this.insertCache.getSize() > 500) {
                this.insertCache.clear();
            }
            if (this.deleteCache.size() > 500) {
                this.deleteCache.clear();
            }
        }
    }

    /** Closes every prepared statement, recording close failures as suppressed on {@code cause}. */
    private static void closeAllOnFailure(List<PreparedStatementWithBatchInfo> statements, Throwable cause) {
        for (PreparedStatementWithBatchInfo psWithInfo : statements) {
            PreparedStatement ps = (PreparedStatement) psWithInfo.l;
            if (ps == null) {
                continue;
            }
            try {
                ps.close();
            }
            catch (SQLException closeError) {
                cause.addSuppressed(closeError);
            }
        }
    }

    class SqlTemplate {
        private final String header;
        private final String tail;
        private final String rowText;
        private final String delimiter;
        private final int maxLevel;
        String[] sqls;

        public SqlTemplate(String header, String tail, String rowText, String delimiter, int maxLevel) {
            this.header = header;
            this.tail = tail;
            this.rowText = rowText;
            this.delimiter = delimiter;
            this.maxLevel = maxLevel;
            this.sqls = new String[maxLevel + 1];
        }

        public String getSql(int level) {
            if (level >= this.sqls.length) {
                throw new RuntimeException(this + " max level is " + this.sqls.length + ", but input level is " + level);
            }
            if (this.sqls[level] == null) {
                StringBuilder sb = new StringBuilder();
                if (this.header != null) {
                    sb.append(this.header);
                }
                for (int i = 0; i < 1 << level; ++i) {
                    if (i > 0) {
                        sb.append(this.delimiter);
                    }
                    sb.append(this.rowText);
                }
                if (null != this.tail) {
                    sb.append(this.tail);
                }
                this.sqls[level] = sb.toString();
            }
            return this.sqls[level];
        }

        public int getMaxLevel() {
            return this.maxLevel;
        }

        public String toString() {
            return "SqlTemplate{header='" + this.header + '\'' + ", tail='" + this.tail + '\'' + ", rowText='" + this.rowText + '\'' + ", delimiter='" + this.delimiter + '\'' + '}';
        }
    }

    /**
     * Two-level cache of SQL templates: by (schema, table, mode), then by a caller-defined key.
     * Tracks how many templates were built since the last {@link #clear()}.
     *
     * @param <T> second-level key type
     */
    static class SqlCache<T> {
        Map<Tuple3<TableSchema, TableName, WriteMode>, Map<T, SqlTemplate>> cacheMap = new HashMap<>();
        int size = 0;

        SqlCache() {
        }

        /**
         * Returns the cached template, building it with {@code b} on a miss.
         *
         * @param tuple first-level key
         * @param t second-level key
         * @param b template factory invoked on a miss
         */
        public SqlTemplate computeIfAbsent(Tuple3<TableSchema, TableName, WriteMode> tuple, T t, BiFunction<Tuple3<TableSchema, TableName, WriteMode>, T, SqlTemplate> b) {
            Map<T, SqlTemplate> subMap = this.cacheMap.computeIfAbsent(tuple, key -> new HashMap<>());
            return subMap.computeIfAbsent(t, key -> {
                SqlTemplate template = b.apply(tuple, key);
                // Count only successfully built templates.
                ++this.size;
                return template;
            });
        }

        /** Number of templates built since construction or the last {@link #clear()}. */
        public int getSize() {
            return this.size;
        }

        /**
         * Drops all cached templates and resets the counter. Previously the counter was left
         * unchanged, so once it exceeded the eviction threshold the cache was cleared on every use.
         */
        public void clear() {
            this.cacheMap.clear();
            this.size = 0;
        }
    }
}

