# Spaces: Sleeping (status text captured from the hosting page; not part of the program)
import os | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import gradio as gr | |
# Access token for gated checkpoints, read from the "HF_Token" environment
# variable. os.getenv returns None when the variable is unset.
hf_token = os.getenv("HF_Token")

# First CUDA device when one is available, otherwise the CPU.
# NOTE(review): the models below are not moved onto `device` in this section.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Small-parameter versions of both models are used for faster inference on
# Hugging Face.
b_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
b_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

# `token` is the current keyword for passing credentials; `use_auth_token`
# is its deprecated predecessor.
g_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", token=hf_token)
g_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", token=hf_token)
def Sentence_Commpletion(model_name, input):
    """Complete a sentence with the selected language model.

    Args:
        model_name: Which model to use; must be ``"Bloom"`` or ``"Gemma"``.
        input: The prompt text to complete. (The name shadows the builtin but
            is kept so callers passing it by keyword continue to work.)

    Returns:
        The decoded text of the first generated sequence with special tokens
        removed, or an empty string when the prompt is empty.

    Raises:
        ValueError: If ``model_name`` is not one of the supported choices
            (including ``None``, e.g. when no option was selected).
    """
    # Fail with a clear message instead of an UnboundLocalError further down.
    if model_name not in ("Bloom", "Gemma"):
        raise ValueError(
            f"Unknown model {model_name!r}; choose 'Bloom' or 'Gemma'."
        )

    # Nothing to complete; avoid handing an empty token sequence to generate().
    if not input:
        return ""

    if model_name == "Bloom":
        tokenizer, model = b_tokenizer, b_model
        # NOTE(review): max_length counts the prompt tokens as well, so long
        # prompts leave little or no room for new text; max_new_tokens (as
        # used for Gemma) may be what is wanted here.
        generation_kwargs = {"max_length": 30, "num_return_sequences": 1}
    else:
        tokenizer, model = g_tokenizer, g_model
        generation_kwargs = {"max_new_tokens": 20}

    # Keep the input tensors on whatever device the model currently lives on.
    encoded = tokenizer(input, return_tensors="pt").to(model.device)
    outputs = model.generate(**encoded, **generation_kwargs)

    # Drop markers such as <bos>/<eos> so only readable text is shown.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio UI: a radio button selects the model and a textbox takes the prompt;
# both values are passed, in that order, to Sentence_Commpletion and the
# returned string is shown as plain text.
interface = gr.Interface(
    fn=Sentence_Commpletion,
    inputs=[
        gr.Radio(["Bloom", "Gemma"], label="Choose model"),
        gr.Textbox(placeholder="Enter sentence"),
    ],
    outputs="text",
    title="Bloom vs Gemma Sentence completion",
)

# Start the web server for the app.
interface.launch()