# app.py (author: GlastonR)
# Last commit: "Update app.py" (616d6ae, verified); file size 2.93 kB
import json

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load the models and tokenizers.
#
# Streamlit re-executes this whole script from top to bottom on every widget
# interaction, so loading the checkpoints at module level without caching
# would reload the weights on every rerun. ``st.cache_resource`` keeps one
# shared (tokenizer, model) pair per checkpoint name across reruns.
QUESTION_MODEL_NAME = "mrm8488/t5-base-finetuned-question-generation-ap"
SQL_MODEL_NAME = "juierror/flan-t5-text2sql-with-schema-v2"


@st.cache_resource
def _load_seq2seq(model_name):
    """Return ``(tokenizer, model)`` for the Hugging Face checkpoint *model_name*."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


# Module-level names are kept so the functions below can use them unchanged.
question_tokenizer, question_model = _load_seq2seq(QUESTION_MODEL_NAME)
sql_tokenizer, sql_model = _load_seq2seq(SQL_MODEL_NAME)
# Function to generate a question based on a table schema
def generate_question(tables):
    """Generate a natural-language question about a table schema.

    Args:
        tables: Mapping of table name to an iterable of column-name strings,
            e.g. ``{"people_name": ["id", "name"]}``.

    Returns:
        The question-generation model's beam-search output (5 beams, at most
        50 tokens), decoded with special tokens removed.
    """
    # Render the schema as "table: (col, col), table: (col, col)".
    table_str = ", ".join(
        f"{table}: ({', '.join(cols)})" for table, cols in tables.items()
    )
    # NOTE(review): this checkpoint's model card appears to use
    # "answer: ... context: ..." style inputs; confirm this free-form prompt
    # yields usable questions.
    prompt = f"Generate a question based on the following table schema: {table_str}"
    # Bound the user-controlled prompt to 512 tokens (same limit as
    # prepare_sql_input), and put the ids on the model's device, as
    # generate_sql does, so this keeps working if the model is moved to a GPU.
    input_ids = question_tokenizer(
        prompt, max_length=512, truncation=True, return_tensors="pt"
    ).input_ids.to(question_model.device)
    output = question_model.generate(input_ids, num_beams=5, max_length=50)
    return question_tokenizer.decode(output[0], skip_special_tokens=True)
# Function to prepare input data for SQL generation
def prepare_sql_input(question, tables):
    """Build and tokenize the text-to-SQL prompt for *question* over *tables*.

    Args:
        question: The natural-language question to translate.
        tables: Mapping of table name to an iterable of column-name strings.

    Returns:
        The ``input_ids`` tensor produced by ``sql_tokenizer`` with
        ``return_tensors="pt"``, truncated to at most 512 tokens.
    """
    # Render the schema as "table(col, col), table(col, col)".
    table_str = ", ".join(
        f"{table}({', '.join(cols)})" for table, cols in tables.items()
    )
    # NOTE(review): the juierror/flan-t5-text2sql-with-schema-v2 model card
    # appears to use "convert question and table into SQL query. tables: ...
    # question: ..." with columns joined by ","; confirm this paraphrase
    # matches the format the model was trained on.
    prompt = f"Convert the question and table schema into an SQL query. Tables: {table_str}. Question: {question}"
    # Truncate explicitly. Passing max_length alone makes the tokenizer log a
    # warning and fall back to an implicit truncation strategy.
    return sql_tokenizer(
        prompt, max_length=512, truncation=True, return_tensors="pt"
    ).input_ids
# Inference function for SQL generation
def generate_sql(question, tables):
    """Translate *question* into an SQL query over the given table schema.

    Args:
        question: The natural-language question.
        tables: Mapping of table name to an iterable of column-name strings.

    Returns:
        The text-to-SQL model's beam-search output (10 beams, at most 512
        tokens), decoded with special tokens removed.
    """
    # Tokenize the prompt and place it on whichever device holds the model.
    encoded = prepare_sql_input(question, tables).to(sql_model.device)
    # NOTE(review): top_k only matters for sampling; with beam search and the
    # default do_sample=False it is likely ignored. Confirm before removing.
    generated = sql_model.generate(
        inputs=encoded,
        num_beams=10,
        top_k=10,
        max_length=512,
    )
    # One prompt in the batch, so take the first returned sequence.
    best = generated[0]
    return sql_tokenizer.decode(best, skip_special_tokens=True)
# Streamlit UI
def _parse_tables(text):
    """Parse the schema text area into a ``{table: [column, ...]}`` dict.

    Returns:
        ``(tables, error)``. On success this is the parsed mapping and
        ``None``. For blank input it is ``({}, None)``. Otherwise it is ``{}``
        and a human-readable message describing why *text* was rejected.
    """
    if not text.strip():
        return {}, None
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as err:
        return {}, f"Invalid JSON: {err}"
    # JSON object keys are always strings; columns must be lists of strings
    # because the prompt builders join them with str.join.
    if not isinstance(parsed, dict) or not all(
        isinstance(cols, list) and all(isinstance(col, str) for col in cols)
        for cols in parsed.values()
    ):
        return {}, (
            "Schema must be a JSON object mapping table names to lists of "
            "column names."
        )
    return parsed, None


def _cached_question(tables):
    """Return ``generate_question(tables)``, reusing the last result for the same schema.

    Streamlit reruns the script on every interaction (e.g. each button press),
    so without this the beam search would be repeated for an unchanged
    schema. Only the most recent schema is kept, which bounds session memory.
    """
    key = json.dumps(tables, sort_keys=True)
    if st.session_state.get("_question_schema") != key:
        # Store the value before the key so a failure in generate_question
        # cannot leave a key paired with a stale question.
        st.session_state["_question"] = generate_question(tables)
        st.session_state["_question_schema"] = key
    return st.session_state["_question"]


def main():
    """Render the app: schema input, suggested question, and SQL generation."""
    st.title("Multi-Model: Text to SQL and Question Generation")
    # Input table schema
    tables_input = st.text_area("Enter table schemas (in JSON format):",
                                '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}')
    # Parse as JSON instead of eval(). The text comes straight from the
    # browser, and eval() would execute arbitrary Python typed into the box.
    tables, error = _parse_tables(tables_input)
    if error:
        st.error(error)

    # If tables are provided, suggest a question for them.
    generated_question = ""
    if tables:
        generated_question = _cached_question(tables)
        st.write(f"Generated Question: {generated_question}")

    # Input question manually if needed
    question = st.text_area("Enter your question (optional):", generated_question)
    if st.button("Generate SQL Query"):
        if question.strip() and tables:
            sql_query = generate_sql(question, tables)
            st.write("Generated SQL Query:")
            # st.code shows the query verbatim. st.write renders Markdown,
            # where characters such as "*" can be consumed as emphasis.
            st.code(sql_query, language="sql")
        else:
            st.write("Please enter both a question and table schemas.")


if __name__ == "__main__":
    main()