arithescientist committed on
Commit
f0e4f1b
1 Parent(s): d0ab6a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -183
app.py CHANGED
@@ -2,16 +2,31 @@ import os
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
- from langchain import OpenAI, LLMChain, PromptTemplate
6
- import sqlparse
7
  import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Initialize conversation history
10
  if 'history' not in st.session_state:
11
  st.session_state.history = []
12
 
13
  # OpenAI API key (ensure it is securely stored)
14
- # You can set the API key in your environment variables or a .env file
15
  openai_api_key = os.getenv("OPENAI_API_KEY")
16
 
17
  # Check if the API key is set
@@ -20,7 +35,7 @@ if not openai_api_key:
20
  st.stop()
21
 
22
  # Step 1: Upload CSV data file (or use default)
23
- st.title("Natural Language to SQL Query App with Enhanced Insights")
24
  st.write("Upload a CSV file to get started, or use the default dataset.")
25
 
26
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
@@ -43,117 +58,64 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
43
  valid_columns = list(data.columns)
44
  st.write(f"Valid columns: {valid_columns}")
45
 
46
- # Step 3: Set up the LLM Chains
47
- # SQL Generation Chain
48
- sql_template = """
49
- You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
50
-
51
- Ensure that:
52
-
53
- - You only use the columns provided.
54
- - When performing string comparisons in the WHERE clause, make them case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
55
- - Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
56
- - Do not apply 'COLLATE NOCASE' to numeric columns.
57
-
58
- If the question is vague or open-ended and does not pertain to specific data retrieval, respond with "NO_SQL" to indicate that a SQL query should not be generated.
59
-
60
- Question: {question}
61
-
62
- Table name: {table_name}
63
-
64
- Valid columns: {columns}
65
-
66
- SQL Query:
 
 
 
 
 
 
 
67
  """
68
- sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
69
- llm = OpenAI(temperature=0, openai_api_key=openai_api_key, max_tokens = 180)
70
- sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)
71
-
72
- # Insights Generation Chain
73
- insights_template = """
74
- You are an expert data scientist. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
75
-
76
- User's Question: {question}
77
-
78
- SQL Query Result:
79
- {result}
80
-
81
- Concise Analysis (max 150 words):
82
- """
83
- insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
84
- insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
85
-
86
- # General Insights and Recommendations Chain
87
- general_insights_template = """
88
- You are an expert data scientist. Based on the entire dataset provided below, generate a concise analysis with key insights and recommendations. Limit the response to 150 words.
89
-
90
- Dataset Summary:
91
- {dataset_summary}
92
-
93
- Concise Analysis and Recommendations (max 150 words):
94
- """
95
- general_insights_prompt = PromptTemplate(template=general_insights_template, input_variables=['dataset_summary'])
96
- general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)
97
-
98
- # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
99
- def clean_sql_query(query):
100
- """Removes incorrect usage of COLLATE NOCASE from the SQL query."""
101
- parsed = sqlparse.parse(query)
102
- statements = []
103
- for stmt in parsed:
104
- tokens = []
105
- idx = 0
106
- while idx < len(stmt.tokens):
107
- token = stmt.tokens[idx]
108
- if (token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'COLLATE'):
109
- # Check if the next token is 'NOCASE'
110
- next_token = stmt.tokens[idx + 2] if idx + 2 < len(stmt.tokens) else None
111
- if next_token and next_token.value.upper() == 'NOCASE':
112
- # Skip 'COLLATE' and 'NOCASE' tokens
113
- idx += 3 # Skip 'COLLATE', whitespace, 'NOCASE'
114
- continue
115
- tokens.append(token)
116
- idx += 1
117
- statements.append(''.join([str(t) for t in tokens]))
118
- return ' '.join(statements)
119
 
