"""Natural-language-to-SQL Streamlit app.

Loads a CSV into SQLite, turns user questions into SQL with Llama-2-7b-chat
(via LangChain), runs the query and asks the model for insights on the result.
"""

import logging
import os
import sqlite3

import pandas as pd
import sqlparse
import streamlit as st
import torch
from langchain import LLMChain, PromptTemplate
from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# The script's logging.info() calls are otherwise dropped: the root logger
# defaults to WARNING. (basicConfig is a no-op if the root logger already has
# handlers.)
logging.basicConfig(level=logging.INFO)

# Initialize conversation history (session_state survives Streamlit reruns).
if "history" not in st.session_state:
    st.session_state["history"] = []

# Set up the Llama-2-7b-chat-hf model
model_id = "meta-llama/Llama-2-7b-chat-hf"

# Get your Hugging Face token (it's stored as a secret in your Space).
# An empty value is treated the same as a missing one.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if not hf_token:
    st.error("Hugging Face API token is not set. Please set the HUGGINGFACEHUB_API_TOKEN secret in your Space.")
    st.stop()

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_llm(model_id, hf_token):
    """Load the tokenizer and model and wrap them as a LangChain LLM.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the loaded objects for the process lifetime, so
    the 7B model is loaded once instead of on every rerun.

    Args:
        model_id: Hugging Face model repository id.
        hf_token: Hugging Face access token used to download the gated model.

    Returns:
        Tuple ``(tokenizer, model, pipe, llm)``. ``llm`` is the
        ``HuggingFacePipeline`` wrapper used by the LangChain chains.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        use_auth_token=hf_token,
        device_map=None,  # the model is moved to `device` manually below
        torch_dtype=torch.float32,  # float32 avoids half-precision issues
    ).to(device)

    # Text-generation pipeline: low temperature with sampling keeps SQL output
    # near-deterministic while allowing slight variation.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        temperature=0.1,
        repetition_penalty=1.1,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,  # the tokenizer has no dedicated pad token here
        device=0 if torch.cuda.is_available() else -1,  # GPU 0 if available, else CPU
    )
    return tokenizer, model, pipe, HuggingFacePipeline(pipeline=pipe)


tokenizer, model, pipe, llm = _load_llm(model_id, hf_token)

# Step 1: Upload CSV data file (or use default)
st.title("Natural Language to SQL Query App with Enhanced Insights")
st.write("Upload a CSV file to get started, or use the default dataset.")
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
# Stdlib helpers used only by this part of the script.
import re
from contextlib import closing


def _safe_table_name(file_name):
    """Turn an uploaded file name into a plain SQL identifier.

    Example: "sales data.2024.csv" -> "sales_data_2024". Runs of non-word
    characters become underscores, and a leading digit gets a "t_" prefix, so
    the LLM can use the name without quoting. If nothing usable remains,
    "uploaded_table" is returned.
    """
    stem = os.path.splitext(file_name)[0]
    name = re.sub(r"\W+", "_", stem).strip("_")
    if not name:
        return "uploaded_table"
    if name[0].isdigit():
        name = f"t_{name}"
    return name


if csv_file is None:
    try:
        data = pd.read_csv("default_data.csv")  # Ensure this file exists in your working directory
    except FileNotFoundError:
        st.error("No CSV uploaded and default_data.csv was not found in the working directory.")
        st.stop()
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    try:
        data = pd.read_csv(csv_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        st.error(f"Could not read {csv_file.name}: {e}")
        st.stop()
    table_name = _safe_table_name(csv_file.name)
    st.write(f"Data Preview ({csv_file.name}):")
st.dataframe(data.head())

# Step 2: Load CSV data into a persistent SQLite database
db_file = "my_database.db"
# Write the table over a short-lived read-write connection. closing() releases
# it; sqlite3's own context manager only commits and does not close.
with closing(sqlite3.connect(db_file)) as write_conn:
    data.to_sql(table_name, write_conn, index=False, if_exists="replace")

# LLM-generated queries run on a READ-ONLY connection, so free-form user input
# cannot get the model to modify or drop data.
# check_same_thread=False: the on_change callback that uses this connection
# may run on a different thread than the script run that opened it.
conn = sqlite3.connect(f"file:{db_file}?mode=ro", uri=True, check_same_thread=False)

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")

# Step 3: Set up the LLM chains. All prompts use the Llama-2 chat format:
# [INST] <<SYS>> system <</SYS>> user [/INST]

# SQL Generation Chain
sql_template = """[INST] <<SYS>>
You are an expert data scientist.
<</SYS>>

Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.

