/*
 * Decompiled with CFR 0.152.
 */
package com.arun.bhardwaj.agents;

import com.arun.bhardwaj.utility.DbUtility;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Natural-language-to-SQL agent for a SQLite database.
 *
 * <p>For each question the agent describes the database schema to an Ollama model, asks it for a
 * single {@code SELECT} statement, applies a keyword-based safety check to the returned SQL and
 * runs it on the supplied connection.
 *
 * <p>The agent does not own the {@link Connection}; the caller is responsible for closing it.
 */
public class SQLAgent {
    /** Default Ollama non-streaming "generate" endpoint. */
    private static final String DEFAULT_OLLAMA_URL = "http://localhost:11434/api/generate";

    /** Default Ollama model tag. */
    private static final String DEFAULT_MODEL = "llama3.2:latest";

    /** Locates the opening quote of the {@code "response"} string value in the Ollama JSON reply. */
    private static final Pattern RESPONSE_FIELD = Pattern.compile("\"response\"\\s*:\\s*\"");

    /** First markdown code fence (optionally tagged sql/sqlite) the model may wrap its SQL in. */
    private static final Pattern CODE_FENCE = Pattern.compile("(?is)```(?:sqlite|sql)?\\s*(.*?)```");

    /** Data-modifying keywords, matched as whole words so identifiers like updated_at are allowed. */
    private static final Pattern BLOCKED_KEYWORD =
            Pattern.compile("\\b(delete|drop|update|insert|alter|truncate)\\b", Pattern.CASE_INSENSITIVE);

    private final Connection connection;
    private final HttpClient httpClient;
    private final String ollamaUrl;
    private final String model;

    /**
     * Creates an agent that talks to the default local Ollama endpoint and model.
     *
     * @param connection open SQLite connection used for schema discovery and query execution
     */
    public SQLAgent(Connection connection) {
        this(connection, DEFAULT_OLLAMA_URL, DEFAULT_MODEL);
    }

    /**
     * Creates an agent for a specific Ollama endpoint and model.
     *
     * @param connection open SQLite connection used for schema discovery and query execution
     * @param ollamaUrl full URL of the Ollama {@code /api/generate} endpoint
     * @param model Ollama model tag to request
     */
    public SQLAgent(Connection connection, String ollamaUrl, String model) {
        this.connection = Objects.requireNonNull(connection, "connection");
        this.ollamaUrl = Objects.requireNonNull(ollamaUrl, "ollamaUrl");
        this.model = Objects.requireNonNull(model, "model");
        this.httpClient = HttpClient.newHttpClient();
    }

    /**
     * Answers a natural-language question by generating and running a SELECT query.
     *
     * @param userQuestion the question to translate into SQL
     * @return one map per result row, keyed by column label in select-list order
     * @throws SecurityException if the generated SQL is not a SELECT or contains a blocked keyword
     * @throws Exception if schema loading, the LLM call, response parsing or query execution fails
     */
    public List<Map<String, Object>> ask(String userQuestion) throws Exception {
        String schema = loadSchema();
        String prompt = buildPrompt(schema, userQuestion);
        String rawResponse = callLLM(prompt);
        String sql = extractSQL(rawResponse);
        System.out.println("Generated SQL is " + sql);
        // Validate exactly the string that will be executed.
        validateSQL(sql);
        return executeSQL(sql);
    }

    /** Builds the instruction prompt containing the schema description and the user question. */
    private String buildPrompt(String schema, String question) {
        return """
                You are an expert SQLite SQL generator.

                RULES:
                - Return ONLY valid SQLite SQL
                - Do NOT explain
                - Only SELECT queries allowed
                - Always use LIMIT 100

                DATABASE SCHEMA:
                %s

                QUESTION:
                %s
                """.formatted(schema, question);
    }

