File size: 5,628 Bytes
f729637 7e553cd 51d6ebd c9129a5 a849b40 06e742c a849b40 06e742c c9129a5 51d6ebd 6599416 6484cbc ebb49a0 e420f46 3dd6650 e420f46 6484cbc 4f6421c 6484cbc a849b40 fbf7e35 ef5bd87 fbf7e35 26fd56b fbf7e35 26fd56b fbf7e35 fe14f27 a849b40 6599416 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
---
license: apache-2.0
---
## few_shot_intent_gpt2_base
这个模型是基于 [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) 模型在 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集上微调的结果.
(1)训练在(11000 steps)处 Early Stop。这相当于加载的 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集的 1 个 epoch 处。
(2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。
最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。
你可以在此处体验该模型 [qgyd2021/gpt2_chat](https://huggingface.co/spaces/qgyd2021/gpt2_chat)。
### TensorBoard 数集
**Eval Loss** 见下图:
![eval_loss.jpg](docs/pictures/eval_loss.jpg)
**Learning rate** 见下图:
学习率从 2e-4 下降到 1.4e-4。
![learning_rate.jpg](docs/pictures/learning_rate.jpg)
### 讨论
(1)最优解在不到 1 个 epoch 处得到。
* 这可能说明 GPT2 模型大小,相对于任务复杂度来说太小了。
* 模型进入到局部最终解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。
(2)后续应考虑针对 prompt-response 中 response 部分进行训练。
* 即只优化 response 部分的损失以提升识别结果与 prompt 之间的注意力机制。当前的训练有可能只是使模型拟合了 few shot 数据的格式,而并没有拟合到意图识别的目的。
(3)模型使用中的体会。
* 如果在使用过程中,模型生成 response 不在 prompt 中给定的选项,这可能说明模型已经过拟合了。
* 如果模型生成 response 在 prompt 中,但答案不正确,则说明模型已学习到生成的表层模型,而没有学习到意图识别的目的。则建议在此模型基础上进一步优化 response 部分的损失。
### 其它
训练时加载数据集的代码
```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm
from project_settings import project_path
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str)
parser.add_argument("--dataset_split", default=None, type=str)
parser.add_argument(
"--dataset_cache_dir",
default=(project_path / "hub_datasets").as_posix(),
type=str
)
parser.add_argument("--num_epochs", default=1, type=int)
parser.add_argument("--train_subset", default="train.jsonl", type=str)
parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
args = parser.parse_args()
return args
def main():
args = get_args()
name_list = [
# "a_intent_prompt",
"amazon_massive_intent_en_us_prompt",
"amazon_massive_intent_zh_cn_prompt",
"atis_intents_prompt",
"banking77_prompt",
"bi_text11_prompt",
"bi_text27_prompt",
# "book6_prompt",
"carer_prompt",
"chatbots_prompt",
"chinese_news_title_prompt",
"cmid_4class_prompt",
"cmid_36class_prompt",
"coig_cqia_prompt",
"conv_intent_prompt",
"crosswoz_prompt",
"dmslots_prompt",
"dnd_style_intents_prompt",
"emo2019_prompt",
"finance21_prompt",
"ide_intent_prompt",
"intent_classification_prompt",
"jarvis_intent_prompt",
"mobile_assistant_prompt",
"mtop_intent_prompt",
"out_of_scope_prompt",
"ri_sawoz_domain_prompt",
"ri_sawoz_general_prompt",
"small_talk_prompt",
"smp2017_task1_prompt",
"smp2019_task1_domain_prompt",
"smp2019_task1_intent_prompt",
# "snips_built_in_intents_prompt",
"star_wars_prompt",
"suicide_intent_prompt",
"snips_built_in_intents_prompt",
"telemarketing_intent_cn_prompt",
"telemarketing_intent_en_prompt",
"vira_intents_prompt",
]
with open(args.train_subset, "w", encoding="utf-8") as f:
for _ in range(args.num_epochs):
for name in name_list:
print(name)
dataset = load_dataset(
path=args.dataset_path,
name=name,
split="train",
cache_dir=args.dataset_cache_dir,
download_mode=DownloadMode.FORCE_REDOWNLOAD,
ignore_verifications=True
)
for sample in tqdm(dataset):
row = json.dumps(sample, ensure_ascii=False)
f.write("{}\n".format(row))
with open(args.valid_subset, "w", encoding="utf-8") as f:
for _ in range(args.num_epochs):
for name in name_list:
print(name)
dataset = load_dataset(
path=args.dataset_path,
name=name,
split="test",
cache_dir=args.dataset_cache_dir,
download_mode=DownloadMode.FORCE_REDOWNLOAD,
ignore_verifications=True
)
for sample in tqdm(dataset):
row = json.dumps(sample, ensure_ascii=False)
f.write("{}\n".format(row))
return
if __name__ == '__main__':
main()
```
|