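# Web page header captured together with this file (not part of the program):
#   author avatar: YanshekWoo's picture
#   last commit:   "update tokenizer class" (ace8906)
#   page links:    raw / history / blame
#   scan status:   No virus
#   file size:     3.63 kB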
import gradio as gr
from typing import List, Optional
from transformers import AutoTokenizer, BertTokenizer, BartForConditionalGeneration
# Title and description text handed to the Gradio interface built further down.
title = "HIT-TMG/dialogue-bart-large-chinese"
description = """
This is a seq2seq model pre-trained on several Chinese dialogue datasets, from bart-large-chinese.
However it is just a simple demo for this pre-trained model. It's better to fine-tune it on downstream tasks for better performance \n
See some details of model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese . \n\n
Besides starting the conversation from scratch, you can also input the whole dialogue history utterance by utterance seperated by '[SEP]'. \n
"""
# Explicit tokenizer class, left disabled in favour of AutoTokenizer below.
# tokenizer = BertTokenizer.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")
# Tokenizer and generation model are both loaded from the same checkpoint name
# at import time.
tokenizer = AutoTokenizer.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")
model = BartForConditionalGeneration.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")
# Truncate on the left: when the encoded history exceeds max_length, the
# beginning of the input is cut rather than the most recent utterances.
tokenizer.truncation_side = 'left'
# Maximum number of input tokens passed to the tokenizer in chat_func.
max_length = 512
# Example inputs offered in the UI: a single opening utterance, and a
# multi-turn history whose utterances are separated by '[SEP]'.
examples = [
["你有什么爱好吗"],
["你好。[SEP]嘿嘿你好,请问你最近在忙什么呢?[SEP]我最近养了一只狗狗,我在训练它呢。"]
]
def _split_utterances(text: str, separator: str) -> List[str]:
    """Split *text* on *separator*, trimming whitespace and dropping blank pieces."""
    pieces = (piece.strip() for piece in text.split(separator))
    return [piece for piece in pieces if piece]


def _pair_utterances(utterances: List[str]) -> List[tuple]:
    """Group a flat utterance list into two-column rows for the chatbot widget.

    With an odd number of utterances an empty left cell is prepended, so the
    last utterance (the freshly generated reply) always ends up in the right
    column.
    """
    padded = utterances if len(utterances) % 2 == 0 else [""] + utterances
    return list(zip(padded[::2], padded[1::2]))


def chat_func(input_utterance: str, history: Optional[List[str]] = None):
    """Generate one reply for the dialogue and return the updated conversation.

    Args:
        input_utterance: Text typed by the user. It may hold a single utterance
            or several utterances separated by the tokenizer's separator token.
        history: Flat list of earlier utterances (the value returned as the
            second element by a previous call), or ``None`` for a new dialogue.

    Returns:
        A ``(display_utterances, history)`` pair: the rows to show in the
        chatbot component, and the flat utterance list including the new reply.
    """
    separator = tokenizer.sep_token

    # Blank pieces (empty input, trailing separator, doubled separators) are
    # dropped: they would add empty segments to the prompt and shift the
    # left/right pairing of the displayed rows.
    new_utterances = _split_utterances(input_utterance or "", separator)

    # Build a fresh list rather than extending the caller's list in place.
    history = list(history or []) + new_utterances

    history_str = "对话历史:" + separator.join(history)
    input_ids = tokenizer(
        history_str,
        return_tensors='pt',
        truncation=True,
        max_length=max_length,
    ).input_ids

    # NOTE(review): top_k is a sampling option; with beam search and no
    # do_sample=True it presumably has no effect -- confirm before relying on it.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=30,
        top_k=32,
        num_beams=4,
        repetition_penalty=1.2,
        no_repeat_ngram_size=4,
    )[0]

    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    history.append(response)

    return _pair_utterances(history), history
# Gradio UI around chat_func: a one-line text box for the current utterance
# plus a "state" slot that carries the utterance list between calls; the
# function's two return values feed the chatbot view and the same state slot.
demo = gr.Interface(
    fn=chat_func,
    inputs=[
        gr.Textbox(lines=1, placeholder="Input current utterance"),
        "state",
    ],
    outputs=["chatbot", "state"],
    examples=examples,
    title=title,
    description=description,
)

# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
# Disabled alternative implementation (plain text in, plain text out), kept for reference only.
# def chat(history):
# history_prefix = "对话历史:"
# history = history_prefix + history
#
# outputs = tokenizer(history,
# return_tensors='pt',
# padding=True,
# truncation=True,
# max_length=512)
#
# input_ids = outputs.input_ids
# output_ids = model.generate(input_ids)[0]
#
# return tokenizer.decode(output_ids, skip_special_tokens=True)
#
#
# chatbot = gr.Chatbot().style(color_map=("green", "pink"))
# demo = gr.Interface(
# chat,
# inputs=gr.Textbox(lines=8, placeholder="输入你的对话历史(请以'[SEP]'作为每段对话的间隔)\nInput the dialogue history (Please split utterances by '[SEP]')"),
# title=title,
# description=description,
# outputs =["text"]
# )
#
#
# if __name__ == "__main__":
# demo.launch()