|
|
|
from transformers import pipeline |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM |
|
import gradio as gr |
|
import torch |
|
|
|
|
|
|
|
# Single source of truth for the checkpoint used by both tokenizer and model.
MODEL_NAME = "facebook/blenderbot-400M-distill"

# Text and example prompts displayed by the Gradio interface.
title = "🤖AI ChatBot"

description = "Building open-domain chatbots is a challenging area for machine learning research."

examples = [["How are you?"]]

# Loaded once at import time and shared by predict().
# BlenderBot is an encoder-decoder model, hence AutoModelForSeq2SeqLM.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
|
|
|
|
def predict(input, history=None):
    """Generate a BlenderBot reply to *input* and return the updated chat.

    BlenderBot is an encoder-decoder (seq2seq) model: ``model.generate``
    returns only the decoder's reply, not "prompt + reply" as a decoder-only
    model would. The conversation is therefore kept as a list of
    ``(user_message, bot_reply)`` string pairs and re-encoded on every turn.

    Args:
        input: The user's new message. The name is kept for compatibility
            even though it shadows the ``input`` builtin.
        history: Previous ``(user_message, bot_reply)`` pairs, or ``None`` /
            empty for a new conversation. The Gradio ``"state"`` input may
            supply ``None`` on the first call. The passed list is not mutated.

    Returns:
        A ``(pairs, pairs)`` tuple: the updated list of
        ``(user_message, bot_reply)`` tuples, once for the ``"chatbot"``
        output and once for the ``"state"`` output.
    """
    # Copy so the caller's/Gradio's stored state is never mutated in place.
    history = list(history) if history else []

    # BlenderBot conversation format (as used by transformers'
    # BlenderbotTokenizer conversation helper): each turn prefixed with a
    # space, turns separated by two spaces.
    turns = [text for pair in history for text in pair] + [input]
    prompt = "  ".join(" " + text.strip() for text in turns)

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # The encoder has a fixed number of positions; keep only the most recent
    # tokens (left-truncate) so long conversations do not overflow them.
    max_tokens = min(tokenizer.model_max_length, model.config.max_position_embeddings)
    input_ids = input_ids[:, -max_tokens:]

    # Generation length/beam settings come from the model's own generation
    # config; generate() already runs without gradient tracking.
    output_ids = model.generate(input_ids)
    reply = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

    history.append((input, reply))
    return history, history
|
|
|
|
|
# Text box + hidden state in; chat transcript + updated state out.
demo = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    examples=examples,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    theme="finlaymacklon/boxy_violet",
)

# share=True asks Gradio for a public share link in addition to the local server.
demo.launch(share=True)
|
|