Spaces:
Sleeping
Sleeping
File size: 5,313 Bytes
b6f0b52 3eb59a4 75829f5 e37eda0 b6f0b52 d9d0b05 f0e4f1b 82bfc51 887daae 82bfc51 d0ab6a9 cd60664 f0e4f1b d0ab6a9 cd60664 887daae 82bfc51 d0ab6a9 9e9d1c1 cd60664 d0ab6a9 cd60664 887daae 82bfc51 cd60664 2129665 cd60664 887daae 2129665 f0e4f1b 887daae 02a6269 fc3c978 2129665 b6f0b52 865d538 2129665 6dd2b20 82bfc51 c6acd31 624a8ad c6acd31 624a8ad b6f0b52 865d538 b6f0b52 c6acd31 b6f0b52 c6acd31 bcb1e04 c6acd31 bcb1e04 c6acd31 b6f0b52 c6acd31 a3c9c61 624a8ad 6dd2b20 a3c9c61 d9d0b05 a3c9c61 887daae f0e4f1b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
import streamlit as st
import pandas as pd
import sqlite3
import logging
import ast # For parsing string representations of lists
from langchain_community.chat_models import ChatOpenAI
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# Configure the root logger at INFO level. logging.basicConfig does nothing
# once the root logger already has handlers, so repeated script runs are safe.
logging.basicConfig(level=logging.INFO)

# Per-session chat transcript: a list of {"role": ..., "content": ...} dicts.
# Only create it when absent so earlier messages are kept across reruns.
if "history" not in st.session_state:
    st.session_state["history"] = []
# The OpenAI API key comes from the environment. Nothing below can work
# without it, so report the problem and halt this script run.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()
def _table_name_from_filename(filename):
    """Return a SQL-friendly table name derived from an uploaded file's name.

    The extension is dropped (via ``os.path.splitext``, so multi-dot names
    keep everything but the last suffix). Every character that is not an
    ASCII letter, digit or underscore is replaced by ``_``. A result that is
    empty or starts with a digit gets a ``t_`` prefix, because unquoted SQL
    identifiers cannot start with a digit.

    This lets the agent's generated SQL name the table without identifier
    quoting, e.g. ``"sales-2023 Q1.csv"`` -> ``"sales_2023_Q1"``.
    """
    stem = os.path.splitext(filename)[0]
    cleaned = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_"
        for ch in stem
    )
    if not cleaned or cleaned[0].isdigit():
        cleaned = f"t_{cleaned}"
    return cleaned


# Step 1: Upload CSV data file (or use default)
st.title("Enhanced Natural Language to SQL Query App")
st.write("Upload a CSV file to get started, or use the default dataset.")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    try:
        data = pd.read_csv("default_data.csv")
    except FileNotFoundError:
        # Same pattern as the API-key check: explain the problem and halt
        # this run instead of showing a raw traceback.
        st.error("default_data.csv was not found. Please upload a CSV file to continue.")
        st.stop()
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    try:
        data = pd.read_csv(csv_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as exc:
        # A malformed, empty, or non-UTF-8 upload should be reported to the
        # user rather than crash the app.
        st.error(f"Could not read {csv_file.name} as CSV: {exc}")
        st.stop()
    table_name = _table_name_from_filename(csv_file.name)
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())
# Step 2: Load CSV data into SQLite database
db_file = 'my_database.db'
# NOTE(review): this is a fixed path, so every session of the app writes to
# the same file; a second upload with the same table name replaces the first.
conn = sqlite3.connect(db_file)
try:
    # Replace any previous table of the same name with the current data.
    data.to_sql(table_name, conn, index=False, if_exists='replace')
finally:
    # Always release the connection, even if to_sql raises. (Using the
    # connection as a context manager would only commit/roll back, not close.)
    conn.close()
# Create SQLDatabase instance restricted to the freshly loaded table, so the
# agent only sees that table.
db = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])
# Step 3: Initialize the LLM and the SQL agent.
# temperature=0 asks the model for its most deterministic output.
llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0)
toolkit = SQLDatabaseToolkit(llm=llm, db=db)

# Ask the executor to return its (action, observation) pairs next to the final
# answer; process_input reads them from response['intermediate_steps'].
executor_options = {"return_intermediate_steps": True}
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_executor_kwargs=executor_options,
)
# Step 4: Define the callback function

