# Source: Hugging Face Space by "arithescientist" — app.py, revision 45afb27 (page header residue converted to a comment).
import os
import streamlit as st
import pandas as pd
import sqlite3
from transformers import pipeline
import sqlparse
import logging
# --- Conversation state ---------------------------------------------------
# The chat transcript must survive Streamlit reruns, so it lives in
# st.session_state rather than a module-level variable.
if 'history' not in st.session_state:
    st.session_state.history = []

# --- Language model -------------------------------------------------------
# distilgpt2 is deliberately small so inference stays responsive on CPU.
llm = pipeline('text-generation', model='distilgpt2')

# --- Step 1: obtain the dataset (uploaded CSV, or a bundled default) ------
st.title("Natural Language to SQL Query App with Enhanced Insights")
st.write("Upload a CSV file to get started, or use the default dataset.")

csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is not None:
    data = pd.read_csv(csv_file)
    # Table is named after the file (extension dropped).
    table_name = csv_file.name.split('.')[0]
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())
else:
    # Fall back to the bundled sample; this file must exist alongside app.py.
    data = pd.read_csv("default_data.csv")
    st.write("Using default_data.csv file.")
    table_name = "default_table"

# --- Step 2: load the CSV into a persistent SQLite database ---------------
db_file = 'my_database.db'
# check_same_thread=False: Streamlit callbacks may run on another thread.
conn = sqlite3.connect(db_file, check_same_thread=False)
data.to_sql(table_name, conn, index=False, if_exists='replace')

# Column metadata, used for validation and for the model's schema prompt.
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
# Function to generate SQL query using Hugging Face model
def generate_sql_query(question, table_name, columns):
    """Ask the LLM to translate *question* into a SQL query.

    Parameters:
        question (str): the user's natural-language question.
        table_name (str): the SQLite table the query should target.
        columns (str): comma-separated list of valid column names.

    Returns:
        str: the model's completion (intended to be a bare SQL query),
        stripped of surrounding whitespace.
    """
    # Keep the prompt short and direct; distilgpt2 has a small context window.
    # FIX: table_name was previously ignored, so the model had no way to
    # reference the correct table in its query.
    prompt = f"""
    You are a SQL expert. Generate a SQL query on the table {table_name} using the columns:
    {columns}.
    Question: {question}
    Respond only with the SQL query.
    """
    # FIX: return_full_text=False strips the echoed prompt — by default the
    # text-generation pipeline returns prompt + completion, so the result
    # always contained the prompt text and could never validate as SQL.
    response = llm(prompt, max_new_tokens=50, truncation=True,
                   return_full_text=False)
    return response[0]['generated_text'].strip()
# Function to generate insights using Hugging Face model
def generate_insights(question, result):
    """Ask the LLM for concise insights about *result* in light of *question*.

    Parameters:
        question (str): the user's original question, for context.
        result (str): stringified query result or dataset summary.

    Returns:
        str: the model's generated insights, stripped of whitespace.
    """
    # FIX: the question parameter was previously ignored, so insights were
    # generated without the user's actual question for context.
    prompt = f"""
    Based on the user's question and the SQL query result below, generate concise data insights:
    Question: {question}
    Result:
    {result}
    """
    # FIX: return_full_text=False so only the completion is returned,
    # not the echoed prompt (the pipeline's default includes the prompt).
    response = llm(prompt, max_new_tokens=100, truncation=True,
                   return_full_text=False)
    return response[0]['generated_text'].strip()
# Function to classify user query as SQL or Insights
def classify_query(question):
    """Classify *question* as a SQL request or a general-insights request.

    Parameters:
        question (str): the user's message.

    Returns:
        str: 'SQL' if the model's completion mentions SQL, else 'INSIGHTS'.
    """
    prompt = f"""
    Classify the following question as 'SQL' or 'INSIGHTS':
    "{question}"
    """
    # FIX: return_full_text=False is essential here — the default pipeline
    # output echoes the prompt, which itself contains the literal 'SQL',
    # so the check below was ALWAYS true and everything classified as SQL.
    response = llm(prompt, max_new_tokens=10, truncation=True,
                   return_full_text=False)
    category = response[0]['generated_text'].strip().upper()
    return 'SQL' if 'SQL' in category else 'INSIGHTS'
