File size: 5,628 Bytes
f729637
 
 
7e553cd
51d6ebd
c9129a5
 
a849b40
06e742c
a849b40
06e742c
c9129a5
 
51d6ebd
6599416
 
 
6484cbc
 
ebb49a0
e420f46
3dd6650
e420f46
 
6484cbc
 
4f6421c
 
6484cbc
 
 
 
 
a849b40
 
 
 
fbf7e35
ef5bd87
fbf7e35
26fd56b
 
 
fbf7e35
26fd56b
fbf7e35
 
 
 
 
fe14f27
a849b40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6599416
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
---
license: apache-2.0
---
## few_shot_intent_gpt2_base

这个模型是基于 [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) 模型在 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集上微调的结果. 

(1)训练在(11000 steps)处 Early Stop。这相当于加载的 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集的 1 个 epoch 处。

(2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。


最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。

你可以在此处体验该模型 [qgyd2021/gpt2_chat](https://huggingface.co/spaces/qgyd2021/gpt2_chat)。


### TensorBoard 数集

**Eval Loss** 见下图:

![eval_loss.jpg](docs/pictures/eval_loss.jpg)


**Learning rate** 见下图:

学习率从 2e-4 下降到 1.4e-4。

![learning_rate.jpg](docs/pictures/learning_rate.jpg)




### 讨论

(1)最优解在不到 1 个 epoch 处得到。

* 这可能说明 GPT2 模型大小,相对于任务复杂度来说太小了。

* 模型进入到局部最终解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。

(2)后续应考虑针对 prompt-response 中 response 部分进行训练。

* 即只优化 response 部分的损失以提升识别结果与 prompt 之间的注意力机制。当前的训练有可能只是使模型拟合了 few shot 数据的格式,而并没有拟合到意图识别的目的。

(3)模型使用中的体会。

* 如果在使用过程中,模型生成 response 不在 prompt 中给定的选项,这可能说明模型已经过拟合了。

* 如果模型生成 response 在 prompt 中,但答案不正确,则说明模型已学习到生成的表层模型,而没有学习到意图识别的目的。则建议在此模型基础上进一步优化 response 部分的损失。



### 其它

训练时加载数据集的代码
```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json

from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(project_path / "hub_datasets").as_posix(),
        type=str
    )

    parser.add_argument("--num_epochs", default=1, type=int)

    parser.add_argument("--train_subset", default="train.jsonl", type=str)
    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    name_list = [
        # "a_intent_prompt",
        "amazon_massive_intent_en_us_prompt",
        "amazon_massive_intent_zh_cn_prompt",
        "atis_intents_prompt",
        "banking77_prompt",
        "bi_text11_prompt",
        "bi_text27_prompt",
        # "book6_prompt",
        "carer_prompt",
        "chatbots_prompt",
        "chinese_news_title_prompt",
        "cmid_4class_prompt",
        "cmid_36class_prompt",
        "coig_cqia_prompt",
        "conv_intent_prompt",
        "crosswoz_prompt",
        "dmslots_prompt",
        "dnd_style_intents_prompt",
        "emo2019_prompt",
        "finance21_prompt",
        "ide_intent_prompt",
        "intent_classification_prompt",
        "jarvis_intent_prompt",
        "mobile_assistant_prompt",
        "mtop_intent_prompt",
        "out_of_scope_prompt",
        "ri_sawoz_domain_prompt",
        "ri_sawoz_general_prompt",
        "small_talk_prompt",
        "smp2017_task1_prompt",
        "smp2019_task1_domain_prompt",
        "smp2019_task1_intent_prompt",
        # "snips_built_in_intents_prompt",
        "star_wars_prompt",
        "suicide_intent_prompt",
        "snips_built_in_intents_prompt",
        "telemarketing_intent_cn_prompt",
        "telemarketing_intent_en_prompt",
        "vira_intents_prompt",
    ]

    with open(args.train_subset, "w", encoding="utf-8") as f:
        for _ in range(args.num_epochs):
            for name in name_list:
                print(name)
                dataset = load_dataset(
                    path=args.dataset_path,
                    name=name,
                    split="train",
                    cache_dir=args.dataset_cache_dir,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD,
                    ignore_verifications=True
                )
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))

    with open(args.valid_subset, "w", encoding="utf-8") as f:
        for _ in range(args.num_epochs):
            for name in name_list:
                print(name)
                dataset = load_dataset(
                    path=args.dataset_path,
                    name=name,
                    split="test",
                    cache_dir=args.dataset_cache_dir,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD,
                    ignore_verifications=True
                )
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))

    return


if __name__ == '__main__':
    main()

```