File size: 6,170 Bytes
b6f0b52
3eb59a4
 
75829f5
e37eda0
b6f0b52
 
 
 
 
 
d9d0b05
 
f0e4f1b
 
 
82bfc51
 
 
 
 
887daae
82bfc51
 
d0ab6a9
 
 
 
 
cd60664
e38120f
d0ab6a9
 
cd60664
 
887daae
82bfc51
d0ab6a9
9e9d1c1
cd60664
d0ab6a9
cd60664
 
 
580d1d7
 
 
 
887daae
82bfc51
 
cd60664
2129665
cd60664
887daae
2129665
f0e4f1b
887daae
02a6269
fc3c978
2129665
 
 
 
 
b6f0b52
865d538
2129665
6dd2b20
580d1d7
 
 
 
3cd3249
e7ab984
0769108
580d1d7
 
 
 
 
 
 
 
 
6dd2b20
82bfc51
580d1d7
c6acd31
 
 
 
 
 
580d1d7
c6acd31
580d1d7
b6f0b52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
865d538
b6f0b52
 
 
 
 
 
c6acd31
 
b6f0b52
c6acd31
 
bcb1e04
c6acd31
 
 
 
 
 
 
 
bcb1e04
c6acd31
 
 
b6f0b52
c6acd31
 
 
 
 
580d1d7
 
 
 
 
 
 
c6acd31
 
 
 
a3c9c61
6dd2b20
580d1d7
a3c9c61
 
 
 
d9d0b05
a3c9c61
887daae
f0e4f1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os 
import streamlit as st
import pandas as pd
import sqlite3
import logging
import ast  # For parsing string representations of lists

from langchain_community.chat_models import ChatOpenAI
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# Application-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)

# The chat transcript lives in session state so it persists across Streamlit reruns.
if "history" not in st.session_state:
    st.session_state["history"] = []

# The OpenAI key must come from the environment; nothing below works without it,
# so halt the script with an explanatory message when it is absent or empty.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()

# Step 1: Upload CSV data file (or use default)
st.title("Business Data Insights Chatbot: Automating SQL Generation & Insights Extraction")
st.write("Upload a CSV file to get started, or use the default dataset.")

csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    try:
        data = pd.read_csv("default_data.csv")
    except FileNotFoundError:
        # With no upload and no default file there is nothing to analyse;
        # explain that instead of showing a raw traceback.
        st.error("default_data.csv was not found. Please upload a CSV file.")
        st.stop()
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    try:
        data = pd.read_csv(csv_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        # A malformed/empty/non-UTF-8 upload should not crash the app.
        st.error(f"Could not read {csv_file.name}: {e}")
        st.stop()
    # Derive a plain SQL identifier from the file name so generated SQL can
    # reference the table unquoted: drop only the final extension, map every
    # character other than ASCII letters/digits/underscore to "_", and avoid
    # an empty name or a leading digit.
    # e.g. "sales-2023.v2.csv" -> "sales_2023_v2", "2024.csv" -> "t_2024".
    stem = os.path.splitext(csv_file.name)[0]
    table_name = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_" for ch in stem
    )
    if not table_name or table_name[0].isdigit():
        table_name = f"t_{table_name}"
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# List the dataset's column names so the user knows what can be asked about.
st.write("**Available Columns:**")
column_list = ", ".join(data.columns)
st.write(column_list)

# Step 2: Load CSV data into SQLite database
# The DataFrame is written into `table_name` (replacing any existing table)
# each time this script runs. The connection is closed in `finally` so it is
# not left open if `to_sql` raises.
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
try:
    data.to_sql(table_name, conn, index=False, if_exists='replace')
finally:
    conn.close()

# Create SQLDatabase instance restricted to the table just written
db = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])

# Chat model (temperature 0) shared by the SQL agent and the insights chain.
llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0)

# SQL agent over `db`. Intermediate steps are requested so the caller can
# show the rows returned by the agent's SQL queries next to its final answer.
toolkit = SQLDatabaseToolkit(llm=llm, db=db)
_executor_options = {"return_intermediate_steps": True}
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    agent_executor_kwargs=_executor_options,
    verbose=True,
)

# Step 3: Sample Questions -- canned prompts, each rendered as a button below.
st.write("**Sample Questions:**")
sample_questions = [
    "Summarize the data for me.",
    "Do you notice any correlations in the datasets?",
    "Can you offer any recommendations based on the datasets?",
    "Provide an analysis of some numbers across some categories.",
]

