File size: 17,081 Bytes
a03c9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
import argparse
# from packaging.version import parse as VersionParse
import torch
from utils.data_modules import AMTDataModule
from utils.task_manager import TaskManager
from model.init_train import initialize_trainer, update_config
from model.ymt3 import YourMT3
from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg
from config.vocabulary import program_vocab_presets
from utils.utils import str2bool
# yapf: disable
# Command-line interface for the training script.
#
# Most options default to None (or the string 'default'), which the help texts
# describe as "fall back to the value defined in config.py"; only explicitly
# given options are meant to override the config. The help strings below state
# the defaults that are actually coded in the `default=` argument.
#
# NOTE: `parser.parse_args()` at the end of this section runs at module level,
# so the command line is parsed as soon as this file is executed/imported.
parser = argparse.ArgumentParser(description="YourMT3")
# General
parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name')
parser.add_argument('-d', '--data-preset', type=str, default='musicnet_thickstun_ext_em', help='dataset preset (default=musicnet_thickstun_ext_em). See config/data.py for more options.')
# Intra-stem augmentation
parser.add_argument('-amp', '--random-amp-range', nargs=2, type=float, default=None, help='random amp range for audio augmentation (default=None, using default value defined in config.py). In command line, use -amp 0.6 1.2')
parser.add_argument('-iaug', '--stem-iaug-prob', type=float, default=None, help='intra-stem augmentation probability (default follow config.py). p=1.0 means no intra-stem augmentation (no stems are dropped)')
# Cross-stem augmentation policy
parser.add_argument('-xk', '--xaug-max-k', type=int, default=None, help='max number of external sources used for cross-stem augmentations. Default follows config.py.')
parser.add_argument('-xtau', '--xaug-tau', type=float, default=None, help='exponential decay rate for cross-stem augmentation. Default follows config.py')
parser.add_argument('-xalpha', '--xaug-alpha', type=float, default=None, help='shape parameter for Weibull distribution. set 1.0 for exponential. Default follows config.py')
parser.add_argument('-xnio', '--xaug-no-instr-overlap', type=str2bool, default=None, help='No instrument overlap flag. Default follows config.py')
parser.add_argument('-xndo', '--xaug-no-drum-overlap', type=str2bool, default=None, help='No drum overlap flag. Default follows config.py')
# NOTE: this flag mixes '-' and '_' in its spelling; it is kept byte-identical so existing command lines keep working.
parser.add_argument('-xiaug', '--uhat-intra-stem_augment', type=str2bool, default=None, help='uhat intra-stem augmentation flag. Default follows config.py')
# Post-mix augmentation (post-mixing)
parser.add_argument('-ps', '--pitch-shift-range', nargs=2, type=int, default=None, help='pitch shift range in semitones (default=None). If None, default value defined in config.py. [0, 0] disables pitch shift. In command line, use -ps -2 2')
# Audio configurations
parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.')
parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300} 128 for MT3, 300 for PerceiverTF. If None, default value defined in config.py will be used.')
parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.')
# Model configurations
parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py')
parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.")
parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.")
parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default' or 'conv1d_t' or 'conv1d_f' or 'hftt' or 'res3b_hftt'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None")
parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.")
parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.')
parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False')
parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False')
parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5(default=False). True or False')
parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")')
parser.add_argument('-edr', '--encoder-dropout-rate', type=float, default=None, help='encoder dropout rate (default=None). If None, default rate defined in config will be used.')
parser.add_argument('-ddr', '--decoder-dropout-rate', type=float, default=None, help='decoder dropout rate (default=None). If None, default rate defined in config will be used.')
parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.")
parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.")
parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.')
# Perceiver-TF configurations
parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.')
parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.')
parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.')
parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.')
parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.')
parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.')
# Decoder configurations
parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
# Task and Evaluation configurations
parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.')
parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation program vocabulary (default=None). If None, default vocabulary of the data preset will be used.')
parser.add_argument('-w', '--write-model-output', type=str2bool, default=False, help='write model test output to file (default=False). True or False')
# Trainer configurations
parser.add_argument('-bsz', '--train-batch-size', nargs=2, type=int, default=None, help='train batch size for sub and local (default=None) per GPU. e.g. "-bsz 6 12". If None, default value defined in config.py will be used.')
parser.add_argument('-pr', '--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}')
parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp')
parser.add_argument('-sb', '--sync-batchnorm', type=str2bool, default=False, help='sync batchnorm (default=False). True or False')
parser.add_argument('-se', '--train-num-samples-per-epoch', type=int, default=90000, help='number of samples per epoch (default=90000). If None, use the total number of files in multi datasets.')
parser.add_argument('-e', '--max-epochs', type=int, default=None, help='number of max epochs (default is None, which is 1000).')
parser.add_argument('-it', '--max-steps', type=int, default=-1, help='number of max steps (default is -1, disabled). This overrides the number of total steps defined in config.')
parser.add_argument('-vit', '--val-interval', type=int, default=None, help='validation interval (default=None). If None, use the check_val_every_n_epoch defined in config.py')
parser.add_argument('-lr', '--base-learning-rate', type=float, default=None, help='base learning rate (default is 1e-03 for AdamW, and auto for AdaFactor)')
parser.add_argument('-o', '--optimizer', type=str, default='AdamWScale', help='optimizer (default=AdamWScale) or AdaFactor or AdamW or CPUAdam or DAdaptAdam. Only check lowercase.')
parser.add_argument('-s', '--scheduler', type=str, default='cosine', help='scheduler name (default=cosine), constant or legacy or cosine')
parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)')
parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")')
parser.add_argument('-wb', '--wandb-mode', type=str, default=None, help='wandb mode for logging (default=None). "disabled" or "online" or "offline". If None, default value defined in config.py will be used.')
# Parsed once here; the resulting namespace is read by the module-level setup below and by main().
args = parser.parse_args()
# yapf: enable
# Enable the "high" float32 matmul precision mode when this torch build offers it.
# Feature detection is used instead of comparing `torch.__version__` against "1.13":
# version strings compare lexicographically, so e.g. "1.9.0" >= "1.13" is True and
# the call would then fail on a release that does not provide the function.
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")
# Initialize trainer for the 'train' stage from the parsed CLI args. Runs at module level
# (before main()), and also yields the W&B logger, `dir_info` (main() reads its
# "lightning_dir" and "last_ckpt_path" entries) and the shared config.
trainer, wandb_logger, dir_info, shared_cfg = initialize_trainer(args, stage='train')
# Update config with args, including augmentation settings. `shared_cfg` is rebound to the
# updated config, and the audio/model configs consumed by main() are produced here.
shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='train')
# Optional keys copied from a single-dataset preset into the multi-preset form.
_OPTIONAL_PRESET_KEYS = ("eval_drum_vocab", "add_pitch_class_metric")


