demo / options.py
ar-houwei-chou
demo
6aee98f
raw
history blame
2.7 kB
from argparse import ArgumentParser, Namespace
import helpers
def add_options(parser: ArgumentParser) -> None:
    """Register the command-line options for this program on *parser*.

    Options are grouped into dataset, model, training, and evaluation
    sections. The parser is mutated in place; nothing is returned.

    Args:
        parser: The argument parser to add the options to.
    """
    # fmt: off
    # Dataset
    parser.add_argument("--train-data", default="saved/processed/train-*.tar", type=str)
    parser.add_argument("--valid-data", default="saved/processed/valid.tar", type=str)
    parser.add_argument("--test-data", default="saved/processed/test.tar", type=str)
    parser.add_argument("--text-vocab-path", default="bert-base-chinese", type=str, help="BertTokenizer used to preprocess the corpus")
    parser.add_argument("--cond-vocab-path", default="./vocab.cond.model", type=str)
    parser.add_argument("--num-workers", default=8, help="Number of data loader workers", type=int)
    parser.add_argument("--tokenize", default="zh", help="Tokenization method used to compute sacrebleu, diversity, and BERTScore; defaults to Chinese", type=str)
    # Model
    parser.add_argument("--d-model", default=512, type=int)
    parser.add_argument("--nhead", default=8, type=int)
    parser.add_argument("--num-encoder-layers", default=6, type=int)
    parser.add_argument("--num-decoder-layers", default=6, type=int)
    parser.add_argument("--max-seq-len", default=256, type=int)
    # NOTE(review): the default is the literal "baseline"; presumably equal to
    # helpers.BASELINE -- confirm, because argparse never checks a default
    # against `choices` (only values given on the command line are checked).
    parser.add_argument("--mode", default="baseline", type=str, choices=[
        helpers.BASELINE, helpers.KOBE_ATTRIBUTE, helpers.KOBE_KNOWLEDGE, helpers.KOBE_FULL])
    # Training
    parser.add_argument("--name", default="exp", type=str, help="experiment name")
    parser.add_argument("--gpu", default=1, type=int)
    parser.add_argument("--grad-clip", default=1.0, type=float, help="clip threshold of gradients")
    parser.add_argument("--epochs", default=30, type=int, help="number of epochs to train")
    parser.add_argument("--patience", default=10, type=int, help="early stopping patience")
    # argparse only runs `type` on string defaults, so a float literal is
    # required for args.lr to be a float when --lr is not given.
    parser.add_argument("--lr", default=1.0, type=float, help="learning rate")
    parser.add_argument("--dropout", default=0.1, type=float, help="dropout rate")
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--seed", default=42, type=int)
    # Evaluation
    parser.add_argument("--test", action="store_true", help="only do evaluation")
    parser.add_argument("--load-file", required=False, type=str, help="path to the checkpoint (.ckpt) for evaluation")
    parser.add_argument("--decoding-strategy", default="greedy", type=str, choices=["greedy", "nucleus"], help="Whether to use greedy decoding or nucleus sampling (https://arxiv.org/abs/1904.09751)")
    # fmt: on
def add_args(args: Namespace):
    """Attach derived vocabulary sizes to an already-parsed *args*.

    Reads ``args.text_vocab_path`` and ``args.cond_vocab_path`` and sets
    ``args.text_vocab_size`` and ``args.cond_vocab_size`` in place using the
    ``helpers`` size lookups. Returns ``None``.
    """
    # Text vocabulary: sized via the BERT-specific helper.
    text_size = helpers.get_bert_vocab_size(args.text_vocab_path)
    args.text_vocab_size = text_size

    # Condition vocabulary: sized via the generic vocab helper.
    cond_size = helpers.get_vocab_size(args.cond_vocab_path)
    args.cond_vocab_size = cond_size