Spaces:
Sleeping
Sleeping
File size: 4,164 Bytes
e37eda0 3eb59a4 75829f5 937d1f9 bc8a89c 1d9b999 e37eda0 6a2a63a e37eda0 937d1f9 6a2a63a 937d1f9 5671d43 937d1f9 3eb59a4 5671d43 6a2a63a 5671d43 3eb59a4 6a2a63a 3eb59a4 6a2a63a 937d1f9 1c7e913 2599708 e37eda0 937d1f9 73b2770 1c7e913 73b2770 1c7e913 2599708 bc8a89c e37eda0 937d1f9 3eb59a4 937d1f9 e37eda0 937d1f9 bc8a89c 6a2a63a 1c7e913 6a2a63a bc8a89c 937d1f9 75829f5 937d1f9 a511dd2 e37eda0 6a2a63a a511dd2 1c7e913 6a2a63a a511dd2 6a2a63a a511dd2 6a2a63a 937d1f9 e37eda0 6a2a63a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import logging
import os
import sqlite3
import tempfile

import openai
import pandas as pd
import sqlparse
import streamlit as st
from langchain import OpenAI
from langchain.chains import RetrievalQA
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.document_loaders import CSVLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.utilities import SQLDatabase
from langchain_community.vectorstores import FAISS
# OpenAI API key (read from the environment so it is never hard-coded).
openai.api_key = os.getenv("OPENAI_API_KEY")

# Step 1: Upload CSV data file, falling back to the bundled default dataset.
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    # No upload: load the default dataset shipped with the app.
    data = pd.read_csv("default_data.csv")
    # Message fixed to match the actual filename read above ("default_data.csv").
    st.write("Using default_data.csv file.")
else:
    data = pd.read_csv(csv_file)
    st.write(f"Data Preview ({csv_file.name}):")
st.dataframe(data.head())
# Step 2: Load the CSV data into an in-memory SQLite database.
conn = sqlite3.connect(':memory:')

# Derive the table name from the uploaded filename, sanitised so it is always
# a legal SQLite identifier: filenames may contain spaces/punctuation or start
# with a digit, which would otherwise break every generated query.
_raw_name = csv_file.name.rsplit('.', 1)[0] if csv_file else "default_table"
table_name = "".join(ch if ch.isalnum() or ch == "_" else "_" for ch in _raw_name)
if not table_name or table_name[0].isdigit():
    table_name = f"t_{table_name}"
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata, used later to validate generated queries.
valid_columns = list(data.columns)
# Debug: show the columns so the user can phrase prompts against real names.
st.write(f"Valid columns: {valid_columns}")
# Step 3: Expose the data to LangChain.
# Bug fix: SQLDatabase.from_uri('sqlite:///:memory:') opens a *separate*,
# empty in-memory database, and assigning `db.raw_connection` does not rewire
# LangChain to ours. Instead, snapshot the populated connection into a
# temporary file with sqlite3's backup API and point LangChain at that file.
_db_path = os.path.join(tempfile.gettempdir(), "langchain_app.db")
_disk_conn = sqlite3.connect(_db_path)
try:
    conn.backup(_disk_conn)  # copies every table from the in-memory DB
finally:
    _disk_conn.close()
db = SQLDatabase.from_uri(f"sqlite:///{_db_path}")

# Step 4: Create the SQL agent with raised iteration/time limits so longer
# reasoning chains are not cut off prematurely.
sql_agent = create_sql_agent(
    OpenAI(temperature=0),
    db=db,
    verbose=True,
    max_iterations=20,      # increased iteration limit
    max_execution_time=90,  # seconds before the run is aborted
)
# Step 5: Build a FAISS retriever over the CSV rows for RAG context.
# Bug fix: an uploaded file exists only in memory — CSVLoader(csv_file.name)
# would raise FileNotFoundError. Persist the dataframe to a temporary CSV
# that CSVLoader can actually read from disk.
if csv_file is not None:
    _rag_csv_path = os.path.join(tempfile.gettempdir(), csv_file.name)
    data.to_csv(_rag_csv_path, index=False)
else:
    _rag_csv_path = "default_data.csv"
embeddings = OpenAIEmbeddings()
documents = CSVLoader(file_path=_rag_csv_path).load()
vector_store = FAISS.from_documents(documents, embeddings)
retriever = vector_store.as_retriever()
rag_chain = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), retriever=retriever)
# Step 6: Define SQL validation helpers
def validate_sql(query, valid_columns):
    """Return True if every column identifier in *query* is in *valid_columns*.

    Bug fix: the previous heuristic (`token.ttype is None`) also matched table
    identifiers — e.g. the `t` in ``SELECT a FROM t`` — so valid queries were
    rejected. This version walks the flattened token stream and skips any name
    that immediately follows FROM/JOIN/INTO/UPDATE/TABLE/AS (tables and
    aliases), checking only genuine column names, case-insensitively.

    Note: this remains a heuristic (sqlparse is non-validating); qualified
    names like ``t.col`` are checked per component.
    """
    statements = sqlparse.parse(query)
    if not statements:
        return False  # empty / unparseable input is never valid
    allowed = {str(col).lower() for col in valid_columns}
    skip_next_name = False  # True when the next Name token is a table or alias
    for token in statements[0].flatten():
        if token.is_keyword:
            skip_next_name = token.normalized in (
                "FROM", "JOIN", "INTO", "UPDATE", "TABLE", "AS",
            )
        elif token.ttype in sqlparse.tokens.Name:
            if skip_next_name:
                skip_next_name = False  # table/alias name, not a column
            elif token.value.lower() not in allowed:
                return False
    return True
def validate_sql_with_sqlparse(query):
    """Return True when sqlparse can extract at least one statement from *query*.

    Note: sqlparse is a non-validating parser, so this only rejects input it
    cannot tokenise into a statement (e.g. the empty string).
    """
    return bool(sqlparse.parse(query))
# Step 7: Turn the user's natural-language prompt into SQL and run it.
user_prompt = st.text_input("Enter your natural language prompt:")
if user_prompt:
    try:
        # Step 8: Append the valid column names so the LLM stays on-schema.
        column_hints = f" Use only these columns: {', '.join(valid_columns)}"
        prompt_with_columns = user_prompt + column_hints

        # Step 9: Retrieve supporting context with the RAG chain.
        context = rag_chain.run(prompt_with_columns)
        st.write(f"Retrieved Context: {context}")

        # Step 10: Have the SQL agent produce a query.
        generated_sql = sql_agent.run(f"{prompt_with_columns} {context}").strip()
        # Debug: display the generated SQL for inspection before running it.
        st.write(f"Generated SQL Query: {generated_sql}")

        # Step 11: Validate before executing anything against the database.
        if not validate_sql_with_sqlparse(generated_sql):
            st.write("Generated SQL is not valid.")
        elif not generated_sql.lower().startswith(("select", "with")):
            # Read-only guard: the SQL comes from an LLM (untrusted input);
            # never execute generated DDL/DML such as DROP or DELETE.
            st.write("Only read-only SELECT queries are allowed.")
        elif not validate_sql(generated_sql, valid_columns):
            st.write("Generated SQL references invalid columns.")
        else:
            # Step 12: Execute the query and display the results.
            result = pd.read_sql(generated_sql, conn)
            st.write("Query Results:")
            st.dataframe(result)
    except Exception as e:
        # logging.exception keeps the traceback (logging.error dropped it).
        logging.exception("An error occurred while handling the prompt")
        st.write(f"Error: {e}")
|