Spaces:
Running
Running
arithescientist
commited on
Commit
•
887daae
1
Parent(s):
6e8db38
Update app.py
Browse files
app.py
CHANGED
@@ -3,13 +3,11 @@ import streamlit as st
|
|
3 |
import pandas as pd
|
4 |
import sqlite3
|
5 |
import logging
|
6 |
-
from langchain.agents import create_sql_agent
|
7 |
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
8 |
from langchain.llms import OpenAI
|
9 |
from langchain.sql_database import SQLDatabase
|
10 |
-
from langchain.prompts import
|
11 |
-
PromptTemplate,
|
12 |
-
)
|
13 |
from langchain.evaluation import load_evaluator
|
14 |
|
15 |
# Initialize logging
|
@@ -19,7 +17,7 @@ logging.basicConfig(level=logging.INFO)
|
|
19 |
if 'history' not in st.session_state:
|
20 |
st.session_state.history = []
|
21 |
|
22 |
-
# OpenAI API key
|
23 |
openai_api_key = os.getenv("OPENAI_API_KEY")
|
24 |
|
25 |
# Check if the API key is set
|
@@ -33,7 +31,7 @@ st.write("Upload a CSV file to get started, or use the default dataset.")
|
|
33 |
|
34 |
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
35 |
if csv_file is None:
|
36 |
-
data = pd.read_csv("default_data.csv") # Ensure this file exists
|
37 |
st.write("Using default_data.csv file.")
|
38 |
table_name = "default_table"
|
39 |
else:
|
@@ -42,19 +40,19 @@ else:
|
|
42 |
st.write(f"Data Preview ({csv_file.name}):")
|
43 |
st.dataframe(data.head())
|
44 |
|
45 |
-
# Step 2: Load CSV data into
|
46 |
db_file = 'my_database.db'
|
47 |
conn = sqlite3.connect(db_file)
|
48 |
data.to_sql(table_name, conn, index=False, if_exists='replace')
|
49 |
|
50 |
-
# SQL table metadata
|
51 |
valid_columns = list(data.columns)
|
52 |
st.write(f"Valid columns: {valid_columns}")
|
53 |
|
54 |
-
# Create SQLDatabase instance
|
55 |
engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
|
56 |
|
57 |
-
# Step 3: Define
|
58 |
few_shot_examples = [
|
59 |
{
|
60 |
"input": "What is the total revenue for each category?",
|
@@ -78,12 +76,15 @@ for ex in few_shot_examples:
|
|
78 |
# Prepare table information
|
79 |
table_info = f"Table: {table_name}\nColumns: {', '.join(valid_columns)}"
|
80 |
|
|
|
|
|
|
|
81 |
# Step 4: Define the prompt template
|
82 |
system_message = """
|
83 |
You are an expert data analyst who can convert natural language questions into SQL queries.
|
84 |
|
85 |
-
|
86 |
-
{
|
87 |
|
88 |
Follow these guidelines:
|
89 |
1. Only use the columns and tables provided.
|
@@ -104,19 +105,16 @@ Question: {input}
|
|
104 |
{agent_scratchpad}
|
105 |
"""
|
106 |
|
107 |
-
# Initialize the LLM
|
108 |
-
llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
|
109 |
-
|
110 |
# Step 5: Create the agent
|
111 |
toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
|
112 |
tools = toolkit.get_tools()
|
113 |
-
tool_names = [tool.name for tool in tools]
|
114 |
tool_descriptions = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
115 |
|
116 |
# Create the prompt
|
117 |
agent_prompt = PromptTemplate(
|
118 |
template=system_message,
|
119 |
-
input_variables=["input", "agent_scratchpad", "table_info", "few_shot_examples", "
|
120 |
)
|
121 |
|
122 |
# Create the agent
|
@@ -146,14 +144,15 @@ def process_input():
|
|
146 |
table_info=table_info,
|
147 |
few_shot_examples=few_shot_str,
|
148 |
agent_scratchpad="",
|
149 |
-
|
|
|
150 |
)
|
151 |
|
152 |
# Extract the SQL query from the agent's response
|
153 |
sql_query = response.strip()
|
154 |
logging.info(f"Generated SQL Query: {sql_query}")
|
155 |
|
156 |
-
#
|
157 |
try:
|
158 |
result = pd.read_sql_query(sql_query, conn)
|
159 |
|
@@ -161,12 +160,12 @@ def process_input():
|
|
161 |
assistant_response = "The query returned no results. Please try a different question."
|
162 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
163 |
else:
|
164 |
-
#
|
165 |
result_display = result.head(10)
|
166 |
st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
|
167 |
st.session_state.history.append({"role": "assistant", "content": result_display})
|
168 |
|
169 |
-
# Generate insights
|
170 |
insights_template = """
|
171 |
You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
|
172 |
|
@@ -183,7 +182,7 @@ def process_input():
|
|
183 |
result_str = result_display.to_string(index=False)
|
184 |
insights = insights_chain.run({'question': user_prompt, 'result': result_str})
|
185 |
|
186 |
-
# Append
|
187 |
st.session_state.history.append({"role": "assistant", "content": insights})
|
188 |
except Exception as e:
|
189 |
logging.error(f"An error occurred during SQL execution: {e}")
|
@@ -194,10 +193,10 @@ def process_input():
|
|
194 |
assistant_response = f"Error: {e}"
|
195 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
196 |
|
197 |
-
# Reset
|
198 |
st.session_state['user_input'] = ''
|
199 |
|
200 |
-
# Step 7: Display
|
201 |
for message in st.session_state.history:
|
202 |
if message['role'] == 'user':
|
203 |
st.markdown(f"**User:** {message['content']}")
|
@@ -208,5 +207,5 @@ for message in st.session_state.history:
|
|
208 |
else:
|
209 |
st.markdown(f"**Assistant:** {message['content']}")
|
210 |
|
211 |
-
#
|
212 |
st.text_input("Enter your message:", key='user_input', on_change=process_input)
|
|
|
3 |
import pandas as pd
|
4 |
import sqlite3
|
5 |
import logging
|
6 |
+
from langchain.agents import create_sql_agent
|
7 |
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
8 |
from langchain.llms import OpenAI
|
9 |
from langchain.sql_database import SQLDatabase
|
10 |
+
from langchain.prompts import PromptTemplate
|
|
|
|
|
11 |
from langchain.evaluation import load_evaluator
|
12 |
|
13 |
# Initialize logging
|
|
|
17 |
if 'history' not in st.session_state:
|
18 |
st.session_state.history = []
|
19 |
|
20 |
+
# OpenAI API key
|
21 |
openai_api_key = os.getenv("OPENAI_API_KEY")
|
22 |
|
23 |
# Check if the API key is set
|
|
|
31 |
|
32 |
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
33 |
if csv_file is None:
|
34 |
+
data = pd.read_csv("default_data.csv") # Ensure this file exists
|
35 |
st.write("Using default_data.csv file.")
|
36 |
table_name = "default_table"
|
37 |
else:
|
|
|
40 |
st.write(f"Data Preview ({csv_file.name}):")
|
41 |
st.dataframe(data.head())
|
42 |
|
43 |
+
# Step 2: Load CSV data into SQLite database
|
44 |
db_file = 'my_database.db'
|
45 |
conn = sqlite3.connect(db_file)
|
46 |
data.to_sql(table_name, conn, index=False, if_exists='replace')
|
47 |
|
48 |
+
# SQL table metadata
|
49 |
valid_columns = list(data.columns)
|
50 |
st.write(f"Valid columns: {valid_columns}")
|
51 |
|
52 |
+
# Create SQLDatabase instance
|
53 |
engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
|
54 |
|
55 |
+
# Step 3: Define few-shot examples
|
56 |
few_shot_examples = [
|
57 |
{
|
58 |
"input": "What is the total revenue for each category?",
|
|
|
76 |
# Prepare table information
|
77 |
table_info = f"Table: {table_name}\nColumns: {', '.join(valid_columns)}"
|
78 |
|
79 |
+
# Initialize the LLM
|
80 |
+
llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
|
81 |
+
|
82 |
# Step 4: Define the prompt template
|
83 |
system_message = """
|
84 |
You are an expert data analyst who can convert natural language questions into SQL queries.
|
85 |
|
86 |
+
You have access to the following tools:
|
87 |
+
{tools}
|
88 |
|
89 |
Follow these guidelines:
|
90 |
1. Only use the columns and tables provided.
|
|
|
105 |
{agent_scratchpad}
|
106 |
"""
|
107 |
|
|
|
|
|
|
|
108 |
# Step 5: Create the agent
|
109 |
toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
|
110 |
tools = toolkit.get_tools()
|
111 |
+
tool_names = ", ".join([tool.name for tool in tools])
|
112 |
tool_descriptions = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
113 |
|
114 |
# Create the prompt
|
115 |
agent_prompt = PromptTemplate(
|
116 |
template=system_message,
|
117 |
+
input_variables=["input", "agent_scratchpad", "table_info", "few_shot_examples", "tools", "tool_names"]
|
118 |
)
|
119 |
|
120 |
# Create the agent
|
|
|
144 |
table_info=table_info,
|
145 |
few_shot_examples=few_shot_str,
|
146 |
agent_scratchpad="",
|
147 |
+
tools=tool_descriptions,
|
148 |
+
tool_names=tool_names
|
149 |
)
|
150 |
|
151 |
# Extract the SQL query from the agent's response
|
152 |
sql_query = response.strip()
|
153 |
logging.info(f"Generated SQL Query: {sql_query}")
|
154 |
|
155 |
+
# Execute SQL query
|
156 |
try:
|
157 |
result = pd.read_sql_query(sql_query, conn)
|
158 |
|
|
|
160 |
assistant_response = "The query returned no results. Please try a different question."
|
161 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
162 |
else:
|
163 |
+
# Display results
|
164 |
result_display = result.head(10)
|
165 |
st.session_state.history.append({"role": "assistant", "content": "Here are the results:"})
|
166 |
st.session_state.history.append({"role": "assistant", "content": result_display})
|
167 |
|
168 |
+
# Generate insights
|
169 |
insights_template = """
|
170 |
You are an expert data analyst. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
|
171 |
|
|
|
182 |
result_str = result_display.to_string(index=False)
|
183 |
insights = insights_chain.run({'question': user_prompt, 'result': result_str})
|
184 |
|
185 |
+
# Append insights to history
|
186 |
st.session_state.history.append({"role": "assistant", "content": insights})
|
187 |
except Exception as e:
|
188 |
logging.error(f"An error occurred during SQL execution: {e}")
|
|
|
193 |
assistant_response = f"Error: {e}"
|
194 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
195 |
|
196 |
+
# Reset user input
|
197 |
st.session_state['user_input'] = ''
|
198 |
|
199 |
+
# Step 7: Display conversation history
|
200 |
for message in st.session_state.history:
|
201 |
if message['role'] == 'user':
|
202 |
st.markdown(f"**User:** {message['content']}")
|
|
|
207 |
else:
|
208 |
st.markdown(f"**Assistant:** {message['content']}")
|
209 |
|
210 |
+
# Input field
|
211 |
st.text_input("Enter your message:", key='user_input', on_change=process_input)
|