    /**
     * Sends the prompt to Ollama as a non-streaming request and returns the raw JSON reply body.
     *
     * @throws IOException on transport failure or a non-2xx HTTP status
     */
    private String callLLM(String prompt) throws IOException, InterruptedException {
        String body = "{\n  \"model\": \"%s\",\n  \"prompt\": \"%s\",\n  \"stream\": false\n}\n"
                .formatted(escape(model), escape(prompt));
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(ollamaUrl))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
        int status = response.statusCode();
        if (status < 200 || status >= 300) {
            throw new IOException("Ollama request failed with HTTP " + status + ": " + response.body());
        }
        return response.body();
    }

    /**
     * Pulls the generated SQL out of the Ollama JSON reply.
     *
     * <p>Decodes the {@code "response"} string value, unwraps a markdown code fence if the model
     * added one, and joins lines into a single line.
     *
     * @throws IllegalStateException if the reply has no {@code "response"} string field
     */
    private String extractSQL(String response) {
        Matcher field = RESPONSE_FIELD.matcher(response);
        if (!field.find()) {
            throw new IllegalStateException("LLM reply has no \"response\" field: " + response);
        }
        String text = decodeJsonString(response, field.end());
        Matcher fence = CODE_FENCE.matcher(text);
        if (fence.find()) {
            text = fence.group(1);
        }
        return text.replaceAll("[\\r\\n]+", " ").strip();
    }

    /**
     * Decodes a JSON string body starting just after its opening quote, up to the closing quote.
     *
     * <p>Handles the standard JSON escapes, including four-hex-digit unicode escapes, which JSON
     * encoders may emit for characters such as '<' and '>'.
     */
    private static String decodeJsonString(String json, int start) {
        StringBuilder out = new StringBuilder();
        int i = start;
        while (i < json.length()) {
            char c = json.charAt(i++);
            if (c == '"') {
                return out.toString();
            }
            if (c != '\\') {
                out.append(c);
                continue;
            }
            if (i >= json.length()) {
                break;
            }
            char esc = json.charAt(i++);
            switch (esc) {
                case 'n' -> out.append('\n');
                case 'r' -> out.append('\r');
                case 't' -> out.append('\t');
                case 'b' -> out.append('\b');
                case 'f' -> out.append('\f');
                case 'u' -> {
                    if (i + 4 > json.length()) {
                        throw new IllegalStateException("Truncated unicode escape in LLM response");
                    }
                    out.append((char) Integer.parseInt(json.substring(i, i + 4), 16));
                    i += 4;
                }
                // Quote, backslash and slash escapes map to the character itself.
                default -> out.append(esc);
            }
        }
        throw new IllegalStateException("Unterminated JSON string in LLM response");
    }

    /**
     * Rejects anything that is not a SELECT or that contains a data-modifying keyword.
     *
     * @throws SecurityException when the SQL fails either check
     */
    private void validateSQL(String sql) {
        // Locale.ROOT: locale-sensitive lowering (e.g. Turkish dotless i) could dodge the checks.
        String lower = sql.strip().toLowerCase(Locale.ROOT);
        if (!lower.startsWith("select")) {
            throw new SecurityException("Only SELECT queries allowed");
        }
        Matcher blocked = BLOCKED_KEYWORD.matcher(lower);
        if (blocked.find()) {
            throw new SecurityException("Blocked SQL keyword: " + blocked.group(1));
        }
    }

    /**
     * Runs the query and materializes every row as an insertion-ordered column-label map.
     * Rows with duplicate column labels keep only the last value for that label.
     */
    private List<Map<String, Object>> executeSQL(String sql) throws SQLException {
        List<Map<String, Object>> results = new ArrayList<>();
        try (PreparedStatement ps = connection.prepareStatement(sql);
             ResultSet rs = ps.executeQuery()) {
            ResultSetMetaData meta = rs.getMetaData();
            int columnCount = meta.getColumnCount();
            while (rs.next()) {
                Map<String, Object> row = new LinkedHashMap<>();
                for (int i = 1; i <= columnCount; i++) {
                    // Column label honours "AS" aliases, per the JDBC contract.
                    row.put(meta.getColumnLabel(i), rs.getObject(i));
                }
                results.add(row);
            }
        }
        return results;
    }

    /**
     * Describes every table listed in sqlite_master as "Table name:" followed by " - column (type)" lines.
     */
    private String loadSchema() throws SQLException {
        // Collect table names first: re-executing on the same Statement would close this ResultSet.
        List<String> tables = new ArrayList<>();
        try (Statement stmt = connection.createStatement();
             ResultSet rs = stmt.executeQuery("SELECT name FROM sqlite_master WHERE type='table'")) {
            while (rs.next()) {
                tables.add(rs.getString("name"));
            }
        }
        StringBuilder schema = new StringBuilder();
        try (Statement stmt = connection.createStatement()) {
            for (String table : tables) {
                schema.append("\nTable ").append(table).append(":\n");
                try (ResultSet cols = stmt.executeQuery("PRAGMA table_info(" + quoteIdentifier(table) + ")")) {
                    while (cols.next()) {
                        schema.append(" - ").append(cols.getString("name"))
                              .append(" (").append(cols.getString("type")).append(")\n");
                    }
                }
            }
        }
        return schema.toString();
    }

    /** Quotes an SQLite identifier, doubling any embedded double quotes. */
    private static String quoteIdentifier(String name) {
        return "\"" + name.replace("\"", "\"\"") + "\"";
    }

    /** Escapes text for embedding inside a JSON string literal (RFC 8259). */
    private String escape(String text) {
        StringBuilder out = new StringBuilder(text.length() + 16);
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '"' -> out.append("\\\"");
                case '\\' -> out.append("\\\\");
                case '\n' -> out.append("\\n");
                case '\r' -> out.append("\\r");
                case '\t' -> out.append("\\t");
                case '\b' -> out.append("\\b");
                case '\f' -> out.append("\\f");
                default -> {
                    if (c < 0x20) {
                        // Remaining control characters must be hex-escaped to keep the JSON valid.
                        out.append(String.format("\\u%04x", (int) c));
                    } else {
                        out.append(c);
                    }
                }
            }
        }
        return out.toString();
    }

    /** Demo entry point: asks one question and prints each result row. */
    public static void main(String[] args) {
        try (Connection conn = DbUtility.getConnection()) {
            SQLAgent agent = new SQLAgent(conn);
            List<Map<String, Object>> result = agent.ask("Can you tell me contct number of Arun Bhardwaj ");
            result.forEach(System.out::println);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}