def _resolve_data_preset(preset_name):
    """Return the data preset for `preset_name` in multi-preset format.

    A name found in `data_preset_single_cfg` is wrapped into a dict with
    "presets" and "eval_vocab" (plus any of the optional keys it defines);
    a name found in `data_preset_multi_cfg` is returned as stored.

    Raises:
        ValueError: if the name is in neither preset table.
    """
    if preset_name in data_preset_single_cfg:
        single_cfg = data_preset_single_cfg[preset_name]
        # convert single preset into multi preset format
        data_preset = {
            "presets": [preset_name],
            "eval_vocab": single_cfg["eval_vocab"],
        }
        for key, value in single_cfg.items():
            if key in _OPTIONAL_PRESET_KEYS:
                data_preset[key] = value
        return data_preset
    if preset_name in data_preset_multi_cfg:
        return data_preset_multi_cfg[preset_name]
    raise ValueError(f"Invalid data preset: {preset_name}")


def main():
    """Build the data module and model from the parsed CLI args and run training.

    Reads the module-level `args`, `trainer`, `wandb_logger`, `dir_info`,
    `shared_cfg`, `audio_cfg` and `model_cfg` prepared above.

    Raises:
        ValueError: if `args.data_preset` is not a known single or multi preset.
    """
    # Data preset
    data_preset = _resolve_data_preset(args.data_preset)

    # Task manager
    tm = TaskManager(task_name=args.task, max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]))
    print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")

    # Vocabulary for validation: an explicit CLI choice wins over the preset's own vocabulary.
    if args.eval_program_vocab is not None:
        eval_program_vocab = program_vocab_presets[args.eval_program_vocab]
    else:
        eval_program_vocab = data_preset["eval_vocab"]
    eval_drum_vocab = data_preset.get("eval_drum_vocab", None)

    dm = AMTDataModule(data_preset_multi=data_preset,
                       task_manager=tm,
                       train_num_samples_per_epoch=args.train_num_samples_per_epoch,
                       audio_cfg=audio_cfg,
                       **shared_cfg["AUGMENTATION"])

    model = YourMT3(
        audio_cfg=audio_cfg,
        model_cfg=model_cfg,
        shared_cfg=shared_cfg,
        pretrained=args.pretrained,
        # Any strategy name containing "offload" forces the CPUAdam optimizer.
        optimizer_name="CPUAdam" if "offload" in args.strategy.lower() else args.optimizer,
        scheduler_name=args.scheduler.lower(),
        base_lr=float(args.base_learning_rate) if args.base_learning_rate is not None else None,
        max_steps=int(args.max_steps),
        task_manager=tm,  # tokenizer is a member of task_manager
        eval_vocab=eval_program_vocab,
        eval_drum_vocab=eval_drum_vocab,
        write_output_dir=dir_info["lightning_dir"] if args.write_model_output else None,
        add_pitch_class_metric=data_preset.get("add_pitch_class_metric", None))
    # if VersionParse(torch.__version__) >= VersionParse("2.1"):
    #     model = torch.compile(model, mode="reduce-overhead")

    # Logging config updated by args
    if trainer.global_rank == 0:
        wandb_logger.experiment.config.update({"audio_cfg": model.audio_cfg}, allow_val_change=True)
        wandb_logger.experiment.config.update({"model_cfg": model.model_cfg}, allow_val_change=True)
        wandb_logger.experiment.config.update(model.shared_cfg, allow_val_change=True)
        # NOTE(review): kept under the rank-0 guard together with the config updates;
        # confirm against the original file's indentation.
        wandb_logger.watch(model, log='gradients', log_freq=5000)

    # last_ckpt_path can be None
    last_ckpt_path = dir_info["last_ckpt_path"]
    if last_ckpt_path is not None:
        # Only the weights are restored (non-strictly); `ckpt_path` is not handed to the trainer.
        # map_location="cpu" keeps the load from restoring tensors onto the device they were
        # saved from; load_state_dict then copies them into the model's own parameters.
        checkpoint = torch.load(last_ckpt_path, map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    # With or without a checkpoint, fit is started without `ckpt_path` (passing None is the same).
    trainer.fit(model, datamodule=dm)


if __name__ == "__main__":
    main()
|