import sys

# Make the locally cloned Petals repository importable
sys.path.insert(0, './petals/')

import torch
import transformers
import gradio as gr

# Distributed BLOOM client from the Petals repository
from src.client.remote_model import DistributedBloomForCausalLM

MODEL_NAME = "bigscience/bloom-petals"

# Load the tokenizer locally and connect to the public Petals swarm for inference
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)

def inference(text, seq_length=1):
    # Tokenize the prompt, generate up to seq_length new tokens through the
    # swarm, and decode the result back to text.
    input_ids = tokenizer([text], return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=seq_length)
    return tokenizer.batch_decode(output)[0]
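
# A hypothetical direct call, for illustration only (the prompt is made up):
#   inference("A cat sat on", seq_length=8)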

# Gradio UI: a text box for the prompt and a slider that sets how many new
# tokens to generate (passed to inference as seq_length).
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(lines=10, label="Input text"),
        gr.Slider(
            minimum=0,
            maximum=1000,
            step=1,
            value=42,
            label="Sequence length for generation"
        )
    ],
    outputs="text"
)
iface.launch()
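
# A minimal sketch of how this demo could be run locally (assumptions: the
# Petals repository is cloned into ./petals and this script is saved as
# app.py, as Hugging Face Spaces expect):
#   python app.py
# launch() prints a local URL where the Gradio interface is served.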