# GenBIChatbot / app.py
# arithescientist's picture
# Update app.py
# e7ab984 verified
import os
import streamlit as st
import pandas as pd
import sqlite3
import logging
import ast # For parsing string representations of lists
from langchain_community.chat_models import ChatOpenAI
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# Emit INFO-level (and above) log records.
logging.basicConfig(level=logging.INFO)

# Create the conversation history list in session state if it is not there yet.
if 'history' not in st.session_state:
    st.session_state['history'] = []

# The OpenAI API key is taken from the environment.
openai_api_key = os.environ.get("OPENAI_API_KEY")

# A missing or empty key is reported to the user and halts the script.
if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()
# Step 1: Upload CSV data file (or use default)
st.title("Business Data Insights Chatbot: Automating SQL Generation & Insights Extraction")
st.write("Upload a CSV file to get started, or use the default dataset.")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])

if csv_file is None:
    # No upload: fall back to the bundled dataset, and fail with a readable
    # message (rather than a traceback) when that file is missing.
    try:
        data = pd.read_csv("default_data.csv")
    except FileNotFoundError:
        st.error("default_data.csv was not found. Please upload a CSV file to continue.")
        st.stop()
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    # Derive the table name from the file name without its final extension,
    # then replace every character that is not a letter or digit with "_" so
    # the name can be used unquoted in generated SQL
    # (e.g. "my sales-2023.v2.csv" -> "my_sales_2023_v2").
    file_stem = os.path.splitext(csv_file.name)[0]
    table_name = "".join(ch if ch.isalnum() else "_" for ch in file_stem)
    # An identifier must not be empty or start with a digit.
    if not table_name or table_name[0].isdigit():
        table_name = f"table_{table_name}"
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Display column names
st.write("**Available Columns:**")
# Column labels are converted to str so non-string headers cannot break join().
st.write(", ".join(str(column) for column in data.columns))
# Step 2: Load CSV data into SQLite database
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
try:
    # Replace any existing table of the same name with the current data.
    data.to_sql(table_name, conn, index=False, if_exists='replace')
finally:
    # Close the connection even when the write fails so the handle is not leaked.
    conn.close()

# Create SQLDatabase instance restricted to the table that was just written
db = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])

# Initialize the LLM (temperature 0)
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)

# Initialize the SQL Agent; intermediate steps are returned so the raw query
# output can be read back from the agent's response.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_executor_kwargs={"return_intermediate_steps": True},
)
# Step 3: Sample Questions
st.write("**Sample Questions:**")

# Canned prompts offered to the user.
sample_questions = [
    "Summarize the data for me.",
    "Do you notice any correlations in the datasets?",
    "Can you offer any recommendations based on the datasets?",
    "Provide an analysis of some numbers across some categories.",
]
def set_sample_question(question):
    """Store ``question`` as the pending user input, then process it."""
    st.session_state.user_input = question
    process_input()
# One button per sample question; clicking it runs set_sample_question with
# that question as its argument.
for sample in sample_questions:
    st.button(sample, on_click=set_sample_question, args=(sample,))
# Step 4: Define the callback function

# Prompt used to turn the agent's answer into a short analysis.
_INSIGHTS_TEMPLATE = """
You are an expert data analyst. Based on the user's question and the response provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
User's Question: {question}
Response:
{response}
Concise Analysis:
"""


def _extract_sql_rows(intermediate_steps):
    """Return the rows of the first parsable ``sql_db_query`` step, or ``[]``.

    Each step is indexed as ``(action, observation)``. A string observation is
    parsed with ``ast.literal_eval``; observations that cannot be parsed (an
    empty string, an error message, values that are not Python literals) or
    that do not parse to a list/tuple are skipped instead of raising, so one
    bad query output does not discard the agent's whole answer.
    """
    for step in intermediate_steps:
        if getattr(step[0], "tool", None) != 'sql_db_query':
            continue
        observation = step[1]
        if isinstance(observation, str):
            try:
                observation = ast.literal_eval(observation)
            except (ValueError, SyntaxError):
                continue
        if isinstance(observation, (list, tuple)):
            return list(observation)
    return []


def _format_sql_rows(rows):
    """Render query rows as a Markdown table, or a placeholder when empty."""
    if not rows:
        return "No results found."
    frame = pd.DataFrame(rows)
    try:
        return frame.to_markdown(index=False)
    except ImportError:
        # to_markdown depends on the optional "tabulate" package; fall back to
        # a plain-text table inside a code fence when it is not installed.
        return f"```\n{frame.to_string(index=False)}\n```"


def process_input():
    """Answer the pending user message and record the exchange in history.

    Reads ``st.session_state['user_input']``; does nothing when it is empty.
    Otherwise appends the user message to ``st.session_state.history``, runs
    the SQL agent, appends the agent's answer together with the formatted
    query result, then appends a short LLM-generated analysis of that answer.
    Any exception is logged and appended to the history as an error message.
    The input field is cleared afterwards.
    """
    user_prompt = st.session_state.get('user_input', '')
    if not user_prompt:
        return

    # Append user message to history
    st.session_state.history.append({"role": "user", "content": user_prompt})

    try:
        with st.spinner("Processing..."):
            # Use the agent to get the response
            response = agent_executor(user_prompt)
            final_answer = response['output']

            # Pull the raw query output out of the agent's intermediate steps
            sql_rows = _extract_sql_rows(response['intermediate_steps'])
            sql_result_formatted = _format_sql_rows(sql_rows)

            # Include the data in the final answer
            assistant_response = f"{final_answer}\n\n**Query Result:**\n{sql_result_formatted}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})

            # Generate insights based on the response
            insights_prompt = PromptTemplate(
                template=_INSIGHTS_TEMPLATE,
                input_variables=['question', 'response'],
            )
            insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
            insights = insights_chain.run({'question': user_prompt, 'response': assistant_response})
            st.session_state.history.append({"role": "assistant", "content": insights})
    except Exception as e:
        # Callback boundary: log with traceback and surface the failure in the chat.
        logging.exception("An error occurred: %s", e)
        # Check for specific errors related to missing columns
        if "no such column" in str(e).lower():
            assistant_response = "Error: One or more columns referenced do not exist in the data."
        else:
            assistant_response = f"Error: {e}"
        st.session_state.history.append({"role": "assistant", "content": assistant_response})

    # Reset user input
    st.session_state['user_input'] = ''
# Step 5: Display conversation history
st.write("## Conversation History")

# Bold speaker prefix for each known role; entries with any other role are skipped.
role_prefixes = {'user': "**User:** ", 'assistant': "**Assistant:** "}
for entry in st.session_state.history:
    prefix = role_prefixes.get(entry['role'])
    if prefix is None:
        continue
    st.markdown(f"{prefix}{entry['content']}")
# Input field: its value is stored under the 'user_input' session key and
# process_input is registered as the on_change callback.
st.text_input(
    label="Enter your message:",
    key='user_input',
    on_change=process_input,
)