def set_sample_question(question):
    """Button callback: put *question* into the input box and process it.

    Writes the text to the ``user_input`` session-state key, then runs
    ``process_input``, which reads that same key.
    """
    st.session_state.user_input = question
    process_input()

# One button per canned prompt; clicking it calls set_sample_question(prompt).
for sample in sample_questions:
    st.button(sample, on_click=set_sample_question, args=(sample,))

# Step 4: Define the callback function
def _extract_sql_result(intermediate_steps):
    """Return the rows from the agent's last parseable ``sql_db_query`` step.

    Each intermediate step is unpacked as an ``(action, observation)`` pair.
    The ``sql_db_query`` observation is expected to be the string form of a
    list of row tuples, so it is parsed with ``ast.literal_eval`` (never
    ``eval``: the text comes from the database via the LLM agent).
    Observations that do not parse to a list/tuple are skipped instead of
    aborting the whole request. These are presumably tool error messages, or
    an empty string when a query matches no rows; confirm for the LangChain
    version in use.
    The last parsed result is kept, on the assumption that later queries are
    the agent's corrected/final ones.

    Returns:
        A list of rows, or ``[]`` when no query result could be parsed.
    """
    rows = []
    for action, observation in intermediate_steps:
        if action.tool != 'sql_db_query':
            continue
        try:
            parsed = ast.literal_eval(observation)
        except (ValueError, SyntaxError, TypeError):
            continue
        if isinstance(parsed, (list, tuple)):
            rows = list(parsed)
    return rows


def _format_sql_result(rows):
    """Render query rows as a Markdown table, or ``"No results found."``.

    ``DataFrame.to_markdown`` needs the optional ``tabulate`` package; if it
    is missing, fall back to a plain-text table in a code fence rather than
    losing the agent's answer to an ImportError.
    """
    if not rows:
        return "No results found."
    df_result = pd.DataFrame(rows)
    try:
        return df_result.to_markdown(index=False)
    except ImportError:
        return f"```\n{df_result.to_string(index=False)}\n```"


def process_input():
    """Answer the question stored in ``st.session_state['user_input']``.

    For a non-empty prompt, three messages are appended to
    ``st.session_state.history``:

    - the user message;
    - the SQL agent's answer, together with the rows its query returned;
    - a short LLM-generated analysis of that answer.

    Any failure is logged with its traceback and recorded in the history as
    an assistant error message. The ``user_input`` key is then reset to
    ``''``. An empty prompt does nothing.
    """
    user_prompt = st.session_state.get('user_input', '')

    if user_prompt:
        try:
            # Append user message to history
            st.session_state.history.append({"role": "user", "content": user_prompt})

            # Run the agent (`invoke` replaces the deprecated `Chain.__call__`)
            with st.spinner("Processing..."):
                response = agent_executor.invoke(user_prompt)

            final_answer = response['output']
            sql_result_formatted = _format_sql_result(
                _extract_sql_result(response['intermediate_steps'])
            )

            # Include the data in the final answer
            assistant_response = f"{final_answer}\n\n**Query Result:**\n{sql_result_formatted}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})

            # Generate insights based on the response
            insights_template = """
            You are an expert data analyst. Based on the user's question and the response provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.

            User's Question: {question}

            Response:
            {response}

            Concise Analysis:
            """
            insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'response'])
            insights_chain = LLMChain(llm=llm, prompt=insights_prompt)

            # `invoke(...)[output_key]` is the non-deprecated form of `.run(...)`
            insights = insights_chain.invoke(
                {'question': user_prompt, 'response': assistant_response}
            )[insights_chain.output_key]

            st.session_state.history.append({"role": "assistant", "content": insights})
        except Exception as e:
            # UI boundary: keep the app running and surface the failure in the chat.
            logging.exception("An error occurred: %s", e)

            # Check for specific errors related to missing columns
            if "no such column" in str(e).lower():
                assistant_response = "Error: One or more columns referenced do not exist in the data."
            else:
                assistant_response = f"Error: {e}"

            st.session_state.history.append({"role": "assistant", "content": assistant_response})

        # Reset user input
        st.session_state['user_input'] = ''

# Step 5: Show the transcript, labelling each message by its role.
st.write("## Conversation History")
_ROLE_PREFIXES = {"user": "**User:**", "assistant": "**Assistant:**"}
for entry in st.session_state.history:
    prefix = _ROLE_PREFIXES.get(entry['role'])
    if prefix is None:
        continue  # messages with any other role are not rendered
    st.markdown(f"{prefix} {entry['content']}")

# Message box; process_input runs as the on_change callback when its value changes.
st.text_input("Enter your message:", key='user_input', on_change=process_input)