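# Streamlit demo: translate a natural-language question into a SQL query using a
# BLOOMZ-3B base model with a LoRA (PEFT) adapter fine-tuned for NL2SQL.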
import streamlit as st
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
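
# Hugging Face Hub repository id of the fine-tuned LoRA adapter.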
model_name = "NL2SQL_BLOOMZ-3B"
HUGGING_FACE_USER_NAME = "abhishek23HF"
peft_model_id = f"{HUGGING_FACE_USER_NAME}/{model_name}"
# Load the base model and tokenizer referenced by the adapter's PEFT config.
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=False)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Apply the LoRA adapter weights on top of the base model.
model = PeftModel.from_pretrained(model, peft_model_id)

def make_inference(db_id, question):
    # Build the prompt in the instruction format the adapter was fine-tuned on.
    batch = tokenizer(f"""
### INSTRUCTION\n
Below is a User Question for a SQL DATABASE. Your job is to write a SQL Query for the given question from the user for that particular Database.
\n\n
### DATABASE_ID:\n{db_id}\n
### USER QUESTION:\n{question}\n\n
### SQL QUERY:\n
""", return_tensors='pt')

    # Generate up to 200 new tokens under mixed precision, then decode to text.
    with torch.cuda.amp.autocast():
        output_tokens = model.generate(**batch, max_new_tokens=200)

    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

# Create two text input boxes for the database ID and the natural-language question.
text_input_db_id = st.text_input("DB ID")
text_input_question = st.text_input("User Query")

# Run inference and display the generated SQL query when the user clicks Submit.
if st.button('Submit'):
    st.write(make_inference(text_input_db_id, text_input_question))