120
- # Function to classify user query
121
- def classify_query(question):
122
- """Classify the user query as either 'SQL' or 'INSIGHTS'."""
123
- classification_template = """
124
- You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
125
-
126
- Determine the appropriate category for the following user question.
127
-
128
- Question: "{question}"
129
-
130
- Category (SQL/INSIGHTS):
131
- """
132
- classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
133
- classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
134
- category = classification_chain.run({'question': question}).strip().upper()
135
- if category.startswith('SQL'):
136
- return 'SQL'
137
- else:
138
- return 'INSIGHTS'
139
-
140
- # Function to generate dataset summary
141
- def generate_dataset_summary(data):
142
- """Generate a summary of the dataset for general insights."""
143
- summary_template = """
144
- You are an expert data scientist. Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.
145
-
146
- Dataset:
147
- {data}
148
-
149
- Dataset Summary:
150
- """
151
- summary_prompt = PromptTemplate(template=summary_template, input_variables=['data'])
152
- summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
153
- summary = summary_chain.run({'data': data.head().to_string(index=False)})
154
- return summary
155
-
156
- # Define the callback function
157
  def process_input():
158
  user_prompt = st.session_state['user_input']
159
 
@@ -162,77 +124,69 @@ def process_input():
162
  # Append user message to history
163
  st.session_state.history.append({"role": "user", "content": user_prompt})
164
 
165
- # Classify the user query
166
- category = classify_query(user_prompt)
167
- logging.info(f"User query classified as: {category}")
168
-
169
- if "COLUMNS" in user_prompt.upper():
170
- assistant_response = f"The columns are: {', '.join(valid_columns)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
172
- elif category == 'SQL':
173
- columns = ', '.join(valid_columns)
174
- generated_sql = sql_generation_chain.run({
175
- 'question': user_prompt,
176
- 'table_name': table_name,
177
- 'columns': columns
178
- }).strip()
179
 
180
- if generated_sql.upper() == "NO_SQL":
181
- # Handle cases where no SQL should be generated
182
- assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
183
-
184
- # Generate dataset summary
185
- dataset_summary = generate_dataset_summary(data)
186
-
187
- # Generate general insights and recommendations
188
- general_insights = general_insights_chain.run({
189
- 'dataset_summary': dataset_summary
190
- })
191
-
192
- # Append the assistant's insights to the history
193
- st.session_state.history.append({"role": "assistant", "content": general_insights})
194
  else:
195
- # Clean the SQL query
196
- cleaned_sql = clean_sql_query(generated_sql)
197
- logging.info(f"Generated SQL Query: {cleaned_sql}")
198
-
199
- # Attempt to execute SQL query and handle exceptions
200
- try:
201
- result = pd.read_sql_query(cleaned_sql, conn)
202
-
203
- if result.empty:
204
- assistant_response = "The query returned no results. Please try a different question."
205
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
206
- else:
207
- # Convert the result to a string for the insights prompt
208
- result_str = result.head(10).to_string(index=False) # Limit to first 10 rows
209
-
210
- # Generate insights and recommendations based on the query result
211
- insights = insights_chain.run({
212
- 'question': user_prompt,
213
- 'result': result_str
214
- })
215
-
216
- # Append the assistant's insights to the history
217
- st.session_state.history.append({"role": "assistant", "content": insights})
218
- # Append the result DataFrame to the history
219
- st.session_state.history.append({"role": "assistant", "content": result})
220
- except Exception as e:
221
- logging.error(f"An error occurred during SQL execution: {e}")
222
- assistant_response = f"Error executing SQL query: {e}"
223
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
224
- else: # INSIGHTS category
225
- # Generate dataset summary
226
- dataset_summary = generate_dataset_summary(data)
227
-
228
- # Generate general insights and recommendations
229
- general_insights = general_insights_chain.run({
230
- 'dataset_summary': dataset_summary
231
- })
232
-
233
- # Append the assistant's insights to the history
234
- st.session_state.history.append({"role": "assistant", "content": general_insights})
235
-
236
  except Exception as e:
