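# Gradio demo: chat with a persona-conditioned bot, backed either by local
# Microsoft DialoGPT checkpoints or by BLOOM served over the Petals swarm.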
import sys

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make a local checkout of the Petals repository importable (expected at ./petals/).
sys.path.insert(0, './petals/')
from petals.client.remote_model import DistributedBloomForCausalLM
MODEL_NAME = "bigscience/test-bloomd-6b3"
# INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
tokenizer_bloomd_6b3 = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd_6b3 = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME,
    # initial_peers=INITIAL_PEERS,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)
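
# With Petals, the heavy transformer blocks run on remote swarm servers; only
# the embedding and LM-head layers are held locally, so these models load
# without requiring the full BLOOM weights on this machine.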
MODEL_NAME = "bigscience/bloom-petals"
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
tokenizer_DialoGPT_medium = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model_DialoGPT_medium = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
tokenizer_DialoGPT_large = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
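
# DialoGPT models a dialogue as a single token stream in which every turn ends
# with the EOS token. predict() below builds the prompt as
# "<persona><eos><history...><user message><eos>" and decodes the model's
# continuation as the bot's reply.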
def predict(
    input_text,
    history=None,
    person_description=None,
    number_of_new_tokens=1000,
    model_name=None,
    del_hist=None,
):
    if history is None or del_hist == 'delete history':
        history = []
    if model_name == 'DialoGPT-small':
        model = model_DialoGPT_small
        tokenizer = tokenizer_DialoGPT_small
    elif model_name == 'DialoGPT-medium':
        model = model_DialoGPT_medium
        tokenizer = tokenizer_DialoGPT_medium
    elif model_name == 'DialoGPT-large':
        model = model_DialoGPT_large
        tokenizer = tokenizer_DialoGPT_large
    elif model_name == 'test-bloomd-6b3':
        model = model_bloomd_6b3
        tokenizer = tokenizer_bloomd_6b3
    elif model_name == 'bloom-petals':
        model = model_bloomd
        tokenizer = tokenizer_bloomd
    else:
        # Fall back to DialoGPT-medium when no model is selected.
        model = model_DialoGPT_medium
        tokenizer = tokenizer_DialoGPT_medium
    # Guard against an empty persona textbox (Gradio may pass None here).
    person_description = person_description or ''
    # Every turn ends with the EOS token, DialoGPT's turn separator.
    person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
    new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
    if history:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    # Prepend the persona so every generation is conditioned on it.
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    max_token_count = number_of_new_tokens + len(input_with_desc_ids[0])
    history = model.generate(input_with_desc_ids, max_length=max_token_count,
                             pad_token_id=tokenizer.eos_token_id).tolist()
    # Drop the persona tokens from the stored history so the description is
    # not duplicated on the next turn.
    history[0] = history[0][len(person_description_ids[0]):]
    # Split on the EOS marker and pair turns as (user, bot) tuples for gr.Chatbot.
    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    return response, history
gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
        "state",
        gr.Textbox(label='Person Description', lines=2, placeholder="Enter a description of the person..."),
        gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
        gr.Radio(
            label='Model name',
            choices=[
                'DialoGPT-small',
                'DialoGPT-medium',
                'DialoGPT-large',
                'test-bloomd-6b3',
                'bloom-petals',
            ],
        ),
        gr.Radio(
            label='Delete history',
            value="Don't delete history",
            choices=[
                'delete history',
                "Don't delete history",
            ],
        ),
    ],
    outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
).launch()
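
# A minimal sketch of exercising the handler directly, bypassing the UI
# (assumes the DialoGPT-medium weights above loaded successfully; the "state"
# input/output maps to the `history` argument and the second return value):
#
#   response, history = predict(
#       "Hello!", history=None, person_description="A friendly assistant.",
#       number_of_new_tokens=10, model_name='DialoGPT-medium',
#       del_hist="Don't delete history",
#   )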