import gradio as gr
import mysql.connector
from mysql.connector import Error
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer
model_name = "premai-io/prem-1B-SQL"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def generate_sql(natural_language_query):
    """Generate a SQL query from a natural language question."""
    # Define your schema information
    schema_info = """
    CREATE TABLE sales (
        pizza_id DECIMAL(8,2) PRIMARY KEY,
        order_id DECIMAL(8,2),
        pizza_name_id VARCHAR(14),
        quantity DECIMAL(4,2),
        order_date DATE,
        order_time VARCHAR(8),
        unit_price DECIMAL(5,2),
        total_price DECIMAL(5,2),
        pizza_size VARCHAR(3),
        pizza_category VARCHAR(7),
        pizza_ingredients VARCHAR(97),
        pizza_name VARCHAR(42)
    );
    """

    # Construct the prompt
    prompt = f"""### Task: Generate a SQL query to answer the following question.

### Database Schema:
{schema_info}

### Question: {natural_language_query}

### SQL Query:"""

    # Tokenize the prompt and generate the SQL query
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=512,
        temperature=0.1,
        do_sample=True,
        top_p=0.95,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Decode the output and keep only the text after the "### SQL Query:" marker
    generated_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    sql_query = generated_query.split("### SQL Query:")[-1].strip()
    return sql_query


def main():
    # Gradio interface setup
    iface = gr.Interface(
        fn=generate_sql,
        inputs="text",
        outputs="text",
        title="Natural Language to SQL Query Generator",
        description="Enter a natural language query to generate the corresponding SQL query.",
    )
    iface.launch()


if __name__ == "__main__":
    main()