Spaces:
Runtime error
Runtime error
File size: 3,534 Bytes
47e97de a1a2817 d2ed16f f67046d e44ee1e 459ca51 02f0e9a e44ee1e d2ed16f 5a9366e d2ed16f d18f655 1a934e9 d18f655 5a9366e d18f655 5a9366e d18f655 5a9366e d2ed16f d18f655 e47f086 287e49e fd3761b 02f0e9a 287e49e 0e59ad2 de1ff8e ec80ab7 287e49e d18f655 02f0e9a d2ed16f d18f655 d2ed16f d18f655 a1a2817 1a934e9 d18f655 d2ed16f 4250d6a a1a2817 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import gradio as gr
from typing import List, Optional
from transformers import BertTokenizer, BartForConditionalGeneration
# Hugging Face Hub identifier shared by the page title, the tokenizer and the model.
_MODEL_ID = "HIT-TMG/dialogue-bart-large-chinese"

# Heading and introductory text passed to the Gradio interface.
title = _MODEL_ID
description = """
This is a seq2seq model pre-trained on several Chinese dialogue datasets, from bart-large-chinese.
However it is just a simple demo for this pre-trained model. It's better to fine-tune it on downstream tasks for better performance \n
See some details of model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese . \n\n
Besides starting the conversation from scratch, you can also input the whole dialogue history utterance by utterance seperated by '[SEP]'. \n
"""

# Upper bound on the number of tokens kept when the input is tokenized.
max_length = 512

tokenizer = BertTokenizer.from_pretrained(_MODEL_ID)
# When the input exceeds max_length, drop tokens from the start rather than the end.
tokenizer.truncation_side = 'left'

model = BartForConditionalGeneration.from_pretrained(_MODEL_ID)

# Sample inputs offered in the UI: a single opening utterance, and a
# multi-turn history whose utterances are joined by '[SEP]'.
examples = [
    ["你有什么爱好吗"],
    ["你好。[SEP]嘿嘿你好,请问你最近在忙什么呢?[SEP]我最近养了一只狗狗,我在训练它呢。"],
]
def chat_func(input_utterance: str, history: Optional[List[str]] = None):
    """Generate the next reply and return the updated conversation.

    Args:
        input_utterance: Text entered by the user. It may hold one utterance
            or several utterances joined by the tokenizer's separator token.
        history: Utterances accumulated by earlier calls, oldest first, or
            ``None`` when the conversation is just starting. The list passed
            in is never modified.

    Returns:
        A ``(display_utterances, history)`` tuple. ``display_utterances`` is a
        list of 2-tuples of consecutive utterances for the chatbot output;
        ``history`` is a new flat list of utterances ending with the
        generated response.
    """
    sep_token = tokenizer.sep_token

    # Drop blank segments (e.g. from a trailing or doubled separator): an
    # empty utterance would shift every later turn to the wrong side of the
    # chatbot display.
    new_utterances = [
        segment.strip()
        for segment in (input_utterance or "").split(sep_token)
        if segment.strip()
    ]

    # Build a fresh list instead of extending the caller's one, so a failure
    # during tokenization or generation cannot leave the stored state holding
    # a user turn without its reply.
    history = [*(history or []), *new_utterances]

    history_str = "对话历史:" + sep_token.join(history)
    input_ids = tokenizer(
        history_str,
        return_tensors='pt',
        truncation=True,
        max_length=max_length,
    ).input_ids

    # NOTE(review): top_k is normally a sampling parameter; with beam search
    # and no do_sample it presumably has no effect — confirm before removing.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=30,
        top_k=32,
        num_beams=4,
        repetition_penalty=1.2,
        no_repeat_ngram_size=4,
    )[0]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    history.append(response)

    # Pair consecutive utterances so that the generated response always ends
    # up as the second element of the last tuple. With an odd number of
    # utterances the first one has no partner and is shown next to an empty
    # string.
    offset = len(history) % 2
    display_utterances = [("", history[0])] if offset else []
    display_utterances.extend(zip(history[offset::2], history[offset + 1::2]))

    return display_utterances, history
# Single-line textbox in which the user types the current utterance.
_utterance_box = gr.Textbox(lines=1, placeholder="Input current utterance")

# The "state" entries on both sides pass chat_func's second return value back
# in as its second argument; "chatbot" displays the first return value.
demo = gr.Interface(
    fn=chat_func,
    inputs=[_utterance_box, "state"],
    outputs=["chatbot", "state"],
    examples=examples,
    title=title,
    description=description,
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
# def chat(history):
# history_prefix = "对话历史:"
# history = history_prefix + history
#
# outputs = tokenizer(history,
# return_tensors='pt',
# padding=True,
# truncation=True,
# max_length=512)
#
# input_ids = outputs.input_ids
# output_ids = model.generate(input_ids)[0]
#
# return tokenizer.decode(output_ids, skip_special_tokens=True)
#
#
# chatbot = gr.Chatbot().style(color_map=("green", "pink"))
# demo = gr.Interface(
# chat,
# inputs=gr.Textbox(lines=8, placeholder="输入你的对话历史(请以'[SEP]'作为每段对话的间隔)\nInput the dialogue history (Please split utterances by '[SEP]')"),
# title=title,
# description=description,
# outputs =["text"]
# )
#
#
# if __name__ == "__main__":
# demo.launch()
|