import os
import streamlit as st
import pandas as pd
import sqlite3
import logging
import ast  # For parsing string representations of lists

from langchain_community.chat_models import ChatOpenAI
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# Initialize logging
logging.basicConfig(level=logging.INFO)

# Initialize conversation history
if 'history' not in st.session_state:
    st.session_state.history = []

# OpenAI API key
openai_api_key = os.getenv("OPENAI_API_KEY")

# Check that the API key is set
if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()

# Step 1: Upload a CSV data file (or use the default)
st.title("Enhanced Natural Language to SQL Query App")
st.write("Upload a CSV file to get started, or use the default dataset.")

csv_file = st.file_uploader("Upload your CSV file", type=["csv"])

if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Ensure this file exists
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    table_name = csv_file.name.split('.')[0]
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Step 2: Load the CSV data into a SQLite database
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
data.to_sql(table_name, conn, index=False, if_exists='replace')
conn.close()

# Create the SQLDatabase instance
db = SQLDatabase.from_uri(f"sqlite:///{db_file}", include_tables=[table_name])

# Initialize the LLM
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)

# Step 3: Initialize the SQL agent
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_executor_kwargs={"return_intermediate_steps": True}
)

# Step 4: Define the callback function
def process_input():
    user_prompt = st.session_state['user_input']
    if user_prompt:
        try:
            # Append the user message to the history
            st.session_state.history.append({"role": "user", "content": user_prompt})

            # Run the agent; intermediate steps come back because the executor
            # was created with return_intermediate_steps=True
            with st.spinner("Processing..."):
                response = agent_executor(user_prompt)

            # Extract the final answer and the data from the intermediate steps
            final_answer = response['output']
            intermediate_steps = response['intermediate_steps']

            # Find the raw SQL query result among the intermediate steps
            sql_result = []
            for step in intermediate_steps:
                if step[0].tool == 'sql_db_query':
                    # The observation is a string representation of a list of rows
                    sql_result = ast.literal_eval(step[1])
                    break

            # Convert the result to a DataFrame for better formatting
            if sql_result:
                df_result = pd.DataFrame(sql_result)
                # to_markdown requires the 'tabulate' package
                sql_result_formatted = df_result.to_markdown(index=False)
            else:
                sql_result_formatted = "No results found."

            # Include the data in the final answer
            assistant_response = f"{final_answer}\n\n**Query Result:**\n{sql_result_formatted}"

            # Append the assistant's response to the history
            st.session_state.history.append({"role": "assistant", "content": assistant_response})

            # Generate insights based on the response
            insights_template = """
            You are an expert data analyst. Based on the user's question and the response provided below,
            generate a concise analysis that includes key data insights and actionable recommendations.
            Limit the response to a maximum of 150 words.

            User's Question: {question}

            Response: {response}

            Concise Analysis:
            """
            insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'response'])
            insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
            insights = insights_chain.run({'question': user_prompt, 'response': assistant_response})

            # Append the assistant's insights to the history
            st.session_state.history.append({"role": "assistant", "content": insights})

        except Exception as e:
            logging.error(f"An error occurred: {e}")
            assistant_response = f"Error: {e}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})

        # Reset the user input
        st.session_state['user_input'] = ''

# Step 5: Display the conversation history
for message in st.session_state.history:
    if message['role'] == 'user':
        st.markdown(f"**User:** {message['content']}")
    elif message['role'] == 'assistant':
        st.markdown(f"**Assistant:** {message['content']}")

# Input field
st.text_input("Enter your message:", key='user_input', on_change=process_input)
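# ---------------------------------------------------------------------------
# Optional bootstrap for the default dataset (a minimal sketch, not part of the
# app logic above). The default branch in Step 1 assumes a default_data.csv
# file sits next to this script; the column names and rows below are
# illustrative placeholders only. If you rely on the default branch, run this
# once as a standalone snippet, or paste it above Step 1 so the file exists
# before it is read. The os.path.exists guard avoids overwriting a real file.
# ---------------------------------------------------------------------------
if not os.path.exists("default_data.csv"):
    pd.DataFrame({
        "product": ["Widget A", "Widget B", "Widget C"],
        "region": ["North", "South", "East"],
        "units_sold": [120, 85, 240],
        "revenue": [1200.0, 935.0, 2640.0],
    }).to_csv("default_data.csv", index=False)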