237
  logging.error(f"An error occurred: {e}")
238
  assistant_response = f"Error: {e}"
@@ -241,7 +195,7 @@ def process_input():
241
  # Reset the user_input in session state
242
  st.session_state['user_input'] = ''
243
 
244
- # Display the conversation history
245
  for message in st.session_state.history:
246
  if message['role'] == 'user':
247
  st.markdown(f"**User:** {message['content']}")
@@ -253,4 +207,4 @@ for message in st.session_state.history:
253
  st.markdown(f"**Assistant:** {message['content']}")
254
 
255
  # Place the input field at the bottom with the callback
256
- st.text_input("Enter your message:", key='user_input', on_change=process_input)
 
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
 
 
5
  import logging
6
+ from langchain.agents import create_sql_agent
7
+ from langchain.agents.agent_toolkits import SQLDatabaseToolkit
8
+ from langchain.llms import OpenAI
9
+ from langchain.sql_database import SQLDatabase
10
+ from langchain.prompts import (
11
+ ChatPromptTemplate,
12
+ FewShotPromptTemplate,
13
+ PromptTemplate,
14
+ SystemMessagePromptTemplate,
15
+ HumanMessagePromptTemplate,
16
+ MessagesPlaceholder
17
+ )
18
+ from langchain.schema import HumanMessage
19
+ from langchain.chat_models import ChatOpenAI
20
+ from langchain.evaluation import load_evaluator
21
+
22
+ # Initialize logging
23
+ logging.basicConfig(level=logging.INFO)
24
 
25
  # Initialize conversation history
26
  if 'history' not in st.session_state:
27
  st.session_state.history = []
28
 
29
  # OpenAI API key (ensure it is securely stored)
 
30
  openai_api_key = os.getenv("OPENAI_API_KEY")
31
 
32
  # Check if the API key is set
 
35
  st.stop()
36
 
37
  # Step 1: Upload CSV data file (or use default)
38
+ st.title("Enhanced Natural Language to SQL Query App")
39
  st.write("Upload a CSV file to get started, or use the default dataset.")
40
 
41
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
 
58
  valid_columns = list(data.columns)
59
  st.write(f"Valid columns: {valid_columns}")
60
 
