# GenBIChatbot — app.py
# Natural-language-to-SQL Streamlit chatbot.
# Provenance: Hugging Face Space "GenBIChatbot" by Ari, commit 7815bdb ("Update app.py").
import os
import streamlit as st
import pandas as pd
import sqlite3
from langchain import OpenAI, LLMChain, PromptTemplate
from langchain_community.utilities import SQLDatabase
import sqlparse
import logging
from sql_metadata import Parser # Added import
# OpenAI API key (ensure it is securely stored; read from the environment,
# never hard-coded)
openai_api_key = os.getenv("OPENAI_API_KEY")

# Step 1: Upload CSV data file (or use default)
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Use default CSV if no file is uploaded
    st.write("Using default data.csv file.")
else:
    data = pd.read_csv(csv_file)
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())

# Step 2: Load CSV data into a persistent SQLite database
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
# Derive the table name from the uploaded file's base name. os.path.splitext
# strips only the final extension, so filenames with embedded dots
# (e.g. "sales.2024.csv" -> "sales.2024") keep their full stem, unlike the
# naive split('.')[0] which would truncate to "sales".
table_name = os.path.splitext(csv_file.name)[0] if csv_file else "default_table"
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
# Step 3: Define SQL validation helpers
def validate_sql(query, valid_columns):
    """Validates the SQL query by ensuring it references only valid columns.

    Fixes two false-rejection cases in the naive membership check:
    - ``SELECT *`` queries: sql_metadata reports the column as ``"*"``,
      which can never appear in *valid_columns*;
    - table-qualified references (``table.column``), which must be reduced
      to the bare column name before comparison.
    Comparison is case-insensitive, matching SQLite's treatment of
    identifiers.

    Returns True when every referenced column is valid, False otherwise
    (the first offending column is reported via st.write).
    """
    # Build the lookup set once, outside the loop, for O(1) membership tests.
    valid_lower = {c.lower() for c in valid_columns}
    parser = Parser(query)
    for column in parser.columns:
        if column == '*':
            # Wildcard expands only to the table's own (valid) columns.
            continue
        # Strip an optional "table." qualifier before checking.
        bare_name = column.split('.')[-1]
        if bare_name.lower() not in valid_lower:
            st.write(f"Invalid column detected: {column}")
            return False
    return True
def validate_sql_with_sqlparse(query):
    """Validates SQL syntax using sqlparse.

    Returns True when sqlparse recognizes at least one statement in the
    query string, False for input it cannot split into statements.
    """
    parsed_statements = sqlparse.parse(query)
    return bool(parsed_statements)
# Step 4: Set up the LLM Chain to generate SQL queries
# Prompt template: {question}, {table_name} and {columns} are filled in at
# run time by sql_generation_chain.run(...) further below.
template = """
You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
Question: {question}
Table name: {table_name}
Valid columns: {columns}
SQL Query:
"""
prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
# temperature=0 keeps generation deterministic, which suits SQL synthesis.
sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
# Step 5: Generate SQL query based on user input
user_prompt = st.text_input("Enter your natural language prompt:")
if user_prompt:
    try:
        # Step 6: Answer "what are the columns" style questions locally,
        # without a round-trip to the LLM.
        if "columns" in user_prompt.lower():
            st.write(f"The columns are: {', '.join(valid_columns)}")
        else:
            columns = ', '.join(valid_columns)
            generated_sql = sql_generation_chain.run({
                'question': user_prompt,
                'table_name': table_name,
                'columns': columns,
            })
            # LLM completions usually begin with whitespace/newlines; strip
            # them so validation and execution see clean SQL.
            generated_sql = generated_sql.strip()
            # Debug: Display generated SQL query for inspection
            st.write(f"Generated SQL Query:\n{generated_sql}")
            # Step 7: Validate SQL query — syntax first, then column names.
            if not validate_sql_with_sqlparse(generated_sql):
                st.write("Generated SQL is not valid.")
            elif not validate_sql(generated_sql, valid_columns):
                st.write("Generated SQL references invalid columns.")
            else:
                # Step 8: Execute SQL query against the SQLite connection.
                result = pd.read_sql_query(generated_sql, conn)
                st.write("Query Results:")
                st.dataframe(result)
    except Exception as e:
        # Top-level boundary: log the failure and surface it in the UI
        # rather than crashing the Streamlit app.
        logging.error(f"An error occurred: {e}")
        st.write(f"Error: {e}")