text2sql / app.py
nileshhanotia's picture
Update app.py
6a33bcd verified
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from functools import lru_cache
import json
import mysql.connector
from mysql.connector import Error
import os
import sys
from datetime import datetime
import time
import logging
import threading
# Configure the root logger once at import time: INFO level, with a
# timestamp and severity in front of every message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Database configuration.
#
# Every connection setting can be overridden through an environment variable
# (DB_HOST, DB_NAME, DB_USER, DB_PASSWORD, DB_PORT) so deployments do not have
# to edit this file. The literals below are only fallbacks that keep the
# previous behaviour when no variable is set.
# NOTE(review): the fallback password is committed to source control; it should
# be rotated and supplied exclusively through DB_PASSWORD.
DB_CONFIG = {
    'host': os.environ.get('DB_HOST', 'sql12.freemysqlhosting.net'),
    'database': os.environ.get('DB_NAME', 'sql12740625'),
    'user': os.environ.get('DB_USER', 'sql12740625'),
    'password': os.environ.get('DB_PASSWORD', 'QGG9kdrE4g'),
    # The connector expects an integer port; environment values are strings.
    'port': int(os.environ.get('DB_PORT', '3306')),
    'pool_size': 5,
    'pool_reset_session': True,
}

# Module-level holders for the loaded model and tokenizer; they stay None
# until initialize_model() assigns them.
GLOBAL_MODEL = None
GLOBAL_TOKENIZER = None

