File size: 2,704 Bytes
6aee98f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from argparse import ArgumentParser, Namespace

import helpers


def add_options(parser: ArgumentParser) -> None:
    """Register all command-line options (dataset, model, training, eval) on *parser*.

    Mutates *parser* in place and returns None. Grouped to mirror the
    pipeline stages; ``# fmt: off`` keeps the long add_argument lines
    un-reformatted by black.
    """
    # fmt: off
    # Dataset
    parser.add_argument("--train-data", default="saved/processed/train-*.tar", type=str)
    parser.add_argument("--valid-data", default="saved/processed/valid.tar", type=str)
    parser.add_argument("--test-data", default="saved/processed/test.tar", type=str)
    parser.add_argument("--text-vocab-path", default="bert-base-chinese", type=str, help="BertTokenizer used to preprocess the corpus")
    parser.add_argument("--cond-vocab-path", default="./vocab.cond.model", type=str)
    parser.add_argument("--num-workers", default=8, help="Number of data loaders", type=int)
    parser.add_argument("--tokenize", default="zh", help="Tokenization method used to compute sacrebleu, diversity, and BERTScore, defaulted to Chinese", type=str)

    # Model
    parser.add_argument("--d-model", default=512, type=int)
    parser.add_argument("--nhead", default=8, type=int)
    parser.add_argument("--num-encoder-layers", default=6, type=int)
    parser.add_argument("--num-decoder-layers", default=6, type=int)
    parser.add_argument("--max-seq-len", default=256, type=int)
    # Valid modes are defined centrally in helpers so train/eval code agrees.
    parser.add_argument("--mode", default="baseline", type=str, choices=[
        helpers.BASELINE, helpers.KOBE_ATTRIBUTE, helpers.KOBE_KNOWLEDGE, helpers.KOBE_FULL])

    # Training
    parser.add_argument("--name", default="exp", type=str, help="experiment name")
    parser.add_argument("--gpu", default=1, type=int)
    parser.add_argument("--grad-clip", default=1.0, type=float, help="clip threshold of gradients")
    parser.add_argument("--epochs", default=30, type=int, help="number of epochs to train")
    parser.add_argument("--patience", default=10, type=int, help="early stopping patience")
    parser.add_argument("--lr", default=1, type=float, help="learning rate")
    parser.add_argument("--dropout", default=0.1, type=float, help="dropout rate")
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--seed", default=42, type=int)

    # Evaluation
    parser.add_argument("--test", action="store_true", help="only do evaluation")
    parser.add_argument("--load-file", required=False, type=str, help="path to the checkpoint (.ckpt) for evaluation")
    parser.add_argument("--decoding-strategy", default="greedy", type=str, choices=["greedy", "nucleus"], help="Whether to use greedy decoding or nucleus sampling (https://arxiv.org/abs/1904.09751)")

    # fmt: on


def add_args(args: Namespace) -> None:
    """Attach derived vocabulary sizes to the parsed *args* namespace.

    Looks up the sizes from the vocab files/models named by
    ``--text-vocab-path`` and ``--cond-vocab-path`` and stores them as
    ``text_vocab_size`` / ``cond_vocab_size``. Mutates *args* in place.
    """
    text_size = helpers.get_bert_vocab_size(args.text_vocab_path)
    cond_size = helpers.get_vocab_size(args.cond_vocab_path)
    args.text_vocab_size = text_size
    args.cond_vocab_size = cond_size