Ensure that:
- You only use the columns provided.
- Column names containing spaces or special characters are wrapped in double quotes.
- When performing string comparisons in the WHERE clause, make them case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
- Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
- Do not apply 'COLLATE NOCASE' to numeric columns.
- Respond with the SQL query only, without any explanation.

If the question is vague or open-ended and does not pertain to specific data retrieval, respond with "NO_SQL" to indicate that a SQL query should not be generated.

Question: {question}
Table name: {table_name}
Valid columns: {columns}

SQL Query: [/INST]"""
sql_prompt = PromptTemplate(template=sql_template, input_variables=["question", "table_name", "columns"])
sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)

# Insights Generation Chain
insights_template = """[INST] <<SYS>>
You are an expert data scientist.
<</SYS>>

Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.

User's Question: {question}
SQL Query Result:
{result}

Concise Analysis (max 150 words): [/INST]"""
insights_prompt = PromptTemplate(template=insights_template, input_variables=["question", "result"])
insights_chain = LLMChain(llm=llm, prompt=insights_prompt)

# General Insights and Recommendations Chain
general_insights_template = """[INST] <<SYS>>
You are an expert data scientist.
<</SYS>>

Based on the entire dataset provided below, generate a concise analysis with key insights and recommendations. Limit the response to 150 words.

Dataset Summary:
{dataset_summary}

Concise Analysis and Recommendations (max 150 words): [/INST]"""
general_insights_prompt = PromptTemplate(template=general_insights_template, input_variables=["dataset_summary"])
general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)

# Query classification chain (built once per run, not once per question)
classification_template = """[INST] <<SYS>>
You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
<</SYS>>

Determine the appropriate category for the following user question. Answer with exactly one word: SQL or INSIGHTS.

Question: "{question}"

Category (SQL/INSIGHTS): [/INST]"""
classification_prompt = PromptTemplate(template=classification_template, input_variables=["question"])
classification_chain = LLMChain(llm=llm, prompt=classification_prompt)

# Dataset summary chain (built once per run, not once per question)
summary_template = """[INST] <<SYS>>
You are an expert data scientist.
<</SYS>>

Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.

Dataset:
{data}