# Function to generate dataset summary
def generate_dataset_summary(data):
    """Ask the LLM for a brief summary of the dataset's first rows.

    Parameters:
        data (pd.DataFrame): the loaded dataset.

    Returns:
        str: the model's generated summary, stripped of whitespace.
    """
    # Only the head is sent: distilgpt2's context window is small.
    summary_template = f"""
    Provide a brief summary of the dataset:
    {data.head().to_string(index=False)}
    """
    # FIX: return_full_text=False so the echoed prompt (which contains the
    # raw data rows) is not returned as part of the "summary".
    response = llm(summary_template, max_new_tokens=100, truncation=True,
                   return_full_text=False)
    return response[0]['generated_text'].strip()
# Function to validate if the generated SQL query is valid
def is_valid_sql(query):
    """Return True if *query* begins with a recognised SQL statement keyword
    (case-insensitive, ignoring leading whitespace)."""
    statement_starts = ("SELECT", "INSERT", "UPDATE", "DELETE",
                        "CREATE", "DROP", "ALTER")
    # str.startswith accepts a tuple of prefixes, checking them all at once.
    return query.strip().upper().startswith(statement_starts)
# Define the callback function
def process_input():
user_prompt = st.session_state['user_input']
if user_prompt:
try:
# Append user message to history
st.session_state.history.append({"role": "user", "content": user_prompt})
# Classify the user query
category = classify_query(user_prompt)
logging.info(f"User query classified as: {category}")
if "COLUMNS" in user_prompt.upper():
assistant_response = f"The columns are: {', '.join(valid_columns)}"
st.session_state.history.append({"role": "assistant", "content": assistant_response})
elif category == 'SQL':
columns = ', '.join(valid_columns)
generated_sql = generate_sql_query(user_prompt, table_name, columns)
if generated_sql.upper() == "NO_SQL":
# Handle cases where no SQL should be generated
assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
# Generate dataset summary
dataset_summary = generate_dataset_summary(data)
# Generate general insights and recommendations
general_insights = generate_insights(user_prompt, dataset_summary)
# Append the assistant's insights to the history
st.session_state.history.append({"role": "assistant", "content": general_insights})
else:
# Validate the SQL query
if is_valid_sql(generated_sql):
# Attempt to execute SQL query and handle exceptions
try:
result = pd.read_sql_query(generated_sql, conn)
if result.empty:
assistant_response = "The query returned no results. Please try a different question."
st.session_state.history.append({"role": "assistant", "content": assistant_response})
else:
# Convert the result to a string for the insights prompt
result_str = result.head(10).to_string(index=False) # Limit to first 10 rows
# Generate insights and recommendations based on the query result
insights = generate_insights(user_prompt, result_str)
# Append the assistant's insights to the history
st.session_state.history.append({"role": "assistant", "content": insights})
# Append the result DataFrame to the history
st.session_state.history.append({"role": "assistant", "content": result})
except Exception as e:
logging.error(f"An error occurred during SQL execution: {e}")
assistant_response = f"Error executing SQL query: {e}"
st.session_state.history.append({"role": "assistant", "content": assistant_response})
else:
# If generated text is not valid SQL, provide feedback to the user
st.session_state.history.append({"role": "assistant", "content": "Generated text is not a valid SQL query. Please try rephrasing your question."})
else: # INSIGHTS category
# Generate dataset summary
dataset_summary = generate_dataset_summary(data)
# Generate general insights and recommendations
general_insights = generate_insights(user_prompt, dataset_summary)
# Append the assistant's insights to the history
st.session_state.history.append({"role": "assistant", "content": general_insights})
except Exception as e:
logging.error(f"An error occurred: {e}")
assistant_response = f"Error: {e}"
st.session_state.history.append({"role": "assistant", "content": assistant_response})
# Reset the user_input in session state
st.session_state['user_input'] = ''
# --- Render the conversation ---------------------------------------------
for entry in st.session_state.history:
    role = entry['role']
    content = entry['content']
    if role == 'user':
        st.markdown(f"**User:** {content}")
    elif role == 'assistant':
        # DataFrame entries are SQL query results and render as tables;
        # everything else is plain assistant text.
        if isinstance(content, pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(content)
        else:
            st.markdown(f"**Assistant:** {content}")

# The input field is created last so it sits at the bottom of the page;
# process_input fires whenever its value changes.
st.text_input("Enter your message:", key='user_input', on_change=process_input)