61
+ # Create SQLDatabase instance with custom table info
62
+ engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
63
+
64
+ # Step 3: Define the few-shot examples for the prompt
65
+ few_shot_examples = [
66
+ {
67
+ "input": "What is the total revenue for each category?",
68
+ "query": f"SELECT category, SUM(revenue) FROM {table_name} GROUP BY category;"
69
+ },
70
+ {
71
+ "input": "Show the top 5 products by sales.",
72
+ "query": f"SELECT product_name, sales FROM {table_name} ORDER BY sales DESC LIMIT 5;"
73
+ },
74
+ {
75
+ "input": "How many orders were placed in the last month?",
76
+ "query": f"SELECT COUNT(*) FROM {table_name} WHERE order_date >= DATE('now', '-1 month');"
77
+ }
78
+ ]
79
+
80
+ # Step 4: Define the prompt templates
81
+ system_prefix = """
82
+ You are an expert data analyst who can convert natural language questions into SQL queries.
83
+ Follow these guidelines:
84
+ 1. Only use the columns and tables provided.
85
+ 2. Use appropriate SQL syntax for SQLite.
86
+ 3. Ensure string comparisons are case-insensitive.
87
+ 4. Do not execute queries that could be harmful or unethical.
88
+ 5. Provide clear and concise SQL queries.
89
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ few_shot_prompt = FewShotPromptTemplate(
92
+ example_prompt=PromptTemplate.from_template("Question: {input}\nSQL Query: {query}"),
93
+ examples=few_shot_examples,
94
+ prefix=system_prefix,
95
+ suffix="Question: {input}\nSQL Query:",
96
+ input_variables=["input"]
97
+ )
98
+
99
+ # Step 5: Initialize the LLM and toolkit
100
+ llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
101
+ toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
102
+
103
+ # Step 6: Create the agent
104
+ agent_prompt = ChatPromptTemplate.from_messages([
105
+ SystemMessagePromptTemplate(prompt=few_shot_prompt),
106
+ HumanMessagePromptTemplate.from_template("{input}")
107
+ ])
108
+
109
+ sql_agent = create_sql_agent(
110
+ llm=llm,
111
+ toolkit=toolkit,
112
+ prompt=agent_prompt,
113
+ verbose=True,
114
+ agent_type="openai-functions",
115
+ max_iterations=5
116
+ )
117
+
118
+ # Step 7: Define the callback function
 
 
 
 
 
 
 
 
 
119
  def process_input():
120
  user_prompt = st.session_state['user_input']
121
 
 
124
  # Append user message to history
125
  st.session_state.history.append({"role": "user", "content": user_prompt})
126
 
127
+ # Use the agent to generate the SQL query
128
+ with st.spinner("Generating SQL query..."):
129
+ response = sql_agent.run(user_prompt)
130
+
131
+ # Check if the response contains SQL code
132
+ if "SELECT" in response.upper():
133
+ sql_query = response.strip()
134
+ logging.info(f"Generated SQL Query: {sql_query}")
135
+
136
+ # Attempt to execute SQL query and handle exceptions
137
+ try:
138
+ result = pd.read_sql_query(sql_query, conn)
139
+
140
+ if result.empty:
141
+ assistant_response = "The query returned no results. Please try a different question."
142
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
143
+ else:
144
+ # Limit the result to first 10 rows for display
145
+ result_display = result.head(10)
146
+ st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
147
+ st.session_state.history.append({"role": "assistant", "content": result_display})
148
+
149
+ # Generate insights based on the query result
150
+ insights_template = """
151
+ You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
152
+
153
+ User's Question: {question}
154
+
155
+ SQL Query Result:
156
+ {result}
157
+
158
+ Concise Analysis:
159
+ """
160
+ insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
161
+ insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
162
+
163
+ result_str = result_display.to_string(index=False)
164
+ insights = insights_chain.run({'question': user_prompt, 'result': result_str})
165
+
166
+ # Append the assistant's insights to the history
167
+ st.session_state.history.append({"role": "assistant", "content": insights})
168
+ except Exception as e:
169
+ logging.error(f"An error occurred during SQL execution: {e}")
170
+ assistant_response = f"Error executing SQL query: {e}"
171
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
172
+ else:
173
+ # Handle responses that do not contain SQL queries
174
+ assistant_response = response
175
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
 
 
 
 
 
 
 
176
 
177
+ # Evaluate the response for harmful content
178
+ try:
179
+ evaluator = load_evaluator("harmful_content", llm=llm)
180
+ eval_result = evaluator.evaluate_strings(
181
+ input=user_prompt,
182
+ prediction=response
183
+ )
184
+ if eval_result['flagged']:
185
+ st.warning("The assistant's response may not be appropriate.")
 
 
 
 
 
186
  else:
187
+ logging.info("Response evaluated as appropriate.")
188
+ except Exception as e:
189
+ logging.error(f"An error occurred during evaluation: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  except Exception as e:
191
  logging.error(f"An error occurred: {e}")
192
  assistant_response = f"Error: {e}"
 
195
  # Reset the user_input in session state
196
  st.session_state['user_input'] = ''
197
 
198
+ # Step 8: Display the conversation history
199
  for message in st.session_state.history:
200
  if message['role'] == 'user':
201
  st.markdown(f"**User:** {message['content']}")
 
207
  st.markdown(f"**Assistant:** {message['content']}")
208
 
209
  # Place the input field at the bottom with the callback
210
+ st.text_input("Enter your message:", key='user_input', on_change=process_input)