import os
import logging
import sqlite3

import pandas as pd
import sqlparse
import streamlit as st
from transformers import pipeline

# Configure logging once; the original called logging.info/error with no config.
logging.basicConfig(level=logging.INFO)

# Initialize conversation history
if 'history' not in st.session_state:
    st.session_state.history = []


@st.cache_resource
def load_llm():
    """Load the text-generation model once per process.

    st.cache_resource keeps the pipeline alive across Streamlit reruns;
    the original reloaded the model on every script execution.
    distilgpt2 is small and fast, suitable for demo-quality output only.
    """
    return pipeline('text-generation', model='distilgpt2')


llm = load_llm()

# Step 1: Upload CSV data file (or use default)
st.title("Natural Language to SQL Query App with Enhanced Insights")
st.write("Upload a CSV file to get started, or use the default dataset.")

csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    # Ensure this file exists in your working directory
    data = pd.read_csv("default_data.csv")
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    table_name = csv_file.name.split('.')[0]
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Step 2: Load CSV data into a persistent SQLite database
db_file = 'my_database.db'
# check_same_thread=False allows the connection to be used across
# Streamlit's script-runner threads.
conn = sqlite3.connect(db_file, check_same_thread=False)
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")


def _generate(prompt, max_new_tokens):
    """Run the LLM and return only the newly generated text.

    The HF text-generation pipeline prepends the prompt to
    'generated_text' by default; return_full_text=False drops it so
    callers receive only the model's continuation.  (The original kept
    the prompt, so is_valid_sql() and classify_query() effectively
    inspected the prompt text instead of the model output.)
    """
    response = llm(prompt, max_new_tokens=max_new_tokens, truncation=True,
                   return_full_text=False)
    return response[0]['generated_text'].strip()


def generate_sql_query(question, table_name, columns):
    """Ask the LLM for a SQL query over `table_name` answering `question`.

    Bug fix: the original accepted table_name but never put it in the
    prompt, so the model could not know what to place in the FROM clause.
    """
    prompt = f"""
You are a SQL expert. Generate a SQL query for the table "{table_name}"
using the columns: {columns}.
Question: {question}
Respond only with the SQL query.
"""
    return _generate(prompt, max_new_tokens=50)


def generate_insights(question, result):
    """Generate concise insights about `result` in the context of `question`.

    Bug fix: the original prompt claimed to use the user's question but
    never interpolated it; include it so the model has the context.
    """
    prompt = f"""
Based on the user's question and the SQL query result below, generate concise data insights:
Question: {question}
Result:
{result}
"""
    return _generate(prompt, max_new_tokens=100)


def classify_query(question):
    """Classify `question` as 'SQL' or 'INSIGHTS' (defaults to INSIGHTS)."""
    prompt = f"""
Classify the following question as 'SQL' or 'INSIGHTS':
"{question}"
"""
    category = _generate(prompt, max_new_tokens=10).upper()
    return 'SQL' if 'SQL' in category else 'INSIGHTS'


def generate_dataset_summary(data):
    """Ask the LLM for a brief natural-language summary of the dataset head."""
    prompt = f"""
Provide a brief summary of the dataset:
{data.head().to_string(index=False)}
"""
    return _generate(prompt, max_new_tokens=100)


def is_valid_sql(query):
    """Cheap syntactic gate: does the text start with a SQL statement keyword?"""
    # str.startswith accepts a tuple of prefixes — one call, no loop.
    sql_keywords = ("SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER")
    return query.strip().upper().startswith(sql_keywords)


def process_input():
    """Callback for the chat input.

    Classifies the user's message, either runs a generated SQL query
    against the SQLite database or produces dataset-level insights, and
    appends every assistant response to st.session_state.history.
    """
    user_prompt = st.session_state['user_input']
    if user_prompt:
        try:
            # Append user message to history
            st.session_state.history.append({"role": "user", "content": user_prompt})

            # Classify the user query
            category = classify_query(user_prompt)
            logging.info("User query classified as: %s", category)

            if "COLUMNS" in user_prompt.upper():
                assistant_response = f"The columns are: {', '.join(valid_columns)}"
                st.session_state.history.append(
                    {"role": "assistant", "content": assistant_response})
            elif category == 'SQL':
                columns = ', '.join(valid_columns)
                generated_sql = generate_sql_query(user_prompt, table_name, columns)
                if generated_sql.upper() == "NO_SQL":
                    # Model declined to produce SQL: fall back to general
                    # insights built from a dataset summary.
                    dataset_summary = generate_dataset_summary(data)
                    general_insights = generate_insights(user_prompt, dataset_summary)
                    st.session_state.history.append(
                        {"role": "assistant", "content": general_insights})
                elif is_valid_sql(generated_sql):
                    # Attempt to execute SQL query and handle exceptions
                    try:
                        result = pd.read_sql_query(generated_sql, conn)
                        if result.empty:
                            assistant_response = ("The query returned no results. "
                                                  "Please try a different question.")
                            st.session_state.history.append(
                                {"role": "assistant", "content": assistant_response})
                        else:
                            # Limit to first 10 rows for the insights prompt
                            result_str = result.head(10).to_string(index=False)
                            insights = generate_insights(user_prompt, result_str)
                            # Insights first, then the result DataFrame itself
                            st.session_state.history.append(
                                {"role": "assistant", "content": insights})
                            st.session_state.history.append(
                                {"role": "assistant", "content": result})
                    except Exception as e:
                        logging.error("An error occurred during SQL execution: %s", e)
                        assistant_response = f"Error executing SQL query: {e}"
                        st.session_state.history.append(
                            {"role": "assistant", "content": assistant_response})
                else:
                    # Generated text is not valid SQL — tell the user.
                    st.session_state.history.append(
                        {"role": "assistant",
                         "content": ("Generated text is not a valid SQL query. "
                                     "Please try rephrasing your question.")})
            else:  # INSIGHTS category
                dataset_summary = generate_dataset_summary(data)
                general_insights = generate_insights(user_prompt, dataset_summary)
                st.session_state.history.append(
                    {"role": "assistant", "content": general_insights})
        except Exception as e:
            logging.error("An error occurred: %s", e)
            assistant_response = f"Error: {e}"
            st.session_state.history.append(
                {"role": "assistant", "content": assistant_response})
        # Reset the user_input in session state
        st.session_state['user_input'] = ''


# Display the conversation history
for message in st.session_state.history:
    if message['role'] == 'user':
        st.markdown(f"**User:** {message['content']}")
    elif message['role'] == 'assistant':
        if isinstance(message['content'], pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(message['content'])
        else:
            st.markdown(f"**Assistant:** {message['content']}")

# Place the input field at the bottom with the callback
st.text_input("Enter your message:", key='user_input', on_change=process_input)