arithescientist committed on
Commit
6dd2b20
1 Parent(s): 9bff135

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -144
app.py CHANGED
@@ -3,13 +3,13 @@ import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  import logging
6
- import json
7
  from langchain.agents.agent_toolkits import SQLDatabaseToolkit
 
 
8
  from langchain.sql_database import SQLDatabase
9
- from langchain.prompts import PromptTemplate
10
- from langchain.chains import LLMChain
11
- # Import ChatOpenAI from langchain_community
12
- from langchain_community.chat_models import ChatOpenAI
13
 
14
  # Initialize logging
15
  logging.basicConfig(level=logging.INFO)
@@ -20,8 +20,6 @@ if 'history' not in st.session_state:
20
 
21
  # OpenAI API key
22
  openai_api_key = os.getenv("OPENAI_API_KEY")
23
- # Alternatively, you can set your API key directly
24
- # openai_api_key = "YOUR_OPENAI_API_KEY"
25
 
26
  # Check if the API key is set
27
  if not openai_api_key:
@@ -54,122 +52,18 @@ engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name
54
  # Initialize the LLM
55
  llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
56
 
57
- # Step 3: Create the agent toolkit (not used directly in the layered approach but kept for completeness)
58
  toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
59
 
60
- # Step 4: Define the layered functions
61
-
62
- # Layer 1: Understanding the Question
63
- def parse_user_question(question):
64
- parsing_prompt = f"""
65
- You are an assistant that extracts key information from user questions for SQL query generation.
66
- Given the following question, identify the relevant columns, tables, and any conditions or filters needed.
67
-
68
- Question: "{question}"
69
-
70
- Provide your answer in the following JSON format:
71
- {{
72
- "columns": [list of columns or aggregation functions],
73
- "table": "table_name",
74
- "conditions": "SQL WHERE clause conditions",
75
- "aggregation": "any aggregation functions needed",
76
- "group_by": [list of columns to group by],
77
- "order_by": "column to order by and direction (e.g., 'Total_Sales DESC')",
78
- "limit": "number of records to return"
79
- }}
80
-
81
- Answer:
82
- """
83
- # Use llm.predict instead of llm()
84
- response = llm.predict(parsing_prompt)
85
- try:
86
- parsed_query = json.loads(response)
87
- return parsed_query
88
- except json.JSONDecodeError as e:
89
- logging.error(f"JSON decoding error: {e}")
90
- return None
91
-
92
- # Layer 2: Generating the SQL Query
93
- def construct_sql_query(parsed_info):
94
- if not parsed_info:
95
- return None
96
-
97
- columns = ', '.join(parsed_info.get('columns', ['*']))
98
- table = parsed_info.get('table', table_name)
99
- conditions = parsed_info.get('conditions', '')
100
- group_by = parsed_info.get('group_by', [])
101
- order_by = parsed_info.get('order_by', '')
102
- limit = parsed_info.get('limit', '')
103
-
104
- sql_query = f"SELECT {columns} FROM {table}"
105
-
106
- if conditions:
107
- sql_query += f" WHERE {conditions}"
108
-
109
- if group_by:
110
- sql_query += f" GROUP BY {', '.join(group_by)}"
111
-
112
- if order_by:
113
- sql_query += f" ORDER BY {order_by}"
114
-
115
- if limit:
116
- sql_query += f" LIMIT {limit}"
117
-
118
- return sql_query
119
-
120
- # Layer 3: Executing the Query and Retrieving Data
121
- def execute_sql_query(sql_query):
122
- try:
123
- result = pd.read_sql_query(sql_query, conn)
124
- return result
125
- except Exception as e:
126
- logging.error(f"SQL execution error: {e}")
127
- return None
128
-
129
- # Layer 4: Formatting and Presenting the Results
130
- def display_results(result):
131
- if result is not None and not result.empty:
132
- st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
133
- st.session_state.history.append({"role": "assistant", "content": result.head(10)})
134
- else:
135
- assistant_response = "The query returned no results. Please try a different question."
136
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
137
-
138
- # Layer 5: Generating Insights or Additional Analysis (Optional)
139
- def generate_insights(question, result):
140
- insights_template = """
141
- You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
142
-
143
- User's Question: {question}
144
-
145
- SQL Query Result:
146
- {result}
147
-
148
- Concise Analysis:
149
- """
150
- insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
151
- insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
152
-
153
- result_str = result.to_string(index=False)
154
- insights = insights_chain.run({'question': question, 'result': result_str})
155
-
156
- st.session_state.history.append({"role": "assistant", "content": insights})
157
-
158
- # Function to Generate Data Summary (for non-SQL responses)
159
- def generate_data_summary():
160
- summary_prompt = f"""
161
- You are an assistant that provides a summary of the dataset.
162
-
163
- Dataset Description:
164
- {data.describe(include='all').to_string()}
165
-
166
- Provide a concise summary of the dataset, highlighting key statistics and any notable observations.
167
- """
168
- # Use llm.predict instead of llm()
169
- summary = llm.predict(summary_prompt)
170
- return summary
171
-
172
- # Step 5: Define the callback function
173
  def process_input():
174
  user_prompt = st.session_state['user_input']
175
 
