"""Streamlit app: generate a question from a table schema and convert it to SQL.

Two seq2seq checkpoints are used: one produces a natural-language question
from a schema description, the other turns a question plus schema into an
SQL query.
"""

import ast
import json

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

QUESTION_MODEL_NAME = "mrm8488/t5-base-finetuned-question-generation-ap"
SQL_MODEL_NAME = "juierror/flan-t5-text2sql-with-schema-v2"


@st.cache_resource
def _load_seq2seq(model_name: str):
    """Load the tokenizer and seq2seq model for ``model_name``.

    Streamlit re-executes the script on every widget interaction; caching the
    loader keeps the checkpoints from being re-instantiated on each rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


# Load the models and tokenizers (module-level names kept for existing callers).
question_tokenizer, question_model = _load_seq2seq(QUESTION_MODEL_NAME)
sql_tokenizer, sql_model = _load_seq2seq(SQL_MODEL_NAME)


def generate_question(tables: dict) -> str:
    """Generate a natural-language question from a table schema.

    Args:
        tables: Mapping of table name to a sequence of column-name strings.

    Returns:
        The decoded text of the top beam produced by the question model.
    """
    # Convert table schema to string, e.g. "people: (id, name), ages: (id, age)".
    table_str = ", ".join(
        f"{table}: ({', '.join(cols)})" for table, cols in tables.items()
    )
    prompt = f"Generate a question based on the following table schema: {table_str}"

    # Tokenize input and generate question.
    input_ids = question_tokenizer(prompt, return_tensors="pt").input_ids
    # Keep the inputs on the same device as the model, as generate_sql does.
    input_ids = input_ids.to(question_model.device)
    output = question_model.generate(input_ids, num_beams=5, max_length=50)
    return question_tokenizer.decode(output[0], skip_special_tokens=True)


def prepare_sql_input(question: str, tables: dict):
    """Build the tokenized prompt for the text-to-SQL model.

    Args:
        question: Natural-language question to translate.
        tables: Mapping of table name to a sequence of column-name strings.

    Returns:
        The ``input_ids`` tensor for the prompt, limited to 512 tokens.
    """
    table_str = ", ".join(
        f"{table}({', '.join(cols)})" for table, cols in tables.items()
    )
    prompt = (
        "Convert the question and table schema into an SQL query. "
        f"Tables: {table_str}. \nQuestion: {question}"
    )
    # Request truncation explicitly so max_length=512 is actually enforced
    # rather than relying on the tokenizer's implicit fallback.
    input_ids = sql_tokenizer(
        prompt, max_length=512, truncation=True, return_tensors="pt"
    ).input_ids
    return input_ids


def generate_sql(question: str, tables: dict) -> str:
    """Generate an SQL query for ``question`` against the schema in ``tables``.

    Args:
        question: Natural-language question to translate.
        tables: Mapping of table name to a sequence of column-name strings.

    Returns:
        The decoded text of the top beam produced by the SQL model.
    """
    input_data = prepare_sql_input(question, tables)
    input_data = input_data.to(sql_model.device)
    # NOTE(review): top_k is a sampling parameter; with do_sample unset it
    # presumably has no effect under beam search - confirm against the
    # transformers generation docs before removing.
    outputs = sql_model.generate(
        inputs=input_data, num_beams=10, top_k=10, max_length=512
    )
    return sql_tokenizer.decode(outputs[0], skip_special_tokens=True)


def _parse_tables(raw: str) -> dict:
    """Parse user-entered schema text into a ``{table: [columns]}`` dict.

    SECURITY: this input comes straight from a text box. It was previously
    passed to ``eval``, which executes arbitrary code; it is now parsed as
    JSON, with ``ast.literal_eval`` as a fallback so Python-literal dicts
    (e.g. single-quoted strings) are still accepted. Arbitrary expressions
    are intentionally no longer evaluated.

    Returns:
        The parsed dict, or ``{}`` when the text cannot be parsed or is not a
        dict whose values are lists/tuples of strings.
    """
    try:
        parsed = json.loads(raw)
    except ValueError:
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return {}

    if not isinstance(parsed, dict):
        return {}
    for cols in parsed.values():
        # The prompt builders join the column names, so they must be strings.
        if not isinstance(cols, (list, tuple)):
            return {}
        if not all(isinstance(col, str) for col in cols):
            return {}
    return parsed


def main():
    """Render the Streamlit UI and run the models in response to user input."""
    st.title("Multi-Model: Text to SQL and Question Generation")

    # Input table schema.
    tables_input = st.text_area(
        "Enter table schemas (in JSON format):",
        '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}',
    )
    tables = _parse_tables(tables_input)
    if tables_input.strip() and not tables:
        # Surface bad input instead of silently treating it as "no schema".
        st.warning(
            "Could not read the table schemas; expected a JSON object mapping "
            "table names to lists of column names."
        )

    # If tables are provided, generate a question.
    generated_question = ""
    if tables:
        generated_question = generate_question(tables)
        st.write(f"Generated Question: {generated_question}")

    # Input question manually if needed; pre-filled with the generated one.
    question = st.text_area("Enter your question (optional):", generated_question)

    if st.button("Generate SQL Query"):
        if question and tables:
            sql_query = generate_sql(question, tables)
            st.write(f"Generated SQL Query: {sql_query}")
        else:
            st.write("Please enter both a question and table schemas.")


if __name__ == "__main__":
    main()