File size: 3,490 Bytes
e37eda0
3eb59a4
 
75829f5
75fd593
9918408
6a2a63a
e37eda0
7815bdb
937d1f9
081eac3
 
cb5f50e
9918408
3eb59a4
5671d43
6a2a63a
5671d43
 
3eb59a4
6a2a63a
3eb59a4
 
9918408
36fba91
 
6a2a63a
 
 
 
 
1c7e913
 
75fd593
9918408
 
7815bdb
 
 
 
 
75fd593
9918408
 
 
 
 
 
 
75fd593
 
7815bdb
75fd593
 
 
7815bdb
 
75fd593
 
 
 
7815bdb
75fd593
 
 
937d1f9
75829f5
937d1f9
75fd593
081eac3
 
 
6a2a63a
75fd593
7815bdb
75fd593
081eac3
75fd593
081eac3
75fd593
081eac3
 
 
 
 
75fd593
 
081eac3
 
6a2a63a
937d1f9
e37eda0
6a2a63a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import streamlit as st
import pandas as pd
import sqlite3
from langchain import OpenAI, LLMChain, PromptTemplate
from langchain_community.utilities import SQLDatabase
import sqlparse
import logging
from sql_metadata import Parser  # Added import

# OpenAI API key (ensure it is securely stored)
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    # Surface the problem up front instead of failing later inside the LLM call.
    st.warning("OPENAI_API_KEY is not set; SQL generation will fail.")

# Step 1: Upload CSV data file (or use default)
DEFAULT_CSV_PATH = "default_data.csv"
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv(DEFAULT_CSV_PATH)  # Use default CSV if no file is uploaded
    st.write(f"Using default {DEFAULT_CSV_PATH} file.")
else:
    data = pd.read_csv(csv_file)
    st.write(f"Data Preview ({csv_file.name}):")
st.dataframe(data.head())

# Step 2: Load CSV data into a persistent SQLite database
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
if csv_file is None:
    table_name = "default_table"
else:
    # The table name is handed to the LLM, which typically writes it unquoted
    # in the generated SQL. So derive a plain identifier: drop only the final
    # extension (so "sales.2023.csv" keeps "2023"), and replace spaces,
    # hyphens and other punctuation with underscores.
    stem = os.path.splitext(csv_file.name)[0]
    table_name = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_" for ch in stem
    )
    # An identifier may not be empty or start with a digit.
    if not table_name or table_name[0].isdigit():
        table_name = f"t_{table_name}"
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")

# Step 3: Define SQL validation helpers
def _column_is_valid(column, allowed):
    """Return True if a parser-reported column reference matches a name in *allowed*.

    *allowed* must hold lower-cased column names. The reference is tried
    as-is first, because CSV headers may themselves contain dots. It is then
    tried with any ``table.`` qualifier removed. Surrounding identifier quotes
    are stripped and case is ignored. The ``*`` wildcard is always accepted.
    """
    for candidate in (column, column.rsplit(".", 1)[-1]):
        name = candidate.strip('"`[]').lower()
        if name == "*" or name in allowed:
            return True
    return False


def validate_sql(query, valid_columns):
    """Validate that *query* references only columns of the loaded table.

    Column references are extracted with ``sql_metadata.Parser``. Each one is
    accepted when it names a known column. The match ignores identifier case
    (SQLite resolves unquoted identifiers case-insensitively), surrounding
    identifier quotes and a ``table.`` qualifier. The ``*`` wildcard
    (``SELECT *``, ``COUNT(*)``, ``tbl.*``) is always accepted.

    Args:
        query: SQL text to inspect.
        valid_columns: Column names of the table the query should target.

    Returns:
        True if every referenced column is known. Otherwise False, after
        writing the first offending column to the Streamlit page.

    Errors raised by ``Parser`` on unparsable SQL propagate to the caller.
    """
    # Build the lookup set once instead of scanning the list per column.
    allowed = {str(name).lower() for name in valid_columns}
    for column in Parser(query).columns:
        if not _column_is_valid(column, allowed):
            st.write(f"Invalid column detected: {column}")
            return False
    return True

