YanshekWoo commited on
Commit
d18f655
1 Parent(s): e590807

ADD dialogue chat

Browse files
Files changed (1) hide show
  1. app.py +62 -21
app.py CHANGED
@@ -1,44 +1,85 @@
1
  import gradio as gr
 
 
2
  from transformers import BertTokenizer, BartForConditionalGeneration
3
 
4
-
5
  title = "HIT-TMG/dialogue-bart-large-chinese"
6
  description = """
7
  This is a seq2seq model fine-tuned on several Chinese dialogue datasets, from bart-large-chinese.
8
  See some details of model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese .
9
-
10
- Input example: 可以 认识 一下 吗 ?[SEP]当然 可以 啦 , 你好 。[SEP]嘿嘿 你好 , 请问 你 最近 在 忙 什么 呢 ?[SEP]我 最近 养 了 一只 狗狗 , 我 在 训练 它 呢 。
11
  """
12
 
 
13
 
14
  tokenizer = BertTokenizer.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")
15
  model = BartForConditionalGeneration.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- def chat(history):
19
- history_prefix = "对话历史:"
20
- history = history_prefix + history
 
 
21
 
22
- outputs = tokenizer(history,
23
- return_tensors='pt',
24
- padding=True,
25
- truncation=True,
26
- max_length=512)
 
27
 
28
- input_ids = outputs.input_ids
29
  output_ids = model.generate(input_ids)[0]
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- return tokenizer.decode(output_ids, skip_special_tokens=True)
32
 
33
 
34
- chatbot = gr.Chatbot().style(color_map=("green", "pink"))
35
- demo = gr.Interface(
36
- chat,
37
- inputs=gr.Textbox(lines=8, placeholder="输入你的对话历史(请以'[SEP]'作为每段对话的间隔)\nInput the dialogue history (Please split utterances by '[SEP]')"),
38
- title=title,
39
- description=description,
40
- outputs =["text"]
41
- )
42
 
43
 
44
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import torch
3
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
4
  from transformers import BertTokenizer, BartForConditionalGeneration
5
 
 
6
# Space header text shown by Gradio.
title = "HIT-TMG/dialogue-bart-large-chinese"
description = """
This is a seq2seq model fine-tuned on several Chinese dialogue datasets, from bart-large-chinese.
See some details of model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese .
"""

# Input example (utterances separated by [SEP]):
# 可以 认识 一下 吗 ?[SEP]当然 可以 啦 , 你好 。[SEP]嘿嘿 你好 , 请问 你 最近 在 忙 什么 呢 ?[SEP]我 最近 养 了 一只 狗狗 , 我 在 训练 它 呢 。

# Load tokenizer and model once at import time; both come from the same checkpoint.
checkpoint = "HIT-TMG/dialogue-bart-large-chinese"
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BartForConditionalGeneration.from_pretrained(checkpoint)

# Truncate from the left so that, when the dialogue history overflows
# max_length, the most recent turns are the ones kept.
tokenizer.truncation_side = 'left'
max_length = 512
19
+
20
+
21
def chat_func(input_utterance, history: Optional[List[str]] = None):
    """Generate one chatbot reply and update the running dialogue history.

    Args:
        input_utterance: the user's latest message.
        history: flat list of alternating user/bot utterances from earlier
            turns (the gradio "state" value); None on the very first turn.

    Returns:
        A pair ``(display_utterances, history)`` where ``display_utterances``
        is a list of ``(user, bot)`` tuples for the Chatbot component and
        ``history`` is the updated flat utterance list to carry as state.
    """
    if history is None:
        history = []
    history.append(input_utterance)

    # The model expects a "对话历史:" ("dialogue history:") prefix with turns
    # joined by the tokenizer's [SEP] token; truncation_side='left' (set at
    # module level) keeps the most recent context when this overflows.
    history_str = "对话历史:" + tokenizer.sep_token.join(history)

    input_ids = tokenizer(history_str,
                          return_tensors='pt',
                          truncation=True,
                          max_length=max_length).input_ids

    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(input_ids)[0]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)

    history.append(response)

    # Pair consecutive utterances as (user, bot) tuples for gr.Chatbot.
    display_utterances = [(history[i], history[i + 1])
                          for i in range(0, len(history) - 1, 2)]

    return display_utterances, history
78
 
79
 
80
# Wire the chat function into the UI: text input plus conversation state in,
# Chatbot rendering plus updated state out. `title` and `description` are
# defined at module level but were no longer passed to the Interface — restore
# them so the Space keeps its header text.
demo = gr.Interface(fn=chat_func,
                    title=title,
                    description=description,
                    inputs=["text", "state"],
                    outputs=["chatbot", "state"])
 
 
 
 
 
83
 
84
 
85
  if __name__ == "__main__":