---
language:
- zh
tags:
- bart-base-chinese
datasets:
- Chinese Persona Chat (CPC)
- LCCC
- Emotional STC (ESTC)
- KdConv
---

# dialogue-bart-base-chinese

This is a seq2seq model fine-tuned from bart-base-chinese on several Chinese dialogue datasets.

# Datasets

We use four Chinese dialogue datasets from [LUGE](https://www.luge.ai/#/):

| Dataset | Count | Domain |
| ---- | ---- | ---- |
| Chinese Persona Chat (CPC) | 23,000 | Open |
| LCCC | 11,987,759 | Open |
| Emotional STC (ESTC) | 899,207 | Open |
| KdConv | 3,000 | Movie, Music, Travel |

# Data format

Input: `[CLS] 对话历史:<history> 知识:<knowledge> [SEP]` (对话历史 = "dialogue history", 知识 = "knowledge")

Output: `[CLS] <response> [SEP]`

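The example in the next section builds the history part of this input; how the knowledge part is attached is not shown there, so the sketch below is an assumption by analogy: history turns are joined with the tokenizer's sep token after the 对话历史: marker, and knowledge is appended after the 知识: marker. The `[CLS]`/`[SEP]` wrapping is added by the tokenizer itself.

```python
# A minimal sketch of assembling the model input (assumed, not from the
# model authors): mirrors the history handling in the example below and
# appends knowledge after the "知识:" ("knowledge") marker.
def build_input(history, knowledge="", sep_token="[SEP]"):
    # "对话历史:" ("dialogue history:"), turns separated by the sep token
    text = "对话历史:" + sep_token.join(history)
    if knowledge:
        text += " 知识:" + knowledge
    return text  # the tokenizer adds [CLS] ... [SEP] around this string
```
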
# Example

```python
from transformers import BertTokenizer, BartForConditionalGeneration

# Note that the tokenizer is a BertTokenizer, not a BartTokenizer
tokenizer = BertTokenizer.from_pretrained("HIT-TMG/dialogue-bart-base-chinese")
model = BartForConditionalGeneration.from_pretrained("HIT-TMG/dialogue-bart-base-chinese")

# An example from the CPC dev data (turns are pre-segmented with spaces)
history = ["可以 认识 一下 吗 ?", "当然 可以 啦 , 你好 。", "嘿嘿 你好 , 请问 你 最近 在 忙 什么 呢 ?", "我 最近 养 了 一只 狗狗 , 我 在 训练 它 呢 。"]

# Prefix with the dialogue-history marker and join turns with the sep token
history_str = "对话历史:" + tokenizer.sep_token.join(history)

input_ids = tokenizer(history_str, return_tensors='pt').input_ids
output_ids = model.generate(input_ids)[0]
print(tokenizer.decode(output_ids))
```
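The decoded string above still contains the special tokens, and `generate` falls back to the model's default decoding settings. Continuing from the example, the variant below uses beam search and strips special tokens when decoding; the hyperparameter values are illustrative assumptions, not settings published with this model.

```python
# Illustrative generation settings (assumptions, not the authors' values)
output_ids = model.generate(
    input_ids,
    num_beams=4,         # beam search instead of the default decoding
    max_length=64,       # cap the response length
    early_stopping=True,
)[0]
# skip_special_tokens=True drops [CLS]/[SEP] from the decoded response
print(tokenizer.decode(output_ids, skip_special_tokens=True))
```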