File size: 3,534 Bytes
47e97de
a1a2817
d2ed16f
f67046d
e44ee1e
 
459ca51
 
02f0e9a
 
e44ee1e
d2ed16f
5a9366e
d2ed16f
 
 
d18f655
 
 
1a934e9
 
 
 
 
d18f655
5a9366e
d18f655
5a9366e
d18f655
5a9366e
d2ed16f
d18f655
 
 
 
 
e47f086
287e49e
fd3761b
02f0e9a
287e49e
0e59ad2
de1ff8e
ec80ab7
 
287e49e
d18f655
 
 
 
02f0e9a
 
 
 
 
d2ed16f
d18f655
d2ed16f
 
d18f655
a1a2817
 
 
1a934e9
d18f655
d2ed16f
 
 
4250d6a
a1a2817
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from typing import List, Optional
from transformers import BertTokenizer, BartForConditionalGeneration

# UI strings shown in the Gradio interface header.
title = "HIT-TMG/dialogue-bart-large-chinese"
description = """
This is a seq2seq model pre-trained on several Chinese dialogue datasets, from bart-large-chinese. 
However it is just a simple demo for this pre-trained model. It's better to fine-tune it on downstream tasks for better performance \n
See some details of model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese . \n\n
Besides starting the conversation from scratch, you can also input the whole dialogue history utterance by utterance separated by '[SEP]'. \n
"""


# Load the pre-trained tokenizer and seq2seq model from the Hugging Face Hub.
# NOTE: this downloads weights on first run and happens at import time.
tokenizer = BertTokenizer.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")
model = BartForConditionalGeneration.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")

# Truncate from the left so the most recent dialogue turns are kept when the
# encoded history exceeds the model's input budget.
tokenizer.truncation_side = 'left'
max_length = 512  # maximum number of input tokens fed to the encoder

# Example inputs for the Gradio demo: a fresh conversation opener, and a
# multi-turn history with utterances joined by '[SEP]'.
examples = [
    ["你有什么爱好吗"],
    ["你好。[SEP]嘿嘿你好,请问你最近在忙什么呢?[SEP]我最近养了一只狗狗,我在训练它呢。"]
]


def chat_func(input_utterance: str, history: Optional[List[str]] = None):
    """Generate one bot reply and return the chatbot display pairs plus state.

    The user may submit a single utterance or several utterances joined by
    the tokenizer's separator token ('[SEP]'); these are appended to the
    running ``history`` (mutated in place when provided). The full history is
    serialized into a prompt, a reply is generated with beam search, and the
    history is rendered as (user, bot) pairs for the Gradio chatbot widget.

    Args:
        input_utterance: current user input, optionally '[SEP]'-joined turns.
        history: prior utterances carried in Gradio state; None on first call.

    Returns:
        A tuple of (list of (user, bot) display pairs, updated history list).
    """
    new_turns = input_utterance.split(tokenizer.sep_token)
    if history is None:
        history = new_turns
    else:
        history.extend(new_turns)

    # Prefix marks the text as dialogue history, matching the pre-training format.
    prompt = "对话历史:" + tokenizer.sep_token.join(history)

    encoded = tokenizer(prompt,
                        return_tensors='pt',
                        truncation=True,
                        max_length=max_length,
                        )

    generated = model.generate(encoded.input_ids,
                               max_new_tokens=30,
                               top_k=32,
                               num_beams=4,
                               repetition_penalty=1.2,
                               no_repeat_ngram_size=4)

    reply = tokenizer.decode(generated[0], skip_special_tokens=True)
    history.append(reply)

    # Pair utterances as (user, bot). With an odd-length history the first
    # utterance has no user turn, so it is displayed as a bot-only opener.
    offset = len(history) % 2
    pairs = [(history[i], history[i + 1])
             for i in range(offset, len(history) - 1, 2)]
    if offset:
        pairs.insert(0, ("", history[0]))

    return pairs, history


# Gradio UI wiring: a one-line textbox for the current utterance plus a
# hidden "state" component that round-trips the utterance history between
# calls; outputs render as a chatbot widget alongside the updated state.
# NOTE(review): the "state"/"chatbot" string shortcuts are a legacy Gradio
# API — confirm they are still supported by the pinned gradio version.
demo = gr.Interface(fn=chat_func,
                    title=title,
                    description=description,
                    inputs=[gr.Textbox(lines=1, placeholder="Input current utterance"), "state"],
                    examples=examples,
                    outputs=["chatbot", "state"])


# Start the local web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()

# def chat(history):
#     history_prefix = "对话历史:"
#     history = history_prefix + history
#
#     outputs = tokenizer(history,
#                         return_tensors='pt',
#                         padding=True,
#                         truncation=True,
#                         max_length=512)
#
#     input_ids = outputs.input_ids
#     output_ids = model.generate(input_ids)[0]
#
#     return tokenizer.decode(output_ids, skip_special_tokens=True)
#
#
# chatbot = gr.Chatbot().style(color_map=("green", "pink"))
# demo = gr.Interface(
#     chat,
#     inputs=gr.Textbox(lines=8, placeholder="输入你的对话历史(请以'[SEP]'作为每段对话的间隔)\nInput the dialogue history (Please split utterances by '[SEP]')"),
#     title=title,
#     description=description,
#     outputs =["text"]
# )
#
#
# if __name__ == "__main__":
#     demo.launch()