arithescientist commited on
Commit
887daae
1 Parent(s): 6e8db38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -25
app.py CHANGED
@@ -3,13 +3,11 @@ import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  import logging
6
- from langchain.agents import create_sql_agent, AgentType
7
  from langchain.agents.agent_toolkits import SQLDatabaseToolkit
8
  from langchain.llms import OpenAI
9
  from langchain.sql_database import SQLDatabase
10
- from langchain.prompts import (
11
- PromptTemplate,
12
- )
13
  from langchain.evaluation import load_evaluator
14
 
15
  # Initialize logging
@@ -19,7 +17,7 @@ logging.basicConfig(level=logging.INFO)
19
  if 'history' not in st.session_state:
20
  st.session_state.history = []
21
 
22
- # OpenAI API key (ensure it is securely stored)
23
  openai_api_key = os.getenv("OPENAI_API_KEY")
24
 
25
  # Check if the API key is set
@@ -33,7 +31,7 @@ st.write("Upload a CSV file to get started, or use the default dataset.")
33
 
34
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
35
  if csv_file is None:
36
- data = pd.read_csv("default_data.csv") # Ensure this file exists in your working directory
37
  st.write("Using default_data.csv file.")
38
  table_name = "default_table"
39
  else:
@@ -42,19 +40,19 @@ else:
42
  st.write(f"Data Preview ({csv_file.name}):")
43
  st.dataframe(data.head())
44
 
45
- # Step 2: Load CSV data into a persistent SQLite database
46
  db_file = 'my_database.db'
47
  conn = sqlite3.connect(db_file)
48
  data.to_sql(table_name, conn, index=False, if_exists='replace')
49
 
50
- # SQL table metadata (for validation and schema)
51
  valid_columns = list(data.columns)
52
  st.write(f"Valid columns: {valid_columns}")
53
 
54
- # Create SQLDatabase instance with custom table info
55
  engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
56
 
