# DialogSumLlama2 / app.py
# Hugging Face Space by "shenoy" — commit 8fd1ef9 ("Add application file and dependencies").
# (The lines above were file-viewer page chrome captured by extraction, not code.)
import gradio as gr
import transformers
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
# 4-bit NF4 quantization config so the 7B model fits in limited GPU memory.
# FIX: BitsAndBytesConfig was referenced here without ever being imported,
# which raised NameError at startup; it is now imported from transformers.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                 # store weights in 4 bits at load time
    bnb_4bit_quant_type="nf4",        # NormalFloat4 quantization variant
    bnb_4bit_compute_dtype="float16",  # matmuls computed in fp16 despite 4-bit storage
)
# --- Base model, tokenizer, and LoRA adapter ------------------------------
# Hub id of the sharded bf16 Llama-2-7B base checkpoint.
model_name = "TinyPixel/Llama-2-7B-bf16-sharded"

# Load the base model under the 4-bit quantization config; device_map="auto"
# lets accelerate place layers on whatever devices are available.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
)

# Tokenizer for the same checkpoint. Llama ships without a pad token, so the
# EOS token is reused for padding.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    return_token_type_ids=False,
)
tokenizer.pad_token = tokenizer.eos_token

# Attach the dialogue-summarization QLoRA adapter on top of the base model.
model = PeftModel.from_pretrained(
    model, "shenoy/DialogSumLlama2_qlora", device_map="auto"
)
# Gradio I/O components.
# FIX: gr.inputs.Textbox / gr.outputs.Textbox were deprecated in Gradio 3.x
# and removed in 4.x; the top-level gr.Textbox component is the replacement
# and accepts the same label/type arguments.
input_text = gr.Textbox(label="Input Text", type="text")
output_text = gr.Textbox(label="Output Text", type="text")
def predict(text):
    """Generate a model response for the raw input *text*.

    Tokenizes the input, generates up to 100 new tokens with a mild
    repetition penalty, and returns the decoded string with special
    tokens stripped.
    """
    encoded = tokenizer(text, return_tensors="pt")
    generated = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=100,
        repetition_penalty=1.2,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Wire predict() to the text components and start the Gradio server.
interface = gr.Interface(
    inputs=input_text,
    outputs=output_text,
    fn=predict,
)
interface.launch()