def validate_sql_with_sqlparse(query):
    """Validate that *query* is exactly one read-only (SELECT) statement.

    ``sqlparse`` is a non-validating parser: ``sqlparse.parse`` yields a
    statement for almost any non-empty text, so a non-empty result says
    nothing about syntax. This check therefore rejects:

    - empty or whitespace-only input;
    - more than one statement (``sqlite3`` executes only one per call);
    - any statement whose type is not ``SELECT``.

    The query is LLM-generated from free user text and runs against the
    persistent database, so DDL/DML such as ``DROP`` or ``DELETE`` must be
    rejected. Genuine syntax errors are still only detected when SQLite runs
    the query.

    Args:
        query: SQL text to check.

    Returns:
        True if *query* is a single SELECT statement, else False.
    """
    if not query or not query.strip():
        return False
    # Ignore whitespace-only fragments (e.g. after a trailing semicolon).
    statements = [stmt for stmt in sqlparse.parse(query) if str(stmt).strip()]
    if len(statements) != 1:
        return False
    # get_type() skips leading comments and reports "SELECT" for WITH ... SELECT.
    return statements[0].get_type() == "SELECT"

# Step 4: Set up the LLM Chain to generate SQL queries
# The model's reply is validated and then executed against SQLite as-is.
# The prompt therefore pins the dialect, asks for a single read-only query,
# and demands bare SQL with no Markdown or commentary around it.
template = """
You are an expert data scientist. Given a natural language question, the name of a SQLite table, and a list of valid columns, write one SQLite SELECT query that answers the question.

Rules:
- Use only the table and the columns listed below.
- Wrap any table or column name that contains spaces or special characters in double quotes.
- Return only the SQL query: no explanation, no comments, no Markdown code fences.

Question: {question}

Table name: {table_name}

Valid columns: {columns}

SQL Query:
"""
prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)

# Step 5: Generate SQL query based on user input

# Prompts answered locally with the column list, without calling the LLM.
# Matching is on the whole normalized prompt, so a question that merely
# mentions columns (e.g. "sum the columns a and b") still goes to the LLM.
_COLUMN_LISTING_PROMPTS = frozenset({
    "columns",
    "list columns",
    "list the columns",
    "list all columns",
    "list all the columns",
    "show columns",
    "show the columns",
    "show me the columns",
    "show all columns",
    "what are the columns",
    "what columns are there",
    "what columns are available",
    "which columns are there",
})
# Optional trailing qualifiers accepted after a column-listing prompt.
_COLUMN_LISTING_SUFFIXES = (
    " in the table",
    " in this table",
    " in the data",
    " in this data",
    " in the dataset",
    " in this dataset",
)


def _is_column_listing_request(prompt):
    """Return True if *prompt* only asks which columns the table has."""
    # Lower-case, collapse whitespace and drop trailing punctuation.
    normalized = " ".join(prompt.lower().split()).rstrip("?.! ")
    for suffix in _COLUMN_LISTING_SUFFIXES:
        if normalized.endswith(suffix):
            normalized = normalized[: -len(suffix)]
            break
    return normalized in _COLUMN_LISTING_PROMPTS


def _clean_generated_sql(raw):
    """Strip whitespace and any Markdown code fence the LLM put around the SQL."""
    text = raw.strip()
    if text.startswith("```"):
        text = text[3:]
        # Drop a language tag such as "```sql" on the opening fence line.
        first_line, newline, rest = text.partition("\n")
        if newline and first_line.strip().lower() in ("sql", "sqlite"):
            text = rest
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


user_prompt = st.text_input("Enter your natural language prompt:")
if user_prompt and user_prompt.strip():
    try:
        # Step 6: Answer plain "what are the columns" questions directly
        if _is_column_listing_request(user_prompt):
            st.write(f"The columns are: {', '.join(valid_columns)}")
        else:
            columns = ', '.join(valid_columns)
            raw_sql = sql_generation_chain.run({'question': user_prompt, 'table_name': table_name, 'columns': columns})
            generated_sql = _clean_generated_sql(raw_sql)

            # Debug: show the SQL as a code block; st.write would render it as
            # Markdown, where characters such as "*" can be swallowed.
            st.write("Generated SQL Query:")
            st.code(generated_sql, language="sql")

            # Step 7: Validate SQL query
            if not validate_sql_with_sqlparse(generated_sql):
                st.write("Generated SQL is not valid.")
            elif not validate_sql(generated_sql, valid_columns):
                st.write("Generated SQL references invalid columns.")
            else:
                # Step 8: Execute SQL query
                result = pd.read_sql_query(generated_sql, conn)
                st.write("Query Results:")
                st.dataframe(result)

    except Exception as e:
        # Top-level UI boundary: keep the traceback in the log, show the message.
        logging.exception("An error occurred: %s", e)
        st.write(f"Error: {e}")