Spaces:
Running
Running
arithescientist
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import streamlit as st
|
|
3 |
import pandas as pd
|
4 |
import sqlite3
|
5 |
import logging
|
6 |
-
from langchain.agents import create_sql_agent
|
7 |
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
8 |
from langchain.llms import OpenAI
|
9 |
from langchain.sql_database import SQLDatabase
|
@@ -65,20 +65,20 @@ engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name
|
|
65 |
few_shot_examples = [
|
66 |
{
|
67 |
"input": "What is the total revenue for each category?",
|
68 |
-
"
|
69 |
},
|
70 |
{
|
71 |
"input": "Show the top 5 products by sales.",
|
72 |
-
"
|
73 |
},
|
74 |
{
|
75 |
"input": "How many orders were placed in the last month?",
|
76 |
-
"
|
77 |
}
|
78 |
]
|
79 |
|
80 |
# Step 4: Define the prompt templates
|
81 |
-
|
82 |
You are an expert data analyst who can convert natural language questions into SQL queries.
|
83 |
Follow these guidelines:
|
84 |
1. Only use the columns and tables provided.
|
@@ -86,37 +86,53 @@ Follow these guidelines:
|
|
86 |
3. Ensure string comparisons are case-insensitive.
|
87 |
4. Do not execute queries that could be harmful or unethical.
|
88 |
5. Provide clear and concise SQL queries.
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
"""
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
)
|
98 |
|
99 |
-
#
|
100 |
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
|
|
|
|
|
|
|
|
|
101 |
toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
|
|
|
|
|
102 |
|
103 |
-
#
|
104 |
agent_prompt = ChatPromptTemplate.from_messages([
|
105 |
-
SystemMessagePromptTemplate(
|
106 |
HumanMessagePromptTemplate.from_template("{input}"),
|
107 |
MessagesPlaceholder(variable_name="agent_scratchpad")
|
108 |
])
|
109 |
|
|
|
|
|
|
|
|
|
110 |
sql_agent = create_sql_agent(
|
111 |
llm=llm,
|
112 |
toolkit=toolkit,
|
113 |
prompt=agent_prompt,
|
114 |
verbose=True,
|
115 |
-
agent_type=
|
116 |
max_iterations=5
|
117 |
)
|
118 |
|
119 |
-
# Step
|
120 |
def process_input():
|
121 |
user_prompt = st.session_state['user_input']
|
122 |
|
@@ -127,8 +143,14 @@ def process_input():
|
|
127 |
|
128 |
# Use the agent to generate the SQL query
|
129 |
with st.spinner("Generating SQL query..."):
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
# Extract the SQL query from the agent's response
|
133 |
sql_query = response.strip()
|
134 |
logging.info(f"Generated SQL Query: {sql_query}")
|
@@ -177,7 +199,7 @@ def process_input():
|
|
177 |
# Reset the user_input in session state
|
178 |
st.session_state['user_input'] = ''
|
179 |
|
180 |
-
# Step
|
181 |
for message in st.session_state.history:
|
182 |
if message['role'] == 'user':
|
183 |
st.markdown(f"**User:** {message['content']}")
|
|
|
3 |
import pandas as pd
|
4 |
import sqlite3
|
5 |
import logging
|
6 |
+
from langchain.agents import create_sql_agent, AgentType
|
7 |
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
8 |
from langchain.llms import OpenAI
|
9 |
from langchain.sql_database import SQLDatabase
|
|
|
65 |
few_shot_examples = [
|
66 |
{
|
67 |
"input": "What is the total revenue for each category?",
|
68 |
+
"output": f"SELECT category, SUM(revenue) FROM {table_name} GROUP BY category;"
|
69 |
},
|
70 |
{
|
71 |
"input": "Show the top 5 products by sales.",
|
72 |
+
"output": f"SELECT product_name, sales FROM {table_name} ORDER BY sales DESC LIMIT 5;"
|
73 |
},
|
74 |
{
|
75 |
"input": "How many orders were placed in the last month?",
|
76 |
+
"output": f"SELECT COUNT(*) FROM {table_name} WHERE order_date >= DATE('now', '-1 month');"
|
77 |
}
|
78 |
]
|
79 |
|
80 |
# Step 4: Define the prompt templates
|
81 |
+
system_message = """
|
82 |
You are an expert data analyst who can convert natural language questions into SQL queries.
|
83 |
Follow these guidelines:
|
84 |
1. Only use the columns and tables provided.
|
|
|
86 |
3. Ensure string comparisons are case-insensitive.
|
87 |
4. Do not execute queries that could be harmful or unethical.
|
88 |
5. Provide clear and concise SQL queries.
|
89 |
+
|
90 |
+
Available tables and columns:
|
91 |
+
{table_info}
|
92 |
+
|
93 |
+
Use the following examples as a guide:
|
94 |
+
{few_shot_examples}
|
95 |
"""
|
96 |
|
97 |
+
# Prepare few-shot examples as a string
|
98 |
+
few_shot_str = ""
|
99 |
+
for ex in few_shot_examples:
|
100 |
+
few_shot_str += f"Q: {ex['input']}\nA: {ex['output']}\n\n"
|
101 |
+
|
102 |
+
# Prepare table information
|
103 |
+
table_info = f"Table: {table_name}\nColumns: {', '.join(valid_columns)}"
|
104 |
|
105 |
+
# Initialize the LLM
|
106 |
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
|
107 |
+
|
108 |
+
# Step 5: Create the agent
|
109 |
+
|
110 |
+
# Get the list of tools from the toolkit
|
111 |
toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
|
112 |
+
tools = toolkit.get_tools()
|
113 |
+
tool_names = [tool.name for tool in tools]
|
114 |
|
115 |
+
# Create the agent prompt
|
116 |
agent_prompt = ChatPromptTemplate.from_messages([
|
117 |
+
SystemMessagePromptTemplate.from_template(system_message),
|
118 |
HumanMessagePromptTemplate.from_template("{input}"),
|
119 |
MessagesPlaceholder(variable_name="agent_scratchpad")
|
120 |
])
|
121 |
|
122 |
+
# Set input variables for the prompt
|
123 |
+
agent_prompt.input_variables = ["input", "agent_scratchpad", "table_info", "few_shot_examples"]
|
124 |
+
|
125 |
+
# Create the agent
|
126 |
sql_agent = create_sql_agent(
|
127 |
llm=llm,
|
128 |
toolkit=toolkit,
|
129 |
prompt=agent_prompt,
|
130 |
verbose=True,
|
131 |
+
agent_type=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
|
132 |
max_iterations=5
|
133 |
)
|
134 |
|
135 |
+
# Step 6: Define the callback function
|
136 |
def process_input():
|
137 |
user_prompt = st.session_state['user_input']
|
138 |
|
|
|
143 |
|
144 |
# Use the agent to generate the SQL query
|
145 |
with st.spinner("Generating SQL query..."):
|
146 |
+
# Run the agent with the necessary inputs
|
147 |
+
response = sql_agent.run(
|
148 |
+
input=user_prompt,
|
149 |
+
table_info=table_info,
|
150 |
+
few_shot_examples=few_shot_str,
|
151 |
+
agent_scratchpad=""
|
152 |
+
)
|
153 |
+
|
154 |
# Extract the SQL query from the agent's response
|
155 |
sql_query = response.strip()
|
156 |
logging.info(f"Generated SQL Query: {sql_query}")
|
|
|
199 |
# Reset the user_input in session state
|
200 |
st.session_state['user_input'] = ''
|
201 |
|
202 |
+
# Step 7: Display the conversation history
|
203 |
for message in st.session_state.history:
|
204 |
if message['role'] == 'user':
|
205 |
st.markdown(f"**User:** {message['content']}")
|