57
- # Step 3: Define the few-shot examples for the prompt
58
  few_shot_examples = [
59
  {
60
  "input": "What is the total revenue for each category?",
@@ -78,12 +76,15 @@ for ex in few_shot_examples:
78
  # Prepare table information
79
  table_info = f"Table: {table_name}\nColumns: {', '.join(valid_columns)}"
80
 
 
 
 
81
  # Step 4: Define the prompt template
82
  system_message = """
83
  You are an expert data analyst who can convert natural language questions into SQL queries.
84
 
85
- Available tools:
86
- {tool_descriptions}
87
 
88
  Follow these guidelines:
89
  1. Only use the columns and tables provided.
@@ -104,19 +105,16 @@ Question: {input}
104
  {agent_scratchpad}
105
  """
106
 
107
- # Initialize the LLM
108
- llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
109
-
110
  # Step 5: Create the agent
111
  toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
112
  tools = toolkit.get_tools()
113
- tool_names = [tool.name for tool in tools]
114
  tool_descriptions = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
115
 
116
  # Create the prompt
117
  agent_prompt = PromptTemplate(
118
  template=system_message,
119
- input_variables=["input", "agent_scratchpad", "table_info", "few_shot_examples", "tool_descriptions"]
120
  )
121
 
122
  # Create the agent
@@ -146,14 +144,15 @@ def process_input():
146
  table_info=table_info,
147
  few_shot_examples=few_shot_str,
148
  agent_scratchpad="",
149
- tool_descriptions=tool_descriptions
 
150
  )
151
 
152
  # Extract the SQL query from the agent's response
153
  sql_query = response.strip()
154
  logging.info(f"Generated SQL Query: {sql_query}")
155
 
156
- # Attempt to execute SQL query and handle exceptions
157
  try:
158
  result = pd.read_sql_query(sql_query, conn)
159
 
@@ -161,12 +160,12 @@ def process_input():
161
  assistant_response = "The query returned no results. Please try a different question."
162
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
163
  else:
164
- # Limit the result to first 10 rows for display
165
  result_display = result.head(10)
166
  st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
167
  st.session_state.history.append({"role": "assistant", "content": result_display})
168
 
169
- # Generate insights based on the query result
170
  insights_template = """
171
  You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
172
 
@@ -183,7 +182,7 @@ def process_input():
183
  result_str = result_display.to_string(index=False)
184
  insights = insights_chain.run({'question': user_prompt, 'result': result_str})
185
 
186
- # Append the assistant's insights to the history
187
  st.session_state.history.append({"role": "assistant", "content": insights})
188
  except Exception as e:
189
  logging.error(f"An error occurred during SQL execution: {e}")
@@ -194,10 +193,10 @@ def process_input():
194
  assistant_response = f"Error: {e}"
195
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
196
 
197
- # Reset the user_input in session state
198
  st.session_state['user_input'] = ''
199
 
200
- # Step 7: Display the conversation history
201
  for message in st.session_state.history:
202
  if message['role'] == 'user':
203
  st.markdown(f"**User:** {message['content']}")
@@ -208,5 +207,5 @@ for message in st.session_state.history:
208
  else:
209
  st.markdown(f"**Assistant:** {message['content']}")
210
 
211
- # Place the input field at the bottom with the callback
212
  st.text_input("Enter your message:", key='user_input', on_change=process_input)
 
3
  import pandas as pd
4
  import sqlite3
5
  import logging
6
+ from langchain.agents import create_sql_agent
7
  from langchain.agents.agent_toolkits import SQLDatabaseToolkit
8
  from langchain.llms import OpenAI
9
  from langchain.sql_database import SQLDatabase
10
+ from langchain.prompts import PromptTemplate
 
 
11
  from langchain.evaluation import load_evaluator
12
 
13
  # Initialize logging
 
17
  if 'history' not in st.session_state:
18
  st.session_state.history = []
19
 
20
+ # OpenAI API key
21
  openai_api_key = os.getenv("OPENAI_API_KEY")
22
 
23
  # Check if the API key is set
 
31
 
32
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
33
  if csv_file is None:
34
+ data = pd.read_csv("default_data.csv") # Ensure this file exists
35
  st.write("Using default_data.csv file.")
36
  table_name = "default_table"
37
  else:
 
40
  st.write(f"Data Preview ({csv_file.name}):")
41
  st.dataframe(data.head())
42
 
43
+ # Step 2: Load CSV data into SQLite database
44
  db_file = 'my_database.db'
45
  conn = sqlite3.connect(db_file)
46
  data.to_sql(table_name, conn, index=False, if_exists='replace')
47
 
48
+ # SQL table metadata
49
  valid_columns = list(data.columns)
50
  st.write(f"Valid columns: {valid_columns}")
51
 
52
+ # Create SQLDatabase instance
53
  engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
54
 
55
+ # Step 3: Define few-shot examples
56
  few_shot_examples = [
57
  {
58
  "input": "What is the total revenue for each category?",
 
76
  # Prepare table information
77
  table_info = f"Table: {table_name}\nColumns: {', '.join(valid_columns)}"
78
 
79
+ # Initialize the LLM
80
+ llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
81
+
82
  # Step 4: Define the prompt template
83
  system_message = """
84
  You are an expert data analyst who can convert natural language questions into SQL queries.
85
 
86
+ You have access to the following tools:
87
+ {tools}
88
 
89
  Follow these guidelines:
90
  1. Only use the columns and tables provided.
 
105
  {agent_scratchpad}
106
  """
107
 
 
 
 
108
  # Step 5: Create the agent
109
  toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
110
  tools = toolkit.get_tools()
111
+ tool_names = ", ".join([tool.name for tool in tools])
112
  tool_descriptions = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
113
 
114
  # Create the prompt
115
  agent_prompt = PromptTemplate(
116
  template=system_message,
117
+ input_variables=["input", "agent_scratchpad", "table_info", "few_shot_examples", "tools", "tool_names"]
118
  )
119
 
120
  # Create the agent
 
144
  table_info=table_info,
145
  few_shot_examples=few_shot_str,
146
  agent_scratchpad="",
147
+ tools=tool_descriptions,
148
+ tool_names=tool_names
149
  )
150
 
151
  # Extract the SQL query from the agent's response
152
  sql_query = response.strip()
153
  logging.info(f"Generated SQL Query: {sql_query}")
154
 
155
+ # Execute SQL query
156
  try:
157
  result = pd.read_sql_query(sql_query, conn)
158
 
 
160
  assistant_response = "The query returned no results. Please try a different question."
161
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
162
  else:
163
+ # Display results
164
  result_display = result.head(10)
165
  st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
166
  st.session_state.history.append({"role": "assistant", "content": result_display})
167
 
168
+ # Generate insights
169
  insights_template = """
170
  You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
171
 
 
182
  result_str = result_display.to_string(index=False)
183
  insights = insights_chain.run({'question': user_prompt, 'result': result_str})
184
 
185
+ # Append insights to history
186
  st.session_state.history.append({"role": "assistant", "content": insights})
187
  except Exception as e:
188
  logging.error(f"An error occurred during SQL execution: {e}")
 
193
  assistant_response = f"Error: {e}"
194
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
195
 
196
+ # Reset user input
197
  st.session_state['user_input'] = ''
198
 
199
+ # Step 7: Display conversation history
200
  for message in st.session_state.history:
201
  if message['role'] == 'user':
202
  st.markdown(f"**User:** {message['content']}")
 
207
  else:
208
  st.markdown(f"**Assistant:** {message['content']}")
209
 
210
+ # Input field
211
  st.text_input("Enter your message:", key='user_input', on_change=process_input)