Dataset Summary: [/INST]"""
summary_prompt = PromptTemplate(template=summary_template, input_variables=["data"])
summary_chain = LLMChain(llm=llm, prompt=summary_prompt)


# Where SQL starts inside a reply: a SELECT, or a CTE header "WITH name [(cols)] AS (".
# The CTE form is required so that prose such as "with the given columns" does
# not match.
_SQL_START = re.compile(
    r"\bSELECT\b|\bWITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\))?\s*AS\s*\(",
    re.IGNORECASE,
)
_FENCED_BLOCK = re.compile(r"```(.*?)```", re.DOTALL)


def _is_no_sql(reply):
    """Return True if the model declined to write SQL.

    Accepts NO_SQL bare, quoted, or inside a sentence, provided the reply does
    not also contain a query.
    """
    return "NO_SQL" in reply.upper() and not _SQL_START.search(reply)


def _extract_sql(reply):
    """Return the first SQL statement found in a (possibly chatty) LLM reply.

    Uses the body of the first ``` fence if there is one. Otherwise it starts
    at the first SELECT or CTE header, which skips a leading sentence or a
    "sql" fence tag. Only the first statement is kept (via sqlparse.split),
    which drops any prose after a terminating semicolon. If no SQL start is
    found, the stripped text is returned.
    """
    fenced = _FENCED_BLOCK.search(reply)
    text = fenced.group(1) if fenced else reply
    start = _SQL_START.search(text)
    if start:
        text = text[start.start():]
    return next((stmt for stmt in sqlparse.split(text) if stmt), text.strip())


# Optional: Clean up function to remove incorrect COLLATE NOCASE usage
def clean_sql_query(query):
    """Remove ``COLLATE NOCASE`` pairs found at the top level of each statement.

    Only each parsed statement's top-level tokens are inspected. Clauses that
    sqlparse groups into a single token (a WHERE clause becomes one
    ``sql.Where`` token) are left untouched, so case-insensitive WHERE
    comparisons survive. Any whitespace or newlines between COLLATE and NOCASE
    are skipped.

    Args:
        query: SQL text, possibly containing several statements.

    Returns:
        The cleaned statements joined with a single space.
    """
    cleaned_statements = []
    for stmt in sqlparse.parse(query):
        tokens = stmt.tokens
        kept = []
        idx = 0
        while idx < len(tokens):
            token = tokens[idx]
            if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == "COLLATE":
                next_idx, next_token = stmt.token_next(idx, skip_ws=True)
                if next_token is not None and next_token.value.upper() == "NOCASE":
                    idx = next_idx + 1  # drop COLLATE, the whitespace between, and NOCASE
                    continue
            kept.append(token)
            idx += 1
        cleaned_statements.append("".join(str(t) for t in kept))
    return " ".join(cleaned_statements)


# Function to classify user query
def classify_query(question):
    """Classify the user query as either 'SQL' or 'INSIGHTS'.

    The reply is scanned for the first SQL / INSIGHT(S) label, so answers such
    as "Category: SQL" are recognised. Replies with no clear SQL label fall
    back to 'INSIGHTS'.
    """
    reply = classification_chain.run({"question": question}).upper()
    label = re.search(r"\b(SQL|INSIGHTS?)\b", reply)
    return "SQL" if label and label.group(1) == "SQL" else "INSIGHTS"


# Function to generate dataset summary
def generate_dataset_summary(data):
    """Ask the LLM for a short summary of ``data``, used for general insights.

    The model receives the real row count, column count and column dtypes
    computed from the whole DataFrame, plus the first five rows. The prompt
    asks for these facts, and five sample rows alone would make the model
    guess them.
    """
    profile = (
        f"Number of records: {len(data)}\n"
        f"Number of columns: {data.shape[1]}\n"
        "Column data types:\n"
        f"{data.dtypes.to_string()}\n\n"
        "Sample rows:\n"
        f"{data.head().to_string(index=False)}"
    )
    return summary_chain.run({"data": profile})


def _append_assistant(content):
    """Append an assistant message (text or DataFrame) to the chat history."""
    st.session_state.history.append({"role": "assistant", "content": content})


def _answer_with_general_insights():
    """Summarise the whole dataset and reply with general insights and recommendations."""
    dataset_summary = generate_dataset_summary(data)
    general_insights = general_insights_chain.run({"dataset_summary": dataset_summary})
    _append_assistant(general_insights)


def _answer_with_sql(user_prompt):
    """Answer ``user_prompt`` by generating, running and explaining a SQL query.

    Falls back to general insights when the model replies NO_SQL. On success it
    appends the model's analysis and then the full result DataFrame to the
    history. Only the first 10 rows are sent to the model. SQL errors are
    reported in the chat.
    """
    reply = sql_generation_chain.run({
        "question": user_prompt,
        "table_name": table_name,
        "columns": ", ".join(map(str, valid_columns)),
    }).strip()

    if _is_no_sql(reply):
        # Vague / open-ended question: answer from the dataset as a whole.
        _answer_with_general_insights()
        return

    cleaned_sql = clean_sql_query(_extract_sql(reply))
    logging.info("Generated SQL Query: %s", cleaned_sql)

    try:
        result = pd.read_sql_query(cleaned_sql, conn)
    except Exception as e:  # user-facing boundary: show any DB/parse error in the chat
        logging.exception("An error occurred during SQL execution: %s", e)
        _append_assistant(f"Error executing SQL query: {e}")
        return

    if result.empty:
        _append_assistant("The query returned no results. Please try a different question.")
        return

    result_str = result.head(10).to_string(index=False)  # Limit to first 10 rows
    insights = insights_chain.run({"question": user_prompt, "result": result_str})
    _append_assistant(insights)
    _append_assistant(result)


# Define the callback function
def process_input():
    """on_change callback for the chat input box.

    Records the user's message, routes it (column listing, SQL answer, or
    general insights), appends the assistant reply or replies to the history,
    and clears the input box. Unexpected errors are shown in the chat and
    logged with a traceback.
    """
    user_prompt = st.session_state["user_input"].strip()
    if user_prompt:
        st.session_state.history.append({"role": "user", "content": user_prompt})
        try:
            # Keyword shortcut first, so no LLM call is spent on it.
            if "COLUMNS" in user_prompt.upper():
                _append_assistant(f"The columns are: {', '.join(map(str, valid_columns))}")
            else:
                category = classify_query(user_prompt)
                logging.info("User query classified as: %s", category)
                if category == "SQL":
                    _answer_with_sql(user_prompt)
                else:  # INSIGHTS category
                    _answer_with_general_insights()
        except Exception as e:  # top-level UI boundary
            logging.exception("An error occurred: %s", e)
            _append_assistant(f"Error: {e}")
    # Reset the user_input in session state
    st.session_state["user_input"] = ""


# Display the conversation history
for message in st.session_state.history:
    if message["role"] == "user":
        st.markdown(f"**User:** {message['content']}")
    elif message["role"] == "assistant":
        if isinstance(message["content"], pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(message["content"])
        else:
            st.markdown(f"**Assistant:** {message['content']}")

# Place the input field at the bottom with the callback
st.text_input("Enter your message:", key="user_input", on_change=process_input)