File size: 4,164 Bytes
e37eda0
3eb59a4
 
75829f5
937d1f9
 
bc8a89c
1d9b999
 
 
 
e37eda0
6a2a63a
e37eda0
937d1f9
6a2a63a
937d1f9
5671d43
937d1f9
3eb59a4
5671d43
6a2a63a
5671d43
 
3eb59a4
6a2a63a
3eb59a4
 
6a2a63a
 
 
 
 
 
 
937d1f9
1c7e913
 
 
2599708
e37eda0
 
937d1f9
73b2770
1c7e913
 
 
 
73b2770
 
1c7e913
2599708
bc8a89c
e37eda0
937d1f9
 
3eb59a4
937d1f9
 
e37eda0
937d1f9
bc8a89c
6a2a63a
 
1c7e913
 
 
 
 
 
6a2a63a
 
 
 
 
 
 
bc8a89c
937d1f9
75829f5
937d1f9
a511dd2
 
 
 
 
 
e37eda0
6a2a63a
a511dd2
 
1c7e913
 
6a2a63a
 
a511dd2
6a2a63a
 
 
 
 
a511dd2
6a2a63a
 
 
 
937d1f9
e37eda0
6a2a63a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import logging
import os
import re
import sqlite3
import tempfile

import openai
import pandas as pd
import sqlparse
import streamlit as st
from langchain import OpenAI
from langchain.chains import RetrievalQA
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.document_loaders import CSVLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.utilities import SQLDatabase
from langchain_community.vectorstores import FAISS

# OpenAI API key (ensure it is securely stored)
# Read from the environment so the secret never lives in source control.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Step 1: Upload CSV data file (or use default)
# file_uploader returns None until the user uploads something; the app then
# falls back to a bundled CSV.
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    # NOTE(review): assumes default_data.csv exists in the working
    # directory — confirm it ships with the app, else this raises here.
    data = pd.read_csv("default_data.csv")  # Use default CSV if no file is uploaded
    st.write("Using default data.csv file.")
else:
    # Uploaded file is an in-memory buffer; pandas reads it directly.
    data = pd.read_csv(csv_file)
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Step 2: Load CSV data into a SQLite database with a dynamic table name.
# NOTE(review): the previous revision wrote the data into one ':memory:'
# connection and then pointed LangChain at 'sqlite:///:memory:' — but every
# sqlite3 connection to ':memory:' opens its own private, empty database,
# so the agent never saw the loaded table (and `db.raw_connection = conn`
# does not redirect the SQLAlchemy engine underneath SQLDatabase). A shared
# temp-file database lets both connections see the same data.
db_path = os.path.join(tempfile.gettempdir(), "csv_sql_agent.db")
conn = sqlite3.connect(db_path)
# Table name derived from the uploaded filename, extension removed.
# rsplit keeps dotted names like "my.data.csv" -> "my.data" intact.
table_name = csv_file.name.rsplit('.', 1)[0] if csv_file else "default_table"
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)

# Debug: Display valid columns for user to verify
st.write(f"Valid columns: {valid_columns}")

# Step 3: Point LangChain's SQLDatabase at the same database file.
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

# Step 4: Build the SQL agent. Generous iteration/time limits so complex
# natural-language questions are not aborted mid-reasoning.
agent_llm = OpenAI(temperature=0)  # deterministic generations for SQL
sql_agent = create_sql_agent(
    agent_llm,
    db=db,
    verbose=True,
    max_iterations=20,       # increased iteration limit
    max_execution_time=90,   # timeout limit of 90 seconds
)

# Step 5: Use FAISS with RAG for context retrieval.
# NOTE(review): the previous revision passed csv_file.name to CSVLoader, but
# a Streamlit upload exists only in memory — `.name` is the client-side
# filename and does not exist on the server's disk, so CSVLoader raised
# FileNotFoundError for every uploaded file. Persist the already-parsed
# DataFrame to a temp CSV so the loader has a real path to read.
if csv_file is not None:
    rag_csv_path = os.path.join(tempfile.gettempdir(), "rag_source.csv")
    data.to_csv(rag_csv_path, index=False)
else:
    rag_csv_path = "default_data.csv"

