|
import ast
import json

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
|
# Fall back to a pass-through decorator on Streamlit releases that predate
# ``st.cache_resource`` so the app still runs there (just without caching).
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_seq2seq(model_name):
    """Load and return ``(tokenizer, model)`` for a Hugging Face seq2seq checkpoint.

    Streamlit re-executes the whole script on every widget interaction;
    caching the loaded objects keeps the weights from being re-read from disk
    (and re-materialized in memory) on each rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


# T5 checkpoint used to produce a sample question from the table schema.
question_tokenizer, question_model = _load_seq2seq(
    "mrm8488/t5-base-finetuned-question-generation-ap"
)

# Flan-T5 checkpoint that translates a question plus table schema into SQL.
sql_tokenizer, sql_model = _load_seq2seq("juierror/flan-t5-text2sql-with-schema-v2")
|
|
|
|
|
def generate_question(tables):
    """Generate a natural-language question about a table schema.

    Args:
        tables: Mapping of table name to an iterable of column names, e.g.
            ``{"people_name": ["id", "name"]}``.

    Returns:
        The question decoded from ``question_model`` (beam search, 5 beams,
        at most 50 tokens) with special tokens removed.
    """
    # Render the schema as "table: (col1, col2), other: (col3)".
    table_str = ", ".join(
        f"{table}: ({', '.join(cols)})" for table, cols in tables.items()
    )
    # NOTE(review): the model card for this checkpoint documents an
    # "answer: ...  context: ..." input format; this free-form instruction may
    # not match what the model was fine-tuned on — confirm output quality.
    prompt = f"Generate a question based on the following table schema: {table_str}"

    # Keep the input on the same device as the model (as generate_sql does) so
    # this keeps working if the model is ever moved off the CPU.
    input_ids = question_tokenizer(prompt, return_tensors="pt").input_ids.to(
        question_model.device
    )
    output = question_model.generate(input_ids, num_beams=5, max_length=50)
    return question_tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
|
|
def prepare_sql_input(question, tables):
    """Build the text-to-SQL prompt and tokenize it for ``sql_model``.

    Args:
        question: Natural-language question to translate into SQL.
        tables: Mapping of table name to an iterable of column names.

    Returns:
        A PyTorch tensor of input ids (batch of one), truncated to at most
        512 tokens.
    """
    # Match the template published on the checkpoint's model card: each table
    # as "name(col1,col2)" (no space after commas), tables joined with ", ",
    # and the lowercase "convert question and table into SQL query." prefix.
    # A fine-tuned model is sensitive to deviations from its training prompt.
    table_str = ", ".join(
        f"{table}({','.join(cols)})" for table, cols in tables.items()
    )
    prompt = (
        "convert question and table into SQL query. "
        f"tables: {table_str}. question: {question}"
    )

    # Passing max_length alone already truncates but makes transformers log a
    # warning; request truncation explicitly.
    return sql_tokenizer(
        prompt, max_length=512, truncation=True, return_tensors="pt"
    ).input_ids
|
|
|
|
|
def generate_sql(question, tables):
    """Translate ``question`` into an SQL query over the given ``tables``.

    Args:
        question: Natural-language question.
        tables: Mapping of table name to its column names.

    Returns:
        The highest-ranked beam decoded to text, special tokens stripped.
    """
    encoded = prepare_sql_input(question, tables).to(sql_model.device)
    # NOTE(review): top_k only applies to sampling-based decoding; with plain
    # beam search (no do_sample) it is likely ignored — confirm before tuning.
    candidates = sql_model.generate(
        inputs=encoded,
        num_beams=10,
        top_k=10,
        max_length=512,
    )
    best = candidates[0]
    return sql_tokenizer.decode(best, skip_special_tokens=True)
|
|
|
|
|
def _parse_tables(raw):
    """Parse the schema text into a ``{table: [columns]}`` mapping without executing it.

    Accepts JSON, falling back to Python literal syntax (e.g. single quotes).
    Neither parser evaluates code, unlike ``eval``.

    Args:
        raw: Text entered by the user.

    Returns:
        ``(tables, error)``: the parsed mapping (``{}`` on failure or blank
        input) and a user-facing error message, or ``None`` on success.
    """
    if not raw.strip():
        return {}, None
    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return {}, "Could not parse the table schemas; please enter valid JSON."

    # The prompt builders join each value as column names, so require a dict
    # of table name -> list of column-name strings.
    well_formed = isinstance(parsed, dict) and all(
        isinstance(name, str)
        and isinstance(cols, (list, tuple))
        and all(isinstance(col, str) for col in cols)
        for name, cols in parsed.items()
    )
    if not well_formed:
        return {}, (
            "Table schemas must map table names to lists of column names, "
            'e.g. {"people": ["id", "name"]}.'
        )
    return parsed, None


def _cached_question(tables):
    """Return the generated question for ``tables``, reusing the last result.

    The script reruns on every interaction (including button clicks), and
    generate_question's beam search is deterministic, so an unchanged schema
    reuses the stored question instead of re-running the model.
    """
    key = json.dumps(tables)
    state = st.session_state
    if state.get("_question_schema") != key:
        state["_question_text"] = generate_question(tables)
        state["_question_schema"] = key
    return state["_question_text"]


def main():
    """Render the UI: schema input, generated sample question, and SQL output.

    A question is generated from the parsed schema and pre-filled into an
    editable text area; clicking the button translates that (possibly
    edited) question into SQL.
    """
    st.title("Multi-Model: Text to SQL and Question Generation")

    tables_input = st.text_area(
        "Enter table schemas (in JSON format):",
        '{"people_name": ["id", "name"], "people_age": ["people_id", "age"]}',
    )
    # This text comes straight from the user: parse it, never eval() it.
    tables, parse_error = _parse_tables(tables_input)
    if parse_error:
        st.error(parse_error)

    generated_question = ""
    if tables:
        generated_question = _cached_question(tables)
        st.write(f"Generated Question: {generated_question}")

    question = st.text_area("Enter your question (optional):", generated_question)

    if st.button("Generate SQL Query"):
        if question.strip() and tables:
            sql_query = generate_sql(question, tables)
            st.write("Generated SQL Query:")
            # st.code shows the query verbatim; st.write would render it as
            # Markdown, where characters such as `*` may be reinterpreted.
            st.code(sql_query, language="sql")
        else:
            st.write("Please enter both a question and table schemas.")
|
|
|
if __name__ == "__main__":
    # Build the UI only when this file runs as a script, not when imported.
    main()
|
|