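"""Streamlit app: ask natural-language questions against an uploaded CSV.

The CSV is loaded into a local SQLite database, a LangChain LLMChain
(OpenAI) turns the question into a SQL query, and the query is validated
with sqlparse and sql_metadata before it is executed.
"""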
import os
import streamlit as st
import pandas as pd
import sqlite3
# Classic (pre-0.1) LangChain imports; newer releases moved these classes.
from langchain import OpenAI, LLMChain, PromptTemplate
import sqlparse
import logging
from sql_metadata import Parser  # extracts column names from SQL for validation
# OpenAI API key, read from the environment (never hard-code it).
# LangChain's OpenAI wrapper also picks up OPENAI_API_KEY automatically.
openai_api_key = os.getenv("OPENAI_API_KEY")
# Step 1: Upload CSV data file (or use default)
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Fall back to the bundled default CSV
    st.write("Using default_data.csv file.")
    source_name = "default_data.csv"
else:
    data = pd.read_csv(csv_file)
    source_name = csv_file.name

st.write(f"Data Preview ({source_name}):")
st.dataframe(data.head())
# Step 2: Load CSV data into a persistent SQLite database
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
data.to_sql(table_name, conn, index=False, if_exists='replace')
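# e.g. an upload named "sales.csv" (hypothetical) would be stored as table
# "sales"; if_exists='replace' rebuilds the table on every script rerun.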
# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
# Step 3: Define SQL validation helpers
def validate_sql(query, valid_columns):
    """Validate the SQL query by ensuring it references only valid columns."""
    parser = Parser(query)
    for column in parser.columns:
        if column == "*":  # "SELECT *" is allowed
            continue
        # sql_metadata may return qualified names like "table.column"
        if column.split(".")[-1] not in valid_columns:
            st.write(f"Invalid column detected: {column}")
            return False
    return True
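# Illustrative usage (hypothetical table/columns, not from the uploaded data):
#   validate_sql("SELECT name FROM people", ["name", "age"])   -> True
#   validate_sql("SELECT salary FROM people", ["name", "age"]) -> False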
def validate_sql_with_sqlparse(query):
    """Check that the query parses into at least one SQL statement."""
    parsed_query = sqlparse.parse(query)
    return len(parsed_query) > 0
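# Note: sqlparse is lenient and accepts almost any text as a statement, so
# this is only a sanity check, not full syntax validation.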
# Step 4: Set up the LLM Chain to generate SQL queries
template = """
You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
Question: {question}
Table name: {table_name}
Valid columns: {columns}
SQL Query:
"""
prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
sql_generation_chain = LLMChain(llm=OpenAI(temperature=0, openai_api_key=openai_api_key), prompt=prompt)
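# Illustrative expectation: for the question "How many rows are there?" on
# table "default_table", the chain should produce something like
# "SELECT COUNT(*) FROM default_table".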
# Step 5: Generate SQL query based on user input
user_prompt = st.text_input("Enter your natural language prompt:")
if user_prompt:
    try:
        # Step 6: Answer "what are the columns" directly, without calling the LLM
        if "columns" in user_prompt.lower():
            st.write(f"The columns are: {', '.join(valid_columns)}")
        else:
            columns = ', '.join(valid_columns)
            generated_sql = sql_generation_chain.run(
                {'question': user_prompt, 'table_name': table_name, 'columns': columns}
            ).strip()  # LLM output often carries leading/trailing whitespace
            # Debug: display the generated SQL query for inspection
            st.write(f"Generated SQL Query:\n{generated_sql}")
            # Step 7: Validate the SQL query before executing it
            if not validate_sql_with_sqlparse(generated_sql):
                st.write("Generated SQL is not valid.")
            elif not validate_sql(generated_sql, valid_columns):
                st.write("Generated SQL references invalid columns.")
            else:
                # Step 8: Execute the SQL query against the SQLite database
                result = pd.read_sql_query(generated_sql, conn)
                st.write("Query Results:")
                st.dataframe(result)
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        st.write(f"Error: {e}")