# Outcome of the most recent test_db_connection() call (False until one succeeds).
db_connection_status = False
@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(model_name):
    """Load the tokenizer and causal-LM weights for *model_name* once per process.

    Streamlit re-executes the whole script on every widget interaction, so
    without this cache the model would be reloaded on every rerun.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on ``device`` and in
        evaluation mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # Use float32 for CPU
    ).to(device)
    # Set model to evaluation mode
    model.eval()
    return tokenizer, model


def initialize_model():
    """Populate ``GLOBAL_MODEL`` and ``GLOBAL_TOKENIZER``.

    The expensive load is delegated to a cached helper, so calling this on
    every Streamlit rerun only costs a cache lookup after the first call.
    The elapsed time is logged; nothing is returned.
    """
    global GLOBAL_MODEL, GLOBAL_TOKENIZER
    logging.info("Initializing model and tokenizer...")
    st.write("Initializing model and tokenizer...")
    start_time = time.time()

    model_name_sql = "premai-io/prem-1B-SQL"
    GLOBAL_TOKENIZER, GLOBAL_MODEL = _load_model_and_tokenizer(model_name_sql)

    logging.info(f"Model initialization took {time.time() - start_time:.2f} seconds")
def test_db_connection():
    """Probe the MySQL server and record the outcome in ``db_connection_status``.

    Opens a connection with a 10-second connect timeout, runs
    ``SELECT DATABASE();`` and releases the connection again.

    Returns:
        A ``(ok, message)`` tuple: ``ok`` is True when the probe succeeded and
        ``message`` is a human-readable description of the outcome.
    """
    global db_connection_status
    connection = None
    try:
        logging.info("Testing database connection...")
        connection = mysql.connector.connect(
            **DB_CONFIG,
            connect_timeout=10
        )
        if connection.is_connected():
            db_info = connection.get_server_info()
            cursor = connection.cursor()
            try:
                cursor.execute("SELECT DATABASE();")
                db_name = cursor.fetchone()[0]
            finally:
                # Release the cursor even when the probe query fails.
                cursor.close()
            db_connection_status = True
            logging.info(f"Successfully connected to MySQL Server version {db_info} - Database: {db_name}")
            return True, f"Successfully connected to MySQL Server version {db_info}\nDatabase: {db_name}"
    except Error as e:
        db_connection_status = False
        logging.error(f"Error connecting to MySQL database: {e}")
        return False, f"Error connecting to MySQL database: {e}"
    finally:
        # Always hand the connection back, including on the failure paths,
        # so a failed probe cannot leak it.
        if connection is not None:
            try:
                connection.close()
            except Error as close_error:
                logging.warning(f"Error closing database connection: {close_error}")
    # connect() returned but the link is not usable: do not leave a stale
    # "connected" flag behind from an earlier successful probe.
    db_connection_status = False
    return False, "Unable to establish database connection"
def get_db_connection():
    """Open a MySQL connection using the settings in ``DB_CONFIG`` and return it.

    Any connector error raised while connecting propagates to the caller.
    """
    logging.info("Getting database connection from pool...")
    new_connection = mysql.connector.connect(**DB_CONFIG)
    return new_connection
def execute_query(query):
    """Run *query* against the database and return all rows.

    NOTE(review): the statement is executed exactly as given. Callers in this
    file pass model-generated SQL, so nothing here prevents a destructive
    statement; consider a read-only database user.

    Args:
        query: SQL text to execute.

    Returns:
        A list of row dictionaries (column name -> value) on success, or a
        string starting with ``"Error executing query:"`` when the connector
        raises.
    """
    logging.info(f"Executing query: {query}")
    connection = None
    cursor = None
    try:
        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)
        cursor.execute(query)
        results = cursor.fetchall()
        logging.info(f"Query executed successfully, retrieved {len(results)} records.")
        return results
    except Error as e:
        logging.error(f"Error executing query: {e}")
        return f"Error executing query: {e}"
    finally:
        # cursor stays None when cursor creation itself failed; guarding on it
        # avoids an UnboundLocalError masking the real error.
        if cursor is not None:
            try:
                cursor.close()
            except Error as close_error:
                logging.warning(f"Error closing cursor: {close_error}")
        # Close unconditionally (not only while still connected) so a dropped
        # connection is still released instead of being leaked.
        if connection is not None:
            try:
                connection.close()
            except Error as close_error:
                logging.warning(f"Error closing database connection: {close_error}")
            logging.info("Database connection closed.")
# Schema text shown to the model. It is kept flush-left on purpose: every
# character of this literal becomes part of the prompt.
_SALES_SCHEMA = """
CREATE TABLE sales (
pizza_id DECIMAL(8,2) PRIMARY KEY,
order_id DECIMAL(8,2),
pizza_name_id VARCHAR(14),
quantity DECIMAL(4,2),
order_date DATE,
order_time VARCHAR(8),
unit_price DECIMAL(5,2),
total_price DECIMAL(5,2),
pizza_size VARCHAR(3),
pizza_category VARCHAR(7),
pizza_ingredients VARCHAR(97),
pizza_name VARCHAR(42)
);
"""


def generate_sql(natural_language_query):
    """Translate a plain-English question into a SQL query for the sales table.

    Builds a prompt from the table schema and the question, samples a
    completion from ``GLOBAL_MODEL`` and returns the text that follows the
    ``### SQL Query:`` marker.

    Args:
        natural_language_query: The user's question.

    Returns:
        The generated SQL text, or a string starting with
        ``"Error generating SQL query:"`` if anything raises.
    """
    logging.info(f"Generating SQL for query: {natural_language_query}")
    try:
        start_time = time.time()
        schema_info = _SALES_SCHEMA
        prompt = (
            "### Task: Generate a SQL query to answer the following question.\n"
            "### Database Schema:\n"
            f"{schema_info}\n"
            f"### Question: {natural_language_query}\n"
            "### SQL Query:"
        )
        inputs = GLOBAL_TOKENIZER(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
            return_attention_mask=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = GLOBAL_MODEL.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                # max_length would count the prompt tokens as well, leaving
                # little or no budget for the answer once the schema and the
                # question are included; max_new_tokens bounds only the output.
                max_new_tokens=256,
                temperature=0.1,
                do_sample=True,
                top_p=0.95,
                num_return_sequences=1,
                pad_token_id=GLOBAL_TOKENIZER.eos_token_id,
            )
        # Decode only the newly generated tokens so the prompt (which may have
        # been truncated and lost its marker) can never leak into the result.
        completion_tokens = outputs[0][prompt_length:]
        generated_query = GLOBAL_TOKENIZER.decode(completion_tokens, skip_special_tokens=True)
        # If the model repeats the marker, keep only what follows the last one.
        sql_query = generated_query.split("### SQL Query:")[-1].strip()
        logging.info(f"SQL generation took {time.time() - start_time:.2f} seconds")
        return sql_query
    except Exception as e:
        logging.error(f"Error generating SQL query: {str(e)}")
        return f"Error generating SQL query: {str(e)}"
def format_result(query_result):
    """Render a query result as plain text for display.

    Args:
        query_result: Either a string containing ``"Error"`` (returned
            unchanged), or a sequence of row dictionaries.

    Returns:
        The error text itself, ``"No results found."`` for an empty result,
        ``key: value`` lines for a single row, or a summary of at most the
        first five rows when there are several.
    """
    is_error_text = isinstance(query_result, str) and "Error" in query_result
    if is_error_text:
        logging.warning(f"Query result contains an error: {query_result}")
        return query_result

    if not query_result:
        logging.info("No results found.")
        return "No results found."

    total = len(query_result)

    # A single row is shown on its own, without the "Found N results" header.
    if total == 1:
        only_row = query_result[0]
        return "\n".join([f"{column}: {value}" for column, value in only_row.items()])

    preview_limit = 5
    lines = [f"Found {total} results:\n"]
    for position, record in enumerate(query_result[:preview_limit], start=1):
        lines.append(f"Result {position}:")
        for column, value in record.items():
            lines.append(f"{column}: {value}")
        # Blank separator line after each row.
        lines.append("")

    if total > preview_limit:
        lines.append(f"(Showing first 5 of {total} results)")
    return "\n".join(lines)
def check_live_connection():
    """Re-test the database connection forever, pausing 10 seconds between probes.

    This function never returns, so it is meant to run on a background thread.
    """
    pause_seconds = 10  # Check every 10 seconds
    while True:
        test_db_connection()
        time.sleep(pause_seconds)
def main():
    """Render the Streamlit page: connection status, question box and results.

    Streamlit re-runs this function on every interaction, so it must not start
    long-lived background work each time it is called.
    """
    st.title("Natural Language to SQL Query")
    st.write("Ask questions about pizza sales data in plain English.")

    # Probe the database synchronously. Starting a polling thread here would
    # add another never-ending thread on every rerun, and reading the status
    # flag right after starting it would always see the initial False value.
    is_live, _ = test_db_connection()
    if is_live:
        st.success("Database connection is live.")
    else:
        st.error("Database connection is not live.")

    # Initialize model
    initialize_model()

    # Input field for natural language query
    natural_language_query = st.text_input(
        "Enter your question",
        placeholder="e.g., What were the total sales for each pizza category?",
    )

    if st.button("Generate and Execute Query"):
        # Treat whitespace-only input the same as no input.
        question = natural_language_query.strip() if natural_language_query else ""
        if question:
            # Generate SQL query
            sql_query = generate_sql(question)
            if sql_query.startswith("Error generating SQL query:"):
                # generate_sql reports failures as text; that text must not be
                # sent to the database as if it were SQL.
                st.error(sql_query)
            else:
                st.write("Generated SQL Query:", sql_query)

                # Execute the generated query
                query_result = execute_query(sql_query)
                formatted_result = format_result(query_result)

                st.write("Query Result:")
                # default=str is deliberate: the sales table has DECIMAL and
                # DATE columns, whose Python values json cannot serialise on
                # its own. This output is for display only.
                st.code(json.dumps(query_result, indent=2, default=str))
                st.write("Human-Readable Response:")
                st.text(formatted_result)
        else:
            logging.warning("User did not enter a query.")
            st.write("Please enter a query.")


if __name__ == "__main__":
    main()