# Prompt for the follow-up "insights" LLM call. Defined once at module level
# rather than rebuilt inside the callback on every message.
_INSIGHTS_TEMPLATE = """
You are an expert data analyst. Based on the user's question and the response provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
User's Question: {question}
Response:
{response}
Concise Analysis:
"""


def _parse_sql_observation(observation):
    """Parse one ``sql_db_query`` tool observation into a list of rows.

    Returns:
        - the parsed list when the observation is a string repr of a list
          (e.g. ``"[(1, 'a')]"``);
        - ``[]`` for an empty/blank string;
        - ``None`` for anything that cannot be parsed as such a list, so the
          caller can skip that step.

    NOTE(review): the empty-string and ``"Error: ..."`` forms are based on
    LangChain's ``SQLDatabase.run`` / ``run_no_throw`` (empty string for
    zero-row results, ``"Error: ..."`` for failed queries). Confirm this
    against the installed langchain version.
    """
    if not isinstance(observation, str):
        return None
    text = observation.strip()
    if not text:
        return []  # the query ran but matched no rows
    try:
        rows = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None  # e.g. an error message from a failed query
    return rows if isinstance(rows, list) else None


def _last_sql_rows(intermediate_steps):
    """Return the rows of the most recent successful ``sql_db_query`` step.

    The agent may issue several queries (for example, a failed attempt
    followed by a corrected one). Scanning from the end picks the one its
    final answer is most likely based on. Returns ``[]`` when no query step
    produced parseable rows.
    """
    for action, observation in reversed(intermediate_steps):
        if getattr(action, "tool", None) != 'sql_db_query':
            continue
        rows = _parse_sql_observation(observation)
        if rows is not None:
            return rows
    return []


def _format_rows(rows):
    """Render query rows as a Markdown table, or a placeholder when empty."""
    if not rows:
        return "No results found."
    frame = pd.DataFrame(rows)
    try:
        return frame.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown requires the optional 'tabulate' package;
        # fall back to a fixed-width text table inside a code fence.
        return f"```\n{frame.to_string(index=False)}\n```"


def process_input():
    """Handle a submitted chat message (``on_change`` callback of the input box).

    Reads the prompt from ``st.session_state['user_input']``. For a
    non-blank prompt it appends three entries to ``st.session_state.history``:
    the user turn, the SQL agent's answer followed by a table of the raw
    query rows, and a short LLM-generated analysis.

    Any exception is logged with its traceback and recorded in the history
    as an ``Error: ...`` assistant message instead of propagating. The input
    box is cleared at the end in every case.
    """
    user_prompt = st.session_state['user_input']
    if user_prompt and user_prompt.strip():
        st.session_state.history.append({"role": "user", "content": user_prompt})
        try:
            # The spinner covers both LLM round-trips (agent and insights).
            with st.spinner("Processing..."):
                # invoke() is the non-deprecated Chain entry point. Because the
                # executor was created with return_intermediate_steps=True,
                # the result also carries the (action, observation) pairs.
                response = agent_executor.invoke({"input": user_prompt})
                final_answer = response['output']
                sql_rows = _last_sql_rows(response.get('intermediate_steps', []))
                assistant_response = (
                    f"{final_answer}\n\n**Query Result:**\n{_format_rows(sql_rows)}"
                )
                st.session_state.history.append(
                    {"role": "assistant", "content": assistant_response}
                )

                # Generate insights based on the question and the answer.
                insights_prompt = PromptTemplate(
                    template=_INSIGHTS_TEMPLATE,
                    input_variables=['question', 'response'],
                )
                insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
                insights = insights_chain.run(
                    {'question': user_prompt, 'response': assistant_response}
                )
                st.session_state.history.append({"role": "assistant", "content": insights})
        except Exception as e:
            # Top-level UI boundary: keep the app alive and show the error.
            logging.exception("An error occurred: %s", e)
            st.session_state.history.append({"role": "assistant", "content": f"Error: {e}"})
    # Reset user input (allowed here because this runs as a widget callback).
    st.session_state['user_input'] = ''
# Step 5: Display conversation history.
# Map each known role to its bold Markdown prefix; entries with any other
# role are not rendered.
_ROLE_LABELS = {"user": "**User:**", "assistant": "**Assistant:**"}
for entry in st.session_state.history:
    label = _ROLE_LABELS.get(entry['role'])
    if label is None:
        continue
    st.markdown(f"{label} {entry['content']}")

# Chat input box; process_input runs as its on_change callback.
st.text_input("Enter your message:", key='user_input', on_change=process_input)
|