Spaces:
Sleeping
Sleeping
Ari
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -2,33 +2,30 @@ import os
|
|
2 |
import streamlit as st
|
3 |
import pandas as pd
|
4 |
import sqlite3
|
5 |
-
import numpy as np # For numerical operations
|
6 |
from langchain import OpenAI, LLMChain, PromptTemplate
|
7 |
import sqlparse
|
8 |
import logging
|
9 |
-
from sklearn.linear_model import LinearRegression # For machine learning tasks
|
10 |
-
from sklearn.model_selection import train_test_split
|
11 |
-
from sklearn.metrics import mean_squared_error, r2_score
|
12 |
|
13 |
# Initialize conversation history
|
14 |
if 'history' not in st.session_state:
|
15 |
st.session_state.history = []
|
16 |
|
17 |
-
# Set up logging
|
18 |
-
logging.basicConfig(level=logging.ERROR)
|
19 |
-
|
20 |
# OpenAI API key (ensure it is securely stored)
|
|
|
21 |
openai_api_key = os.getenv("OPENAI_API_KEY")
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
-
|
|
|
26 |
|
27 |
# Step 1: Upload CSV data file (or use default)
|
28 |
-
st.title("
|
|
|
|
|
29 |
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
30 |
if csv_file is None:
|
31 |
-
data = pd.read_csv("default_data.csv") #
|
32 |
st.write("Using default_data.csv file.")
|
33 |
else:
|
34 |
data = pd.read_csv(csv_file)
|
@@ -45,30 +42,16 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
|
|
45 |
valid_columns = list(data.columns)
|
46 |
st.write(f"Valid columns: {valid_columns}")
|
47 |
|
48 |
-
# Step 3:
|
49 |
-
|
50 |
-
def extract_code(response):
|
51 |
-
"""Extracts code enclosed between <CODE> and </CODE> tags."""
|
52 |
-
import re
|
53 |
-
pattern = r"<CODE>(.*?)</CODE>"
|
54 |
-
match = re.search(pattern, response, re.DOTALL)
|
55 |
-
if match:
|
56 |
-
return match.group(1).strip()
|
57 |
-
else:
|
58 |
-
return None
|
59 |
-
|
60 |
-
# Step 4: Set up the LLM Chain to generate SQL queries or Python code
|
61 |
template = """
|
62 |
-
You are an expert data scientist
|
|
|
|
|
63 |
|
64 |
-
|
65 |
-
-
|
66 |
-
-
|
67 |
-
-
|
68 |
-
- Ensure that you only use the columns provided.
|
69 |
-
- Do not include any import statements in the code.
|
70 |
-
- For case-insensitive string comparisons in SQL, use either 'LOWER(column) = LOWER(value)' or 'column = value COLLATE NOCASE', but do not use both together.
|
71 |
-
- Provide the code between <CODE> and </CODE> tags.
|
72 |
|
73 |
Question: {question}
|
74 |
|
@@ -76,12 +59,34 @@ Table name: {table_name}
|
|
76 |
|
77 |
Valid columns: {columns}
|
78 |
|
79 |
-
|
80 |
"""
|
81 |
prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
|
82 |
-
llm =
|
83 |
sql_generation_chain = LLMChain(llm=llm, prompt=prompt)
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
# Define the callback function
|
86 |
def process_input():
|
87 |
user_prompt = st.session_state['user_input']
|
@@ -96,59 +101,27 @@ def process_input():
|
|
96 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
97 |
else:
|
98 |
columns = ', '.join(valid_columns)
|
99 |
-
|
100 |
'question': user_prompt,
|
101 |
'table_name': table_name,
|
102 |
'columns': columns
|
103 |
})
|
104 |
|
105 |
-
#
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
assistant_response = f"Error executing SQL query: {e}"
|
121 |
-
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
122 |
-
else:
|
123 |
-
# It's Python code
|
124 |
-
st.write(f"Generated Python Code:\n{code}")
|
125 |
-
try:
|
126 |
-
# Prepare the local namespace
|
127 |
-
local_vars = {
|
128 |
-
'pd': pd,
|
129 |
-
'np': np,
|
130 |
-
'data': data.copy(),
|
131 |
-
'result': None,
|
132 |
-
'LinearRegression': LinearRegression,
|
133 |
-
'train_test_split': train_test_split,
|
134 |
-
'mean_squared_error': mean_squared_error,
|
135 |
-
'r2_score': r2_score
|
136 |
-
}
|
137 |
-
exec(code, {}, local_vars)
|
138 |
-
result = local_vars.get('result')
|
139 |
-
if result is not None:
|
140 |
-
assistant_response = "Result:"
|
141 |
-
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
142 |
-
st.session_state.history.append({"role": "assistant", "content": result})
|
143 |
-
else:
|
144 |
-
assistant_response = "Code executed successfully."
|
145 |
-
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
146 |
-
except Exception as e:
|
147 |
-
logging.error(f"An error occurred during code execution: {e}")
|
148 |
-
assistant_response = f"Error executing code: {e}"
|
149 |
-
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
150 |
-
else:
|
151 |
-
assistant_response = response.strip()
|
152 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
153 |
|
154 |
except Exception as e:
|
@@ -164,17 +137,11 @@ for message in st.session_state.history:
|
|
164 |
if message['role'] == 'user':
|
165 |
st.markdown(f"**User:** {message['content']}")
|
166 |
elif message['role'] == 'assistant':
|
167 |
-
|
168 |
-
|
169 |
-
st.
|
170 |
-
st.dataframe(content)
|
171 |
-
elif isinstance(content, (int, float)):
|
172 |
-
st.markdown(f"**Assistant:** {content}")
|
173 |
-
elif isinstance(content, dict):
|
174 |
-
st.markdown("**Assistant:** Here are the results:")
|
175 |
-
st.json(content)
|
176 |
else:
|
177 |
-
st.markdown(f"**Assistant:** {content}")
|
178 |
|
179 |
# Place the input field at the bottom with the callback
|
180 |
st.text_input("Enter your message:", key='user_input', on_change=process_input)
|
|
|
2 |
import streamlit as st
|
3 |
import pandas as pd
|
4 |
import sqlite3
|
|
|
5 |
from langchain import OpenAI, LLMChain, PromptTemplate
|
6 |
import sqlparse
|
7 |
import logging
|
|
|
|
|
|
|
8 |
|
9 |
# Initialize conversation history
|
10 |
if 'history' not in st.session_state:
|
11 |
st.session_state.history = []
|
12 |
|
|
|
|
|
|
|
13 |
# OpenAI API key (ensure it is securely stored)
|
14 |
+
# You can set the API key in your environment variables or a .env file
|
15 |
openai_api_key = os.getenv("OPENAI_API_KEY")
|
16 |
|
17 |
+
# Check if the API key is set
|
18 |
+
if not openai_api_key:
|
19 |
+
st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
|
20 |
+
st.stop()
|
21 |
|
22 |
# Step 1: Upload CSV data file (or use default)
|
23 |
+
st.title("Natural Language to SQL Query App")
|
24 |
+
st.write("Upload a CSV file to get started, or use the default dataset.")
|
25 |
+
|
26 |
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
27 |
if csv_file is None:
|
28 |
+
data = pd.read_csv("default_data.csv") # Ensure this file exists in your working directory
|
29 |
st.write("Using default_data.csv file.")
|
30 |
else:
|
31 |
data = pd.read_csv(csv_file)
|
|
|
42 |
valid_columns = list(data.columns)
|
43 |
st.write(f"Valid columns: {valid_columns}")
|
44 |
|
45 |
+
# Step 3: Set up the LLM Chain to generate SQL queries
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
template = """
|
47 |
+
You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
|
48 |
+
|
49 |
+
Ensure that:
|
50 |
|
51 |
+
- You only use the columns provided.
|
52 |
+
- When performing string comparisons in the WHERE clause, make them case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
|
53 |
+
- Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
|
54 |
+
- Do not apply 'COLLATE NOCASE' to numeric columns.
|
|
|
|
|
|
|
|
|
55 |
|
56 |
Question: {question}
|
57 |
|
|
|
59 |
|
60 |
Valid columns: {columns}
|
61 |
|
62 |
+
SQL Query:
|
63 |
"""
|
64 |
prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
|
65 |
+
llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
|
66 |
sql_generation_chain = LLMChain(llm=llm, prompt=prompt)
|
67 |
|
68 |
+
# Optional: Clean up function to remove incorrect COLLATE NOCASE usage
|
69 |
+
def clean_sql_query(query):
|
70 |
+
"""Removes incorrect usage of COLLATE NOCASE from the SQL query."""
|
71 |
+
parsed = sqlparse.parse(query)
|
72 |
+
statements = []
|
73 |
+
for stmt in parsed:
|
74 |
+
tokens = []
|
75 |
+
idx = 0
|
76 |
+
while idx < len(stmt.tokens):
|
77 |
+
token = stmt.tokens[idx]
|
78 |
+
if (token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'COLLATE'):
|
79 |
+
# Check if the next token is 'NOCASE'
|
80 |
+
next_token = stmt.tokens[idx + 2] if idx + 2 < len(stmt.tokens) else None
|
81 |
+
if next_token and next_token.value.upper() == 'NOCASE':
|
82 |
+
# Skip 'COLLATE' and 'NOCASE' tokens
|
83 |
+
idx += 3 # Skip 'COLLATE', whitespace, 'NOCASE'
|
84 |
+
continue
|
85 |
+
tokens.append(token)
|
86 |
+
idx += 1
|
87 |
+
statements.append(''.join([str(t) for t in tokens]))
|
88 |
+
return ' '.join(statements)
|
89 |
+
|
90 |
# Define the callback function
|
91 |
def process_input():
|
92 |
user_prompt = st.session_state['user_input']
|
|
|
101 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
102 |
else:
|
103 |
columns = ', '.join(valid_columns)
|
104 |
+
generated_sql = sql_generation_chain.run({
|
105 |
'question': user_prompt,
|
106 |
'table_name': table_name,
|
107 |
'columns': columns
|
108 |
})
|
109 |
|
110 |
+
# Debug: Display generated SQL query for inspection
|
111 |
+
st.write(f"Generated SQL Query:\n{generated_sql}")
|
112 |
+
|
113 |
+
# Clean the SQL query
|
114 |
+
generated_sql = clean_sql_query(generated_sql)
|
115 |
+
|
116 |
+
# Attempt to execute SQL query and handle exceptions
|
117 |
+
try:
|
118 |
+
result = pd.read_sql_query(generated_sql, conn)
|
119 |
+
assistant_response = f"Generated SQL Query:\n{generated_sql}"
|
120 |
+
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
121 |
+
st.session_state.history.append({"role": "assistant", "content": result})
|
122 |
+
except Exception as e:
|
123 |
+
logging.error(f"An error occurred during SQL execution: {e}")
|
124 |
+
assistant_response = f"Error executing SQL query: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
126 |
|
127 |
except Exception as e:
|
|
|
137 |
if message['role'] == 'user':
|
138 |
st.markdown(f"**User:** {message['content']}")
|
139 |
elif message['role'] == 'assistant':
|
140 |
+
if isinstance(message['content'], pd.DataFrame):
|
141 |
+
st.markdown("**Assistant:** Query Results:")
|
142 |
+
st.dataframe(message['content'])
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
else:
|
144 |
+
st.markdown(f"**Assistant:** {message['content']}")
|
145 |
|
146 |
# Place the input field at the bottom with the callback
|
147 |
st.text_input("Enter your message:", key='user_input', on_change=process_input)
|