# Business Data Insights Chatbot — Streamlit app that loads a CSV into SQLite
# and answers natural-language questions via a LangChain SQL agent.
import ast  # For parsing string representations of lists
import logging
import os
import sqlite3

import pandas as pd
import streamlit as st
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.chat_models import ChatOpenAI
from langchain_community.utilities import SQLDatabase
# Configure application-wide logging.
logging.basicConfig(level=logging.INFO)

# Conversation history must survive Streamlit reruns, so keep it in session state.
if "history" not in st.session_state:
    st.session_state.history = []

# The OpenAI API key is supplied through the environment; refuse to run without it.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()
# Step 1: Upload a CSV data file (or fall back to the bundled default dataset).
st.title("Business Data Insights Chatbot: Automating SQL Generation & Insights Extraction")
st.write("Upload a CSV file to get started, or use the default dataset.")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    # NOTE(review): assumes default_data.csv ships alongside the app — confirm it exists.
    data = pd.read_csv("default_data.csv")
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    # splitext keeps dotted filenames intact (e.g. "sales.2024.csv" -> "sales.2024"),
    # whereas split('.')[0] would truncate them to "sales".
    table_name = os.path.splitext(csv_file.name)[0]
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Show the available columns so users know what they can ask about.
st.write("**Available Columns:**")
st.write(", ".join(data.columns.tolist()))

# Step 2: Load the CSV data into a local SQLite database.
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
try:
    data.to_sql(table_name, conn, index=False, if_exists='replace')
finally:
    # Guarantee the connection is released even if to_sql raises;
    # the original bare connect/close pair leaked the handle on error.
    conn.close()

# Wrap the database for LangChain, restricted to the freshly loaded table.
db = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
# Build the language model and the SQL agent that will answer questions.
# temperature=0 keeps the generated SQL deterministic.
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)

# The toolkit exposes schema-inspection and query-execution tools over the database.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# return_intermediate_steps lets the callback recover the raw SQL query result later.
agent_executor = create_sql_agent(
    toolkit=toolkit,
    llm=llm,
    agent_executor_kwargs={"return_intermediate_steps": True},
    verbose=True,
)
# Step 3: Offer one-click sample questions.
st.write("**Sample Questions:**")
sample_questions = [
    "Summarize the data for me.",
    "Do you notice any correlations in the datasets?",
    "Can you offer any recommendations based on the datasets?",
    "Provide an analysis of some numbers across some categories.",
]


def set_sample_question(question):
    """Copy a sample question into the input box and process it immediately."""
    st.session_state['user_input'] = question
    # process_input is defined later in the file; the name is resolved at
    # click time, not at definition time, so this forward reference is safe.
    process_input()


for sample in sample_questions:
    st.button(sample, on_click=set_sample_question, args=(sample,))
# Step 4: Define the callback that runs whenever the user submits a message.
def process_input():
    """Handle one user message end-to-end.

    Reads the prompt from ``st.session_state['user_input']``, runs the SQL
    agent, appends three entries to the conversation history: the user
    message, the agent's answer (with the raw query result rendered as a
    markdown table), and a short LLM-generated analysis. On failure a
    human-readable error message is appended instead. The input box is
    always cleared afterwards.
    """
    user_prompt = st.session_state.get('user_input', '')
    if not user_prompt:
        return
    try:
        st.session_state.history.append({"role": "user", "content": user_prompt})
        with st.spinner("Processing..."):
            response = agent_executor(user_prompt)

        final_answer = response['output']

        # Recover the raw rows from the agent's sql_db_query tool call, if any.
        # .get guards against agents that return no intermediate steps.
        sql_result = []
        for action, observation in response.get('intermediate_steps', []):
            if action.tool == 'sql_db_query':
                # The observation is the str() of a Python list of row tuples.
                # literal_eval parses only literals (safe on untrusted text),
                # but it raises on non-literal output such as a driver error
                # string — guard it so a formatting hiccup doesn't hide the
                # agent's actual answer.
                try:
                    sql_result = ast.literal_eval(observation)
                except (ValueError, SyntaxError):
                    sql_result = []
                break

        # Render the rows as a markdown table for readability.
        if sql_result:
            sql_result_formatted = pd.DataFrame(sql_result).to_markdown(index=False)
        else:
            sql_result_formatted = "No results found."

        assistant_response = f"{final_answer}\n\n**Query Result:**\n{sql_result_formatted}"
        st.session_state.history.append({"role": "assistant", "content": assistant_response})

        # Follow up with a concise LLM-generated analysis of the answer.
        insights_template = """
You are an expert data analyst. Based on the user's question and the response provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
User's Question: {question}
Response:
{response}
Concise Analysis:
"""
        insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'response'])
        insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
        insights = insights_chain.run({'question': user_prompt, 'response': assistant_response})
        st.session_state.history.append({"role": "assistant", "content": insights})
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        # Surface a friendlier message for the common missing-column failure.
        if "no such column" in str(e).lower():
            assistant_response = "Error: One or more columns referenced do not exist in the data."
        else:
            assistant_response = f"Error: {e}"
        st.session_state.history.append({"role": "assistant", "content": assistant_response})
    finally:
        # Clear the input box whether the call succeeded or failed, so the
        # user never resubmits a stale prompt on the next rerun.
        st.session_state['user_input'] = ''
# Step 5: Render the running conversation, then the input box.
st.write("## Conversation History")
role_labels = {'user': 'User', 'assistant': 'Assistant'}
for message in st.session_state.history:
    label = role_labels.get(message['role'])
    if label is not None:
        st.markdown(f"**{label}:** {message['content']}")

# The on_change callback fires when the user presses Enter in the text field.
st.text_input("Enter your message:", key='user_input', on_change=process_input)