arithescientist commited on
Commit
2129665
1 Parent(s): 68db37e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -1,14 +1,14 @@
1
- import os
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  import logging
6
- from langchain.llms import OpenAI
7
  from langchain.chat_models import ChatOpenAI
8
- from langchain.chains import SQLChain as SQLDatabaseChain # Updated import
 
 
9
  from langchain.prompts import PromptTemplate
10
  from langchain.chains import LLMChain
11
- from langchain.sql_database import SQLDatabase
12
 
13
  # Initialize logging
14
  logging.basicConfig(level=logging.INFO)
@@ -44,15 +44,21 @@ else:
44
  db_file = 'my_database.db'
45
  conn = sqlite3.connect(db_file)
46
  data.to_sql(table_name, conn, index=False, if_exists='replace')
 
47
 
48
  # Create SQLDatabase instance
49
- engine = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
50
 
51
  # Initialize the LLM
52
  llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
53
 
54
- # Initialize the SQLDatabaseChain
55
- sql_chain = SQLDatabaseChain(llm=llm, database=engine, verbose=True)
 
 
 
 
 
56
 
57
  # Step 4: Define the callback function
58
  def process_input():
@@ -63,9 +69,9 @@ def process_input():
63
  # Append user message to history
64
  st.session_state.history.append({"role": "user", "content": user_prompt})
65
 
66
- # Use the SQLDatabaseChain to get the response
67
  with st.spinner("Processing..."):
68
- response = sql_chain.run(user_prompt)
69
 
70
  # Append the assistant's response to the history
71
  st.session_state.history.append({"role": "assistant", "content": response})
 
1
+ import os
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  import logging
 
6
  from langchain.chat_models import ChatOpenAI
7
+ from langchain.agents import create_sql_agent
8
+ from langchain.agents.agent_toolkits import SQLDatabaseToolkit
9
+ from langchain.sql_database import SQLDatabase
10
  from langchain.prompts import PromptTemplate
11
  from langchain.chains import LLMChain
 
12
 
13
  # Initialize logging
14
  logging.basicConfig(level=logging.INFO)
 
44
  db_file = 'my_database.db'
45
  conn = sqlite3.connect(db_file)
46
  data.to_sql(table_name, conn, index=False, if_exists='replace')
47
+ conn.close()
48
 
49
  # Create SQLDatabase instance
50
+ db = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
51
 
52
  # Initialize the LLM
53
  llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)
54
 
55
+ # Initialize the SQL Agent
56
+ toolkit = SQLDatabaseToolkit(db=db, llm=llm)
57
+ agent_executor = create_sql_agent(
58
+ llm=llm,
59
+ toolkit=toolkit,
60
+ verbose=True
61
+ )
62
 
63
  # Step 4: Define the callback function
64
  def process_input():
 
69
  # Append user message to history
70
  st.session_state.history.append({"role": "user", "content": user_prompt})
71
 
72
+ # Use the agent to get the response
73
  with st.spinner("Processing..."):
74
+ response = agent_executor.run(user_prompt)
75
 
76
  # Append the assistant's response to the history
77
  st.session_state.history.append({"role": "assistant", "content": response})