/*
 * Decompiled with CFR 0.152.
 */
package ghidra.features.bsim.query.client;

import generic.lsh.vector.LSHVector;
import generic.lsh.vector.WeightedLSHCosineVectorFactory;
import ghidra.features.bsim.query.BSimPostgresDBConnectionManager;
import ghidra.features.bsim.query.BSimServerInfo;
import ghidra.features.bsim.query.FunctionDatabase;
import ghidra.features.bsim.query.LSHException;
import ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase;
import ghidra.features.bsim.query.client.Configuration;
import ghidra.features.bsim.query.client.tables.CachedStatement;
import ghidra.features.bsim.query.description.FunctionDescription;
import ghidra.features.bsim.query.description.SignatureRecord;
import ghidra.features.bsim.query.description.VectorResult;
import ghidra.features.bsim.query.protocol.AdjustVectorIndex;
import ghidra.features.bsim.query.protocol.BSimQuery;
import ghidra.features.bsim.query.protocol.PasswordChange;
import ghidra.features.bsim.query.protocol.PrewarmRequest;
import ghidra.features.bsim.query.protocol.QueryNearestVector;
import ghidra.features.bsim.query.protocol.QueryResponseRecord;
import ghidra.features.bsim.query.protocol.ResponseAdjustIndex;
import ghidra.features.bsim.query.protocol.ResponseNearestVector;
import ghidra.features.bsim.query.protocol.ResponsePassword;
import ghidra.features.bsim.query.protocol.ResponsePrewarm;
import ghidra.features.bsim.query.protocol.SimilarityVectorResult;
import java.io.IOException;
import java.net.URL;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A BSim {@link FunctionDatabase} backed by a PostgreSQL server.
 * <p>
 * Signature vectors are stored in the {@code vectable} table using the {@code lshvector}
 * server extension. They are inserted and removed through the {@code insert_vec} and
 * {@code remove_vec} PL/pgSQL functions created by {@link #createDatabase(Configuration)}.
 */
public final class PostgresFunctionDatabase
extends AbstractSQLFunctionDatabase<WeightedLSHCosineVectorFactory> {
    /** Database layout version passed to {@link AbstractSQLFunctionDatabase}. */
    public static final int LAYOUT_VERSION = 6;
    /** Database connected to when issuing CREATE DATABASE for a new BSim database. */
    private static final String DEFAULT_DATABASE_NAME = "postgres";
    /** Vector result limit used when a {@link QueryNearestVector} specifies a limit of 0. */
    private static final int DEFAULT_VECTOR_MAX = 2000000;

    /**
     * Strong reference to the PostgreSQL driver logger. java.util.logging only keeps weak
     * references to loggers, so without this field the logger (and the level set below)
     * could be garbage collected and silently revert to its default configuration.
     */
    private static final Logger POSTGRES_LOGGER = Logger.getLogger("org.postgresql.Driver");

    static {
        // NOTE(review): FINEST is the most verbose java.util.logging level, so this enables all
        // driver log records rather than suppressing them; confirm that is intended. At this
        // level the driver may emit SQL text, which would include ALTER ROLE ... PASSWORD
        // statements if a handler accepting FINEST records is installed.
        POSTGRES_LOGGER.setLevel(Level.FINEST);
    }

    /** Nearest-neighbor query: parameters are (vector SQL text, sim threshold, sig threshold, limit). */
    private static final String SELECT_NEAREST_VECTOR_SQL = "WITH const(cvec) AS (VALUES( lshvector_in( CAST( ? AS cstring) ) ) ), comp AS ( SELECT id,count,cvec,vec,lshvector_compare(cvec,vec) AS cfunc FROM const,vectable        WHERE cvec % vec) SELECT id,count,(comp.cfunc).sim,(comp.cfunc).sig,vec FROM comp WHERE (comp.cfunc).sim > ? AND (comp.cfunc).sig > ? ORDER BY (comp.cfunc).sim DESC LIMIT ?";
    /** Lookup of a single vectable row by id. */
    private static final String SELECT_VECTOR_BY_ID_SQL = "SELECT id,count,vec FROM vectable WHERE id = ?";

    private final BSimPostgresDBConnectionManager.BSimPostgresDataSource postgresDs;
    /** When true, sessions are configured with synchronous_commit OFF. */
    private final boolean asynchronous;
    private final CachedStatement<Statement> reusableStatement = new CachedStatement<>();
    private final CachedStatement<PreparedStatement> selectVectorByRowIdStatement = new CachedStatement<>();
    private final CachedStatement<PreparedStatement> selectNearestVectorStatement = new CachedStatement<>();

    /**
     * Creates a client for the BSim PostgreSQL database identified by the given URL.
     *
     * @param postgresUrl URL identifying the PostgreSQL server and database
     * @param async if true, sessions initialized or created by this instance run with
     *        {@code synchronous_commit} OFF; otherwise ON
     */
    public PostgresFunctionDatabase(URL postgresUrl, boolean async) {
        super(BSimPostgresDBConnectionManager.getDataSource(postgresUrl),
            FunctionDatabase.generateLSHVectorFactory(), LAYOUT_VERSION);
        this.postgresDs = (BSimPostgresDBConnectionManager.BSimPostgresDataSource) this.ds;
        this.asynchronous = async;
    }

    /** Closes the statements cached by this class, then delegates to the superclass. */
    @Override
    public void close() {
        reusableStatement.close();
        selectVectorByRowIdStatement.close();
        selectNearestVectorStatement.close();
        super.close();
    }

    /** Returns the cached general-purpose statement, creating it on first use. */
    private Statement getReusableStatement() throws SQLException {
        return reusableStatement.prepareIfNeeded(() -> this.initConnection().createStatement());
    }

    /**
     * Quotes a value as a PostgreSQL identifier, doubling any embedded double quotes so the
     * value cannot terminate the quoted identifier early.
     *
     * @param identifier raw identifier text (must not be null)
     * @return the identifier wrapped in double quotes
     */
    private static String quoteIdentifier(String identifier) {
        return "\"" + identifier.replace("\"", "\"\"") + "\"";
    }

    /** Locks the exe, description and vector tables against concurrent writers. */
    @Override
    protected void lockTablesForWrite() throws SQLException {
        String stmtstring = "LOCK TABLE exetable, desctable, vectable IN SHARE ROW EXCLUSIVE MODE";
        getReusableStatement().execute(stmtstring);
    }

    /**
     * Issues ALTER ROLE to change a user's password. The data source's cached password is
     * updated only after the statement succeeds.
     *
     * @param c connection to execute on
     * @param username role whose password is changed
     * @param newPassword new password characters
     * @throws SQLException if the ALTER ROLE statement fails
     */
    private void changePassword(Connection c, String username, char[] newPassword)
            throws SQLException {
        StringBuilder buffer = new StringBuilder();
        buffer.append("ALTER ROLE ");
        buffer.append(quoteIdentifier(username));
        buffer.append(" WITH PASSWORD '");
        for (char ch : newPassword) {
            if (ch == '\'') {
                buffer.append(ch); // double single quotes to escape them inside the literal
            }
            buffer.append(ch);
        }
        buffer.append('\'');
        try (Statement st = c.createStatement()) {
            st.executeUpdate(buffer.toString());
            postgresDs.setPassword(username, newPassword);
        }
    }

    /**
     * Creates the {@code insert_vec} and {@code remove_vec} server functions.
     * <p>
     * {@code insert_vec} hashes the vector with {@code lshvector_hash}. It increments the count
     * of an existing row with that id, or inserts a new row with count 1, and returns the hash.
     * {@code remove_vec} subtracts {@code countdiff} from the row's count, deleting the row when
     * the count would not stay positive. It returns -1 if the row is missing, 0 if the row was
     * decremented, and 1 if the row was deleted.
     */
    private void createVectorFunctions(Statement st) throws SQLException {
        st.executeUpdate("CREATE FUNCTION insert_vec(newvec lshvector,OUT ourhash BIGINT) AS $$ DECLARE  curs1 CURSOR (key BIGINT) FOR SELECT count FROM vectable WHERE id = key FOR UPDATE;  ourcount INTEGER; BEGIN  ourhash := lshvector_hash(newvec);  OPEN curs1( ourhash );  FETCH curs1 INTO ourcount;  IF FOUND THEN    UPDATE vectable SET count = ourcount + 1 WHERE CURRENT OF curs1;  ELSE    INSERT INTO vectable (id,count,vec) VALUES(ourhash,1,newvec);  END IF;  CLOSE curs1; END; $$ LANGUAGE plpgsql;");
        st.executeUpdate("CREATE FUNCTION remove_vec(vecid BIGINT,countdiff INTEGER) RETURNS INTEGER AS $$DECLARE  curs1 CURSOR (key BIGINT) FOR SELECT count FROM vectable WHERE id = key FOR UPDATE;  ourcount INTEGER;  rescode INTEGER;BEGIN  rescode = -1;  OPEN curs1( vecid );  FETCH curs1 INTO ourcount;  IF FOUND AND ourcount > countdiff THEN    UPDATE vectable SET count = ourcount - countdiff WHERE CURRENT OF curs1;    rescode = 0;  ELSIF FOUND THEN    DELETE FROM vectable WHERE CURRENT OF curs1;    rescode = 1;  END IF;  CLOSE curs1;  RETURN rescode;END;$$ LANGUAGE plpgsql;");
    }

    /** Calls the server-side {@code lsh_load()} function and discards its result rows. */
    private void serverLoadWeights(Connection db) throws SQLException {
        try (Statement st = db.createStatement();
                ResultSet rs = st.executeQuery("SELECT lsh_load()")) {
            while (rs.next()) {
                // drain result; only the side effect of lsh_load() matters
            }
        }
    }

    /** Sets the session's synchronous_commit mode according to {@link #asynchronous}. */
    private void applySynchronousCommit(Statement st) throws SQLException {
        if (asynchronous) {
            st.executeUpdate("SET SESSION synchronous_commit TO OFF");
        }
        else {
            st.executeUpdate("SET SESSION synchronous_commit to ON");
        }
    }

    /** Loads server weights and sets synchronous_commit before the generic initialization. */
    @Override
    protected void initializeDatabase(Configuration config) throws SQLException {
        Connection db = initConnection();
        serverLoadWeights(db);
        try (Statement st = db.createStatement()) {
            applySynchronousCommit(st);
        }
        super.initializeDatabase(config);
    }

    /**
     * Creates the (empty) PostgreSQL database named by this instance's server info. It does so
     * by connecting to the {@value #DEFAULT_DATABASE_NAME} database on the same server. The
     * temporary data source is always disposed.
     */
    @Override
    protected void generateRawDatabase() throws SQLException {
        BSimServerInfo serverInfo = postgresDs.getServerInfo();
        BSimServerInfo defaultServerInfo = new BSimServerInfo(BSimServerInfo.DBType.postgres,
            serverInfo.getServerName(), serverInfo.getPort(), DEFAULT_DATABASE_NAME);
        String createdbstring = "CREATE DATABASE " + quoteIdentifier(serverInfo.getDBName());
        BSimPostgresDBConnectionManager.BSimPostgresDataSource defaultDs =
            BSimPostgresDBConnectionManager.getDataSource(defaultServerInfo);
        try (Connection db = defaultDs.getConnection();
                Statement st = db.createStatement()) {
            st.executeUpdate(createdbstring);
            postgresDs.initializeFrom(defaultDs);
        }
        finally {
            defaultDs.dispose();
        }
    }

    /**
     * Creates the generic BSim tables via the superclass, then adds the Postgres-specific pieces.
     * These are the lshvector extension, the {@code vectable} table with its GIN index, the
     * vector functions and the schema grants. It then loads weights and sets synchronous_commit.
     *
     * @throws SQLException wrapping (with cause, SQLState and vendor code preserved) any failure
     */
    @Override
    protected void createDatabase(Configuration config) throws SQLException {
        try {
            super.createDatabase(config);
            Connection db = super.initConnection();
            try (Statement st = db.createStatement()) {
                st.executeUpdate("CREATE EXTENSION IF NOT EXISTS lshvector");
                st.executeUpdate("CREATE TABLE vectable(id BIGINT UNIQUE,count INTEGER,vec lshvector)");
                st.executeUpdate("CREATE INDEX vectable_vec_idx ON vectable USING gin (vec gin_lshvector_ops)");
                createVectorFunctions(st);
                st.executeUpdate("REVOKE ALL ON SCHEMA PUBLIC FROM PUBLIC");
                st.executeUpdate("GRANT USAGE ON SCHEMA PUBLIC TO PUBLIC");
                st.executeUpdate("GRANT SELECT ON ALL TABLES IN SCHEMA PUBLIC TO PUBLIC");
                st.executeUpdate("GRANT USAGE ON ALL SEQUENCES IN SCHEMA PUBLIC TO PUBLIC");
                serverLoadWeights(db);
                applySynchronousCommit(st);
            }
        }
        catch (SQLException err) {
            throw new SQLException("Could not create database: " + err.getMessage(),
                err.getSQLState(), err.getErrorCode(), err);
        }
    }

    /** Drops the GIN index on {@code vectable.vec}. */
    private void dropIndex(Connection c) throws SQLException {
        try (Statement st = c.createStatement()) {
            st.execute("DROP INDEX vectable_vec_idx");
        }
    }

    /**
     * Calls {@code lsh_reload()}, raises maintenance_work_mem to 2GB, and recreates the GIN
     * index on {@code vectable.vec}. Note the SET applies to the whole session of {@code c}.
     */
    private void rebuildIndex(Connection c) throws SQLException {
        try (Statement st = c.createStatement()) {
            // Close the lsh_reload() result explicitly before reusing the statement; JDBC would
            // otherwise close it implicitly on the next execute.
            try (ResultSet rs = st.executeQuery("SELECT lsh_reload()")) {
                // result is not needed
            }
            st.execute("SET maintenance_work_mem TO '2GB'");
            st.execute("CREATE INDEX vectable_vec_idx ON vectable USING gin (vec gin_lshvector_ops)");
        }
    }

    /**
     * Loads vector tables and indices into cache using the pg_prewarm extension. The extension
     * is created for the duration of the call and dropped afterwards (on success).
     * For each argument, 0 skips the relation, 1 uses pg_prewarm's 'read' mode, and any other
     * value omits the mode argument (pg_prewarm's default).
     *
     * @param c connection to execute on
     * @param mainIndex mode for the {@code vectable_vec_idx} GIN index
     * @param secondaryIndex mode for {@code vectable_id_key} (presumably the index backing the
     *        UNIQUE id column)
     * @param vectors mode for the {@code vectable} table itself
     * @return the first column returned by prewarming the main index, or -1 if it was skipped
     *         or returned no rows
     */
    private int preWarm(Connection c, int mainIndex, int secondaryIndex, int vectors)
            throws SQLException {
        try (Statement st = c.createStatement()) {
            int res = -1;
            st.execute("CREATE EXTENSION IF NOT EXISTS pg_prewarm");
            if (mainIndex != 0) {
                res = prewarmRelation(st, "vectable_vec_idx", mainIndex, true);
            }
            if (secondaryIndex != 0) {
                prewarmRelation(st, "vectable_id_key", secondaryIndex, false);
            }
            if (vectors != 0) {
                prewarmRelation(st, "vectable", vectors, false);
            }
            st.execute("DROP EXTENSION pg_prewarm");
            return res;
        }
    }

    /**
     * Runs pg_prewarm on one relation and drains the result.
     *
     * @param st statement to execute on
     * @param relation relation name (a trusted constant, not user input)
     * @param mode 1 for 'read' mode, any other non-zero value for the default mode
     * @param readCount whether to read column 1 of the first row as an int
     * @return column 1 of the first row if {@code readCount} and a row exists, else -1
     */
    private static int prewarmRelation(Statement st, String relation, int mode, boolean readCount)
            throws SQLException {
        String sql = mode == 1 ? "SELECT pg_prewarm('" + relation + "','read')"
                : "SELECT pg_prewarm('" + relation + "')";
        int blocks = -1;
        try (ResultSet rs = st.executeQuery(sql)) {
            if (readCount && rs.next()) {
                blocks = rs.getInt(1);
            }
            while (rs.next()) {
                // drain any remaining rows
            }
        }
        return blocks;
    }

    /**
     * Inserts (or bumps the count of) a signature's vector via {@code insert_vec}.
     *
     * @return the vector id returned by {@code insert_vec}
     * @throws SQLException if no id row is returned
     */
    @Override
    protected long storeSignatureRecord(SignatureRecord sigrec) throws SQLException {
        String sql = "SELECT insert_vec( '" + sigrec.getLSHVector().saveSQL() + "')";
        try (ResultSet rs = getReusableStatement().executeQuery(sql)) {
            if (!rs.next()) {
                throw new SQLException("Did not get vector id after insertion");
            }
            return rs.getLong(1);
        }
    }

    /**
     * Decrements a vector's count via {@code remove_vec}.
     *
     * @return the {@code remove_vec} result code: -1 missing, 0 decremented, 1 deleted
     * @throws SQLException if no result row is returned
     */
    @Override
    protected int deleteVectors(long id, int countdiff) throws SQLException {
        String sql = "SELECT remove_vec( " + Long.toString(id) + "," + Integer.toString(countdiff) + ")";
        try (ResultSet rs = getReusableStatement().executeQuery(sql)) {
            if (!rs.next()) {
                throw new SQLException("Did not get result code after deletion");
            }
            return rs.getInt(1);
        }
    }

    /**
     * Finds stored vectors whose similarity and significance against {@code vec} exceed the
     * given thresholds, ordered by descending similarity. Candidate rows are pre-filtered with
     * the lshvector {@code %} operator. Fully populated results are appended to
     * {@code resultset}.
     *
     * @return the sum of the hit counts of the appended results
     * @throws SQLException on query failure or if a stored vector cannot be decoded
     */
    @Override
    protected int queryNearestVector(List<VectorResult> resultset, LSHVector vec, double simthresh,
            double sigthresh, int max) throws SQLException {
        PreparedStatement s = selectNearestVectorStatement.prepareIfNeeded(
            () -> this.initConnection().prepareStatement(SELECT_NEAREST_VECTOR_SQL));
        s.setString(1, vec.saveSQL());
        s.setDouble(2, simthresh);
        s.setDouble(3, sigthresh);
        s.setInt(4, max);
        int total = 0;
        try (ResultSet rs = s.executeQuery()) {
            while (rs.next()) {
                VectorResult curres = new VectorResult();
                curres.vectorid = rs.getLong(1);
                curres.hitcount = rs.getInt(2);
                curres.sim = rs.getDouble(3);
                curres.signif = rs.getDouble(4);
                String vecstring = rs.getString(5);
                try {
                    curres.vec = ((WeightedLSHCosineVectorFactory) this.vectorFactory)
                            .restoreVectorFromSql(vecstring);
                }
                catch (IOException e) {
                    throw new SQLException(e.getMessage(), e);
                }
                // add only once fully populated so a failure never leaves a partial entry
                resultset.add(curres);
                total += curres.hitcount;
            }
        }
        return total;
    }

    /**
     * Runs a nearest-vector search for every function in the query's manager that has a
     * signature whose self-significance reaches the query's significance threshold. The
     * response counters and result list are filled in.
     */
    @Override
    protected void queryNearestVector(QueryNearestVector query) throws SQLException {
        ResponseNearestVector response = query.nearresponse;
        response.totalvec = 0;
        response.totalmatch = 0;
        response.uniquematch = 0;
        int vectormax = query.vectormax == 0 ? DEFAULT_VECTOR_MAX : query.vectormax;
        Iterator<FunctionDescription> iter = query.manage.listAllFunctions();
        while (iter.hasNext()) {
            FunctionDescription frec = iter.next();
            SignatureRecord srec = frec.getSignatureRecord();
            if (srec == null) {
                continue;
            }
            LSHVector thevec = srec.getLSHVector();
            double selfSignificance =
                ((WeightedLSHCosineVectorFactory) this.vectorFactory).getSelfSignificance(thevec);
            if (selfSignificance < query.signifthresh) {
                continue;
            }
            ++response.totalvec;
            List<VectorResult> resultset = new ArrayList<>();
            queryNearestVector(resultset, thevec, query.thresh, query.signifthresh, vectormax);
            if (resultset.isEmpty()) {
                continue;
            }
            SimilarityVectorResult simres = new SimilarityVectorResult(frec);
            simres.addNotes(resultset);
            response.totalmatch += simres.getTotalCount();
            if (simres.getTotalCount() == 1) {
                ++response.uniquematch;
            }
            response.result.add(simres);
        }
    }

    /**
     * Fetches a single vector row by id.
     *
     * @throws SQLException if no row has that id or the stored vector cannot be decoded
     */
    @Override
    protected VectorResult queryVectorId(long id) throws SQLException {
        PreparedStatement s = selectVectorByRowIdStatement.prepareIfNeeded(
            () -> this.initConnection().prepareStatement(SELECT_VECTOR_BY_ID_SQL));
        s.setLong(1, id);
        try (ResultSet rs = s.executeQuery()) {
            if (!rs.next()) {
                throw new SQLException("Bad vectable rowid");
            }
            VectorResult rowres = new VectorResult();
            rowres.vectorid = rs.getLong(1);
            rowres.hitcount = rs.getInt(2);
            try {
                rowres.vec = ((WeightedLSHCosineVectorFactory) this.vectorFactory)
                        .restoreVectorFromSql(rs.getString(3));
            }
            catch (IOException e) {
                throw new SQLException(e.getMessage(), e);
            }
            return rowres;
        }
    }

    @Override
    public String getUserName() {
        return postgresDs.getUserName();
    }

    /**
     * Sets the preferred user name for future connections.
     *
     * @throws IllegalStateException if the data source status is already Ready
     */
    @Override
    public void setUserName(String userName) {
        if (postgresDs.getStatus() == FunctionDatabase.Status.Ready) {
            throw new IllegalStateException("Connection has already been established");
        }
        postgresDs.setPreferredUserName(userName);
    }

    /**
     * Handles the Postgres-specific requests (prewarm, password change, index adjustment) and
     * delegates every other query to the superclass.
     */
    @Override
    public QueryResponseRecord doQuery(BSimQuery<?> query, Connection c)
            throws SQLException, LSHException, FunctionDatabase.DatabaseNonFatalException {
        if (query instanceof PrewarmRequest) {
            fdbPrewarm((PrewarmRequest) query, c);
        }
        else if (query instanceof PasswordChange) {
            fdbPasswordChange((PasswordChange) query, c);
        }
        else if (query instanceof AdjustVectorIndex) {
            fdbAdjustVectorIndex((AdjustVectorIndex) query, c);
        }
        else {
            return super.doQuery(query, c);
        }
        return query.getResponse();
    }

    /** Rebuilds or drops the vector index; {@code success} is set only if no exception occurs. */
    private void fdbAdjustVectorIndex(AdjustVectorIndex query, Connection c) throws SQLException {
        ResponseAdjustIndex response = query.adjustresponse;
        response.success = false;
        if (query.doRebuild) {
            rebuildIndex(c);
        }
        else {
            dropIndex(c);
        }
        response.success = true;
    }

    /** Prewarms according to the request and records the main-index result as blockCount. */
    private void fdbPrewarm(PrewarmRequest request, Connection c) throws SQLException {
        ResponsePrewarm response = request.prewarmresponse;
        response.blockCount = preWarm(c, request.mainIndexConfig, request.secondaryIndexConfig,
            request.vectorTableConfig);
    }

    /**
     * Changes a user's password. SQL failures are reported in the response rather than thrown.
     *
     * @throws LSHException if the username is missing or the password is null/empty
     */
    private void fdbPasswordChange(PasswordChange query, Connection c) throws LSHException {
        ResponsePassword response = query.passwordResponse;
        if (query.username == null) {
            throw new LSHException("Missing username for password change");
        }
        if (query.newPassword == null || query.newPassword.length == 0) {
            throw new LSHException("No password provided");
        }
        response.changeSuccessful = true;
        response.errorMessage = null;
        try {
            changePassword(c, query.username, query.newPassword);
        }
        catch (SQLException e) {
            response.changeSuccessful = false;
            response.errorMessage = e.getMessage();
        }
    }

    /** Formats a bitwise AND of two SQL expressions using PostgreSQL's {@code &} operator. */
    @Override
    public String formatBitAndSQL(String v1, String v2) {
        return "(" + v1 + " & " + v2 + ")";
    }
}