@@ -178,30 +72,58 @@ def process_input():
178
  # Append user message to history
179
  st.session_state.history.append({"role": "user", "content": user_prompt})
180
 
 
181
  with st.spinner("Processing..."):
182
- # Layer 1: Understand the question
183
- parsed_query = parse_user_question(user_prompt)
184
- logging.info(f"Parsed Query: {parsed_query}")
185
-
186
- if parsed_query and parsed_query.get('columns'):
187
- # Layer 2: Generate the SQL query
188
- sql_query = construct_sql_query(parsed_query)
189
- logging.info(f"Constructed SQL Query: {sql_query}")
190
-
191
- # Layer 3: Execute the SQL query and get the result
192
- result = execute_sql_query(sql_query)
193
-
194
- # Layer 4: Display the results
195
- display_results(result)
196
-
197
- # Layer 5: Generate insights (optional)
198
- if result is not None and not result.empty:
199
- generate_insights(user_prompt, result.head(10))
 
 
 
200
  else:
201
- # If no columns are identified, provide a summary
202
- summary = generate_data_summary()
203
- st.session_state.history.append({"role": "assistant", "content": summary})
204
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  except Exception as e:
206
  logging.error(f"An error occurred: {e}")
207
  assistant_response = f"Error: {e}"
@@ -210,7 +132,7 @@ def process_input():
210
  # Reset user input
211
  st.session_state['user_input'] = ''
212
 
213
- # Step 6: Display conversation history
214
  for message in st.session_state.history:
215
  if message['role'] == 'user':
216
  st.markdown(f"**User:** {message['content']}")
 
3
  import pandas as pd
4
  import sqlite3
5
  import logging
6
+ from langchain.agents import create_sql_agent
7
  from langchain.agents.agent_toolkits import SQLDatabaseToolkit
8
+ from langchain.agents.agent_types import AgentType
9
+ from langchain.llms import OpenAI
10
  from langchain.sql_database import SQLDatabase
11
+ from langchain.chat_models import ChatOpenAI
12
+ from langchain.evaluation import load_evaluator
 
 
13
 
14
  # Initialize logging
15
  logging.basicConfig(level=logging.INFO)
 
20
 
21
  # OpenAI API key
22
  openai_api_key = os.getenv("OPENAI_API_KEY")
 
 
23
 
24
  # Check if the API key is set
25
  if not openai_api_key:
 
52
  # Initialize the LLM
53
  llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
54
 
55
+ # Step 3: Create the agent
56
  toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
57
 
58
+ sql_agent = create_sql_agent(
59
+ llm=llm,
60
+ toolkit=toolkit,
61
+ verbose=True,
62
+ agent_type=AgentType.OPENAI_FUNCTIONS,
63
+ max_iterations=5
64
+ )
65
+
66
+ # Step 4: Define the callback function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def process_input():
68
  user_prompt = st.session_state['user_input']
69
 
 
72
  # Append user message to history
73
  st.session_state.history.append({"role": "user", "content": user_prompt})
74
 
75
+ # Use the agent to generate the SQL query and get the response
76
  with st.spinner("Processing..."):
77
+ response = sql_agent.run(user_prompt)
78
+
79
+ # Check if the response contains a SQL query
80
+ if "```sql" in response:
81
+ # Extract the SQL query
82
+ start_index = response.find("```sql") + len("```sql")
83
+ end_index = response.find("```", start_index)
84
+ sql_query = response[start_index:end_index].strip()
85
+ else:
86
+ # If no SQL code is found, assume the entire response is the SQL query
87
+ sql_query = response.strip()
88
+
89
+ logging.info(f"Generated SQL Query: {sql_query}")
90
+
91
+ # Attempt to execute SQL query and handle exceptions
92
+ try:
93
+ result = pd.read_sql_query(sql_query, conn)
94
+
95
+ if result.empty:
96
+ assistant_response = "The query returned no results. Please try a different question."
97
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
98
  else:
99
+ # Limit the result to first 10 rows for display
100
+ result_display = result.head(10)
101
+ st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
102
+ st.session_state.history.append({"role": "assistant", "content": result_display})
103
+
104
+ # Generate insights based on the query result
105
+ insights_template = """
106
+ You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
107
+
108
+ User's Question: {question}
109
+
110
+ SQL Query Result:
111
+ {result}
112
+
113
+ Concise Analysis:
114
+ """
115
+ insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
116
+ insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
117
+
118
+ result_str = result_display.to_string(index=False)
119
+ insights = insights_chain.run({'question': user_prompt, 'result': result_str})
120
+
121
+ # Append the assistant's insights to the history
122
+ st.session_state.history.append({"role": "assistant", "content": insights})
123
+ except Exception as e:
124
+ logging.error(f"An error occurred during SQL execution: {e}")
125
+ assistant_response = f"Error executing SQL query: {e}"
126
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
127
  except Exception as e:
128
  logging.error(f"An error occurred: {e}")
129
  assistant_response = f"Error: {e}"
 
132
  # Reset user input
133
  st.session_state['user_input'] = ''
134
 
135
+ # Step 5: Display conversation history
136
  for message in st.session_state.history:
137
  if message['role'] == 'user':
138
  st.markdown(f"**User:** {message['content']}")