import argparse |
from s5.utils.util import str2bool |
from s5.train import train |
from s5.dataloading import Datasets |
def _build_parser():
    """Create the command-line parser for the training script.

    Every option is registered with a default, so ``parse_args([])`` yields a
    fully populated namespace.

    Returns:
        argparse.ArgumentParser: parser with all options registered, in the
        order they are listed in ``--help``.
    """
    parser = argparse.ArgumentParser()

    # --- wandb logging ---------------------------------------------------
    parser.add_argument("--USE_WANDB", type=str2bool, default=False,
                        help="log with wandb?")
    parser.add_argument("--wandb_project", type=str, default=None,
                        help="wandb project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="wandb entity name, e.g. username")

    # --- data ------------------------------------------------------------
    parser.add_argument("--dir_name", type=str, default='./cache_dir',
                        help="name of directory where data is cached")
    parser.add_argument("--dataset", type=str, choices=Datasets.keys(),
                        default='mnist-classification',
                        help="dataset name")

    # --- model -----------------------------------------------------------
    parser.add_argument("--n_layers", type=int, default=6,
                        help="Number of layers in the network")
    parser.add_argument("--d_model", type=int, default=128,
                        help="Number of features, i.e. H, "
                             "dimension of layer inputs/outputs")
    parser.add_argument("--ssm_size_base", type=int, default=256,
                        help="SSM Latent size, i.e. P")
    parser.add_argument("--blocks", type=int, default=8,
                        help="How many blocks, J, to initialize with")
    parser.add_argument("--C_init", type=str, default="trunc_standard_normal",
                        choices=["trunc_standard_normal", "lecun_normal", "complex_normal"],
                        help="Options for initialization of C: \\"
                             "trunc_standard_normal: sample from trunc. std. normal then multiply by V \\ "
                             "lecun_normal sample from lecun normal, then multiply by V\\ "
                             "complex_normal: sample directly from complex standard normal")
    parser.add_argument("--discretization", type=str, default="zoh",
                        choices=["zoh", "bilinear"],
                        help="discretization method")
    parser.add_argument("--mode", type=str, default="pool", choices=["pool", "last"],
                        help="options: (for classification tasks) \\"
                             " pool: mean pooling \\"
                             "last: take last element")
    parser.add_argument("--activation_fn", default="half_glu1", type=str,
                        choices=["full_glu", "half_glu1", "half_glu2", "gelu"],
                        help="activation function")
    parser.add_argument("--conj_sym", type=str2bool, default=True,
                        help="whether to enforce conjugate symmetry")
    parser.add_argument("--clip_eigs", type=str2bool, default=False,
                        help="whether to enforce the left-half plane condition")
    parser.add_argument("--bidirectional", type=str2bool, default=False,
                        help="whether to use bidirectional model")
    parser.add_argument("--dt_min", type=float, default=0.001,
                        help="min value to sample initial timescale params from")
    parser.add_argument("--dt_max", type=float, default=0.1,
                        help="max value to sample initial timescale params from")
    parser.add_argument("--prenorm", type=str2bool, default=True,
                        help="True: use prenorm, False: use postnorm")
    parser.add_argument("--batchnorm", type=str2bool, default=True,
                        help="True: use batchnorm, False: use layernorm")
    parser.add_argument("--bn_momentum", type=float, default=0.95,
                        help="batchnorm momentum")

    # --- optimization ----------------------------------------------------
    parser.add_argument("--bsz", type=int, default=64,
                        help="batch size")
    parser.add_argument("--epochs", type=int, default=100,
                        help="max number of epochs")
    parser.add_argument("--early_stop_patience", type=int, default=1000,
                        help="number of epochs to continue training when val loss plateaus")
    parser.add_argument("--ssm_lr_base", type=float, default=1e-3,
                        help="initial ssm learning rate")
    # argparse applies `type` only to *string* defaults, so float-typed
    # options need float literals to yield a float when the flag is omitted.
    parser.add_argument("--lr_factor", type=float, default=1.0,
                        help="global learning rate = lr_factor*ssm_lr_base")
    parser.add_argument("--dt_global", type=str2bool, default=False,
                        help="Treat timescale parameter as global parameter or SSM parameter")
    parser.add_argument("--lr_min", type=float, default=0.0,
                        help="minimum learning rate")
    parser.add_argument("--cosine_anneal", type=str2bool, default=True,
                        help="whether to use cosine annealing schedule")
    parser.add_argument("--warmup_end", type=int, default=1,
                        help="epoch to end linear warmup")
    parser.add_argument("--lr_patience", type=int, default=1000000,
                        help="patience before decaying learning rate for lr_decay_on_val_plateau")
    parser.add_argument("--reduce_factor", type=float, default=1.0,
                        help="factor to decay learning rate for lr_decay_on_val_plateau")
    parser.add_argument("--p_dropout", type=float, default=0.0,
                        help="probability of dropout")
    parser.add_argument("--weight_decay", type=float, default=0.05,
                        help="weight decay value")
    parser.add_argument("--opt_config", type=str, default="standard",
                        choices=['standard',
                                 'BandCdecay',
                                 'BfastandCdecay',
                                 'noBCdecay'],
                        help="Opt configurations: \\ "
                             "standard: no weight decay on B (ssm lr), weight decay on C (global lr) \\"
                             "BandCdecay: weight decay on B (ssm lr), weight decay on C (global lr) \\"
                             "BfastandCdecay: weight decay on B (global lr), weight decay on C (global lr) \\"
                             "noBCdecay: no weight decay on B (ssm lr), no weight decay on C (ssm lr) \\")
    parser.add_argument("--jax_seed", type=int, default=1919,
                        help="seed randomness")
    return parser


if __name__ == "__main__":
    # Parse the command line and hand the resulting namespace to train().
    train(_build_parser().parse_args())