arithescientist commited on
Commit
fc3c978
·
verified ·
1 Parent(s): d6689d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -20
app.py CHANGED
@@ -3,7 +3,7 @@ import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  import logging
6
- from langchain.agents import create_sql_agent
7
  from langchain.agents.agent_toolkits import SQLDatabaseToolkit
8
  from langchain.llms import OpenAI
9
  from langchain.sql_database import SQLDatabase
@@ -65,20 +65,20 @@ engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name
65
  few_shot_examples = [
66
  {
67
  "input": "What is the total revenue for each category?",
68
- "query": f"SELECT category, SUM(revenue) FROM {table_name} GROUP BY category;"
69
  },
70
  {
71
  "input": "Show the top 5 products by sales.",
72
- "query": f"SELECT product_name, sales FROM {table_name} ORDER BY sales DESC LIMIT 5;"
73
  },
74
  {
75
  "input": "How many orders were placed in the last month?",
76
- "query": f"SELECT COUNT(*) FROM {table_name} WHERE order_date >= DATE('now', '-1 month');"
77
  }
78
  ]
79
 
80
  # Step 4: Define the prompt templates
81
- system_prefix = """
82
  You are an expert data analyst who can convert natural language questions into SQL queries.
83
  Follow these guidelines:
84
  1. Only use the columns and tables provided.
@@ -86,37 +86,53 @@ Follow these guidelines:
86
  3. Ensure string comparisons are case-insensitive.
87
  4. Do not execute queries that could be harmful or unethical.
88
  5. Provide clear and concise SQL queries.
 
 
 
 
 
 
89
  """
90
 
91
- few_shot_prompt = FewShotPromptTemplate(
92
- example_prompt=PromptTemplate.from_template("Question: {input}\nSQL Query: {query}"),
93
- examples=few_shot_examples,
94
- prefix=system_prefix,
95
- suffix="Question: {input}\nSQL Query:",
96
- input_variables=["input", "agent_scratchpad"]
97
- )
98
 
99
- # Step 5: Initialize the LLM and toolkit
100
  llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
 
 
 
 
101
  toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
 
 
102
 
103
- # Step 6: Create the agent using 'zero-shot-react-description' agent type
104
  agent_prompt = ChatPromptTemplate.from_messages([
105
- SystemMessagePromptTemplate(prompt=few_shot_prompt),
106
  HumanMessagePromptTemplate.from_template("{input}"),
107
  MessagesPlaceholder(variable_name="agent_scratchpad")
108
  ])
109
 
 
 
 
 
110
  sql_agent = create_sql_agent(
111
  llm=llm,
112
  toolkit=toolkit,
113
  prompt=agent_prompt,
114
  verbose=True,
115
- agent_type="zero-shot-react-description",
116
  max_iterations=5
117
  )
118
 
119
- # Step 7: Define the callback function
120
  def process_input():
121
  user_prompt = st.session_state['user_input']
122
 
@@ -127,8 +143,14 @@ def process_input():
127
 
128
  # Use the agent to generate the SQL query
129
  with st.spinner("Generating SQL query..."):
130
- response = sql_agent.run(user_prompt)
131
-
 
 
 
 
 
 
132
  # Extract the SQL query from the agent's response
133
  sql_query = response.strip()
134
  logging.info(f"Generated SQL Query: {sql_query}")
@@ -177,7 +199,7 @@ def process_input():
177
  # Reset the user_input in session state
178
  st.session_state['user_input'] = ''
179
 
180
- # Step 8: Display the conversation history
181
  for message in st.session_state.history:
182
  if message['role'] == 'user':
183
  st.markdown(f"**User:** {message['content']}")
 
3
  import pandas as pd
4
  import sqlite3
5
  import logging
6
+ from langchain.agents import create_sql_agent, AgentType
7
  from langchain.agents.agent_toolkits import SQLDatabaseToolkit
8
  from langchain.llms import OpenAI
9
  from langchain.sql_database import SQLDatabase
 
65
  few_shot_examples = [
66
  {
67
  "input": "What is the total revenue for each category?",
68
+ "output": f"SELECT category, SUM(revenue) FROM {table_name} GROUP BY category;"
69
  },
70
  {
71
  "input": "Show the top 5 products by sales.",
72
+ "output": f"SELECT product_name, sales FROM {table_name} ORDER BY sales DESC LIMIT 5;"
73
  },
74
  {
75
  "input": "How many orders were placed in the last month?",
76
+ "output": f"SELECT COUNT(*) FROM {table_name} WHERE order_date >= DATE('now', '-1 month');"
77
  }
78
  ]
79
 
80
  # Step 4: Define the prompt templates
81
+ system_message = """
82
  You are an expert data analyst who can convert natural language questions into SQL queries.
83
  Follow these guidelines:
84
  1. Only use the columns and tables provided.
 
86
  3. Ensure string comparisons are case-insensitive.
87
  4. Do not execute queries that could be harmful or unethical.
88
  5. Provide clear and concise SQL queries.
89
+
90
+ Available tables and columns:
91
+ {table_info}
92
+
93
+ Use the following examples as a guide:
94
+ {few_shot_examples}
95
  """
96
 
97
+ # Prepare few-shot examples as a string
98
+ few_shot_str = ""
99
+ for ex in few_shot_examples:
100
+ few_shot_str += f"Q: {ex['input']}\nA: {ex['output']}\n\n"
101
+
102
+ # Prepare table information
103
+ table_info = f"Table: {table_name}\nColumns: {', '.join(valid_columns)}"
104
 
105
+ # Initialize the LLM
106
  llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
107
+
108
+ # Step 5: Create the agent
109
+
110
+ # Get the list of tools from the toolkit
111
  toolkit = SQLDatabaseToolkit(db=engine, llm=llm)
112
+ tools = toolkit.get_tools()
113
+ tool_names = [tool.name for tool in tools]
114
 
115
+ # Create the agent prompt
116
  agent_prompt = ChatPromptTemplate.from_messages([
117
+ SystemMessagePromptTemplate.from_template(system_message),
118
  HumanMessagePromptTemplate.from_template("{input}"),
119
  MessagesPlaceholder(variable_name="agent_scratchpad")
120
  ])
121
 
122
+ # Set input variables for the prompt
123
+ agent_prompt.input_variables = ["input", "agent_scratchpad", "table_info", "few_shot_examples"]
124
+
125
+ # Create the agent
126
  sql_agent = create_sql_agent(
127
  llm=llm,
128
  toolkit=toolkit,
129
  prompt=agent_prompt,
130
  verbose=True,
131
+ agent_type=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
132
  max_iterations=5
133
  )
134
 
135
+ # Step 6: Define the callback function
136
  def process_input():
137
  user_prompt = st.session_state['user_input']
138
 
 
143
 
144
  # Use the agent to generate the SQL query
145
  with st.spinner("Generating SQL query..."):
146
+ # Run the agent with the necessary inputs
147
+ response = sql_agent.run(
148
+ input=user_prompt,
149
+ table_info=table_info,
150
+ few_shot_examples=few_shot_str,
151
+ agent_scratchpad=""
152
+ )
153
+
154
  # Extract the SQL query from the agent's response
155
  sql_query = response.strip()
156
  logging.info(f"Generated SQL Query: {sql_query}")
 
199
  # Reset the user_input in session state
200
  st.session_state['user_input'] = ''
201
 
202
+ # Step 7: Display the conversation history
203
  for message in st.session_state.history:
204
  if message['role'] == 'user':
205
  st.markdown(f"**User:** {message['content']}")