# SQL_Generation / app.py
# Last commit: f2a9e61 ("changes") by daljeetsingh
## https://www.kaggle.com/code/unravel/fine-tuning-of-a-sql-model
import spaces
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gradio as gr
import torch
from transformers.utils import logging
from example_queries import small_query, long_query
# Raise the transformers library's log verbosity to INFO.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# The UI compares two checkpoints side by side: the stock base model and a
# fine-tuned text-to-SQL variant of it.
# Alternative fine-tuned checkpoint referenced by the author:
# "cssupport/t5-small-awesome-text-to-sql".
model_name = 't5-small'
ft_model_name = "daljeetsingh/sql_ft_t5small_kag"

# One tokenizer, loaded from the base checkpoint, is shared by both models.
tokenizer = AutoTokenizer.from_pretrained(model_name)


def _load_bf16_seq2seq(checkpoint):
    """Load the seq2seq model at *checkpoint* with bfloat16 weights."""
    return AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)


original_model = _load_bf16_seq2seq(model_name)
ft_model = _load_bf16_seq2seq(ft_model_name)

# Place both models on the GPU once, at import time.
original_model.to('cuda')
ft_model.to('cuda')
def _generate_and_decode(model, input_ids):
    """Generate from *model* for *input_ids* and decode the first sequence.

    Generation is capped at 200 new tokens, and special tokens are stripped
    from the decoded text.
    """
    sequences = model.generate(input_ids, max_new_tokens=200)
    return tokenizer.decode(sequences[0], skip_special_tokens=True)


@spaces.GPU
def translate_text(text):
    """Generate SQL for *text* with both the base and the fine-tuned model.

    Args:
        text: The prompt entered in the UI textbox.

    Returns:
        A two-element list ``[original_output, fine_tuned_output]``, one value
        per output textbox wired to this handler. If tokenization or
        generation fails, both elements are the same ``"Error: ..."`` message.
    """
    try:
        # Tokenize inside the try so tokenizer and device-transfer failures
        # are reported to the UI in the same way as generation failures.
        inputs = tokenizer(f"{text}", return_tensors='pt').to('cuda')
        input_ids = inputs["input_ids"]
        output = _generate_and_decode(original_model, input_ids)
        ft_output = _generate_and_decode(ft_model, input_ids)
        return [output, ft_output]
    except Exception as e:
        # UI boundary: keep the app alive, but record the traceback.
        logger.exception("SQL generation failed")
        message = f"Error: {str(e)}"
        # The click handler has two outputs. Return one value for each;
        # a bare string would cause an output-count mismatch in Gradio and
        # hide the real error.
        return [message, message]
with gr.Blocks() as demo:
    # Layout: prompt and button on the left, the two model outputs on the right.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter prompt...",
                lines=8,
                value=small_query,
            )
            submit_btn = gr.Button(value="Generate")
        with gr.Column():
            orig_output = gr.Textbox(lines=2, label="OriginalModel")
            ft_output = gr.Textbox(lines=8, label="FTModel")

    # translate_text returns [original, fine-tuned]; outputs follow that order.
    submit_btn.click(
        translate_text,
        inputs=[prompt],
        outputs=[orig_output, ft_output],
        api_name=False,
    )

    # Each example row supplies a value for the single `prompt` input.
    examples = gr.Examples(
        examples=[[query] for query in (small_query, long_query)],
        inputs=[prompt],
    )
# Start the Gradio server for the `demo` UI built above.
demo.launch(
    debug=True,
    share=True,
    show_api=False,
)