# Hugging Face Spaces app (Space status when captured: Sleeping)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import torch | |
import gradio as gr | |
from transformers import BlenderbotTokenizer | |
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, BlenderbotConfig | |
from transformers import BlenderbotTokenizerFast | |
import contextlib | |
# Hugging Face checkpoint used for both the tokenizer and the model.
# Lighter alternative for constrained hardware: "facebook/blenderbot-400M-distill".
mname = "facebook/blenderbot-3B"

# Loaded once at import time; predict() below reads these module-level objects
# on every call.
tokenizer = BlenderbotTokenizerFast.from_pretrained(mname)
model = BlenderbotForConditionalGeneration.from_pretrained(mname)

# Separator space was missing ("...blenderbot-3Bmodel loaded").
print(mname + " model loaded")
def predict(input, history=None):
    """Generate a BlenderBot reply and return the updated chat transcript.

    Args:
        input: The user's latest message.
        history: Flat list of alternating user/bot utterances from earlier
            turns (the Gradio ``state`` value). ``None`` or empty starts a new
            conversation. The list passed in is not modified.

    Returns:
        A ``(pairs, history)`` tuple. ``pairs`` is a list of
        ``(user_message, bot_reply)`` tuples for the ``chatbot`` output.
        ``history`` is the updated flat list to store back into ``state``.
    """
    # Build a fresh list instead of appending to the argument. The old
    # ``history=[]`` default was a single shared list, so every call that
    # omitted ``history`` leaked earlier conversations into the new one.
    # Gradio may also pass ``None`` as the initial state.
    history = list(history) if history else []
    history.append(input)

    # Condition the model on at most the last three utterances: the new
    # message plus the previous exchange. ``history[-3:]`` is correct for any
    # length, unlike ``history[len(history) - 3:]``, which misbehaves when
    # len(history) < 3 (e.g. len 2 -> ``[-1:]``).
    context = "</s> <s>".join(str(turn) for turn in history[-3:])

    # NOTE(review): BlenderBot checkpoints are commonly configured with 128
    # position embeddings. Confirm model.config.max_position_embeddings
    # before relying on max_length=512 for the input and the generated reply.
    encoded = tokenizer([context], return_tensors="pt", max_length=512, truncation=True)
    reply_ids = model.generate(**encoded, max_length=512, pad_token_id=tokenizer.eos_token_id)
    reply = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    history.append(reply)

    # Pair consecutive (user, bot) utterances for the Gradio ``chatbot`` output.
    pairs = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    return pairs, history
# Gradio UI: a text box plus hidden session state go into predict(); the
# returned pairs render in a chatbot widget and the flat history is stored
# back into the state.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "state"],      # user message, conversation history
    outputs=["chatbot", "state"],  # rendered transcript, updated history
)
demo.launch()