qgyd2021's picture
Update README.md
fe14f27 verified
|
raw
history blame
4.65 kB
metadata
license: apache-2.0

few_shot_intent_gpt2

这个模型是在 qgyd2021/few_shot_intent_sft 数据集上对 uer/gpt2-chinese-cluecorpussmall 模型进行微调得到的。

(1)训练在第 11000 步(11000 steps)处 Early Stop,这大约相当于在所加载的 qgyd2021/few_shot_intent_sft 数据集上训练了 1 个 epoch。

(2)此处保存的是最优权重 checkpoint-6000(第 6000 步),约相当于在原数据集上训练了 0.63 个 epoch。

最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。

你可以在 qgyd2021/gpt2_chat 体验该模型。

Eval Loss 见下图:

eval_loss.jpg

讨论

(1)最优解在训练不到 1 个 epoch 时就已得到。

这可能说明,相对于任务的复杂度而言,GPT2 模型的规模太小了。

其它

训练时加载数据集的代码

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json

from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm

from project_settings import project_path


def get_args():
    """Parse the command-line options of the JSONL export script.

    Options:
        --dataset_path: value passed as ``path`` to ``load_dataset``.
        --dataset_split: accepted for compatibility; ``main`` does not read it.
        --dataset_cache_dir: ``cache_dir`` for ``load_dataset``; defaults to
            ``<project_path>/hub_datasets``.
        --num_epochs: repeat count used by ``main`` when writing the outputs.
        --train_subset: output JSONL path for the ``train`` split.
        --valid_subset: output JSONL path for the ``test`` split.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    default_cache_dir = (project_path / "hub_datasets").as_posix()

    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset_path", type=str, default="qgyd2021/few_shot_intent_sft")
    cli.add_argument("--dataset_split", type=str, default=None)
    cli.add_argument("--dataset_cache_dir", type=str, default=default_cache_dir)
    cli.add_argument("--num_epochs", type=int, default=1)
    cli.add_argument("--train_subset", type=str, default="train.jsonl")
    cli.add_argument("--valid_subset", type=str, default="valid.jsonl")
    return cli.parse_args()


def _load_split(name_list, split, args):
    """Load ``split`` of every dataset config in ``name_list``, in list order.

    Each config is fetched exactly once per call. The loaded datasets are
    returned so the caller can iterate them repeatedly (e.g. once per
    epoch) without calling ``load_dataset`` again.
    """
    datasets = []
    for name in name_list:
        print(name)
        datasets.append(load_dataset(
            path=args.dataset_path,
            name=name,
            split=split,
            cache_dir=args.dataset_cache_dir,
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
            # NOTE(review): newer ``datasets`` releases deprecate/remove
            # ``ignore_verifications`` in favour of ``verification_mode``;
            # confirm against the pinned version before upgrading.
            ignore_verifications=True,
        ))
    return datasets


def _write_jsonl(path, datasets, repeat=1):
    """Write every sample of ``datasets`` to ``path`` as one JSON object per line.

    The whole sequence of datasets is written ``repeat`` times, in order.
    ``ensure_ascii=False`` keeps non-ASCII (e.g. Chinese) text readable.
    """
    with open(path, "w", encoding="utf-8") as f:
        for _ in range(repeat):
            for dataset in datasets:
                for sample in tqdm(dataset):
                    f.write(json.dumps(sample, ensure_ascii=False) + "\n")


def main():
    """Export the few-shot intent SFT configs to train/valid JSONL files.

    For every config in ``name_list``:
    - the ``train`` split is written to ``args.train_subset``, and the whole
      training set is repeated ``args.num_epochs`` times;
    - the ``test`` split is written once to ``args.valid_subset``.
    """
    args = get_args()

    # Commented-out configs are excluded from the export. The order of this
    # list determines the row order of both output files.
    name_list = [
        # "a_intent_prompt",
        "amazon_massive_intent_en_us_prompt",
        "amazon_massive_intent_zh_cn_prompt",
        "atis_intents_prompt",
        "banking77_prompt",
        "bi_text11_prompt",
        "bi_text27_prompt",
        # "book6_prompt",
        "carer_prompt",
        "chatbots_prompt",
        "chinese_news_title_prompt",
        "cmid_4class_prompt",
        "cmid_36class_prompt",
        "coig_cqia_prompt",
        "conv_intent_prompt",
        "crosswoz_prompt",
        "dmslots_prompt",
        "dnd_style_intents_prompt",
        "emo2019_prompt",
        "finance21_prompt",
        "ide_intent_prompt",
        "intent_classification_prompt",
        "jarvis_intent_prompt",
        "mobile_assistant_prompt",
        "mtop_intent_prompt",
        "out_of_scope_prompt",
        "ri_sawoz_domain_prompt",
        "ri_sawoz_general_prompt",
        "small_talk_prompt",
        "smp2017_task1_prompt",
        "smp2019_task1_domain_prompt",
        "smp2019_task1_intent_prompt",
        "star_wars_prompt",
        "suicide_intent_prompt",
        # Out of alphabetical order; moving it would change output row order.
        "snips_built_in_intents_prompt",
        "telemarketing_intent_cn_prompt",
        "telemarketing_intent_en_prompt",
        "vira_intents_prompt",
    ]

    # Load each config once and reuse it across epochs, instead of
    # force-redownloading every config on every epoch.
    train_datasets = _load_split(name_list, "train", args)
    _write_jsonl(args.train_subset, train_datasets, repeat=args.num_epochs)

    # The validation set is written once: repeating it per epoch only
    # duplicates rows (slowing evaluation) without changing what is measured.
    valid_datasets = _load_split(name_list, "test", args)
    _write_jsonl(args.valid_subset, valid_datasets)


if __name__ == '__main__':
    main()