GenBIChatbot / app.py
arithescientist's picture
Update app.py
624a8ad verified
raw
history blame
5.31 kB
import os
import streamlit as st
import pandas as pd
import sqlite3
import logging
import ast # For parsing string representations of lists
from langchain_community.chat_models import ChatOpenAI
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# Emit INFO-level (and above) log records through the root logger.
logging.basicConfig(level=logging.INFO)

# Chat transcript kept in Streamlit session state: a list of
# {"role": ..., "content": ...} dicts rendered further down the page.
if "history" not in st.session_state:
    st.session_state["history"] = []

# The OpenAI key comes from the environment; the app halts here without it.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()
# Step 1: Upload CSV data file (or use default)
def _table_name_from_filename(file_name, fallback="uploaded_table"):
    """Derive a plain SQL table identifier from an uploaded file's name.

    Only the final extension is removed, so ``sales.2023.csv`` keeps its
    full stem. Every character that is not an ASCII letter, digit or
    underscore becomes ``_``. The result can then be referenced in
    generated SQL without quoting, e.g. ``my data.csv`` -> ``my_data``.

    Args:
        file_name: Name of the uploaded file, e.g. ``"sales.2023.csv"``.
        fallback: Name returned when the stem is empty.

    Returns:
        A non-empty name that does not start with a digit,
        e.g. ``"sales_2023"``.
    """
    stem = os.path.splitext(os.path.basename(file_name))[0]
    name = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_"
        for ch in stem
    )
    if not name:
        return fallback
    # An unquoted SQL identifier must not begin with a digit.
    if name[0].isdigit():
        name = f"t_{name}"
    return name


st.title("Enhanced Natural Language to SQL Query App")
st.write("Upload a CSV file to get started, or use the default dataset.")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])

# Read either the bundled default dataset or the uploaded file. A missing
# default file (OSError) or unparseable/undecodable CSV (pandas parser
# errors and UnicodeDecodeError are ValueError subclasses) is reported in
# the UI instead of crashing the page.
try:
    data = pd.read_csv("default_data.csv" if csv_file is None else csv_file)
except (OSError, ValueError) as exc:
    st.error(f"Could not load the CSV data: {exc}")
    st.stop()

if csv_file is None:
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    table_name = _table_name_from_filename(csv_file.name)
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())
# Step 2: Load CSV data into SQLite database
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
try:
    # if_exists='replace' drops any existing table of the same name first,
    # so the table always mirrors the DataFrame just loaded.
    data.to_sql(table_name, conn, index=False, if_exists='replace')
finally:
    # Close the connection even when to_sql raises, so it is never leaked.
    conn.close()

# Create SQLDatabase instance that exposes only the freshly written table.
db = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])

# Step 3: Initialize the LLM and the SQL agent
# temperature=0 asks the model for its least random completions.
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)
# return_intermediate_steps=True makes the executor's output include the
# (action, observation) pairs, which process_input uses to show the raw
# query result next to the agent's answer.
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_executor_kwargs={"return_intermediate_steps": True}
)
# Step 4: Define the callback function
_INSIGHTS_TEMPLATE = """
You are an expert data analyst. Based on the user's question and the response provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
User's Question: {question}
Response:
{response}
Concise Analysis:
"""


def _extract_sql_rows(intermediate_steps):
    """Return the rows produced by the agent's most recent usable SQL query.

    Args:
        intermediate_steps: ``(action, observation)`` pairs taken from the
            agent executor's ``intermediate_steps`` output.

    Returns:
        The parsed list of rows, or ``[]`` when no ``sql_db_query`` step
        produced a parseable list.
    """
    # Walk backwards: when the agent re-ran a query, its final answer is
    # based on the latest run, not the first one.
    for action, observation in reversed(intermediate_steps):
        if action.tool != 'sql_db_query':
            continue
        # NOTE(review): langchain_community's SQL query tool appears to return
        # "" when the query matched no rows and an "Error: ..." string when it
        # failed. Neither is a Python literal, so they must not abort the reply.
        if isinstance(observation, str) and not observation.strip():
            return []
        try:
            # Expected form: the string repr of a list of row tuples.
            rows = ast.literal_eval(observation)
        except (ValueError, SyntaxError, TypeError):
            continue  # error message or other non-literal text; try an earlier query
        if isinstance(rows, list):
            return rows
    return []


def _format_sql_rows(rows):
    """Render query rows as a Markdown table, or a placeholder when empty."""
    if not rows:
        return "No results found."
    df_result = pd.DataFrame(rows)
    try:
        return df_result.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown depends on the optional 'tabulate' package;
        # fall back to a plain-text table in a code fence.
        return f"```\n{df_result.to_string(index=False)}\n```"


def _generate_insights(question, response_text):
    """Ask the LLM for a short analysis of ``response_text`` given ``question``."""
    insights_prompt = PromptTemplate(template=_INSIGHTS_TEMPLATE, input_variables=['question', 'response'])
    insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
    # LLMChain stores the generated text under its output key "text".
    return insights_chain.invoke({'question': question, 'response': response_text})['text']


def process_input():
    """Handle a submitted chat message (``on_change`` callback of the input box).

    The callback does the following:

    1. Appends the user's message to ``st.session_state.history``.
    2. Runs the SQL agent and appends an assistant message containing the
       agent's answer plus the tabulated query result.
    3. Appends a second assistant message with a short LLM-written analysis.

    Any exception is logged with its traceback and reported in the history
    as an ``"Error: ..."`` message. The input box is cleared afterwards.
    """
    user_prompt = (st.session_state.get('user_input') or '').strip()
    if user_prompt:
        try:
            st.session_state.history.append({"role": "user", "content": user_prompt})
            with st.spinner("Processing..."):
                # The executor was built with return_intermediate_steps=True,
                # so the result carries the tool calls alongside 'output'.
                response = agent_executor.invoke({"input": user_prompt})
                final_answer = response['output']
                sql_rows = _extract_sql_rows(response['intermediate_steps'])
                sql_result_formatted = _format_sql_rows(sql_rows)

                # Include the data in the final answer
                assistant_response = f"{final_answer}\n\n**Query Result:**\n{sql_result_formatted}"
                st.session_state.history.append({"role": "assistant", "content": assistant_response})

                insights = _generate_insights(user_prompt, assistant_response)
                st.session_state.history.append({"role": "assistant", "content": insights})
        except Exception as e:
            # Top-level UI boundary: record the traceback, show the error in chat.
            logging.exception("An error occurred: %s", e)
            assistant_response = f"Error: {e}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})
    # Reset user input (allowed here because this runs as a widget callback).
    st.session_state['user_input'] = ''
# Step 5: Display conversation history
# Bold speaker label for each known role; entries with any other role are
# skipped rather than rendered.
_ROLE_LABELS = {"user": "User", "assistant": "Assistant"}
for entry in st.session_state.history:
    speaker = _ROLE_LABELS.get(entry['role'])
    if speaker is not None:
        st.markdown(f"**{speaker}:** {entry['content']}")

# Input field: process_input runs as the on_change callback when the value changes.
st.text_input("Enter your message:", key='user_input', on_change=process_input)