metadata
license: apache-2.0
few_shot_intent_gpt2
这个模型是基于 uer/gpt2-chinese-cluecorpussmall 模型在 qgyd2021/few_shot_intent_sft 数据集上微调的结果.
(1)训练在第 11000 步(11000 steps)处触发 Early Stop,这大约相当于在加载的 qgyd2021/few_shot_intent_sft 数据集上训练了 1 个 epoch。
(2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。
最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。
你可以在此处体验该模型 qgyd2021/gpt2_chat。
TensorBoard 数据
Eval Loss 见下图:
Learning rate 见下图:
讨论
(1)最优解在不到 1 个 epoch 处得到。
这可能说明 GPT2 模型的规模相对于任务复杂度来说太小了。
模型陷入局部最优解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。
其它
训练时加载数据集的代码
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm
from project_settings import project_path
def get_args():
    """Parse the command-line options of the JSONL export script.

    Options (all optional):
        --dataset_path       Hugging Face dataset repository to read.
        --dataset_split      Accepted but not read by main() in this script.
        --dataset_cache_dir  Cache directory passed to ``load_dataset``
                             (defaults to ``<project_path>/hub_datasets``).
        --num_epochs         How many times each output file repeats the data.
        --train_subset       Output JSONL path for the "train" split.
        --valid_subset       Output JSONL path for the "test" split.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    hub_cache_dir = (project_path / "hub_datasets").as_posix()

    # (flag, default, type) in the order they appear in --help.
    option_specs = (
        ("--dataset_path", "qgyd2021/few_shot_intent_sft", str),
        ("--dataset_split", None, str),
        ("--dataset_cache_dir", hub_cache_dir, str),
        ("--num_epochs", 1, int),
        ("--train_subset", "train.jsonl", str),
        ("--valid_subset", "valid.jsonl", str),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, default_value, value_type in option_specs:
        arg_parser.add_argument(flag, default=default_value, type=value_type)
    return arg_parser.parse_args()
def _load_subsets(args, name_list):
    """Load each subset config of ``args.dataset_path`` exactly once.

    Returns the loaded objects (one per name, same order as ``name_list``);
    with no ``split`` argument ``load_dataset`` returns all splits, so the
    caller can index "train" / "test" without downloading again.
    """
    subsets = []
    for name in name_list:
        print(name)
        subsets.append(
            load_dataset(
                path=args.dataset_path,
                name=name,
                cache_dir=args.dataset_cache_dir,
                # FORCE_REDOWNLOAD re-fetches on every call, which is why each
                # config is loaded only once here rather than per split/epoch.
                download_mode=DownloadMode.FORCE_REDOWNLOAD,
                # NOTE(review): `ignore_verifications` is deprecated in newer
                # `datasets` releases (replaced by `verification_mode`); kept
                # for compatibility with the version this script targets.
                ignore_verifications=True,
            )
        )
    return subsets


def _write_jsonl(filename, splits, num_epochs):
    """Write every sample of ``splits`` to ``filename`` as UTF-8 JSON lines.

    The whole sequence of splits is written ``num_epochs`` times, in
    epoch-major order (all splits for epoch 0, then all for epoch 1, ...).
    """
    with open(filename, "w", encoding="utf-8") as f:
        for _ in range(num_epochs):
            for split in splits:
                for sample in tqdm(split):
                    # ensure_ascii=False keeps Chinese text readable in the file.
                    f.write(json.dumps(sample, ensure_ascii=False) + "\n")


def main():
    """Export the few-shot intent SFT subsets to local JSONL files.

    Every subset in ``name_list`` is loaded from ``args.dataset_path``; its
    "train" split goes to ``args.train_subset`` and its "test" split to
    ``args.valid_subset``. Each file repeats the data ``args.num_epochs``
    times, subsets in ``name_list`` order within each repetition.
    """
    args = get_args()

    # Commented-out entries are subsets deliberately left out of the export.
    name_list = [
        # "a_intent_prompt",
        "amazon_massive_intent_en_us_prompt",
        "amazon_massive_intent_zh_cn_prompt",
        "atis_intents_prompt",
        "banking77_prompt",
        "bi_text11_prompt",
        "bi_text27_prompt",
        # "book6_prompt",
        "carer_prompt",
        "chatbots_prompt",
        "chinese_news_title_prompt",
        "cmid_4class_prompt",
        "cmid_36class_prompt",
        "coig_cqia_prompt",
        "conv_intent_prompt",
        "crosswoz_prompt",
        "dmslots_prompt",
        "dnd_style_intents_prompt",
        "emo2019_prompt",
        "finance21_prompt",
        "ide_intent_prompt",
        "intent_classification_prompt",
        "jarvis_intent_prompt",
        "mobile_assistant_prompt",
        "mtop_intent_prompt",
        "out_of_scope_prompt",
        "ri_sawoz_domain_prompt",
        "ri_sawoz_general_prompt",
        "small_talk_prompt",
        "smp2017_task1_prompt",
        "smp2019_task1_domain_prompt",
        "smp2019_task1_intent_prompt",
        "star_wars_prompt",
        "suicide_intent_prompt",
        # Kept at this (non-alphabetical) position to preserve output order.
        "snips_built_in_intents_prompt",
        "telemarketing_intent_cn_prompt",
        "telemarketing_intent_en_prompt",
        "vira_intents_prompt",
    ]

    subsets = _load_subsets(args, name_list)

    _write_jsonl(
        args.train_subset,
        [subset["train"] for subset in subsets],
        args.num_epochs,
    )
    # NOTE(review): the validation data is also repeated num_epochs times;
    # duplicate eval rows do not change a mean eval loss, only its cost.
    _write_jsonl(
        args.valid_subset,
        [subset["test"] for subset in subsets],
        args.num_epochs,
    )
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()