# Bloom_vs_gemma / app.py
# (Hugging Face Space by injilashah — "Update app.py", commit 3f51521, verified; file size 1.44 kB)
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# Hugging Face access token (required for the gated google/gemma-2-2b weights).
hf_token = os.getenv("HF_Token")
# Run on GPU when available; fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Small-parameter model variants are used for faster inference on HF Spaces.
b_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
b_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m").to(device)  # was loaded but never moved to `device`
# `token=` replaces the deprecated `use_auth_token=` argument in transformers.
g_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", token=hf_token)
g_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", token=hf_token).to(device)
def Sentence_Commpletion(model_name, input):
    """Complete a sentence with the selected model.

    Args:
        model_name: Either "Bloom" or "Gemma" (as offered by the Gradio radio).
        input: The sentence prefix to complete. (Name shadows the builtin but is
            kept for interface compatibility.)

    Returns:
        The decoded generated text.

    Raises:
        ValueError: If ``model_name`` is not one of the known models
            (previously this fell through to an UnboundLocalError).
    """
    if model_name == "Bloom":
        tokenizer, model = b_tokenizer, b_model
    elif model_name == "Gemma":
        # Fixed: originally `Tokenizer(...)`/`Model.generate(...)` (capitalized) -> NameError.
        tokenizer, model = g_tokenizer, g_model
    else:
        raise ValueError(f"Unknown model: {model_name!r}")
    # Move encoded inputs to the model's device so generate works on GPU too.
    inputid = tokenizer(input, return_tensors="pt").to(model.device)
    if model_name == "Bloom":
        # Fixed: was `model.generate(input.inputid, ...)` — `input` is a str and
        # has no `.inputid`; the encoding must be unpacked as keyword args.
        outputs = model.generate(**inputid, max_length=30, num_return_sequences=1)
    else:
        outputs = model.generate(**inputid, max_new_tokens=20)
    return tokenizer.decode(outputs[0])
# Gradio UI: pick a model, type a sentence prefix, show the completion.
interface = gr.Interface(
    fn=Sentence_Commpletion,
    inputs=[
        gr.Radio(["Bloom", "Gemma"], label="Choose model"),
        gr.Textbox(placeholder="Enter sentence"),  # fixed typo: "sentece"
    ],
    outputs="text",
    title="Bloom vs Gemma Sentence completion",
)
interface.launch()