import logging
import os
import sqlite3

import pandas as pd
import sqlparse
import streamlit as st
import torch
from langchain.chains import LLMChain
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
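# Basic logging setup so the logging.info/logging.error calls below are visible
# on the console (assumption: INFO level is appropriate; adjust as needed).
logging.basicConfig(level=logging.INFO)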
# Initialize conversation history
if 'history' not in st.session_state:
    st.session_state['history'] = []
# Set up the Llama-2-7b-chat-hf model
model_id = "meta-llama/Llama-2-7b-chat-hf"
# Get your Hugging Face token (it's stored as a secret in your Space)
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if hf_token is None:
    st.error("Hugging Face API token is not set. Please set the HUGGINGFACEHUB_API_TOKEN secret in your Space.")
    st.stop()
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the tokenizer and model with the token
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    use_auth_token=hf_token,
    device_map=None,  # We'll set the device manually
    torch_dtype=torch.float32  # Use float32 to avoid half-precision issues
).to(device)
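# Note: a 7B-parameter model in float32 needs roughly 28 GB of RAM/VRAM
# (7e9 params x 4 bytes). If memory is tight, torch.float16 on a GPU is a
# common alternative, at the cost of the half-precision issues noted above.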
# Create the text-generation pipeline with appropriate parameters
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    temperature=0.1,
    repetition_penalty=1.1,
    do_sample=True,  # Use sampling to introduce some randomness
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    device=0 if torch.cuda.is_available() else -1  # Use GPU if available
)
# Wrap the pipeline with HuggingFacePipeline for use in LangChain
llm = HuggingFacePipeline(pipeline=pipe)
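# Note: in the langchain 0.0.x API used here, HuggingFacePipeline strips the
# prompt from text-generation output, so the chains below receive only the
# newly generated text, not the echoed prompt.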
# Step 1: Upload CSV data file (or use default)
st.title("Natural Language to SQL Query App with Enhanced Insights")
st.write("Upload a CSV file to get started, or use the default dataset.")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Ensure this file exists in your working directory
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    table_name = csv_file.name.split('.')[0]
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())
# Step 2: Load CSV data into a persistent SQLite database
db_file = 'my_database.db'
# check_same_thread=False because Streamlit callbacks can run on a different
# thread than the one that created the connection
conn = sqlite3.connect(db_file, check_same_thread=False)
data.to_sql(table_name, conn, index=False, if_exists='replace')
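# if_exists='replace' drops and recreates the table on every rerun, so the
# SQLite table always mirrors the currently loaded CSV.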
# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
# Step 3: Set up the LLM Chains with adjusted prompts
# SQL Generation Chain
sql_template = """
[INST] <<SYS>>
You are an expert data scientist.
<</SYS>>
Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
Ensure that:
- You only use the columns provided.
- When performing string comparisons in the WHERE clause, make them case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
- Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
- Do not apply 'COLLATE NOCASE' to numeric columns.
If the question is vague or open-ended and does not pertain to specific data retrieval, respond with "NO_SQL" to indicate that a SQL query should not be generated.
Question: {question}
Table name: {table_name}
Valid columns: {columns}
SQL Query:
[/INST]
"""
sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)
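# Illustrative use (hypothetical question and column names): given
#   sql_generation_chain.run({'question': 'How many rows are there?',
#                             'table_name': 'default_table', 'columns': 'a, b'})
# the model is expected to return something like 'SELECT COUNT(*) FROM default_table;'
# or the literal string 'NO_SQL' for non-retrieval questions.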
# Insights Generation Chain
insights_template = """
[INST] <<SYS>>
You are an expert data scientist.
<</SYS>>
Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
User's Question: {question}
SQL Query Result:
{result}
Concise Analysis (max 150 words):
[/INST]
"""
insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
# General Insights and Recommendations Chain
general_insights_template = """
[INST] <<SYS>>
You are an expert data scientist.
<</SYS>>
Based on the entire dataset provided below, generate a concise analysis with key insights and recommendations. Limit the response to 150 words.
Dataset Summary:
{dataset_summary}
Concise Analysis and Recommendations (max 150 words):
[/INST]
"""
general_insights_prompt = PromptTemplate(template=general_insights_template, input_variables=['dataset_summary'])
general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)
# Optional: Clean up function to remove incorrect COLLATE NOCASE usage
def clean_sql_query(query):
    """Removes incorrect usage of COLLATE NOCASE from the SQL query."""
    parsed = sqlparse.parse(query)
    statements = []
    for stmt in parsed:
        tokens = []
        idx = 0
        while idx < len(stmt.tokens):
            token = stmt.tokens[idx]
            if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'COLLATE':
                # Check if the next non-whitespace token is 'NOCASE'
                next_token = stmt.tokens[idx + 2] if idx + 2 < len(stmt.tokens) else None
                if next_token and next_token.value.upper() == 'NOCASE':
                    idx += 3  # Skip 'COLLATE', whitespace, 'NOCASE'
                    continue
            tokens.append(token)
            idx += 1
        statements.append(''.join(str(t) for t in tokens))
    return ' '.join(statements)
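# Illustrative example: clean_sql_query("SELECT name FROM t ORDER BY n COLLATE NOCASE")
# should drop the trailing COLLATE NOCASE. Note the skip logic assumes exactly one
# whitespace token between COLLATE and NOCASE, which is how sqlparse normally
# tokenizes the clause.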
# Function to classify user query
def classify_query(question):
    """Classify the user query as either 'SQL' or 'INSIGHTS'."""
    classification_template = """
[INST] <<SYS>>
You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
<</SYS>>
Determine the appropriate category for the following user question.
Question: "{question}"
Category (SQL/INSIGHTS):
[/INST]
"""
    classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
    classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
    category = classification_chain.run({'question': question}).strip().upper()
    if category.startswith('SQL'):
        return 'SQL'
    else:
        return 'INSIGHTS'
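# Anything the model does not label as starting with 'SQL' falls through to
# 'INSIGHTS', so ambiguous classifications default to the general-analysis path.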
# Function to generate dataset summary
def generate_dataset_summary(data):
    """Generate a summary of the dataset for general insights."""
    summary_template = """
[INST] <<SYS>>
You are an expert data scientist.
<</SYS>>
Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.
Dataset:
{data}
Dataset Summary:
[/INST]
"""
    summary_prompt = PromptTemplate(template=summary_template, input_variables=['data'])
    summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
    summary = summary_chain.run({'data': data.head().to_string(index=False)})
    return summary
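# Only data.head() (the first five rows) is sent to the model, so the "summary"
# reflects a sample rather than the full dataset; any record counts it reports
# are effectively guesses.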
# Define the callback function
def process_input():
    user_prompt = st.session_state['user_input']
    if user_prompt:
        try:
            # Append user message to history
            st.session_state.history.append({"role": "user", "content": user_prompt})
            # Classify the user query
            category = classify_query(user_prompt)
            logging.info(f"User query classified as: {category}")
            if "COLUMNS" in user_prompt.upper():
                assistant_response = f"The columns are: {', '.join(valid_columns)}"
                st.session_state.history.append({"role": "assistant", "content": assistant_response})
            elif category == 'SQL':
                columns = ', '.join(valid_columns)
                generated_sql = sql_generation_chain.run({
                    'question': user_prompt,
                    'table_name': table_name,
                    'columns': columns
                }).strip()
                if generated_sql.upper() == "NO_SQL":
                    # Handle cases where no SQL should be generated
                    assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
                    st.session_state.history.append({"role": "assistant", "content": assistant_response})
                    # Generate dataset summary
                    dataset_summary = generate_dataset_summary(data)
                    # Generate general insights and recommendations
                    general_insights = general_insights_chain.run({
                        'dataset_summary': dataset_summary
                    })
                    # Append the assistant's insights to the history
                    st.session_state.history.append({"role": "assistant", "content": general_insights})
                else:
                    # Clean the SQL query
                    cleaned_sql = clean_sql_query(generated_sql)
                    logging.info(f"Generated SQL Query: {cleaned_sql}")
                    # Attempt to execute SQL query and handle exceptions
                    try:
                        result = pd.read_sql_query(cleaned_sql, conn)
                        if result.empty:
                            assistant_response = "The query returned no results. Please try a different question."
                            st.session_state.history.append({"role": "assistant", "content": assistant_response})
                        else:
                            # Convert the result to a string for the insights prompt
                            result_str = result.head(10).to_string(index=False)  # Limit to first 10 rows
                            # Generate insights and recommendations based on the query result
                            insights = insights_chain.run({
                                'question': user_prompt,
                                'result': result_str
                            })
                            # Append the assistant's insights to the history
                            st.session_state.history.append({"role": "assistant", "content": insights})
                            # Append the result DataFrame to the history
                            st.session_state.history.append({"role": "assistant", "content": result})
                    except Exception as e:
                        logging.error(f"An error occurred during SQL execution: {e}")
                        assistant_response = f"Error executing SQL query: {e}"
                        st.session_state.history.append({"role": "assistant", "content": assistant_response})
            else:  # INSIGHTS category
                # Generate dataset summary
                dataset_summary = generate_dataset_summary(data)
                # Generate general insights and recommendations
                general_insights = general_insights_chain.run({
                    'dataset_summary': dataset_summary
                })
                # Append the assistant's insights to the history
                st.session_state.history.append({"role": "assistant", "content": general_insights})
        except Exception as e:
            logging.error(f"An error occurred: {e}")
            assistant_response = f"Error: {e}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})
        # Reset the user_input in session state
        st.session_state['user_input'] = ''
# Display the conversation history
for message in st.session_state.history:
    if message['role'] == 'user':
        st.markdown(f"**User:** {message['content']}")
    elif message['role'] == 'assistant':
        if isinstance(message['content'], pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(message['content'])
        else:
            st.markdown(f"**Assistant:** {message['content']}")
# Place the input field at the bottom with the callback
st.text_input("Enter your message:", key='user_input', on_change=process_input)