embeddings = OpenAIEmbeddings()
documents = CSVLoader(file_path=rag_csv_path).load()

vector_store = FAISS.from_documents(documents, embeddings)
retriever = vector_store.as_retriever()
rag_chain = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), retriever=retriever)

# Step 6: Define SQL validation helpers
# SQL keywords and builtin function names that may legally appear in a query
# without being column references.
_SQL_KEYWORDS = frozenset("""
    select from where and or not as group by order limit offset join inner
    left right outer full cross on having distinct count sum avg min max
    asc desc like in is null between union all case when then else end
    cast exists
""".split())


def validate_sql(query, valid_columns):
    """Return True if every column-like identifier in *query* is allowed.

    The previous implementation walked sqlparse's top-level tokens and
    compared every un-typed token (``token.ttype is None``) against
    *valid_columns*. Table identifiers after FROM and entire WHERE-clause
    groups are also un-typed tokens, so virtually every well-formed query
    was rejected. This version tokenizes the text itself: string literals,
    numbers, keywords, and the identifier immediately following
    FROM/JOIN/AS (table names and aliases) are skipped; every remaining
    identifier must be a known column.

    Args:
        query: SQL text to check (as produced by the agent).
        valid_columns: iterable of permitted column names, matched
            case-insensitively.

    Returns:
        bool: True when no unexpected identifier is found.

    Limitation: a bare table alias written without AS is still treated as a
    column reference and will fail validation.
    """
    allowed = {col.lower() for col in valid_columns}
    # Drop quoted literals so their contents are not mistaken for identifiers.
    stripped = re.sub(r"'[^']*'|\"[^\"]*\"", " ", query)
    skip_next = False  # True right after FROM/JOIN/AS: next name is not a column
    for word in re.findall(r"[A-Za-z_][A-Za-z0-9_]*", stripped):
        lower = word.lower()
        if lower in ("from", "join", "as"):
            skip_next = True
        elif lower in _SQL_KEYWORDS:
            skip_next = False
        elif skip_next:
            # Table name or alias, not a column reference.
            skip_next = False
        elif lower not in allowed:
            return False
    return True

def validate_sql_with_sqlparse(query):
    """Return True when sqlparse extracts at least one statement from *query*."""
    statements = sqlparse.parse(query)
    return bool(statements)

# Step 7: Generate SQL query based on user input and run it with LangChain SQL Agent.
# The whole pipeline is guarded by one try/except so any LLM/DB failure is
# surfaced to the user instead of crashing the Streamlit script run.
user_prompt = st.text_input("Enter your natural language prompt:")
if user_prompt:
    try:
        # Step 8: Add valid column names to the prompt so the model is less
        # likely to hallucinate columns that don't exist in the table.
        column_hints = f" Use only these columns: {', '.join(valid_columns)}"
        prompt_with_columns = user_prompt + column_hints

        # Step 9: Retrieve context using RAG (relevant CSV rows via FAISS).
        context = rag_chain.run(prompt_with_columns)
        st.write(f"Retrieved Context: {context}")

        # Step 10: Generate SQL query using SQL agent.
        # NOTE(review): an SQL agent's .run() typically returns the final
        # natural-language answer, not raw SQL — confirm this output is
        # actually executable SQL before relying on Step 12.
        generated_sql = sql_agent.run(f"{prompt_with_columns} {context}")

        # Debug: Display generated SQL query for inspection
        st.write(f"Generated SQL Query: {generated_sql}")

        # Step 11: Validate SQL query — first for parseability, then for
        # column references; invalid queries are reported, never executed.
        if not validate_sql_with_sqlparse(generated_sql):
            st.write("Generated SQL is not valid.")
        elif not validate_sql(generated_sql, valid_columns):
            st.write("Generated SQL references invalid columns.")
        else:
            # Step 12: Execute SQL query against the same connection the
            # CSV was loaded into, and render the result table.
            result = pd.read_sql(generated_sql, conn)
            st.write("Query Results:")
            st.dataframe(result)

    except Exception as e:
        # Broad catch is deliberate at this UI boundary: log it and show
        # the message to the user rather than killing the app.
        logging.error(f"An error occurred: {e}")
        st.write(f"Error: {e}")