Spaces:
Sleeping
Sleeping
from functools import lru_cache

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

from example_strings import example1, example2, example3
# Prompt skeleton fed to the NSQL models: table schemas, task instruction and
# natural-language question, each section separated by "\n\n \n\n". The
# skeleton ends with "SELECT" so that generation continues the query from it.
# NOTE(review): the NSQL model cards prefix the task and question lines with
# "-- "; confirm whether the inputs here are expected to carry that prefix.
template_str = "\n\n \n\n".join(("{table_schemas}", "{task_spec}", "{prompt}", "SELECT"))
@lru_cache(maxsize=1)
def load_model(model_name: str):
    """Load the tokenizer and causal-LM weights for an NSQL checkpoint.

    Results are memoized so that repeated requests for the same model reuse
    the already-loaded objects instead of calling ``from_pretrained`` again
    on every inference. ``maxsize=1`` keeps only the most recently used model
    resident, because the checkpoints (up to nsql-6B) are large; switching
    models in the UI evicts the previous one.

    Args:
        model_name: Checkpoint name under the ``NumbersStation`` organization,
            e.g. ``"nsql-350M"``.

    Returns:
        A ``(tokenizer, model)`` tuple. The cached objects are shared between
        callers, so they must not be mutated.
    """
    repo_id = f"NumbersStation/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    return tokenizer, model
def build_complete_prompt(table_schemas: str, task_spec: str, prompt: str) -> str:
    """Render the module-level ``template_str`` with the user's three inputs.

    Args:
        table_schemas: Schema text placed at the top of the prompt.
        task_spec: Instruction text placed after the schemas.
        prompt: Natural-language question placed before the trailing "SELECT".

    Returns:
        The fully substituted prompt string.
    """
    substitutions = {
        "table_schemas": table_schemas,
        "task_spec": task_spec,
        "prompt": prompt,
    }
    return template_str.format_map(substitutions)
def infer(
    table_schemas: str,
    task_spec: str,
    prompt: str,
    model_choice: str = "nsql-350M",
    max_new_tokens: int = 500,
) -> str:
    """Generate a SQL completion for a natural-language question.

    Args:
        table_schemas: Schema text describing the tables being queried.
        task_spec: Instruction text inserted between schemas and question.
        prompt: The natural-language question.
        model_choice: NSQL checkpoint name passed to ``load_model``.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Unlike ``max_length``, this does not shrink as the
            table schemas get longer.

    Returns:
        The decoded output sequence with special tokens removed. For a causal
        LM this includes the prompt text followed by the model's continuation
        of the trailing "SELECT".
    """
    tokenizer, model = load_model(model_choice)
    input_text = build_complete_prompt(table_schemas, task_spec, prompt)
    encoded = tokenizer(input_text, return_tensors="pt")
    # `max_length` counts prompt tokens too, so a long schema could leave
    # little or no budget for the query; bound only the generated tokens.
    generated_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.get("attention_mask"),
        max_new_tokens=max_new_tokens,
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Markdown text shown as the description of the Gradio interface.
description = """The NSQL model family was published by [Numbers Station](https://www.numbersstation.ai/) and is available in three flavors:

- [nsql-6B](https://huggingface.co/NumbersStation/nsql-6B)
- [nsql-2B](https://huggingface.co/NumbersStation/nsql-2B)
- [nsql-350M](https://huggingface.co/NumbersStation/nsql-350M)

This demo lets you choose from all of them and provides the three examples you can also find in their model cards.
In general you should first provide the table schemas of the tables you have questions about and then prompt it with a natural language question.
The model will then generate a SQL query that you can run against your database.
"""
# Gradio UI: three free-text inputs plus a model selector, wired to `infer`.
iface = gr.Interface(
    title="Text to SQL with NSQL",
    description=description,
    fn=infer,
    inputs=[
        gr.Text(label="Table schemas", placeholder="Insert your table schemas here"),
        gr.Text(
            label="Specify Task",
            value="Using valid SQLite, answer the following questions for the tables provided above.",
        ),
        gr.Text(label="Prompt", placeholder="Put your natural language prompt here"),
        gr.Dropdown(["nsql-6B", "nsql-2B", "nsql-350M"], value="nsql-6B"),
    ],
    outputs="text",
    examples=[example1, example2, example3],
)

# Start the server only when run as a script (as `python app.py` does), so
# importing this module does not launch a web server as a side effect.
if __name__ == "__main__":
    iface.launch()