import os
import streamlit as st
import pandas as pd
import sqlite3
import openai
from langchain_openai import AzureChatOpenAI
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.utilities import SQLDatabase
from langchain_community.document_loaders import CSVLoader
from langchain_community.vectorstores import FAISS
from langchain_openai.embeddings import AzureOpenAIEmbeddings
from langchain.chains import RetrievalQA
import sqlparse
import logging
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
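# Expected .env layout (illustrative; the model and deployment names depend on your Azure OpenAI resource):
#   OPENAI_API_KEY=<your-azure-openai-key>
#   AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com
#   OPENAI_API_VERSION=2023-05-15
#   CHAT_MODEL=<base model of your chat deployment>
#   CHAT_DEPLOYMENT=<your chat deployment name>
#   EMBED_MODEL=<base model of your embedding deployment>
#   EMBED_DEPLOYMENT=<your embedding deployment name>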
# Set up API credentials and environment variables
api_key = os.getenv("OPENAI_API_KEY")
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
api_version = os.getenv("OPENAI_API_VERSION", "2023-05-15") # Set a default if not provided
chat_model = os.getenv("CHAT_MODEL")
chat_deployment = os.getenv("CHAT_DEPLOYMENT")
embed_model = os.getenv("EMBED_MODEL")
embed_deployment = os.getenv("EMBED_DEPLOYMENT")
# Default to a specific endpoint if the environment variable is missing
if not azure_endpoint:
    azure_endpoint = "https://<your-azure-endpoint>.openai.azure.com"  # Replace with your actual endpoint
# OpenAI API key (ensure it is securely stored)
openai.api_key = api_key
# Initialize Azure OpenAI LLM (Language Model)
llm = AzureChatOpenAI(
    temperature=0,
    model=chat_model,
    deployment_name=chat_deployment,
    api_key=api_key,
    azure_endpoint=azure_endpoint,
    api_version=api_version
)
# Step 1: Upload CSV data file (or use default)
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Use default CSV if no file is uploaded
    st.write("Using default_data.csv file.")
else:
    data = pd.read_csv(csv_file)
    st.write(f"Data Preview ({csv_file.name}):")
st.dataframe(data.head())
# Step 2: Load CSV data into a persistent SQLite database
# Use a persistent database file instead of in-memory to retain schema context
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
data.to_sql(table_name, conn, index=False, if_exists='replace')
# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
# Debug: Display valid columns for user to verify
st.write(f"Valid columns: {valid_columns}")
# Step 3: Set up the SQL Database for LangChain
# SQLDatabase.from_uri opens its own connection to the same persistent SQLite file created above
db = SQLDatabase.from_uri(f'sqlite:///{db_file}')
# Step 4: Create the SQL agent with increased iteration and time limits
sql_agent = create_sql_agent(
    llm,
    db=db,
    verbose=True,
    max_iterations=20,       # Increased iteration limit
    max_execution_time=90    # Set timeout limit to 90 seconds
)
# Step 5: Use FAISS with RAG for context retrieval
embeddings = AzureOpenAIEmbeddings(
    model=embed_model,
    deployment_name=embed_deployment,
    azure_endpoint=azure_endpoint,
    api_key=api_key,
    api_version=api_version
)
# CSVLoader reads from disk; an uploaded file exists only in memory, so persist the DataFrame first
if csv_file:
    data.to_csv(csv_file.name, index=False)
loader = CSVLoader(file_path=csv_file.name if csv_file else "default_data.csv")
documents = loader.load()
vector_store = FAISS.from_documents(documents, embeddings)
retriever = vector_store.as_retriever()
rag_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
# Step 6: Define SQL validation helpers
def validate_sql(query, valid_columns):
    """Heuristic check that the query only references known columns or the source table.

    Note: this may reject queries that use aliases or SQL functions.
    """
    allowed = {col.lower() for col in valid_columns} | {table_name.lower()}
    parsed = sqlparse.parse(query)
    if not parsed:
        return False
    for token in parsed[0].flatten():
        # sqlparse tags bare identifiers (column and table names) with the Name token type
        if token.ttype is sqlparse.tokens.Name and token.value.lower() not in allowed:
            return False
    return True
def validate_sql_with_sqlparse(query):
    """Validates SQL syntax using sqlparse (only checks that the string parses at all)."""
    parsed_query = sqlparse.parse(query)
    return len(parsed_query) > 0
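# Illustrative behaviour of the helpers above (assuming the default dataset, i.e. a
# "default_table" table with a hypothetical "price" column):
#   validate_sql_with_sqlparse("SELECT price FROM default_table")  -> True  (string parses as SQL)
#   validate_sql("SELECT price FROM default_table", ["price"])     -> True  (only known identifiers)
#   validate_sql("SELECT salary FROM default_table", ["price"])    -> False (unknown column "salary")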
# Step 7: Generate SQL query based on user input and run it with LangChain SQL Agent
user_prompt = st.text_input("Enter your natural language prompt:")
if user_prompt:
    try:
        # Step 8: Add valid column names to the prompt
        column_hints = f" Use only these columns: {', '.join(valid_columns)}"
        prompt_with_columns = user_prompt + column_hints
        # Step 9: Retrieve context using RAG
        context = rag_chain.run(prompt_with_columns)
        st.write(f"Retrieved Context: {context}")
        # Step 10: Generate SQL query using SQL agent
        generated_sql = sql_agent.run(f"{prompt_with_columns} {context}")
        # Debug: Display generated SQL query for inspection
        st.write(f"Generated SQL Query: {generated_sql}")
        # Step 11: Validate SQL query
        if not validate_sql_with_sqlparse(generated_sql):
            st.write("Generated SQL is not valid.")
        elif not validate_sql(generated_sql, valid_columns):
            st.write("Generated SQL references invalid columns.")
        else:
            # Step 12: Execute SQL query
            result = pd.read_sql(generated_sql, conn)
            st.write("Query Results:")
            st.dataframe(result)
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        st.write(f"Error: {e}")