import os
import streamlit as st
import pandas as pd
import sqlite3
from langchain_openai import AzureChatOpenAI
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.utilities import SQLDatabase
from langchain_community.document_loaders import CSVLoader
from langchain_community.vectorstores import FAISS
from langchain_openai.embeddings import AzureOpenAIEmbeddings
from langchain.chains import RetrievalQA
import sqlparse
import logging
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Set up API credentials and environment variables
api_key = os.getenv("OPENAI_API_KEY")
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
api_version = os.getenv("OPENAI_API_VERSION", "2023-05-15")  # Set a default if not provided
chat_model = os.getenv("CHAT_MODEL")
chat_deployment = os.getenv("CHAT_DEPLOYMENT")
embed_model = os.getenv("EMBED_MODEL")
embed_deployment = os.getenv("EMBED_DEPLOYMENT")
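
# For reference, a sample .env (placeholder values; adjust to your Azure resource):
# OPENAI_API_KEY=<your-azure-openai-key>
# AZURE_OPENAI_ENDPOINT=https://<resource-name>.openai.azure.com
# OPENAI_API_VERSION=2023-05-15
# CHAT_MODEL=<chat model, e.g. gpt-4>
# CHAT_DEPLOYMENT=<chat deployment name>
# EMBED_MODEL=<embedding model, e.g. text-embedding-ada-002>
# EMBED_DEPLOYMENT=<embedding deployment name>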

# Fail fast if the endpoint is missing rather than calling a placeholder URL
if not azure_endpoint:
    st.error("AZURE_OPENAI_ENDPOINT is not set. Add it to your .env file.")
    st.stop()

# Initialize Azure OpenAI LLM (Language Model)
llm = AzureChatOpenAI(
    temperature=0,
    model=chat_model,
    azure_deployment=chat_deployment,
    api_key=api_key,
    azure_endpoint=azure_endpoint,
    api_version=api_version
)

# Step 1: Upload CSV data file (or use default)
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    csv_path = "default_data.csv"  # Fall back to the bundled default CSV
    data = pd.read_csv(csv_path)
    st.write("Using default_data.csv.")
else:
    # Persist the upload to disk so CSVLoader (Step 5) can read it by path;
    # a Streamlit upload otherwise lives only in memory
    csv_path = csv_file.name
    with open(csv_path, "wb") as f:
        f.write(csv_file.getbuffer())
    data = pd.read_csv(csv_path)
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Step 2: Load CSV data into a persistent SQLite database
# Use a persistent database file instead of in-memory to retain schema context
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
table_name = os.path.splitext(csv_file.name)[0] if csv_file else "default_table"
data.to_sql(table_name, conn, index=False, if_exists='replace')
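# Note: if_exists='replace' drops and recreates the table on every Streamlit rerun,
# so the SQLite schema always mirrors the currently loaded CSV.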

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)

# Debug: Display valid columns for user to verify
st.write(f"Valid columns: {valid_columns}")

# Step 3: Set up the SQL Database for LangChain
db = SQLDatabase.from_uri(f'sqlite:///{db_file}')  # opens its own connection to the same file

# Step 4: Create the SQL agent with increased iteration and time limits
sql_agent = create_sql_agent(
    llm,
    db=db,
    verbose=True,
    max_iterations=20,  # Increased iteration limit
    max_execution_time=90  # Set timeout limit to 90 seconds
)
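# The agent's SQL toolkit gives the LLM tools to list tables, inspect schemas, and
# sanity-check queries before running them; verbose=True prints that reasoning.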

# Step 5: Use FAISS with RAG for context retrieval
embeddings = AzureOpenAIEmbeddings(
    model=embed_model,
    azure_deployment=embed_deployment,
    azure_endpoint=azure_endpoint,
    api_key=api_key,
    api_version=api_version
)
loader = CSVLoader(file_path=csv_path)  # csv_path was persisted to disk in Step 1
documents = loader.load()

vector_store = FAISS.from_documents(documents, embeddings)
retriever = vector_store.as_retriever()
rag_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
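# RetrievalQA uses the default "stuff" chain: it embeds the question, fetches the
# most similar CSV rows from FAISS, and passes them to the LLM as context.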

# Step 6: Define SQL validation helpers
def validate_sql(query, valid_columns, table_name):
    """Validates the SQL query by ensuring it references only known columns or the table itself."""
    parsed = sqlparse.parse(query)
    if not parsed:
        return False
    allowed = {col.lower() for col in valid_columns} | {table_name.lower()}
    for token in parsed[0].flatten():
        # Leaf Name tokens are bare identifiers (columns, tables); reject unknown ones
        if token.ttype is sqlparse.tokens.Name and token.value.lower() not in allowed:
            return False
    return True

def validate_sql_with_sqlparse(query):
    """Sanity-checks syntax; sqlparse is non-validating, so this only rejects empty input."""
    parsed_query = sqlparse.parse(query)
    return len(parsed_query) > 0
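
# Example (hypothetical table "default_table" with columns age and city):
#   validate_sql("SELECT age FROM default_table WHERE city = 'Oslo'",
#                ["age", "city"], "default_table")   # True
#   validate_sql("SELECT salary FROM default_table",
#                ["age", "city"], "default_table")   # False: unknown column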

# Step 7: Generate SQL query based on user input and run it with LangChain SQL Agent
user_prompt = st.text_input("Enter your natural language prompt:")
if user_prompt:
    try:
        # Step 8: Add valid column names to the prompt
        column_hints = f" Use only these columns: {', '.join(valid_columns)}"
        prompt_with_columns = user_prompt + column_hints

        # Step 9: Retrieve context using RAG
        context = rag_chain.run(prompt_with_columns)
        st.write(f"Retrieved Context: {context}")

        # Step 10: Generate the SQL query with the agent; ask it for SQL only, since
        # it otherwise returns a natural-language answer pd.read_sql cannot execute
        generated_sql = sql_agent.run(
            f"Return only a single SQLite query (no explanation) that answers: {prompt_with_columns} {context}")
        
        # Debug: Display generated SQL query for inspection
        st.write(f"Generated SQL Query: {generated_sql}")

        # Step 11: Validate SQL query
        if not validate_sql_with_sqlparse(generated_sql):
            st.write("Generated SQL is not valid.")
        elif not validate_sql(generated_sql, valid_columns, table_name):
            st.write("Generated SQL references invalid columns.")
        else:
            # Step 12: Execute SQL query
            result = pd.read_sql(generated_sql, conn)
            st.write("Query Results:")
            st.dataframe(result)

    except Exception as e:
        logging.error(f"An error occurred: {e}")
        st.write(f"Error: {e}")