Spaces:
Runtime error
Runtime error
from typing import List, Optional, Tuple

import gradio as gr
from transformers import BartForConditionalGeneration, BertTokenizer
# Hugging Face Hub identifier of the checkpoint. It is shared by the page
# title, the tokenizer and the model so the three cannot drift apart.
MODEL_NAME = "HIT-TMG/dialogue-bart-large-chinese"

# Texts shown on the demo page.
title = MODEL_NAME
description = """
This is a seq2seq model pre-trained on several Chinese dialogue datasets, from bart-large-chinese.
However it is just a simple demo for this pre-trained model. It's better to fine-tune it on downstream tasks for better performance \n
See some details of model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese . \n\n
Besides starting the conversation from scratch, you can also input the whole dialogue history utterance by utterance seperated by '[SEP]'. \n
"""

# Both objects are loaded once at import time and reused by every request.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)

# Truncate from the left so that an over-long dialogue loses its oldest
# tokens and keeps the most recent ones.
tokenizer.truncation_side = 'left'

# Upper bound on the number of input tokens handed to the tokenizer.
max_length = 512

# Sample inputs offered in the UI: a single opening utterance, and a
# multi-turn history whose utterances are joined by '[SEP]'.
examples = [
    ["你有什么爱好吗"],
    ["你好。[SEP]嘿嘿你好,请问你最近在忙什么呢?[SEP]我最近养了一只狗狗,我在训练它呢。"]
]
def _pair_utterances(history: List[str]) -> List[Tuple[str, str]]:
    """Group a flat utterance list into 2-tuples for the chatbot output.

    Consecutive utterances are paired so that the last utterance is always
    the second element of the last pair. When the list has an odd length,
    the first utterance is paired with an empty first element so that the
    remaining utterances still line up.
    """
    if len(history) % 2 == 0:
        return list(zip(history[0::2], history[1::2]))
    return [("", history[0])] + list(zip(history[1::2], history[2::2]))


def chat_func(
    input_utterance: str,
    history: Optional[List[str]] = None,
) -> Tuple[List[Tuple[str, str]], List[str]]:
    """Generate the next reply and return the updated conversation.

    Args:
        input_utterance: The new input text. It may contain several
            utterances joined by the tokenizer's separator token, in which
            case each one is added to the history separately.
        history: Utterances accumulated so far, or ``None`` to start a new
            conversation. The list passed in is not modified.

    Returns:
        A ``(display_utterances, history)`` tuple: the utterance pairs for
        the chatbot output, and the flat history including the new input
        and the generated response.
    """
    new_utterances = input_utterance.split(tokenizer.sep_token)
    # Build a fresh list rather than extending the caller's list in place;
    # the updated history is handed back through the return value.
    history = [*(history or []), *new_utterances]

    # The whole dialogue is fed to the model as one prefixed string with
    # utterances joined by the separator token.
    history_str = "对话历史:" + tokenizer.sep_token.join(history)
    input_ids = tokenizer(
        history_str,
        return_tensors='pt',
        truncation=True,
        max_length=max_length,
    ).input_ids

    # NOTE(review): top_k is passed together with num_beams and without
    # do_sample; presumably it has no effect under beam search -- confirm
    # against the transformers generation documentation.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=30,
        top_k=32,
        num_beams=4,
        repetition_penalty=1.2,
        no_repeat_ngram_size=4,
    )[0]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    history.append(response)

    return _pair_utterances(history), history
# Single-line text input for the current utterance.
utterance_box = gr.Textbox(lines=1, placeholder="Input current utterance")

# The "state" input/output pair carries the second argument and second
# return value of chat_func (the utterance history) between calls.
demo = gr.Interface(
    fn=chat_func,
    inputs=[utterance_box, "state"],
    outputs=["chatbot", "state"],
    title=title,
    description=description,
    examples=examples,
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
# def chat(history): | |
# history_prefix = "对话历史:" | |
# history = history_prefix + history | |
# | |
# outputs = tokenizer(history, | |
# return_tensors='pt', | |
# padding=True, | |
# truncation=True, | |
# max_length=512) | |
# | |
# input_ids = outputs.input_ids | |
# output_ids = model.generate(input_ids)[0] | |
# | |
# return tokenizer.decode(output_ids, skip_special_tokens=True) | |
# | |
# | |
# chatbot = gr.Chatbot().style(color_map=("green", "pink")) | |
# demo = gr.Interface( | |
# chat, | |
# inputs=gr.Textbox(lines=8, placeholder="输入你的对话历史(请以'[SEP]'作为每段对话的间隔)\nInput the dialogue history (Please split utterances by '[SEP]')"), | |
# title=title, | |
# description=description, | |
# outputs =["text"] | |
# ) | |
# | |
# | |
# if __name__ == "__main__": | |
# demo.launch() | |