mimbres commited on
Commit
a03c9b4
·
1 Parent(s): 7888f4e
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +1 -1
  2. amt/src/.coverage +0 -0
  3. amt/src/.coveragerc +5 -0
  4. amt/src/config/.DS_Store +0 -0
  5. amt/src/config/config.py +272 -0
  6. amt/src/config/data_presets.py +811 -0
  7. amt/src/config/task.py +119 -0
  8. amt/src/config/vocabulary.py +384 -0
  9. amt/src/extras/.DS_Store +0 -0
  10. amt/src/extras/Dockerfile +18 -0
  11. amt/src/extras/check_drum_channel_slakh.py +24 -0
  12. amt/src/extras/dataset_mutable_var_sanity_check.py +81 -0
  13. amt/src/extras/datasets_eval_testing.py +42 -0
  14. amt/src/extras/demo_cross_augmentation.py +69 -0
  15. amt/src/extras/demo_intra_augmentation.py +52 -0
  16. amt/src/extras/download_mirst500.py +50 -0
  17. amt/src/extras/fig/label_smooth_interval_of_interest.png +0 -0
  18. amt/src/extras/fig/pitchshift_benchnmark.png +0 -0
  19. amt/src/extras/fig/pitchshift_stretch_and_resampler_process_time.png +0 -0
  20. amt/src/extras/inspecting_slakh_bass.py +34 -0
  21. amt/src/extras/install_deepspeed.md +28 -0
  22. amt/src/extras/label_smoothing.py +67 -0
  23. amt/src/extras/multi_channel_seqlen_stats.py +177 -0
  24. amt/src/extras/npy_speed_benchmark.py +187 -0
  25. amt/src/extras/perceivertf_inspect.py +640 -0
  26. amt/src/extras/perceivertf_multi_inspect.py +778 -0
  27. amt/src/extras/pitch_shift_benchmark.py +167 -0
  28. amt/src/extras/remove_silence_musicnet_midi.py +32 -0
  29. amt/src/extras/rotary_positional_embedding.py +191 -0
  30. amt/src/extras/run_spleeter_mir1k.sh +17 -0
  31. amt/src/extras/run_spleeter_mirst500.sh +13 -0
  32. amt/src/extras/run_spleeter_mirst500_cmedia.sh +13 -0
  33. amt/src/extras/swap_channel.py +122 -0
  34. amt/src/extras/t5_dev.py +41 -0
  35. amt/src/extras/t5perceiver.py +443 -0
  36. amt/src/extras/unimax_sampler/README.md +45 -0
  37. amt/src/extras/unimax_sampler/demo.py +15 -0
  38. amt/src/extras/unimax_sampler/unimax_sampler.py +168 -0
  39. amt/src/install_dataset.py +285 -0
  40. amt/src/model/RoPE/RoPE.py +306 -0
  41. amt/src/model/conformer_helper.py +169 -0
  42. amt/src/model/conformer_mod.py +439 -0
  43. amt/src/model/conv_block.py +217 -0
  44. amt/src/model/ff_layer.py +238 -0
  45. amt/src/model/init_train.py +281 -0
  46. amt/src/model/lm_head.py +40 -0
  47. amt/src/model/lr_scheduler.py +91 -0
  48. amt/src/model/ops.py +111 -0
  49. amt/src/model/optimizers.py +218 -0
  50. amt/src/model/perceiver_helper.py +290 -0
.gitignore CHANGED
@@ -1,2 +1,2 @@
1
- amt/
2
  examples/
 
1
+ amt/logs/
2
  examples/
amt/src/.coverage ADDED
Binary file (53.2 kB). View file
 
amt/src/.coveragerc ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [run]
2
+ omit =
3
+ train.py
4
+ test.py
5
+ install*.py
amt/src/config/.DS_Store ADDED
Binary file (6.15 kB). View file
 
amt/src/config/config.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """config.py"""
2
+ import numpy as np
3
+ # yapf: disable
4
+ """
5
+ audio_cfg:
6
+ - Used by 'ymt3' to create a spectrogram layer.
7
+ - Input shape of model is determined by audio_cfg.
8
+ - 'train.py' arguments can override these defaults.
9
+ """
10
+ audio_cfg = {
11
+ # Overwrittable by args in train.py
12
+ "codec": "melspec", # {melspec, spec} melspec for MT3, spec for PerceiverTF
13
+ "hop_length": 128, # {128, 300} 128 for MT3, 300 for PerceiverTF
14
+ # Shared audio parameters
15
+ "audio_backend": "torchaudio", # {torchaudio, nnAudio}
16
+ "sample_rate": 16000,
17
+ "input_frames": 32767, # number of input frames (~=2.048 s), determining in-/output shape of front layers.
18
+ "n_fft": 2048,
19
+ "n_mels": 512, # only for melspec
20
+ "f_min": 50.0,
21
+ "f_max": 8000.0,
22
+ } # TODO: currently dataloader is not updated by "input_frames"
23
+
24
+ """
25
+ model_cfg:
26
+ - Encoder type dictates use of T5_CFG or PERCEIVER_TF_CFG.
27
+ - 'train.py' arguments can override these defaults.
28
+ """
29
+ model_cfg = {
30
+ "encoder_type": "t5", # {"t5", "perceiver-tf", "conformer"}
31
+ "decoder_type": "t5", # {"t5", "multi-t5"}
32
+ "pre_encoder_type": "default", # {None, "default", "conv", "conv1d", "conv2d_avpt"} by default, t5:None, perceiver:conv.
33
+ "pre_encoder_type_default": {"t5": None, "perceiver-tf": "conv", "conformer": None},
34
+ "pre_decoder_type": "default", # {None, 'linear', 'conv1', 'mlp', 'group_linear'} see model/projection_layer.py
35
+ "pre_decoder_type_default": { # [enc_type][dec_type]
36
+ "t5": {"t5": None,},
37
+ "perceiver-tf": {"t5": "linear", "multi-t5": "mc_shared_linear"},
38
+ "conformer": {"t5": None,},
39
+ },
40
+ "conv_out_channels": 128, # number of filters for 'conv' pre_encoder. Otherwise ignored.
41
+ "t5_basename": "google/t5-v1_1-small",
42
+ "pretrained": False, # bool, if True, load pretrained weights from t5_basename. Mismatched layers are ignored.
43
+ "use_task_conditional_encoder": True, # True by default, but default task is None. So not activated by default.
44
+ "use_task_conditional_decoder": True, # True by default, but default task is None. So not activated by default.
45
+ "d_feat": "auto", # Input audio feature dimension for encoder. Automatically inferred by audio_cfg and existence of pre_encoders.
46
+ "tie_word_embeddings": True, # If True, weights of embed_tokens and lm_head are tied for stabilizing gradients.
47
+ "vocab_size": "auto", # int or "auto", automatically inferred by task manager.
48
+ "num_max_positions": "auto", # int or "auto". Length of positional encoding. Automatically inferred by "feat_length", "event_length" and task_manager.max_task_token_length.
49
+ # 'vocab_size', 'tie_word_embeddings' and 'num_max_positions' are auto-copied to encoder and decoder configs in the below.
50
+ "encoder": {
51
+ "t5": {
52
+ "d_model": 512, # Hidden size of T5 encoder.
53
+ "num_heads": 6,
54
+ "num_layers": 8,
55
+ "dropout_rate": 0.05,
56
+ "position_encoding_type": "sinusoidal", # {'sinusoidal', 'trainable'}.
57
+ "ff_widening_factor": 2, # wideening factor for MLP/MoE layers. Default is 2 in T5.
58
+ "ff_layer_type": "t5_gmlp", # {'t5_gmlp', 'moe', 'mlp', 'gmlp'}. 'moe' for mixture of experts, 'mlp' for standard transformer dense layer, 'gmlp' for simple gated MLP.
59
+ },
60
+ "perceiver-tf": {
61
+ "num_latents": 24, # number of latents in Perceiver. 24 in perceiver-tf paper.
62
+ "d_latent": 128, # latent dimension of Perceiver. 128 in perceiver-tf paper.
63
+ "d_model": "q", # int or "q" or "kv". Inner-dim of sca and local/temporal self-att.
64
+ # "q" follows "latent_dim". "kv" follows "d_feat". Best practice is to inc-/decrease 'd_latent', instead of 'd_model'.
65
+ "num_blocks": 3, # number of Perceiver-TF blocks in encoder. L in the paper.
66
+ "num_local_transformers_per_block": 2, # N in the paper.
67
+ "num_temporal_transformers_per_block": 2, # M in the paper.
68
+ "sca_use_query_residual": False,
69
+ "dropout_rate": 0.1,
70
+ "position_encoding_type": "trainable", # {'trainable', 'rotary', 'alibi', 'alibit', None, 'tkd','td', 'tk', 'kdt'}. alibit is alibi with trainable slopes.
71
+ "attention_to_channel": True, # Whether to use channel attention in sca.
72
+ "layer_norm_type": "layer_norm", # {'layer_norm', 'rms_norm'}
73
+ "ff_layer_type": "mlp", # {'moe', 'mlp', gmlp}. 'moe' for mixture of experts, 'mlp' for standard transformer dense layer, 'gmlp' for simple gated MLP.
74
+ "ff_widening_factor": 1, # wideening factor for MLP/MoE layers. Default is 1.
75
+ "moe_num_experts": 4, # number of experts in MoE layer. Default is 4. Disabled if ff_layer_type is not 'moe'.
76
+ "moe_topk": 2, # top-k routing in MoE layer. Default is 2. Disabled if ff_layer_type is not 'moe'.
77
+ "hidden_act": 'gelu', # activation function in MLP/MoE layer. Default is 'gelu'. {'gelu', 'silu', 'relu'}
78
+ "rotary_type_sca": "pixel", # {'l'|'lang', 'p'|'pixel'}. Default is 'pixel'.
79
+ "rotary_type_latent": "pixel", # {'l'|'lang', 'p'|'pixel'}. Default is 'pixel'.
80
+ "rotary_type_temporal": "lang", # {'l'|'lang', 'p'|'pixel'}. Default is 'lang'.
81
+ "rotary_apply_to_keys": False, # Whether to apply rotary to keys. Default is False.
82
+ "rotary_partial_pe": False, # Whether to use partial positional encoding. Default is False.
83
+ },
84
+ "conformer": {
85
+ "d_model": 512, # Hidden size of T5 encoder.
86
+ "intermediate_size": 512, # or 2048. size of the intermediate feed forward layer in each T5Block
87
+ "num_heads": 8,
88
+ "num_layers": 8,
89
+ "dropout_rate": 0.1,
90
+ "layerdrop": 0.1, # see https://arxiv.org/abs/1909.11556
91
+ "position_encoding_type": "rotary", # {'rotary', 'relative'}.
92
+ "conv_dim": (512, 512, 512, 512, 512, 512, 512),
93
+ "conv_stride": (5, 2, 2, 2, 2, 2, 2),
94
+ "conv_kernel": (10, 3, 3, 3, 3, 3, 3),
95
+ "conv_depthwise_kernel_size": 31,
96
+ },
97
+
98
+ },
99
+ "decoder": {
100
+ "t5": {
101
+ "d_model": 512, # Hidden size of T5 encoder. If encoder has lower dim, it is projected to this dim for enc-dec cross att.
102
+ "num_heads": 6,
103
+ "num_layers": 8,
104
+ "dropout_rate": 0.05,
105
+ "position_encoding_type": "sinusoidal", # {'sinusoidal', 'trainable'}.
106
+ "ff_widening_factor": 2, # wideening factor for MLP/MoE layers. Default is 2 in T5.
107
+ "ff_layer_type": "t5_gmlp", # {'t5_gmlp', 'moe', 'mlp', 'gmlp'}. 'moe' for mixture of experts, 'mlp' for standard transformer dense layer, 'gmlp' for simple gated MLP.
108
+ },
109
+ "multi-t5": {
110
+ "d_model": 512, # Hidden size of T5 encoder. Recommended: {256 or 512}
111
+ "num_heads": 6,
112
+ "num_layers": 8,
113
+ "dropout_rate": 0.05,
114
+ "position_encoding_type": "sinusoidal", # {'sinusoidal', 'trainable'}.
115
+ "ff_widening_factor": 2, # wideening factor for MLP/MoE layers. Default is 2 in T5.
116
+ "ff_layer_type": "t5_gmlp", # {'t5_gmlp', 'moe', 'mlp', 'gmlp'}. 'moe' for mixture of experts, 'mlp' for standard transformer dense layer, 'gmlp' for simple gated MLP.
117
+ "num_channels": 13,
118
+ },
119
+ },
120
+ "feat_length": "auto", # Input audio feature length for encoder. Automatically inferred by audio_cfg.
121
+ # mt3: 256 time steps
122
+ "event_length": 1024, # max length of event tokens excluding task tokens <-- 128 for multi-t5
123
+ "init_factor": 1.0, # initialization factor for embedding layers
124
+ }
125
+
126
+ # yapf: enable
127
+ shared_cfg = {
128
+ "PATH": {
129
+ "data_home": "../../data", # path to the data directory. If using relative path, it is relative to /src directory.
130
+ },
131
+ "BSZ": { # global batch size is local_bsz * n_GPUs in DDP mode
132
+ "train_sub": 12, #20, # sub-batch size is per CPU worker
133
+ "train_local": 24, #40, # local batch size is per GPU in DDP mode
134
+ "validation": 64, # validation batch size is per GPU in DDP mode
135
+ "test": 64,
136
+ },
137
+ "AUGMENTATION": {
138
+ "train_random_amp_range": [0.8, 1.1], # min and max amplitude scaling factor
139
+ "train_stem_iaug_prob": 0.7, # probability of stem activation in intra-stem augmentation
140
+ "train_stem_xaug_policy": {
141
+ "max_k": 3,
142
+ "tau": 0.3,
143
+ "alpha": 1.0,
144
+ "max_subunit_stems": 12, # the number of subunit stems to be reduced to this number of stems
145
+ "p_include_singing": None, # NOT IMPLEMENTED; probability of including singing for cross augmented examples. if None, use base probaility.
146
+ "no_instr_overlap": True,
147
+ "no_drum_overlap": True,
148
+ "uhat_intra_stem_augment": True,
149
+ },
150
+ "train_pitch_shift_range": [-2, 2], # [min, max] in semitones. None or [0, 0] for no pitch shift.
151
+ },
152
+ "DATAIO": { # do not set `shuffle` here.
153
+ "num_workers": 4, # num_worker is per GPU in DDP mode
154
+ "prefetch_factor": 2, #2,
155
+ "pin_memory": True,
156
+ "persistent_workers": False,
157
+ },
158
+ "CHECKPOINT": {
159
+ "save_top_k": 4, # max top k checkpoints to save
160
+ "monitor": 'validation/macro_onset_f',
161
+ "mode": 'max',
162
+ # "every_n_epochs": 20, # only working when check_val_every_n_epoch is 0
163
+ "save_last": True, # save last model
164
+ "filename": "{epoch}-{step}",
165
+ },
166
+ "TRAINER": { # do not coverwrite args in this section
167
+ "limit_train_batches": 1.0, # How much of training dataset to check (float = fraction, int = num_batches)
168
+ "limit_val_batches": 1.0,
169
+ "limit_test_batches": 1.0,
170
+ "gradient_clip_val": 1.0, # {0 or None} means don't clip.
171
+ "accumulate_grad_batches": 1, #1, # Accumulates grads every k batches. If set to 1, no effect.
172
+ "check_val_every_n_epoch": 1, #5, 1 for very large dataset such as EGMD
173
+ "num_sanity_val_steps": 0,
174
+ },
175
+ "WANDB": {
176
+ "save_dir": "../logs",
177
+ "cache_dir": "../logs/.wandb_cache",
178
+ "resume": "allow",
179
+ "anonymous": "allow", # {never, allow, must}
180
+ "mode": "online", # {online, offline, disabled}
181
+ },
182
+ "LR_SCHEDULE": {
183
+ # "scheduler_type": "cosine", # {legacy, cosine, constant}
184
+ "warmup_steps": 1000, # only for cosine scheduler, legacy scheduler follows T5's legacy schedule
185
+ "total_steps": 100000, # argparser of train.py can overwrite this
186
+ "final_cosine": 1e-5, # only for cosine scheduler
187
+ },
188
+ "TOKENIZER": {
189
+ "max_shift_steps": "auto", # max number of shift steps in the model. (int) or "auto". If "auto", it is set by audio_cfg["input_frames"] and shift_steps_ms. 206 with default setup.
190
+ "shift_step_ms": 10, # shift step in ms
191
+ },
192
+ }
193
+
194
+ T5_BASE_CFG = {
195
+ "google/t5-v1_1-small": {
196
+ "architectures": ["T5ForConditionalGeneration"],
197
+ "d_ff":
198
+ 1024, # size of the intermediate feed forward layer in each T5Block. Can be overwrten by ff_widening_factor in model_cfg.
199
+ "d_kv": 64, # d_kv has to be equal to d_model // num_heads.
200
+ # "d_model": 512, # encoder hiddnen size, defined by model_cfg
201
+ "decoder_start_token_id": 0,
202
+ "dense_act_fn": "gelu_new",
203
+ # "dropout_rate": 0.05, # can be overwritten by args in ymt3
204
+ "eos_token_id": 1,
205
+ "feed_forward_proj": "gated-gelu",
206
+ "initializer_factor": 1.0,
207
+ "is_encoder_decoder": True,
208
+ "is_gated_act": True,
209
+ "layer_norm_epsilon": 1e-06,
210
+ "model_type": "t5",
211
+ # "num_decoder_layers": 8, # defined by model_cfg
212
+ # "num_heads": 6, # defined by model_cfg
213
+ # "num_layers": 8, # defined by model_cfg
214
+ "output_past": True,
215
+ "pad_token_id": 0,
216
+ "relative_attention_num_buckets": 32,
217
+ # "tie_word_embeddings": True,
218
+ "use_cache": True,
219
+ # "vocab_size": 1391 # vocab_size is automatically set by the task manager...
220
+ },
221
+ "google/t5-efficient-small": {
222
+ "architectures": ["T5ForConditionalGeneration"],
223
+ "d_ff": 2048,
224
+ "d_kv": 64,
225
+ "d_model": 512,
226
+ "decoder_start_token_id": 0,
227
+ "dropout_rate": 0.1,
228
+ "eos_token_id": 1,
229
+ "feed_forward_proj": "relu",
230
+ "initializer_factor": 1.0,
231
+ "is_encoder_decoder": True,
232
+ "layer_norm_epsilon": 1e-06,
233
+ "model_type": "t5",
234
+ "num_decoder_layers": 6,
235
+ "num_heads": 8,
236
+ "num_layers": 6,
237
+ "pad_token_id": 0,
238
+ "relative_attention_num_buckets": 32,
239
+ "torch_dtype": "float32",
240
+ "transformers_version": "4.17.0.dev0",
241
+ "use_cache": True,
242
+ },
243
+ }
244
+
245
+ # yapf: enable
246
+ DEEPSPEED_CFG = {
247
+ "zero_allow_untested_optimizer": True,
248
+ "optimizer": {
249
+ "type": "adam",
250
+ "params": {
251
+ "lr": 1e-4,
252
+ "betas": [0.998, 0.999],
253
+ "eps": 1e-3,
254
+ "weight_decay": 0.001,
255
+ "adam_w_mode": True,
256
+ }
257
+ },
258
+ "scheduler": {
259
+ "type": "WarmupLR",
260
+ "params": {
261
+ "last_batch_iteration": -1,
262
+ "warmup_min_lr": 0,
263
+ "warmup_max_lr": 3e-5,
264
+ "warmup_num_steps": 100,
265
+ },
266
+ },
267
+ "zero_optimization": {
268
+ "stage": 0, #0,1,2,3
269
+ # "offload_optimizer":
270
+ # False, # Enable Offloading optimizer state/calculation to the host CPU
271
+ },
272
+ }
amt/src/config/data_presets.py ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ data.py:
2
+ Data presets for training and evaluation.
3
+
4
+ Single Presets:
5
+ musicnet_mt3
6
+ musicnet_em
7
+ musicnet_thickstun
8
+ slakh
9
+ guitarset
10
+ ...
11
+
12
+ Multi Presets:
13
+ all_mmegs
14
+ ...
15
+
16
+ """
17
+ from config.vocabulary import *
18
+ from config.vocabulary import drum_vocab_presets, program_vocab_presets
19
+ from utils.utils import deduplicate_splits, merge_splits, merge_vocab
20
+
21
+ data_preset_single_cfg = {
22
+ "musicnet_mt3": {
23
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
24
+ "dataset_name": "musicnet",
25
+ "train_split": "train_mt3",
26
+ "validation_split": "validation_mt3_acoustic",
27
+ "test_split": "test_mt3_acoustic",
28
+ "has_stem": False,
29
+ },
30
+ "musicnet_mt3_synth_only": { # sanity-check
31
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
32
+ "dataset_name": "musicnet",
33
+ "train_split": "train_mt3_synth",
34
+ "validation_split": "validation_mt3_synth",
35
+ "test_split": "test_mt3_acoustic",
36
+ "has_stem": False,
37
+ },
38
+ "musicnet_mt3_em": {
39
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
40
+ "dataset_name": "musicnet",
41
+ "train_split": "train_mt3_em",
42
+ "validation_split": "validation_mt3_em",
43
+ "test_split": "test_mt3_em",
44
+ "has_stem": False,
45
+ },
46
+ "musicnet_thickstun": { # exp4
47
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
48
+ "dataset_name": "musicnet",
49
+ "train_split": "train_thickstun",
50
+ "validation_split": "test_thickstun",
51
+ "test_split": "test_thickstun",
52
+ "has_stem": False,
53
+ },
54
+ "musicnet_thickstun_em": { # NOTE: this is not the use of external 'synth' in the paper, but the use of 'synth' within the dataset
55
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
56
+ "dataset_name": "musicnet",
57
+ "train_split": "train_thickstun_em",
58
+ "validation_split": "test_thickstun_em",
59
+ "test_split": "test_thickstun_em",
60
+ "has_stem": False,
61
+ },
62
+ "musicnet_thickstun_ext": { # exp4
63
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
64
+ "dataset_name": "musicnet",
65
+ "train_split": "train_thickstun",
66
+ "validation_split": "test_thickstun_ext",
67
+ "test_split": "test_thickstun_ext",
68
+ "has_stem": False,
69
+ },
70
+ "musicnet_thickstun_ext_em": { # NOTE: this is not the use of external 'synth' in the paper, but the use of 'synth' within the dataset
71
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
72
+ "dataset_name": "musicnet",
73
+ "train_split": "train_thickstun_em",
74
+ "validation_split": "test_thickstun_ext_em",
75
+ "test_split": "test_thickstun_ext_em",
76
+ "has_stem": False,
77
+ },
78
+ "maps_default": {
79
+ "eval_vocab": [PIANO_SOLO_CLASS],
80
+ "dataset_name": "maps",
81
+ "train_split": "train",
82
+ "validation_split": "test",
83
+ "test_split": "test",
84
+ "has_stem": False,
85
+ },
86
+ "maps_all": {
87
+ "eval_vocab": [None],
88
+ "dataset_name": "maps",
89
+ "train_split": "all",
90
+ "validation_split": None,
91
+ "test_split": None,
92
+ "has_stem": False,
93
+ },
94
+ "maestro": {
95
+ "eval_vocab": [PIANO_SOLO_CLASS],
96
+ "dataset_name": "maestro",
97
+ "train_split": "train",
98
+ "validation_split": "validation",
99
+ "test_split": "test",
100
+ "has_stem": False,
101
+ },
102
+ "maestro_final": {
103
+ "eval_vocab": [PIANO_SOLO_CLASS],
104
+ "dataset_name": "maestro",
105
+ "train_split": merge_splits(["train", "validation"], dataset_name="maestro"),
106
+ "validation_split": "test",
107
+ "test_split": "test",
108
+ "has_stem": False,
109
+ },
110
+ "guitarset": { # 4 random players for train, 1 for valid, and 1 for test
111
+ "eval_vocab": [GUITAR_SOLO_CLASS],
112
+ "dataset_name": "guitarset",
113
+ "train_split": "train",
114
+ "validation_split": "validation",
115
+ "test_split": "test",
116
+ "has_stem": False,
117
+ },
118
+ "guitarset_pshift": { # guitarset + pitch shift
119
+ "eval_vocab": [GUITAR_SOLO_CLASS],
120
+ "dataset_name": "guitarset",
121
+ "train_split": "train_pshift",
122
+ "validation_split": "validation",
123
+ "test_split": "test",
124
+ "has_stem": False,
125
+ },
126
+ "guitarset_progression": { # progression 1 and 2 as train, progression 3 as test
127
+ "eval_vocab": [GUITAR_SOLO_CLASS],
128
+ "dataset_name": "guitarset",
129
+ "train_split": merge_splits(["progression_1", "progression_2"], dataset_name="guitarset"),
130
+ "validation_split": "progression_3",
131
+ "test_split": "progression_3",
132
+ "has_stem": False,
133
+ },
134
+ "guitarset_progression_pshift": { # guuitarset_progression + pitch shift
135
+ "eval_vocab": [GUITAR_SOLO_CLASS],
136
+ "dataset_name": "guitarset",
137
+ "train_split": merge_splits(["progression_1_pshift", "progression_2_pshift"], dataset_name="guitarset"),
138
+ "validation_split": "progression_3",
139
+ "test_split": "progression_3",
140
+ "has_stem": False,
141
+ },
142
+ "guitarset_minus_bn": { # guuitarset_style + pitch shift
143
+ "eval_vocab": [GUITAR_SOLO_CLASS],
144
+ "dataset_name": "guitarset",
145
+ "train_split": merge_splits(["Funk_pshift", "SS_pshift", "Jazz_pshift", "Rock_pshift"],
146
+ dataset_name="guitarset"),
147
+ "validation_split": "BN",
148
+ "test_split": "BN",
149
+ "has_stem": False,
150
+ },
151
+ "guitarset_minus_funk": { # guuitarset_style + pitch shift
152
+ "eval_vocab": [GUITAR_SOLO_CLASS],
153
+ "dataset_name": "guitarset",
154
+ "train_split": merge_splits(["BN_pshift", "SS_pshift", "Jazz_pshift", "Rock_pshift"],
155
+ dataset_name="guitarset"),
156
+ "validation_split": "Funk",
157
+ "test_split": "Funk",
158
+ "has_stem": False,
159
+ },
160
+ "guitarset_minus_ss": { # guuitarset_style + pitch shift
161
+ "eval_vocab": GUITAR_SOLO_CLASS,
162
+ "dataset_name": "guitarset",
163
+ "train_split": merge_splits(["BN_pshift", "Funk_pshift", "Jazz_pshift", "Rock_pshift"],
164
+ dataset_name="guitarset"),
165
+ "validation_split": "SS",
166
+ "test_split": "SS",
167
+ "has_stem": False,
168
+ },
169
+ "guitarset_minus_jazz": { # guuitarset_style + pitch shift
170
+ "eval_vocab": [GUITAR_SOLO_CLASS],
171
+ "dataset_name": "guitarset",
172
+ "train_split": merge_splits(["BN_pshift", "Funk_pshift", "SS_pshift", "Rock_pshift"],
173
+ dataset_name="guitarset"),
174
+ "validation_split": "Jazz",
175
+ "test_split": "Jazz",
176
+ "has_stem": False,
177
+ },
178
+ "guitarset_minus_rock": { # guuitarset_style + pitch shift
179
+ "eval_vocab": [GUITAR_SOLO_CLASS],
180
+ "dataset_name": "guitarset",
181
+ "train_split": merge_splits(["BN_pshift", "Funk_pshift", "SS_pshift", "Jazz_pshift"],
182
+ dataset_name="guitarset"),
183
+ "validation_split": "Rock",
184
+ "test_split": "Rock",
185
+ "has_stem": False,
186
+ },
187
+ "guitarset_all": {
188
+ "eval_vocab": [None],
189
+ "dataset_name": "guitarset",
190
+ "train_split": "all",
191
+ "validation_split": None,
192
+ "test_split": None,
193
+ "has_stem": False,
194
+ },
195
+ "enstdrums_dtp": {
196
+ "eval_vocab": [None],
197
+ "eval_drum_vocab": drum_vocab_presets["ksh"],
198
+ "dataset_name": "enstdrums",
199
+ "train_split": merge_splits(["drummer_1_dtp", "drummer_2_dtp", "drummer_1_dtp", "drummer_2_dtp"], dataset_name="enstdrums"),
200
+ "validation_split": "drummer_1_dtp", # for sanity check
201
+ "test_split": "drummer_3_dtp",
202
+ "has_stem": False,
203
+ },
204
+ "enstdrums_dtm": {
205
+ "eval_vocab": [None],
206
+ "eval_drum_vocab": drum_vocab_presets["ksh"],
207
+ "dataset_name": "enstdrums",
208
+ "train_split": merge_splits(["drummer_1_dtm", "drummer_2_dtm", "drummer_1_dtp", "drummer_2_dtp"], dataset_name="enstdrums"),
209
+ "validation_split": "drummer_3_dtm_r2", # 0.6 * drum
210
+ "test_split": "drummer_3_dtm_r1", # 0.75 * drum
211
+ "has_stem": True,
212
+ },
213
+ "enstdrums_random_dtm": { # single dataset training as a denoising ADT model
214
+ "eval_vocab": [None],
215
+ "eval_drum_vocab": drum_vocab_presets["ksh"],
216
+ "dataset_name": "enstdrums",
217
+ "train_split": "train_dtm",
218
+ "validation_split": "validation_dtm",
219
+ "test_split": "test_dtm",
220
+ "has_stem": True,
221
+ },
222
+ "enstdrums_random": { # multi dataset training with random split of 70:15:15
223
+ "eval_vocab": [None],
224
+ "eval_drum_vocab": drum_vocab_presets["ksh"],
225
+ "dataset_name": "enstdrums",
226
+ "train_split": "train_dtp",
227
+ "validation_split": "test_dtm",
228
+ "test_split": "test_dtm",
229
+ "has_stem": True,
230
+ },
231
+ "enstdrums_random_plus_dtd": { # multi dataset training plus dtd
232
+ "eval_vocab": [None],
233
+ "eval_drum_vocab": drum_vocab_presets["ksh"],
234
+ "dataset_name": "enstdrums",
235
+ "train_split": merge_splits(["train_dtp", "all_dtd"], dataset_name="enstdrums"),
236
+ "validation_split": "test_dtm",
237
+ "test_split": "test_dtm",
238
+ "has_stem": True,
239
+ },
240
+ "mir_st500": {
241
+ "eval_vocab": [SINGING_SOLO_CLASS],
242
+ "dataset_name": "mir_st500",
243
+ "train_split": "train_stem",
244
+ "validation_split": "test",
245
+ "test_split": "test",
246
+ "has_stem": True,
247
+ },
248
+ "mir_st500_voc": {
249
+ "eval_vocab": [SINGING_SOLO_CLASS],
250
+ "dataset_name": "mir_st500",
251
+ "train_split": "train_vocal",
252
+ "validation_split": "test_vocal",
253
+ "test_split": "test_vocal",
254
+ "has_stem": False,
255
+ },
256
+ "mir_st500_voc_debug": { # using train_vocal for test (for debugging)
257
+ "eval_vocab": [SINGING_SOLO_CLASS],
258
+ "dataset_name": "mir_st500",
259
+ "train_split": "train_vocal",
260
+ "validation_split": "test_vocal",
261
+ "test_split": "train_vocal",
262
+ "has_stem": False,
263
+ },
264
+ "slakh": {
265
+ "eval_vocab": [GM_INSTR_CLASS],
266
+ "eval_drum_vocab": drum_vocab_presets["gm"],
267
+ "dataset_name": "slakh",
268
+ "train_split": "train",
269
+ "validation_split": "validation",
270
+ "test_split": "test",
271
+ "has_stem": True,
272
+ },
273
+ "slakh_final": {
274
+ "eval_vocab": [GM_INSTR_CLASS],
275
+ "eval_drum_vocab": drum_vocab_presets["gm"],
276
+ "dataset_name": "slakh",
277
+ "train_split": merge_splits(["train", "validation"], dataset_name="slakh"),
278
+ "validation_split": "test",
279
+ "test_split": "test",
280
+ "has_stem": True,
281
+ },
282
+ "rwc_pop_bass": {
283
+ "eval_vocab": [BASS_SOLO_CLASS],
284
+ "add_pitch_class_metric": ["Bass"],
285
+ "dataset_name": "rwc_pop",
286
+ "train_split": None,
287
+ "validation_split": "bass",
288
+ "test_split": "bass",
289
+ "has_stem": False,
290
+ },
291
+ "rwc_pop_full": {
292
+ "eval_vocab": [GM_INSTR_CLASS_PLUS],
293
+ "add_pitch_class_metric": list(GM_INSTR_CLASS_PLUS.keys()),
294
+ "dataset_name": "rwc_pop",
295
+ "train_split": None,
296
+ "validation_split": "full",
297
+ "test_split": "full",
298
+ "has_stem": False,
299
+ },
300
+ "egmd": {
301
+ "eval_vocab": [None],
302
+ "eval_drum_vocab": drum_vocab_presets["ksh"],
303
+ "dataset_name": "egmd",
304
+ "train_split": "train",
305
+ "validation_split": "validation",
306
+ "test_split": "test_reduced", # EGMD has 5000+ test files, so we reudce it to 200 files to save time
307
+ # "train_limit_num_files": 4402, #8804, # 17608, # limit the number of files for training to random choice of half.
308
+ "has_stem": False,
309
+ },
310
+ "urmp": {
311
+ "eval_vocab": [GM_INSTR_CLASS],
312
+ "dataset_name": "urmp",
313
+ "train_split": "train",
314
+ "validation_split": "test",
315
+ "test_split": "test",
316
+ "has_stem": True,
317
+ },
318
+ "cmedia": {
319
+ "eval_vocab": [SINGING_SOLO_CLASS],
320
+ "dataset_name": "cmedia",
321
+ "train_split": "train_stem",
322
+ "validation_split": "train",
323
+ "test_split": "train",
324
+ "has_stem": True,
325
+ },
326
+ "cmedia_voc": {
327
+ "eval_vocab": [SINGING_SOLO_CLASS],
328
+ "dataset_name": "cmedia",
329
+ "train_split": "train_vocal",
330
+ "validation_split": "train_vocal",
331
+ "test_split": "train_vocal",
332
+ "has_stem": False,
333
+ },
334
+ "idmt_smt_bass": {
335
+ "eval_vocab": [BASS_SOLO_CLASS],
336
+ "dataset_name": "idmt_smt_bass",
337
+ "train_split": "train",
338
+ "validation_split": "validation",
339
+ "test_split": "validation",
340
+ "has_stem": False,
341
+ },
342
+ "geerdes": { # full mix dataset for evaluation
343
+ "eval_vocab": [GM_INSTR_CLASS_PLUS],
344
+ "dataset_name": "geerdes",
345
+ "train_split": None,
346
+ "validation_split": None,
347
+ "test_split": "all",
348
+ "has_stem": False,
349
+ },
350
+ "geerdes_sep": { # Using vocal/accomp separation for evalutation
351
+ "eval_vocab": [GM_INSTR_CLASS_PLUS],
352
+ "dataset_name": "geerdes",
353
+ "train_split": None,
354
+ "validation_split": None,
355
+ "test_split": "all_sep",
356
+ "has_stem": False,
357
+ },
358
+ "geerdes_half": { # Using half dataset for train/val
359
+ "eval_vocab": [GM_INSTR_CLASS_PLUS],
360
+ "dataset_name": "geerdes",
361
+ "train_split": "train",
362
+ "validation_split": "validation",
363
+ "test_split": "validation",
364
+ "has_stem": False,
365
+ },
366
+ "geerdes_half_sep": { # Using half dataset with vocal/accomp separation for train/val
367
+ "eval_vocab": [GM_INSTR_CLASS_PLUS],
368
+ "dataset_name": "geerdes",
369
+ "train_split": "train_sep",
370
+ "validation_split": "validation_sep",
371
+ "test_split": "validation_sep",
372
+ "has_stem": False,
373
+ },
374
+ }
375
+
376
+ data_preset_multi_cfg = {
377
+ "musicnet_mt3_em_synth_plus_maps": {
378
+ "presets": ["musicnet_mt3_em_synth", "maps_all"],
379
+ "weights": [0.6, 0.4],
380
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
381
+ },
382
+ "musicnet_em_synth_table2_plus_maps": {
383
+ "presets": ["musicnet_em_synth_table2", "maps_all"],
384
+ "weights": [0.6, 0.4],
385
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
386
+ },
387
+ "musicnet_em_synth_table2_plus_maps_multi": {
388
+ "presets": ["musicnet_em_synth_table2", "maps_default"],
389
+ "weights": [0.6, 0.4],
390
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
391
+ },
392
+ "guitarset_progression_plus_maps": {
393
+ "presets": ["guitarset_progression", "maps_all"],
394
+ "weights": [0.5, 0.5],
395
+ "eval_vocab": [GUITAR_SOLO_CLASS],
396
+ },
397
+ "guitarset_pshift_plus_maps": {
398
+ "presets": ["guitarset_pshift", "maps_default"],
399
+ "weights": [0.6, 0.4],
400
+ "eval_vocab": [merge_vocab([GUITAR_SOLO_CLASS, PIANO_SOLO_CLASS])],
401
+ },
402
+ "guitarset_pshift_plus_musicnet_thick": {
403
+ "presets": ["guitarset_pshift", "musicnet_thickstun_em"],
404
+ "weights": [0.5, 0.5],
405
+ "eval_vocab": [merge_vocab([GUITAR_SOLO_CLASS, PIANO_SOLO_CLASS])],
406
+ },
407
+ "multi_sanity_check": {
408
+ "presets": ["musicnet_mt3_synth_only", "musicnet_mt3_synth_only"],
409
+ "weights": [0.6, 0.4],
410
+ "eval_vocab": [MUSICNET_INSTR_CLASS],
411
+ },
412
+ "all_mmegs": {
413
+ "presets": [
414
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", "guitarset_pshift"
415
+ ],
416
+ "weights": [0.2, 0.2, 0.2, 0.2, 0.2],
417
+ "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset
418
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
419
+ "val_max_num_files": 20, # max 20 files per dataset
420
+ "test_max_num_files": None,
421
+ },
422
+ "all_gt_cv0": {
423
+ "presets": [
424
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", "guitarset_minus_bn"
425
+ ],
426
+ "weights": [0.2, 0.2, 0.2, 0.2, 0.2],
427
+ "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset
428
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
429
+ "val_max_num_files": 20, # max 20 files per dataset
430
+ "test_max_num_files": None,
431
+ },
432
+ "all_gt_cv1": {
433
+ "presets": [
434
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
435
+ "guitarset_minus_funk"
436
+ ],
437
+ "weights": [0.2, 0.2, 0.2, 0.2, 0.2],
438
+ "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset
439
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
440
+ "val_max_num_files": 20, # max 20 files per dataset
441
+ "test_max_num_files": None,
442
+ },
443
+ "all_gt_cv2": {
444
+ "presets": [
445
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", "guitarset_minus_ss"
446
+ ],
447
+ "weights": [0.2, 0.2, 0.2, 0.2, 0.2],
448
+ "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset
449
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
450
+ "val_max_num_files": 20, # max 20 files per dataset
451
+ "test_max_num_files": None,
452
+ },
453
+ "all_gt_cv3": {
454
+ "presets": [
455
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
456
+ "guitarset_minus_rock"
457
+ ],
458
+ "weights": [0.2, 0.2, 0.2, 0.2, 0.2],
459
+ "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset
460
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
461
+ "val_max_num_files": 20, # max 20 files per dataset
462
+ "test_max_num_files": None,
463
+ },
464
+ "all_gt_cv4": {
465
+ "presets": [
466
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
467
+ "guitarset_minus_jazz"
468
+ ],
469
+ "weights": [0.2, 0.2, 0.2, 0.2, 0.2],
470
+ "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset
471
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
472
+ "val_max_num_files": 20, # max 20 files per dataset
473
+ "test_max_num_files": None,
474
+ },
475
+ "all_enstdrums_random": {
476
+ "presets": [
477
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_random", "guitarset"
478
+ ],
479
+ "weights": [0.2, 0.2, 0.2, 0.2, 0.2],
480
+ "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset
481
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
482
+ "val_max_num_files": 20, # max 20 files per dataset
483
+ "test_max_num_files": None,
484
+ },
485
+ "all_plus_egmd": {
486
+ "presets": [
487
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_random_plus_dtd",
488
+ "guitarset", "egmd"
489
+ ],
490
+ "weights": [0.2, 0.2, 0.2, 0.1, 0.1, 0.2],
491
+ "eval_vocab": [None] * 6, # None means instrument-agnostic F1 for each dataset
492
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
493
+ "val_max_num_files": 20, # max 20 files per dataset
494
+ "test_max_num_files": None,
495
+ },
496
+ "all_dtp_egmd": {
497
+ "presets": [
498
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", "guitarset", "egmd"
499
+ ],
500
+ "weights": [0.2, 0.2, 0.2, 0.1, 0.1, 0.2],
501
+ "eval_vocab": [None] * 6, # None means instrument-agnostic F1 for each dataset
502
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
503
+ "val_max_num_files": 20, # max 20 files per dataset
504
+ "test_max_num_files": None,
505
+ },
506
+ "all_weighted_slakh": {
507
+ "presets": [
508
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", "guitarset_pshift", "egmd"
509
+ ],
510
+ "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.2],
511
+ "eval_vocab": [None] * 6, # None means instrument-agnostic F1 for each dataset
512
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
513
+ "val_max_num_files": 20, # max 20 files per dataset
514
+ "test_max_num_files": None,
515
+ },
516
+ "all_weighted_mt3": { # for comparison with MT3
517
+ "presets": [
518
+ "slakh", "musicnet_mt3", "mir_st500_voc", "enstdrums_dtp",
519
+ "guitarset_progression_pshift", "egmd"
520
+ ],
521
+ "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.2],
522
+ "eval_vocab": [None] * 6, # None means instrument-agnostic F1 for each dataset
523
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
524
+ "val_max_num_files": 20, # max 20 files per dataset
525
+ "test_max_num_files": None,
526
+ },
527
+ "all_weighted_mt3_em": { # musicnet_mt3_em
528
+ "presets": [
529
+ "slakh", "musicnet_mt3_em", "mir_st500_voc", "enstdrums_dtp",
530
+ "guitarset_progression_pshift", "egmd"
531
+ ],
532
+ "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.2],
533
+ "eval_vocab": [None] * 6, # None means instrument-agnoßstic F1 for each dataset
534
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
535
+ "val_max_num_files": 20, # max 20 files per dataset
536
+ "test_max_num_files": None,
537
+ },
538
+ "all_urmp": {
539
+ "presets": [
540
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
541
+ "guitarset_pshift", "egmd", "urmp"
542
+ ],
543
+ "weights": [0.5, 0.2, 0.1, 0.05, 0.05, 0.05, 0.1],
544
+ "eval_vocab": [None] * 7, # None means instrument-agnostic F1 for each dataset
545
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
546
+ "val_max_num_files": 20, # max 20 files per dataset
547
+ "test_max_num_files": None,
548
+ },
549
+ "all_urmp_mt3": { # for comparison with MT3 including URMP
550
+ "presets": [
551
+ "slakh", "musicnet_mt3", "mir_st500_voc", "enstdrums_dtp",
552
+ "guitarset_progression", "egmd", "urmp"
553
+ ],
554
+ "weights": [0.5, 0.2, 0.1, 0.05, 0.05, 0.0125, 0.1],
555
+ "eval_vocab": [None] * 7, # None means instrument-agnostic F1 for each dataset
556
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
557
+ "val_max_num_files": 20, # max 20 files per dataset
558
+ "test_max_num_files": None,
559
+ },
560
+ "all_urmp_mt3_em": { # musicnet_mt3_em including URMP
561
+ "presets": [
562
+ "slakh", "musicnet_mt3_em", "mir_st500_voc", "enstdrums_dtp",
563
+ "guitarset_progression", "egmd", "urmp"
564
+ ],
565
+ "weights": [0.5, 0.2, 0.1, 0.05, 0.05, 0.0125, 0.1],
566
+ "eval_vocab": [None] * 7, # None means instrument-agnostic F1 for each dataset
567
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
568
+ "val_max_num_files": 20, # max 20 files per dataset
569
+ "test_max_num_files": None,
570
+ },
571
+ "all_maestro": { # including Mestro and URMP
572
+ "presets": [
573
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
574
+ "guitarset_pshift", "egmd", "urmp", "maestro"
575
+ ],
576
+ "weights": [0.5, 0.1, 0.125, 0.075, 0.025, 0.01, 0.1, 0.1],
577
+ "eval_vocab": [None] * 8, # None means instrument-agnostic F1 for each dataset
578
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
579
+ "val_max_num_files": 20, # max 20 files per dataset
580
+ "test_max_num_files": None,
581
+ },
582
+ "all_maestro_mt3": { # for comparison with MT3 including URMP
583
+ "presets": [
584
+ "slakh", "musicnet_mt3", "mir_st500_voc", "enstdrums_dtp",
585
+ "guitarset_progression", "egmd", "urmp", "maestro"
586
+ ],
587
+ "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.0125, 0.1, 0.1],
588
+ "eval_vocab": [None] * 8, # None means instrument-agnostic F1 for each dataset
589
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
590
+ "val_max_num_files": 20, # max 20 files per dataset
591
+ "test_max_num_files": None,
592
+ },
593
+ "all_maestro_mt3_em": { # musicnet_mt3_em including URMP
594
+ "presets": [
595
+ "slakh", "musicnet_mt3_em", "mir_st500_voc", "enstdrums_dtp",
596
+ "guitarset_progression", "egmd", "urmp", "maestro"
597
+ ],
598
+ "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.0125, 0.1, 0.1],
599
+ "eval_vocab": [None] * 8, # None means instrument-agnostic F1 for each dataset
600
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
601
+ "val_max_num_files": 20, # max 20 files per dataset
602
+ "test_max_num_files": None,
603
+ },
604
+ "singing_v1": { # slakh + mir_st500 without spleeter
605
+ "presets": ["slakh", "mir_st500"],
606
+ "weights": [0.8, 0.2],
607
+ "eval_vocab": [None, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset
608
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
609
+ "val_max_num_files": 20, # max 20 files per dataset
610
+ "test_max_num_files": None,
611
+ },
612
+ "all_singing_v1": { # for singing-only task
613
+ "presets": [
614
+ "slakh", "musicnet_thickstun_em", "mir_st500_stem", "enstdrums_dtp",
615
+ "guitarset_pshift", "egmd", "urmp", "maestro"
616
+ ],
617
+ "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.0125, 0.1, 0.1],
618
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset
619
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
620
+ "val_max_num_files": 20, # max 20 files per dataset
621
+ "test_max_num_files": None,
622
+ },
623
+ "all_singing_drum_v1": { # for singing-only and drum-only tasks
624
+ "presets": [
625
+ "slakh", "musicnet_thickstun_em", "mir_st500_stem", "enstdrums_dtm",
626
+ "guitarset_pshift", "egmd", "urmp", "maestro"
627
+ ],
628
+ "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.0125, 0.1, 0.1],
629
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset
630
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
631
+ "val_max_num_files": 20, # max 20 files per dataset
632
+ "test_max_num_files": None,
633
+ },
634
+ "all_cross": { # including Mestro and URMP
635
+ "presets": [
636
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
637
+ "guitarset_pshift", "egmd", "urmp", "maestro"
638
+ ],
639
+ "weights": [0.5, 0.1, 0.125, 0.075, 0.025, 0.01, 0.1, 0.1],
640
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset
641
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
642
+ "val_max_num_files": 20, # max 20 files per dataset
643
+ "test_max_num_files": None,
644
+ },
645
+ "all_cross_rebal": { # rebalanced for cross-augment, using spleeter
646
+ "presets": [
647
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
648
+ "guitarset_pshift", "egmd", "urmp", "maestro"
649
+ ],
650
+ "weights": [0.4, 0.15, 0.15, 0.075, 0.025, 0.01, 0.1, 0.1],
651
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset
652
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
653
+ "val_max_num_files": 20, # max 20 files per dataset
654
+ "test_max_num_files": None,
655
+ },
656
+ "all_cross_rebal2": { # rebalanced for cross-augment, using spleeter
657
+ "presets": [
658
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
659
+ "guitarset_pshift", "egmd", "urmp", "maestro"
660
+ ],
661
+ "weights": [0.275, 0.19, 0.19, 0.1, 0.025, 0.02, 0.1, 0.1],
662
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset
663
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
664
+ "val_max_num_files": 20, # max 20 files per dataset
665
+ "test_max_num_files": None,
666
+ },
667
+ "all_cross_rebal4": { # rebalanced for cross-augment, using spleeter
668
+ "presets": [
669
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
670
+ "guitarset_pshift", "egmd", "urmp", "maestro"
671
+ ],
672
+ "weights": [0.258, 0.19, 0.2, 0.125, 0.022, 0.005, 0.1, 0.1],
673
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset
674
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
675
+ "val_max_num_files": 20, # max 20 files per dataset
676
+ "test_max_num_files": None,
677
+ },
678
+ "all_cross_rebal5": { # rebalanced for cross-augment, using spleeter
679
+ "presets": [
680
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
681
+ "guitarset_pshift", "egmd", "urmp", "maestro"
682
+ ],
683
+ "weights": [0.295, 0.19, 0.24, 0.05, 0.02, 0.005, 0.1, 0.1],
684
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset
685
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
686
+ "val_max_num_files": 20, # max 20 files per dataset
687
+ "test_max_num_files": None,
688
+ },
689
+ "all_cross_stem": { # accomp stem for sub-task learning + rebalanced for cross-augment
690
+ "presets": [
691
+ "slakh", "musicnet_thickstun_em", "mir_st500_stem", "enstdrums_dtm",
692
+ "guitarset_pshift", "egmd", "urmp", "maestro"
693
+ ],
694
+ "weights": [0.4, 0.15, 0.15, 0.075, 0.025, 0.01, 0.1, 0.1],
695
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset
696
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
697
+ "val_max_num_files": 20, # max 20 files per dataset
698
+ "test_max_num_files": None,
699
+ },
700
+ "all_cross_stem_rebal3": { # accomp stem for sub-task learning + rebalanced for cross-augment
701
+ "presets": [
702
+ "slakh", "musicnet_thickstun_em", "mir_st500_stem", "enstdrums_dtm",
703
+ "guitarset_pshift", "egmd", "urmp", "maestro"
704
+ ],
705
+ "weights": [0.265, 0.18, 0.21, 0.1, 0.025, 0.02, 0.1, 0.1],
706
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset
707
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
708
+ "val_max_num_files": 20, # max 20 files per dataset
709
+ "test_max_num_files": None,
710
+ },
711
+ "all_cross_v6": { # +cmeida +idmt_smt_bass
712
+ "presets": [
713
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
714
+ "guitarset", "egmd", "urmp", "maestro", "idmt_smt_bass", "cmedia_voc",
715
+ ],
716
+ "weights": [0.295, 0.19, 0.19, 0.05, 0.01, 0.005, 0.1, 0.1, 0.01, 0.05],
717
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset
718
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
719
+ "val_max_num_files": 20, # max 20 files per dataset
720
+ "test_max_num_files": None,
721
+ },
722
+ "all_cross_v6_geerdes": { # +geerdes_half
723
+ "presets": [
724
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
725
+ "guitarset", "egmd", "urmp", "maestro", "idmt_smt_bass", "cmedia_voc",
726
+ "geerdes_half", "geerdes_half_sep"
727
+ ],
728
+ "weights": [0.295, 0.19, 0.19, 0.05, 0.01, 0.005, 0.075, 0.075, 0.01, 0.05, 0.025, 0.025],
729
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS,
730
+ SINGING_SOLO_CLASS, GM_INSTR_CLASS_PLUS, GM_INSTR_CLASS_PLUS], # None means instrument-agnostic F1 for each dataset
731
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
732
+ "val_max_num_files": 20, # max 20 files per dataset
733
+ "test_max_num_files": None,
734
+ },
735
+ "all_cross_v6_geerdes_rebal": { # +geerdes_half
736
+ "presets": [
737
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
738
+ "guitarset", "egmd", "urmp", "maestro", "idmt_smt_bass", "cmedia_voc",
739
+ "geerdes_half", "geerdes_half_sep"
740
+ ],
741
+ "weights": [0.245, 0.175, 0.19, 0.05, 0.01, 0.005, 0.075, 0.05, 0.01, 0.05, 0.075, 0.075],
742
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS,
743
+ SINGING_SOLO_CLASS, GM_INSTR_EXT_CLASS_PLUS, GM_INSTR_EXT_CLASS_PLUS], # None means instrument-agnostic F1 for each dataset
744
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
745
+ "val_max_num_files": 20, # max 20 files per dataset
746
+ "test_max_num_files": None,
747
+ },
748
+ "all_cross_v7": {
749
+ "presets": [
750
+ "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
751
+ "guitarset_progression_pshift", "egmd", "urmp", "maestro", "idmt_smt_bass", "cmedia_voc",
752
+ ],
753
+ "weights": [0.295, 0.19, 0.191, 0.05, 0.01, 0.004, 0.1, 0.1, 0.01, 0.05],
754
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset
755
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
756
+ "val_max_num_files": 20, # max 20 files per dataset
757
+ "test_max_num_files": None,
758
+ },
759
+ "all_cross_final": {
760
+ "presets": [
761
+ "slakh_final", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp",
762
+ "guitarset_progression_pshift", "egmd", "urmp", "maestro_final", "idmt_smt_bass", "cmedia_voc",
763
+ ],
764
+ "weights": [0.295, 0.19, 0.191, 0.05, 0.01, 0.004, 0.1, 0.1, 0.01, 0.05],
765
+ "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset
766
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
767
+ "val_max_num_files": 20, # max 20 files per dataset
768
+ "test_max_num_files": None,
769
+ },
770
+ "all_eval_final": { # The final evaluation set
771
+ "presets": [
772
+ "slakh", "musicnet_thickstun", "musicnet_thickstun_em", "musicnet_thickstun_ext",
773
+ "musicnet_thickstun_ext_em", "mir_st500_voc", "mir_st500", "enstdrums_dtp",
774
+ "enstdrums_dtm", "guitarset_progression_pshift", "rwc_pop_bass", "maestro", "urmp",
775
+ "maps_default", "rwc_pop_full", # "geerdes", "geerdes_sep",
776
+ ],
777
+ "eval_vocab": [
778
+ GM_INSTR_CLASS, MUSICNET_INSTR_CLASS, MUSICNET_INSTR_CLASS, MUSICNET_INSTR_CLASS,
779
+ MUSICNET_INSTR_CLASS, SINGING_SOLO_CLASS, SINGING_SOLO_CLASS, None,
780
+ None, None, BASS_SOLO_CLASS, PIANO_SOLO_CLASS, GM_INSTR_CLASS,
781
+ PIANO_SOLO_CLASS, GM_INSTR_CLASS_PLUS, # GM_INSTR_CLASS_PLUS, GM_INSTR_CLASS_PLUS
782
+ ],
783
+ "eval_drum_vocab": drum_vocab_presets["ksh"],
784
+ },
785
+ "geerdes_eval": { # Geerdes evaluation sets for models trained without Geerdes.
786
+ "presets": ["geerdes_sep", "geerdes"],
787
+ "eval_vocab": [GM_INSTR_CLASS_PLUS, GM_INSTR_CLASS_PLUS],
788
+ "eval_drum_vocab": drum_vocab_presets["gm"],
789
+ },
790
+ "geerdes_half_eval": { # Geerdes evaluation sets for models trained with Geerdes-half
791
+ "presets": ["geerdes_half_sep", "geerdes_half"],
792
+ "eval_vocab": [GM_INSTR_CLASS_PLUS, GM_INSTR_CLASS_PLUS],
793
+ "eval_drum_vocab": drum_vocab_presets["gm"],
794
+ },
795
+ "minimal": { # slakh + mir_st500 with spleeter
796
+ "presets": ["slakh", "mir_st500_voc"],
797
+ "weights": [0.8, 0.2],
798
+ "eval_vocab": [None, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset
799
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
800
+ "val_max_num_files": 20, # max 20 files per dataset
801
+ "test_max_num_files": None,
802
+ },
803
+ "singing_debug": { # slakh + mir_st500 with spleeter
804
+ "presets": ["mir_st500_voc_debug"],
805
+ "weights": [1.0],
806
+ "eval_vocab": [SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset
807
+ "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric
808
+ "val_max_num_files": 20, # max 20 files per dataset
809
+ "test_max_num_files": None,
810
+ },
811
+ }
amt/src/config/task.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ """task.py"""
11
+ from config.vocabulary import *
12
+ from utils.note_event_dataclasses import Event
13
+
14
+ task_cfg = {
15
+ "mt3_midi": { # 11 classes + drum class
16
+ "name": "mt3_midi",
17
+ "train_program_vocab": program_vocab_presets["mt3_midi"],
18
+ "train_drum_vocab": drum_vocab_presets["gm"],
19
+ },
20
+ "mt3_midi_plus": { # 11 classes + singing + drum class
21
+ "name": "mt3_midi_plus",
22
+ "train_program_vocab": program_vocab_presets["mt3_midi_plus"],
23
+ "train_drum_vocab": drum_vocab_presets["gm"],
24
+ },
25
+ "mt3_full": { # 34 classes (except drums) as in MT3 paper
26
+ "name": "mt3_full",
27
+ "train_program_vocab": program_vocab_presets["mt3_full"],
28
+ "train_drum_vocab": drum_vocab_presets["gm"],
29
+ },
30
+ "mt3_full_plus": { # 34 classes (except drums) as in MT3 paper + singing + drum class
31
+ "name": "mt3_full_plus",
32
+ "train_program_vocab": program_vocab_presets["mt3_full_plus"],
33
+ "train_drum_vocab": drum_vocab_presets["gm"],
34
+ },
35
+ "gm_ext_plus": { # 13 classes + singing + chorus (except drums)
36
+ "name": "gm_ext_plus",
37
+ "train_program_vocab": program_vocab_presets["gm_ext_plus"],
38
+ "train_drum_vocab": drum_vocab_presets["gm"],
39
+ },
40
+ "singing_v1": {
41
+ "name": "singing",
42
+ "train_program_vocab": program_vocab_presets["mt3_full_plus"],
43
+ "train_drum_vocab": drum_vocab_presets["gm"],
44
+ "subtask_tokens": ["task", "transcribe_singing", "transcribe_all"],
45
+ "ignore_decoding_tokens": ["task", "transcribe_singing", "transcribe_all"],
46
+ "max_task_token_length": 2,
47
+ "eval_subtask_prefix": {
48
+ "default": [Event("transcribe_all", 0), Event("task", 0)],
49
+ "singing-only": [Event("transcribe_singing", 0),
50
+ Event("task", 0)],
51
+ }
52
+ },
53
+ "singing_drum_v1": {
54
+ "name": "singing_drum",
55
+ "train_program_vocab": program_vocab_presets["mt3_full_plus"],
56
+ "train_drum_vocab": drum_vocab_presets["gm"],
57
+ "subtask_tokens": ["task", "transcribe_singing", "transcribe_drum", "transcribe_all"],
58
+ "ignore_decoding_tokens": [
59
+ "task", "transcribe_singing", "transcribe_drum", "transcribe_all"
60
+ ],
61
+ "max_task_token_length": 2,
62
+ "eval_subtask_prefix": {
63
+ "default": [Event("transcribe_all", 0), Event("task", 0)],
64
+ "singing-only": [Event("transcribe_singing", 0),
65
+ Event("task", 0)],
66
+ "drum-only": [Event("transcribe_drum", 0),
67
+ Event("task", 0)],
68
+ }
69
+ },
70
+ "mc13": { # multi-channel decoding task of {11 classes + drums + singing}
71
+ "name": "mc13",
72
+ "train_program_vocab": program_vocab_presets["gm_plus"],
73
+ "train_drum_vocab": drum_vocab_presets["gm"],
74
+ "num_decoding_channels": len(program_vocab_presets["gm_plus"]) + 1, # 13
75
+ "max_note_token_length_per_ch": 512, # multi-channel decoding exclusive parameter
76
+ "mask_loss_strategy": None, # multi-channel decoding exclusive parameter
77
+ },
78
+ "mc13_256": { # multi-channel decoding task of {11 classes + drums + singing}
79
+ "name": "mc13_256",
80
+ "train_program_vocab": program_vocab_presets["gm_plus"],
81
+ "train_drum_vocab": drum_vocab_presets["gm"],
82
+ "num_decoding_channels": len(program_vocab_presets["gm_plus"]) + 1, # 13
83
+ "max_note_token_length_per_ch": 256, # multi-channel decoding exclusive parameter
84
+ "mask_loss_strategy": None, # multi-channel decoding exclusive parameter
85
+ },
86
+ "mc13_full_plus": { # multi-channel decoding task of {34 classes + drums + singing & chorus} mapped to 13 channels
87
+ "name": "mc13_full_plus",
88
+ "train_program_vocab": program_vocab_presets["mt3_full_plus"],
89
+ "train_drum_vocab": drum_vocab_presets["gm"],
90
+ "program2channel_vocab_source": program_vocab_presets["gm_plus"],
91
+ "num_decoding_channels": 13,
92
+ "max_note_token_length_per_ch": 512, # multi-channel decoding exclusive parameter
93
+ "mask_loss_strategy": None, # multi-channel decoding exclusive parameter
94
+ },
95
+ "mc13_full_plus_256": { # multi-channel decoding task of {34 classes + drums + singing & chorus} mapped to 13 channels
96
+ "name": "mc13_full_plus_256",
97
+ "train_program_vocab": program_vocab_presets["mt3_full_plus"],
98
+ "train_drum_vocab": drum_vocab_presets["gm"],
99
+ "program2channel_vocab_source": program_vocab_presets["gm_plus"],
100
+ "num_decoding_channels": 13,
101
+ "max_note_token_length_per_ch": 256, # multi-channel decoding exclusive parameter
102
+ "mask_loss_strategy": None, # multi-channel decoding exclusive parameter
103
+ },
104
+ "exc_v1": {
105
+ "name": "exclusive",
106
+ "train_program_vocab": program_vocab_presets["mt3_full_plus"],
107
+ "train_drum_vocab": drum_vocab_presets["gm"],
108
+ "subtask_tokens": ["transcribe", "all", ":"],
109
+ # "ignore_decoding_tokens": [
110
+ # "task", "transcribe_singing", "transcribe_drum", "transcribe_all"
111
+ # ],
112
+ # "max_task_token_length": 2,
113
+ "ignore_decoding_tokens_from_and_to": ["transcribe", ":"],
114
+ "eval_subtask_prefix": { # this is the main task that transcribe all instruments
115
+ "default": [Event("transcribe", 0), Event("all", 0), Event(":", 0)],
116
+ },
117
+ "shuffle_subtasks": True,
118
+ },
119
+ }
amt/src/config/vocabulary.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ """vocabulary.py
11
+
12
+ Vocabulary for instrument classes. Vocabulary can be used as train_vocab
13
+ or test_vocab in data_presets.py or train.py arguments.
14
+
15
+ - When it is used as train_vocab, it maps the instrument classes to the first
16
+ program number of the class. For example, if you use 'GM_INSTR_CLASS' as
17
+ train_vocab, then the program number of 'Piano' is [0,1,2,3,4,5,6,7]. These
18
+ program numbers are trained as program [0] in the model.
19
+
20
+ - When it is used as eval_vocab, any program number in the instrument class
21
+ is considered as correct.
22
+
23
+
24
+ MUSICNET_INSTR_CLASS: 3 classes used for MusicNet benchmark
25
+ GM_INSTR_CLASS: equivalent to 'MIDI Class' defined by MT3.
26
+ GM_INSTR_CLASS_PLUS: GM_INSTR_CLASS + singing voice
27
+ GM_INSTR_FULL: 128 GM instruments, which is extended from 'MT3_FULL'
28
+ MT3_FULL: this matches the class names in Table 3 of MT3 paper
29
+ ENST_DRUM_NOTES: 20 drum notes used in ENST dataset
30
+ GM_DRUM_NOTES: 45 GM drum notes with percussions
31
+
32
+ Program 128 is reserved for 'drum' internally.
33
+ Program 129 is reserved for 'unannotated', internally.
34
+ Program 100 is reserved for 'singing voice (melody)' in GM_INSTR_CLASS_PLUS.
35
+ Program 101 is reserved for 'singing voice (chorus)' in GM_INSTR_CLASS_PLUS.
36
+
37
+
38
+ """
39
+ # yapf: disable
40
+ import numpy as np
41
+
42
+ PIANO_SOLO_CLASS = {
43
+ "Piano": np.arange(0, 8),
44
+ }
45
+
46
+ GUITAR_SOLO_CLASS = {
47
+ "Guitar": np.arange(24, 32),
48
+ }
49
+
50
+ SINGING_SOLO_CLASS = {
51
+ "Singing Voice": [100, 101],
52
+ }
53
+
54
+ SINGING_CHORUS_SEP_CLASS = {
55
+ "Singing Voice": [100],
56
+ "Singing Voice (chorus)": [101],
57
+ }
58
+
59
+ BASS_SOLO_CLASS = {
60
+ "Bass": np.arange(32, 40),
61
+ }
62
+
63
+ MUSICNET_INSTR_CLASS = {
64
+ "Piano": np.arange(0, 8),
65
+ "Strings": np.arange(40, 52), # Solo strings + ensemble strings
66
+ "Winds": np.arange(64, 80), # Reed + Pipe
67
+ }
68
+
69
+ GM_INSTR_CLASS = {
70
+ "Piano": np.arange(0, 8),
71
+ "Chromatic Percussion": np.arange(8, 16),
72
+ "Organ": np.arange(16, 24),
73
+ "Guitar": np.arange(24, 32),
74
+ "Bass": np.arange(32, 40),
75
+ "Strings": np.arange(40, 56), # Strings + Ensemble
76
+ # "Strings": np.arange(40, 48),
77
+ # "Ensemble": np.arange(48, 56),
78
+ "Brass": np.arange(56, 64),
79
+ "Reed": np.arange(64, 72),
80
+ "Pipe": np.arange(72, 80),
81
+ "Synth Lead": np.arange(80, 88),
82
+ "Synth Pad": np.arange(88, 96),
83
+ }
84
+
85
+ GM_INSTR_CLASS_PLUS = GM_INSTR_CLASS.copy()
86
+ GM_INSTR_CLASS_PLUS["Singing Voice"] = [100, 101]
87
+
88
+ GM_INSTR_EXT_CLASS = { # Best for enjoyable MIDI file generation
89
+ "Acoustic Piano": [0, 1, 3, 6, 7],
90
+ "Electric Piano": [2, 4, 5],
91
+ "Chromatic Percussion": np.arange(8, 16),
92
+ "Organ": np.arange(16, 24),
93
+ "Guitar (clean)": np.arange(24, 28),
94
+ "Guitar (distortion)": [30, 28, 29, 31], # np.arange(28, 32),
95
+ "Bass": [33, 32, 34, 35, 36, 37, 38, 39], # np.arange(32, 40),
96
+ "Strings": [48, 40, 41, 42, 43, 44, 45, 46, 47, 49, 50, 51, 52, 53, 54, 55], # np.arange(40, 56),
97
+ "Brass": np.arange(56, 64),
98
+ "Reed": np.arange(64, 72),
99
+ "Pipe": np.arange(72, 80),
100
+ "Synth Lead": np.arange(80, 88),
101
+ "Synth Pad": np.arange(88, 96),
102
+ }
103
+ GM_INSTR_EXT_CLASS_PLUS = GM_INSTR_EXT_CLASS.copy()
104
+ GM_INSTR_EXT_CLASS_PLUS["Singing Voice"] = [100]
105
+ GM_INSTR_EXT_CLASS_PLUS["Singing Voice (chorus)"] = [101]
106
+
107
+ GM_INSTR_FULL = {
108
+ "Acoustic Grand Piano": [0],
109
+ "Bright Acoustic Piano": [1],
110
+ "Electric Grand Piano": [2],
111
+ "Honky-tonk Piano": [3],
112
+ "Electric Piano 1": [4],
113
+ "Electric Piano 2": [5],
114
+ "Harpsichord": [6],
115
+ "Clavinet": [7],
116
+ "Celesta": [8],
117
+ "Glockenspiel": [9],
118
+ "Music Box": [10],
119
+ "Vibraphone": [11],
120
+ "Marimba": [12],
121
+ "Xylophone": [13],
122
+ "Tubular Bells": [14],
123
+ "Dulcimer": [15],
124
+ "Drawbar Organ": [16],
125
+ "Percussive Organ": [17],
126
+ "Rock Organ": [18],
127
+ "Church Organ": [19],
128
+ "Reed Organ": [20],
129
+ "Accordion": [21],
130
+ "Harmonica": [22],
131
+ "Tango Accordion": [23],
132
+ "Acoustic Guitar (nylon)": [24],
133
+ "Acoustic Guitar (steel)": [25],
134
+ "Electric Guitar (jazz)": [26],
135
+ "Electric Guitar (clean)": [27],
136
+ "Electric Guitar (muted)": [28],
137
+ "Overdriven Guitar": [29],
138
+ "Distortion Guitar": [30],
139
+ "Guitar Harmonics": [31],
140
+ "Acoustic Bass": [32],
141
+ "Electric Bass (finger)": [33],
142
+ "Electric Bass (pick)": [34],
143
+ "Fretless Bass": [35],
144
+ "Slap Bass 1": [36],
145
+ "Slap Bass 2": [37],
146
+ "Synth Bass 1": [38],
147
+ "Synth Bass 2": [39],
148
+ "Violin": [40],
149
+ "Viola": [41],
150
+ "Cello": [42],
151
+ "Contrabass": [43],
152
+ "Tremolo Strings": [44],
153
+ "Pizzicato Strings": [45],
154
+ "Orchestral Harp": [46],
155
+ "Timpani": [47],
156
+ "String Ensemble 1": [48],
157
+ "String Ensemble 2": [49],
158
+ "Synth Strings 1": [50],
159
+ "Synth Strings 2": [51],
160
+ "Choir Aahs": [52],
161
+ "Voice Oohs": [53],
162
+ "Synth Choir": [54],
163
+ "Orchestra Hit": [55],
164
+ "Trumpet": [56],
165
+ "Trombone": [57],
166
+ "Tuba": [58],
167
+ "Muted Trumpet": [59],
168
+ "French Horn": [60],
169
+ "Brass Section": [61],
170
+ "Synth Brass 1": [62],
171
+ "Synth Brass 2": [63],
172
+ "Soprano Sax": [64],
173
+ "Alto Sax": [65],
174
+ "Tenor Sax": [66],
175
+ "Baritone Sax": [67],
176
+ "Oboe": [68],
177
+ "English Horn": [69],
178
+ "Bassoon": [70],
179
+ "Clarinet": [71],
180
+ "Piccolo": [72],
181
+ "Flute": [73],
182
+ "Recorder": [74],
183
+ "Pan Flute": [75],
184
+ "Bottle Blow": [76],
185
+ "Shakuhachi": [77],
186
+ "Whistle": [78],
187
+ "Ocarina": [79],
188
+ "Lead 1 (square)": [80],
189
+ "Lead 2 (sawtooth)": [81],
190
+ "Lead 3 (calliope)": [82],
191
+ "Lead 4 (chiff)": [83],
192
+ "Lead 5 (charang)": [84],
193
+ "Lead 6 (voice)": [85],
194
+ "Lead 7 (fifths)": [86],
195
+ "Lead 8 (bass + lead)": [87],
196
+ "Pad 1 (new age)": [88],
197
+ "Pad 2 (warm)": [89],
198
+ "Pad 3 (polysynth)": [90],
199
+ "Pad 4 (choir)": [91],
200
+ "Pad 5 (bowed)": [92],
201
+ "Pad 6 (metallic)": [93],
202
+ "Pad 7 (halo)": [94],
203
+ "Pad 8 (sweep)": [95],
204
+ # "FX 1 (rain)": [96],
205
+ # "FX 2 (soundtrack)": [97],
206
+ # "FX 3 (crystal)": [98],
207
+ # "FX 4 (atmosphere)": [99],
208
+ # "FX 5 (brightness)": [100],
209
+ # "FX 6 (goblins)": [101],
210
+ # "FX 7 (echoes)": [102],
211
+ # "FX 8 (sci-fi)": [103],
212
+ # "Sitar": [104],
213
+ # "Banjo": [105],
214
+ # "Shamisen": [106],
215
+ # "Koto": [107],
216
+ # "Kalimba": [108],
217
+ # "Bagpipe": [109],
218
+ # "Fiddle": [110],
219
+ # "Shanai": [111],
220
+ # "Tinkle Bell": [112],
221
+ # "Agogo": [113],
222
+ # "Steel Drums": [114],
223
+ # "Woodblock": [115],
224
+ # "Taiko Drum": [116],
225
+ # "Melodic Tom": [117],
226
+ # "Synth Drum": [118],
227
+ # "Reverse Cymbal": [119],
228
+ # "Guitar Fret Noise": [120],
229
+ # "Breath Noise": [121],
230
+ # "Seashore": [122],
231
+ # "Bird Tweet": [123],
232
+ # "Telephone Ring": [124],
233
+ # "Helicopter": [125],
234
+ # "Applause": [126],
235
+ # "Gunshot": [127]
236
+ }
237
+
238
+ MT3_FULL = { # this matches the class names in Table 3 of MT3 paper
239
+ "Acoustic Piano": [0, 1, 3, 6, 7],
240
+ "Electric Piano": [2, 4, 5],
241
+ "Chromatic Percussion": np.arange(8, 16),
242
+ "Organ": np.arange(16, 24),
243
+ "Acoustic Guitar": np.arange(24, 26),
244
+ "Clean Electric Guitar": np.arange(26, 29),
245
+ "Distorted Electric Guitar": np.arange(29, 32),
246
+ "Acoustic Bass": [32, 35],
247
+ "Electric Bass": [33, 34, 36, 37, 38, 39],
248
+ "Violin": [40],
249
+ "Viola": [41],
250
+ "Cello": [42],
251
+ "Contrabass": [43],
252
+ "Orchestral Harp": [46],
253
+ "Timpani": [47],
254
+ "String Ensemble": [48, 49, 44, 45],
255
+ "Synth Strings": [50, 51],
256
+ "Choir and Voice": [52, 53, 54],
257
+ "Orchestra Hit": [55],
258
+ "Trumpet": [56, 59],
259
+ "Trombone": [57],
260
+ "Tuba": [58],
261
+ "French Horn": [60],
262
+ "Brass Section": [61, 62, 63],
263
+ "Soprano/Alto Sax": [64, 65],
264
+ "Tenor Sax": [66],
265
+ "Baritone Sax": [67],
266
+ "Oboe": [68],
267
+ "English Horn": [69],
268
+ "Bassoon": [70],
269
+ "Clarinet": [71],
270
+ "Pipe": [73, 72, 74, 75, 76, 77, 78, 79],
271
+ "Synth Lead": np.arange(80, 88),
272
+ "Synth Pad": np.arange(88, 96),
273
+ }
274
+
275
+ MT3_FULL_PLUS = MT3_FULL.copy()
276
+ MT3_FULL_PLUS["Singing Voice"] = [100]
277
+ MT3_FULL_PLUS["Singing Voice (chorus)"] = [101]
278
+
279
+ ENST_DRUM_NOTES = {
280
+ "bd": [36], # Kick Drum
281
+ "sd": [38], # Snare Drum
282
+ "sweep": [0], # Brush sweep
283
+ "sticks": [1], # Sticks
284
+ "rs": [2], # Rim shot
285
+ "cs": [37], # X-stick
286
+ "chh": [42], # Closed Hi-Hat
287
+ "ohh": [46], # Open Hi-Hat
288
+ "cb": [56], # Cowbell
289
+ "c": [3], # Other Cymbals
290
+ "lmt": [47], # Low Mid Tom
291
+ "mt": [48], # Mid Tom
292
+ "mtr": [58], # Mid Tom Rim
293
+ "lt": [45], # Low Tom
294
+ "ltr": [50], # Low Tom Rim
295
+ "lft": [41], # Low Floor Tom
296
+ "rc": [51], # Ride Cymbal
297
+ "ch": [52], # Chinese Cymbal
298
+ "cr": [49], # Crash Cymbal
299
+ "spl": [55], # Splash Cymbal
300
+ }
301
+
302
+ EGMD_DRUM_NOTES = {
303
+ "Kick Drum": [36], # Listed by order of most common annotation
304
+ "Snare X-stick": [37], # Snare X-Stick, https://youtu.be/a2KFrrKaoYU?t=80
305
+ "Snare Drum": [38], # Snare (head) and Electric Snare
306
+ "Closed Hi-Hat": [42, 44, 22], # 44 is pedal hi-hat
307
+ "Open Hi-Hat": [46, 26],
308
+ "Cowbell": [56],
309
+ "High Floor Tom": [43],
310
+ "Low Floor Tom": [41], # Lowest Tom
311
+ "Low Tom": [45],
312
+ "Low-Mid Tom": [47],
313
+ "Mid Tom": [48],
314
+ "Low Tom (Rim)": [50], # TD-17: 47, 50, 58
315
+ "Mid Tom (Rim)": [58],
316
+ # "Ride Cymbal": [51, 53, 59],
317
+ "Ride": [51],
318
+ "Ride (Bell)": [53], # https://youtu.be/b94hZoM5s3k?t=323
319
+ "Ride (Edge)": [59],
320
+ "Chinese Cymbal": [52],
321
+ "Crash Cymbal": [49, 57],
322
+ "Splash Cymbal": [55],
323
+ }
324
+
325
+ # Inspired by Roland TD-17 MIDI note map, https://rolandus.zendesk.com/hc/en-us/articles/360005173411-TD-17-Default-Factory-MIDI-Note-Map
326
+ GM_DRUM_NOTES = {
327
+ "Kick Drum": [36, 35], # Listed by order of most common annotation
328
+ "Snare X-stick": [37, 2], # Snare X-Stick, https://youtu.be/a2KFrrKaoYU?t=80
329
+ "Snare Drum": [38, 40], # Snare (head) and Electric Snare
330
+ "Closed Hi-Hat": [42, 44, 22], # 44 is pedal hi-hat
331
+ "Open Hi-Hat": [46, 26],
332
+ "Cowbell": [56],
333
+ "High Floor Tom": [43],
334
+ "Low Floor Tom": [41], # Lowest Tom
335
+ "Low Tom": [45],
336
+ "Low-Mid Tom": [47],
337
+ "Mid Tom": [48],
338
+ "Low Tom (Rim)": [50], # TD-17: 47, 50, 58
339
+ "Mid Tom (Rim)": [58],
340
+ # "Ride Cymbal": [51, 53, 59],
341
+ "Ride": [51],
342
+ "Ride (Bell)": [53], # https://youtu.be/b94hZoM5s3k?t=323
343
+ "Ride (Edge)": [59],
344
+ "Chinese Cymbal": [52],
345
+ "Crash Cymbal": [49, 57],
346
+ "Splash Cymbal": [55],
347
+ }
348
+
349
+ KICK_SNARE_HIHAT = {
350
+ "Kick Drum": [36, 35],
351
+ "Snare Drum": [38, 40],
352
+ # "Snare Drum + X-Stick": [38, 40, 37, 2],
353
+ # "Snare X-stick": [37, 2], # Snare X-Stick, https://youtu.be/a2KFrrKaoYU?t=80
354
+ "Hi-Hat": [42, 44, 46, 22, 26],
355
+ # "Ride Cymbal": [51, 53, 59],
356
+ # "Hi-Hat + Ride": [42, 44, 46, 22, 26, 51, 53, 59],
357
+ # "HiHat + all Cymbals": [42, 44, 46, 22, 26, 51, 53, 59, 52, 49, 57, 55],
358
+ # "Kick Drum + Low Tom": [36, 35, 45],
359
+ # "All Cymbal": [51, 53, 59, 52, 49, 57, 55]
360
+ # "all": np.arange(30, 60)
361
+ }
362
+
363
+ drum_vocab_presets = {
364
+ "gm": GM_DRUM_NOTES,
365
+ "egmd": EGMD_DRUM_NOTES,
366
+ "enst": ENST_DRUM_NOTES,
367
+ "ksh": KICK_SNARE_HIHAT,
368
+ "kshr": {
369
+ "Kick Drum": [36, 35],
370
+ "Snare Drum": [38, 40],
371
+ "Hi-Hat": [42, 44, 46, 22, 26, 51, 53, 59],
372
+ }
373
+ }
374
+
375
+ program_vocab_presets = {
376
+ "gm_full": GM_INSTR_FULL, # 96 classes (except drums)
377
+ "mt3_full": MT3_FULL, # 34 classes (except drums) as in MT3 paper
378
+ "mt3_midi": GM_INSTR_CLASS, # 11 classes (except drums) as in MT3 paper
379
+ "mt3_midi_plus": GM_INSTR_CLASS_PLUS, # 11 classes + singing (except drums)
380
+ "mt3_full_plus": MT3_FULL_PLUS, # 34 classes (except drums) mt3_full + singing (except drums)
381
+ "gm": GM_INSTR_CLASS, # 11 classes (except drums)
382
+ "gm_plus": GM_INSTR_CLASS_PLUS, # 11 classes + singing (except drums)
383
+ "gm_ext_plus": GM_INSTR_EXT_CLASS_PLUS, # 13 classes + singing + chorus (except drums)
384
+ }
amt/src/extras/.DS_Store ADDED
Binary file (10.2 kB). View file
 
amt/src/extras/Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
2
+ LABEL maintainer="https://github.com/mimbres/YourMT3"
3
+
4
+ ENV TZ=Europe/London \
5
+ DEBIAN_FRONTEND=noninteractive
6
+ RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
7
+
8
+ RUN apt-get update
9
+ ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
10
+
11
+ RUN apt-get update --fix-missing && apt-get install -y wget curl \
12
+ nano git ffmpeg sox tmux htop
13
+ RUN pip3 install --upgrade pip
14
+ RUN pip3 install mirdata mido git+https://github.com/craffel/mir_eval.git \
15
+ matplotlib lightning>=2.0.2 pytest-timeout pytest deprecated librosa \
16
+ einops transformers wandb
17
+
18
+ CMD [ "/bin/bash" ]
amt/src/extras/check_drum_channel_slakh.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.mirdata_dev.datasets import slakh16k
2
+
3
+
4
+ def check_drum_channel_slakh(data_home: str):
5
+ ds = slakh16k.Dataset(data_home, version='default')
6
+ for track_id in ds.track_ids:
7
+ is_drum = ds.track(track_id).is_drum
8
+ midi = MidiFile(ds.track(track_id).midi_path)
9
+ cnt = 0
10
+ for msg in midi:
11
+ if 'note' in msg.type:
12
+ if is_drum and (msg.channel != 9):
13
+ print('found drum track with channel != 9 in track_id: ',
14
+ track_id)
15
+ if not is_drum and (msg.channel == 9):
16
+ print(
17
+ 'found non-drum track with channel == 9 in track_id: ',
18
+ track_id)
19
+ if is_drum and (msg.channel == 9):
20
+ cnt += 1
21
+ if cnt > 0:
22
+ print(f'found {cnt} notes in drum track with ch 9 in track_id: ',
23
+ track_id)
24
+ return
amt/src/extras/dataset_mutable_var_sanity_check.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ for n in range(1000):
2
+ sampled_data = ds.__getitem__(n)
3
+
4
+ a = deepcopy(sampled_data['note_event_segments'])
5
+ b = deepcopy(sampled_data['note_event_segments'])
6
+
7
+ for (note_events, tie_note_events, start_time) in list(zip(*b.values())):
8
+ note_events = pitch_shift_note_events(note_events, 2)
9
+ tie_note_events = pitch_shift_note_events(tie_note_events, 2)
10
+
11
+ # compare
12
+ for i, (note_events, tie_note_events, start_time) in enumerate(list(zip(*b.values()))):
13
+ for j, ne in enumerate(note_events):
14
+ if ne.is_drum is False:
15
+ if ne.pitch != a['note_events'][i][j].pitch + 2:
16
+ print(i, j)
17
+ assert ne.pitch == a['note_events'][i][j].pitch + 2
18
+
19
+ for k, tne in enumerate(tie_note_events):
20
+ assert tne.pitch == a['tie_note_events'][i][k].pitch + 2
21
+
22
+ print('test {} passed'.format(n))
23
+
24
+
25
+ def assert_note_events_almost_equal(actual_note_events,
26
+ predicted_note_events,
27
+ ignore_time=False,
28
+ ignore_activity=True,
29
+ delta=5.1e-3):
30
+ """
31
+ Asserts that the given lists of Note instances are equal up to a small
32
+ floating-point tolerance, similar to `assertAlmostEqual` of `unittest`.
33
+ Tolerance is 5e-3 by default, which is 5 ms for 100 ticks-per-second.
34
+
35
+ If `ignore_time` is True, then the time field is ignored. (useful for
36
+ comparing tie note events, default is False)
37
+
38
+ If `ignore_activity` is True, then the activity field is ignored (default
39
+ is True).
40
+ """
41
+ assert len(actual_note_events) == len(predicted_note_events)
42
+ for j, (actual_note_event,
43
+ predicted_note_event) in enumerate(zip(actual_note_events, predicted_note_events)):
44
+ if ignore_time is False:
45
+ assert abs(actual_note_event.time - predicted_note_event.time) <= delta
46
+ assert actual_note_event.is_drum == predicted_note_event.is_drum
47
+ if actual_note_event.is_drum is False and predicted_note_event.is_drum is False:
48
+ assert actual_note_event.program == predicted_note_event.program
49
+ assert actual_note_event.pitch == predicted_note_event.pitch
50
+ assert actual_note_event.velocity == predicted_note_event.velocity
51
+ if ignore_activity is False:
52
+ assert actual_note_event.activity == predicted_note_event.activity
53
+
54
+
55
+ cache_old = deepcopy(dict(ds.cache))
56
+ for n in range(500):
57
+ sampled_data = ds.__getitem__(n)
58
+ cache_new = ds.cache
59
+ cnt = 0
60
+ for k, v in cache_new.items():
61
+ if k in cache_old:
62
+ cnt += 1
63
+ assert (cache_new[k]['programs'] == cache_old[k]['programs']).all()
64
+ assert (cache_new[k]['is_drum'] == cache_old[k]['is_drum']).all()
65
+ assert (cache_new[k]['has_stems'] == cache_old[k]['has_stems'])
66
+ assert (cache_new[k]['has_unannotated'] == cache_old[k]['has_unannotated'])
67
+ assert (cache_new[k]['audio_array'] == cache_old[k]['audio_array']).all()
68
+
69
+ for nes_new, nes_old in zip(cache_new[k]['note_event_segments']['note_events'],
70
+ cache_old[k]['note_event_segments']['note_events']):
71
+ assert_note_events_almost_equal(nes_new, nes_old)
72
+
73
+ for tnes_new, tnes_old in zip(cache_new[k]['note_event_segments']['tie_note_events'],
74
+ cache_old[k]['note_event_segments']['tie_note_events']):
75
+ assert_note_events_almost_equal(tnes_new, tnes_old, ignore_time=True)
76
+
77
+ for s_new, s_old in zip(cache_new[k]['note_event_segments']['start_times'],
78
+ cache_old[k]['note_event_segments']['start_times']):
79
+ assert s_new == s_old
80
+ cache_old = deepcopy(dict(ds.cache))
81
+ print(n, cnt)
amt/src/extras/datasets_eval_testing.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.datasets_eval import AudioFileDataset
2
+ from torch.utils.data import DataLoader
3
+ import pytorch_lightning as pl
4
+
5
+
6
+ def test():
7
+
8
+ ds = AudioFileDataset()
9
+ dl = DataLoader(
10
+ ds, batch_size=None, collate_fn=lambda k: k
11
+ ) # empty collate_fn is required to use mixed types.
12
+
13
+ for x, y in dl:
14
+ break
15
+
16
+ class MyModel(pl.LightningModule):
17
+
18
+ def __init__(self, **kwargs):
19
+ super().__init__()
20
+
21
+ def forward(self, x):
22
+ return x
23
+
24
+ def training_step(self, batch, batch_idx):
25
+ return 0
26
+
27
+ def validation_step(self, batch, batch_idx):
28
+ print(batch)
29
+ return 0
30
+
31
+ def train_dataloader(self):
32
+ return dl
33
+
34
+ def val_dataloader(self):
35
+ return dl
36
+
37
+ def configure_optimizers(self):
38
+ return None
39
+
40
+ model = MyModel()
41
+ trainer = pl.Trainer()
42
+ trainer.validate(model)
amt/src/extras/demo_cross_augmentation.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ from typing import Dict, Tuple
11
+ from copy import deepcopy
12
+ import soundfile as sf
13
+ import torch
14
+ from utils.data_modules import AMTDataModule
15
+ from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg
16
+ from utils.augment import intra_stem_augment_processor
17
+
18
+
19
+ def get_ds(data_preset_multi: Dict, train_num_samples_per_epoch: int = 90000):
20
+ dm = AMTDataModule(data_preset_multi=data_preset_multi, train_num_samples_per_epoch=train_num_samples_per_epoch)
21
+ dm.setup('fit')
22
+ dl = dm.train_dataloader()
23
+ ds = dl.flattened[0].dataset
24
+ return ds
25
+
26
+
27
+ def debug_func(num_segments: int = 10):
28
+ sampled_data, sampled_ids = ds._get_rand_segments_from_cache(num_segments)
29
+ ux_sampled_data, _ = ds._get_rand_segments_from_cache(ux_count_sum, False, sampled_ids)
30
+ s = deepcopy(sampled_data)
31
+ intra_stem_augment_processor(sampled_data, submix_audio=False)
32
+
33
+
34
+ def gen_audio(index: int = 0):
35
+ # audio_arr: (b, 1, nframe), note_token_arr: (b, l), task_token_arr: (b, task_l)
36
+ audio_arr, note_token_arr, task_token_arr = ds.__getitem__(index)
37
+
38
+ # merge all the segments into one audio file
39
+ audio = audio_arr.permute(0, 2, 1).reshape(-1).squeeze().numpy()
40
+
41
+ # save the audio file
42
+ sf.write('xaug_demo_audio.wav', audio, 16000, subtype='PCM_16')
43
+
44
+
45
+ data_preset_multi = data_preset_multi_cfg["all_cross_rebal5"]
46
+ ds = get_ds(data_preset_multi)
47
+ ds.random_amp_range = [0.8, 1.1]
48
+ ds.stem_xaug_policy = {
49
+ "max_k": 5,
50
+ "tau": 0.3,
51
+ "alpha": 1.0,
52
+ "max_subunit_stems": 12,
53
+ "no_instr_overlap": True,
54
+ "no_drum_overlap": True,
55
+ "uhat_intra_stem_augment": True,
56
+ }
57
+ gen_audio(3)
58
+
59
+ # for k in ds.cache.keys():
60
+ # arr = ds.cache[k]['audio_array']
61
+ # arr = np.sum(arr, axis=1).reshape(-1)
62
+ # # sf.write(f'xxx/{k}.wav', arr, 16000, subtype='PCM_16')
63
+ # if np.min(arr) > -0.5:
64
+ # print(k)
65
+
66
+ # arr = ds.cache[52]['audio_array']
67
+ # for i in range(arr.shape[1]):
68
+ # a = arr[:, i, :].reshape(-1)
69
+ # sf.write(f'xxx52/52_{i}.wav', a, 16000, subtype='PCM_16')
amt/src/extras/demo_intra_augmentation.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ import numpy as np
11
+ import torch
12
+ import json
13
+ import soundfile as sf
14
+ from utils.datasets_train import get_cache_data_loader
15
+
16
+
17
+ def get_filelist(track_id: int) -> dict:
18
+ filelist = '../../data/yourmt3_indexes/slakh_train_file_list.json'
19
+ with open(filelist, 'r') as f:
20
+ fl = json.load(f)
21
+ new_filelist = dict()
22
+ for key, value in fl.items():
23
+ if int(key) == track_id:
24
+ new_filelist[0] = value
25
+ return new_filelist
26
+
27
+
28
+ def get_ds(track_id: int, random_amp_range: list = [1., 1.], stem_aug_prob: float = 0.8):
29
+ filelist = get_filelist(track_id)
30
+ dl = get_cache_data_loader(filelist,
31
+ 'train',
32
+ 1,
33
+ 1,
34
+ random_amp_range=random_amp_range,
35
+ stem_aug_prob=stem_aug_prob,
36
+ shuffle=False)
37
+ ds = dl.dataset
38
+ return ds
39
+
40
+
41
+ def gen_audio(track_id: int, n_segments: int = 30, random_amp_range: list = [1., 1.], stem_aug_prob: float = 0.8):
42
+ ds = get_ds(track_id, random_amp_range, stem_aug_prob)
43
+ audio = []
44
+ for i in range(n_segments):
45
+ audio.append(ds.__getitem__(0)[0])
46
+ # audio.append(ds.__getitem__(i)[0])
47
+
48
+ audio = torch.concat(audio, dim=2).numpy()[0, 0, :]
49
+ sf.write('audio.wav', audio, 16000, subtype='PCM_16')
50
+
51
+
52
+ gen_audio(1, 20)
amt/src/extras/download_mirst500.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ from pytube import YouTube
5
+
6
+
7
+ def downloadMp3(yt, idx, askPath=0):
8
+ # extract only audio
9
+ video = yt.streams.filter(only_audio=True).first()
10
+
11
+ destination = 'mp3File'
12
+ # check for destination to save file
13
+ if (askPath == 1):
14
+ print("Enter the destination (leave blank for default dir mp3File)")
15
+ destination = str(input(">> ")) or 'mp3File'
16
+
17
+ # download the file
18
+ out_file = video.download(output_path=destination)
19
+
20
+ # save the file
21
+ # base, ext = os.path.splitext(out_file)
22
+ dir_path, file_base = os.path.split(out_file)
23
+
24
+ new_file = os.path.join(dir_path, f'{idx}.mp3')
25
+ os.rename(out_file, new_file)
26
+ # result of success
27
+ print(yt.title + " has been successfully downloaded.")
28
+
29
+
30
+ MISSING_FILE_IDS = [
31
+ 16, 26, 33, 38, 40, 50, 53, 55, 60, 81, 82, 98, 107, 122, 126, 127, 129, 141, 145, 150, 172,
32
+ 201, 205, 206, 215, 216, 221, 226, 232, 240, 243, 245, 255, 257, 267, 273, 278, 279, 285, 287,
33
+ 291, 304, 312, 319, 321, 325, 329, 332, 333, 336, 337, 342, 359, 375, 402, 417, 438, 445, 454,
34
+ 498
35
+ ]
36
+
37
+ data_link_file = '../../../data/mir_St500_yourmt3_16k/MIR-ST500_20210206/MIR-ST500_link.json'
38
+ data_link = json.load(open(data_link_file, 'r'))
39
+ download_fail = []
40
+
41
+ for i in MISSING_FILE_IDS:
42
+ print(f'Downloading {i}...')
43
+ yt = YouTube(data_link[str(i)])
44
+ try:
45
+ downloadMp3(yt, idx=i)
46
+ except:
47
+ download_fail.append(i)
48
+ print(f'Failed to download {i}.')
49
+
50
+ print(f'Failed to download {len(download_fail)} files: {download_fail}')
amt/src/extras/fig/label_smooth_interval_of_interest.png ADDED
amt/src/extras/fig/pitchshift_benchnmark.png ADDED
amt/src/extras/fig/pitchshift_stretch_and_resampler_process_time.png ADDED
amt/src/extras/inspecting_slakh_bass.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mirdata
2
+ from utils.mirdata_dev.datasets import slakh16k
3
+
4
+ ds = slakh16k.Dataset(data_home='../../data', version='2100-yourmt3-16k')
5
+ mtrack_ids = ds.mtrack_ids
6
+
7
+ # Collect plugin names
8
+ plugin_names = set()
9
+ cnt = 0
10
+ for mtrack_id in mtrack_ids:
11
+ mtrack = ds.multitrack(mtrack_id)
12
+ for track_id in mtrack.track_ids:
13
+ track = ds.track(track_id)
14
+ if track.instrument.lower() == 'bass':
15
+ if track.plugin_name == 'upright_bass.nkm':
16
+ print(f'{str(cnt)}: {track_id}: {track.plugin_name}')
17
+ # if track.plugin_name not in plugin_names:
18
+ # plugin_names.add(track.plugin_name)
19
+ # print(f'{str(cnt)}: {track_id}: {track.plugin_name}')
20
+ # cnt += 1
21
+ """
22
+ 0: Track00001-S03: scarbee_rickenbacker_bass_palm_muted.nkm
23
+ 1: Track00002-S01: classic_bass.nkm
24
+ 2: Track00004-S01: scarbee_rickenbacker_bass.nkm
25
+ 3: Track00005-S04: scarbee_jay_bass_both.nkm
26
+ 4: Track00006-S03: pop_bass.nkm
27
+ 5: Track00008-S00: scarbee_pre_bass.nkm
28
+ 6: Track00013-S00: jazz_upright.nkm
29
+ 7: Track00014-S01: funk_bass.nkm
30
+ 8: Track00016-S01: scarbee_mm_bass.nkm
31
+ 9: Track00024-S07: upright_bass.nkm
32
+ 10: Track00027-S03: scarbee_jay_bass_slap_both.nkm
33
+ 11: Track00094-S08: upright_bass2.nkm
34
+ """
amt/src/extras/install_deepspeed.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+
3
+ # not required on pytorch 2.0:latest container
4
+ pip install cupy-cuda11x -f https://pip.cupy.dev/aarch64
5
+
6
+ apt-get update
7
+ apt-get install git
8
+ apt-get install libaio-dev
9
+
10
+ DS_BUILD_OPS=1 pip install deepspeed
11
+ ds_report
12
+
13
+
14
+ pip install deepspeed==0.7.7
15
+
16
+ git clone https://github.com/NVIDIA/apex
17
+ cd apex
18
+ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
19
+
20
+ In case you have trouble building apex from source we recommend using the NGC containers
21
+ from here which come with a pre-built PyTorch and apex release.
22
+
23
+ nvcr.io/nvidia/pytorch:23.01-py3
24
+
25
+ pip install deepspeed, pip install transformers[deepspeed]
26
+ https://www.deepspeed.ai/docs/config-json/#autotuning
27
+
28
+ """
amt/src/extras/label_smoothing.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+
5
+ a = torch.signal.windows.gaussian(11, sym=True, std=3)
6
+ plt.plot(a)
7
+
8
+
9
+ def gaussian_smoothing(y_hot, mu=5, sigma=0.865):
10
+ """
11
+ y_hot: one-hot encoded array
12
+ """
13
+ #sigma = np.sqrt(np.abs(np.log(0.05) / ((4 - mu)**2))) / 2
14
+
15
+ # Generate index array
16
+ i = np.arange(len(y_hot))
17
+
18
+ # Gaussian function
19
+ y_smooth = np.exp(-(i - mu)**2 / (2 * sigma**2))
20
+
21
+ # Normalize the resulting array
22
+ y_smooth /= y_smooth.sum()
23
+ return y_smooth, sigma
24
+
25
+
26
+ # y_ls = (1 - α) * y_hot + α / K, where K is the number of classes, alpha is the smoothing parameter
27
+
28
+ y_hot = torch.zeros(11)
29
+ y_hot[5] = 1
30
+ plt.plot(y_hot, 'b.-')
31
+
32
+ alpha = 0.3
33
+ y_ls = (1 - alpha) * y_hot + alpha / 10
34
+ plt.plot(y_ls, 'r.-')
35
+
36
+ y_gs, std = gaussian_smoothing(y_hot, A=0.5)
37
+ plt.plot(y_gs, 'g.-')
38
+
39
+ y_gst_a, std = gaussian_smoothing(y_hot, A=0.5, mu=5.5)
40
+ plt.plot(y_gst_a, 'y.-')
41
+
42
+ y_gst_b, std = gaussian_smoothing(y_hot, A=0.5, mu=5.8)
43
+ plt.plot(y_gst_b, 'c.-')
44
+
45
+ plt.legend([
46
+ 'y_hot', 'label smoothing' + '\n' + '(alpha=0.3)',
47
+ 'gaussian smoothing' + '\n' + 'for interval of interest' + '\n' + 'mu=5',
48
+ 'gaussian smoothing' + '\n' + 'mu=5.5', 'gaussian smoothing' + '\n' + 'mu=5.8'
49
+ ])
50
+
51
+ plt.grid()
52
+ plt.xticks(np.arange(11), np.arange(0, 110, 10))
53
+ plt.xlabel('''Time (ms)
54
+ original (quantized) one hot label:
55
+ [0,0,0,0,0,1,0,0,0,0,0]
56
+ \n
57
+ label smooting is defined as:
58
+ y_ls = (1 - α) * y_hot + α / K,
59
+ where K is the number of classes, α is the smoothing parameter
60
+ \n
61
+ gaussian smoothing for the interval (± 10ms) of interest:
62
+ y_gs = A * exp(-(i - mu)**2 / (2 * sigma**2))
63
+ with sigma = 0.865 an mu = 5
64
+ \n
65
+ gaussian smoothing with unqunatized target timing:
66
+ mu = 5.5 for 55ms target timing
67
+ ''')
amt/src/extras/multi_channel_seqlen_stats.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ from typing import Dict, Tuple
11
+ from copy import deepcopy
12
+ from collections import Counter
13
+ import numpy as np
14
+ import torch
15
+ from utils.data_modules import AMTDataModule
16
+ from utils.task_manager import TaskManager
17
+ from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg
18
+ from utils.augment import intra_stem_augment_processor
19
+
20
+
21
+ def get_ds(data_preset_multi: Dict, task_name: str, train_num_samples_per_epoch: int = 90000):
22
+ tm = TaskManager(task_name=task_name)
23
+ tm.max_note_token_length_per_ch = 1024 # only to check the max length
24
+ dm = AMTDataModule(data_preset_multi=data_preset_multi,
25
+ task_manager=tm,
26
+ train_num_samples_per_epoch=train_num_samples_per_epoch)
27
+ dm.setup('fit')
28
+ dl = dm.train_dataloader()
29
+ ds = dl.flattened[0].dataset
30
+ return ds
31
+
32
+
33
+ data_preset_multi = data_preset_multi_cfg["all_cross_v6"]
34
+ task_name = "mc13" # "mt3_full_plus"
35
+ ds = get_ds(data_preset_multi, task_name=task_name)
36
+ ds.random_amp_range = [0.8, 1.1]
37
+ ds.stem_xaug_policy = {
38
+ "max_k": 5,
39
+ "tau": 0.3,
40
+ "alpha": 1.0,
41
+ "max_subunit_stems": 12,
42
+ "no_instr_overlap": True,
43
+ "no_drum_overlap": True,
44
+ "uhat_intra_stem_augment": True,
45
+ }
46
+
47
+ length_all = []
48
+ for i in range(40000):
49
+ if i % 5000 == 0:
50
+ print(i)
51
+ audio_arr, note_token_arr, task_totken_arr, pshift_steps = ds.__getitem__(i)
52
+ lengths = torch.sum(note_token_arr != 0, dim=2).flatten().cpu().tolist()
53
+ length_all.extend(lengths)
54
+
55
+ length_all = np.asarray(length_all)
56
+
57
+ # stats
58
+ empty_sequence = np.sum(length_all < 3) / len(length_all) * 100
59
+ print("empty_sequences:", f"{empty_sequence:.2f}", "%")
60
+
61
+ mean_except_empty = np.mean(length_all[length_all > 2])
62
+ print("mean_except_empty:", mean_except_empty)
63
+
64
+ median_except_empty = np.median(length_all[length_all > 2])
65
+ print("median_except_empty:", median_except_empty)
66
+
67
+ ch_less_than_768 = np.sum(length_all < 768) / len(length_all) * 100
68
+ print("ch_less_than_768:", f"{ch_less_than_768:.2f}", "%")
69
+
70
+ ch_larger_than_512 = np.sum(length_all > 512) / len(length_all) * 100
71
+ print("ch_larger_than_512:", f"{ch_larger_than_512:.6f}", "%")
72
+
73
+ ch_larger_than_256 = np.sum(length_all > 256) / len(length_all) * 100
74
+ print("ch_larger_than_256:", f"{ch_larger_than_256:.6f}", "%")
75
+
76
+ ch_larger_than_128 = np.sum(length_all > 128) / len(length_all) * 100
77
+ print("ch_larger_than_128:", f"{ch_larger_than_128:.6f}", "%")
78
+
79
+ ch_larger_than_64 = np.sum(length_all > 64) / len(length_all) * 100
80
+ print("ch_larger_than_64:", f"{ch_larger_than_64:.6f}", "%")
81
+
82
+ song_length_all = length_all.reshape(-1, 13)
83
+ song_larger_than_512 = 0
84
+ song_larger_than_256 = 0
85
+ song_larger_than_128 = 0
86
+ song_larger_than_64 = 0
87
+ for l in song_length_all:
88
+ if np.sum(l > 512) > 0:
89
+ song_larger_than_512 += 1
90
+ if np.sum(l > 256) > 0:
91
+ song_larger_than_256 += 1
92
+ if np.sum(l > 128) > 0:
93
+ song_larger_than_128 += 1
94
+ if np.sum(l > 64) > 0:
95
+ song_larger_than_64 += 1
96
+ num_songs = len(song_length_all)
97
+ print("song_larger_than_512:", f"{song_larger_than_512/num_songs*100:.4f}", "%")
98
+ print("song_larger_than_256:", f"{song_larger_than_256/num_songs*100:.4f}", "%")
99
+ print("song_larger_than_128:", f"{song_larger_than_128/num_songs*100:.4f}", "%")
100
+ print("song_larger_than_64:", f"{song_larger_than_64/num_songs*100:.4f}", "%")
101
+
102
+ instr_dict = {
103
+ 0: "Piano",
104
+ 1: "Chromatic Percussion",
105
+ 2: "Organ",
106
+ 3: "Guitar",
107
+ 4: "Bass",
108
+ 5: "Strings + Ensemble",
109
+ 6: "Brass",
110
+ 7: "Reed",
111
+ 8: "Pipe",
112
+ 9: "Synth Lead",
113
+ 10: "Synth Pad",
114
+ 11: "Singing",
115
+ 12: "Drums",
116
+ }
117
+ cnt_larger_than_512 = Counter()
118
+ for i in np.where(length_all > 512)[0] % 13:
119
+ cnt_larger_than_512[i] += 1
120
+ print("larger_than_512:")
121
+ for k, v in cnt_larger_than_512.items():
122
+ print(f" - {instr_dict[k]}: {v}")
123
+
124
+ cnt_larger_than_256 = Counter()
125
+ for i in np.where(length_all > 256)[0] % 13:
126
+ cnt_larger_than_256[i] += 1
127
+ print("larger_than_256:")
128
+ for k, v in cnt_larger_than_256.items():
129
+ print(f" - {instr_dict[k]}: {v}")
130
+
131
+ cnt_larger_than_128 = Counter()
132
+ for i in np.where(length_all > 128)[0] % 13:
133
+ cnt_larger_than_128[i] += 1
134
+ print("larger_than_128:")
135
+ for k, v in cnt_larger_than_128.items():
136
+ print(f" - {instr_dict[k]}: {v}")
137
+ """
138
+ empty_sequences: 91.06 %
139
+ mean_except_empty: 36.68976799156269
140
+ median_except_empty: 31.0
141
+ ch_less_than_768: 100.00 %
142
+ ch_larger_than_512: 0.000158 %
143
+ ch_larger_than_256: 0.015132 %
144
+ ch_larger_than_128: 0.192061 %
145
+ ch_larger_than_64: 0.661260 %
146
+ song_larger_than_512: 0.0021 %
147
+ song_larger_than_256: 0.1926 %
148
+ song_larger_than_128: 2.2280 %
149
+ song_larger_than_64: 6.1033 %
150
+
151
+ larger_than_512:
152
+ - Guitar: 7
153
+ - Strings + Ensemble: 3
154
+ larger_than_256:
155
+ - Piano: 177
156
+ - Guitar: 680
157
+ - Strings + Ensemble: 79
158
+ - Organ: 2
159
+ - Chromatic Percussion: 11
160
+ - Bass: 1
161
+ - Synth Lead: 2
162
+ - Brass: 1
163
+ - Reed: 5
164
+ larger_than_128:
165
+ - Guitar: 4711
166
+ - Strings + Ensemble: 1280
167
+ - Piano: 5548
168
+ - Bass: 211
169
+ - Synth Pad: 22
170
+ - Pipe: 18
171
+ - Chromatic Percussion: 55
172
+ - Synth Lead: 22
173
+ - Organ: 75
174
+ - Reed: 161
175
+ - Brass: 45
176
+ - Drums: 11
177
+ """
amt/src/extras/npy_speed_benchmark.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tasks.utils.event_codec import Event, EventRange
3
+ from tasks.utils import event_codec
4
+
5
+ ec = event_codec.Codec(
6
+ max_shift_steps=1000, # this means 0,1,...,1000
7
+ steps_per_second=100,
8
+ event_ranges=[
9
+ EventRange('pitch', min_value=0, max_value=127),
10
+ EventRange('velocity', min_value=0, max_value=1),
11
+ EventRange('tie', min_value=0, max_value=0),
12
+ EventRange('program', min_value=0, max_value=127),
13
+ EventRange('drum', min_value=0, max_value=127),
14
+ ],
15
+ )
16
+
17
+ events = [
18
+ Event(type='shift', value=0), # actually not needed
19
+ Event(type='shift', value=1), # 10 ms shift
20
+ Event(type='shift', value=1000), # 10 s shift
21
+ Event(type='pitch', value=0), # lowest pitch 8.18 Hz
22
+ Event(type='pitch', value=60), # C4 or 261.63 Hz
23
+ Event(type='pitch', value=127), # highest pitch G9 or 12543.85 Hz
24
+ Event(type='velocity', value=0), # lowest velocity)
25
+ Event(type='velocity', value=1), # lowest velocity)
26
+ Event(type='tie', value=0), # tie
27
+ Event(type='program', value=0), # program
28
+ Event(type='program', value=127), # program
29
+ Event(type='drum', value=0), # drum
30
+ Event(type='drum', value=127), # drum
31
+ ]
32
+
33
+ events = events * 100
34
+ tokens = [ec.encode_event(e) for e in events]
35
+ tokens = np.array(tokens, dtype=np.int16)
36
+
37
+ import csv
38
+ # Save events to a CSV file
39
+ with open('events.csv', 'w', newline='') as file:
40
+ writer = csv.writer(file)
41
+ for event in events:
42
+ writer.writerow([event.type, event.value])
43
+
44
+ # Load events from a CSV file
45
+ with open('events.csv', 'r') as file:
46
+ reader = csv.reader(file)
47
+ events2 = [Event(row[0], int(row[1])) for row in reader]
48
+
49
+
50
+ import json
51
+ # Save events to a JSON file
52
+ with open('events.json', 'w') as file:
53
+ json.dump([event.__dict__ for event in events], file)
54
+
55
+ # Load events from a JSON file
56
+ with open('events.json', 'r') as file:
57
+ events = [Event(**event_dict) for event_dict in json.load(file)]
58
+
59
+
60
+
61
+
62
+ """----------------------------"""
63
+ # Write the tokens to a npy file
64
+ import numpy as np
65
+ np.save('tokens.npy', tokens)
66
+
67
+ def t_npy():
68
+ t = np.load('tokens.npy', allow_pickle=True) # allow pickle doesn't affect speed
69
+
70
+ os.makedirs('temp', exist_ok=True)
71
+ for i in range(2400):
72
+ np.save(f'temp/tokens{i}.npy', tokens)
73
+
74
+ def t_npy2400():
75
+ for i in range(2400):
76
+ t = np.load(f'temp/tokens{i}.npy')
77
+ def t_npy2400_take200():
78
+ for i in range(200):
79
+ t = np.load(f'temp/tokens{i}.npy')
80
+
81
+ import shutil
82
+ shutil.rmtree('temp', ignore_errors=True)
83
+
84
+ # Write the 2400 tokens to a single npy file
85
+ data = dict()
86
+ for i in range(2400):
87
+ data[f'arr{i}'] = tokens.copy()
88
+ np.save(f'tokens_2400x.npy', data)
89
+ def t_npy2400single():
90
+ t = np.load('tokens_2400x.npy', allow_pickle=True).item()
91
+
92
+ def t_mmap2400single():
93
+ t = np.load('tokens_2400x.npy', mmap_mode='r')
94
+
95
+ # Write the tokens to a npz file
96
+ np.savez('tokens.npz', arr0=tokens)
97
+ def t_npz():
98
+ npz_file = np.load('tokens.npz')
99
+ tt = npz_file['arr0']
100
+
101
+ data = dict()
102
+ for i in range(2400):
103
+ data[f'arr{i}'] = tokens
104
+ np.savez('tokens.npz', **data )
105
+ def t_npz2400():
106
+ npz_file = np.load('tokens.npz')
107
+ for i in range(2400):
108
+ tt = npz_file[f'arr{i}']
109
+
110
+ def t_npz2400_take200():
111
+ npz_file = np.load('tokens.npz')
112
+ # npz_file.files
113
+ for i in range(200):
114
+ tt = npz_file[f'arr{i}']
115
+
116
+
117
+ # Write the tokens to a txt file
118
+ with open('tokens.txt', 'w') as file:
119
+ file.write(' '.join(map(str, tokens)))
120
+
121
+ def t_txt():
122
+ # Read the tokens from the file
123
+ with open('tokens.txt', 'r') as file:
124
+ t = list(map(int, file.read().split()))
125
+ t = np.array(t)
126
+
127
+
128
+ # Write the tokens to a CSV file
129
+ with open('tokens.csv', 'w', newline='') as file:
130
+ writer = csv.writer(file)
131
+ writer.writerow(tokens)
132
+
133
+ def t_csv():
134
+ # Read the tokens from the CSV file
135
+ with open('tokens.csv', 'r') as file:
136
+ reader = csv.reader(file)
137
+ t = list(map(int, next(reader)))
138
+ t = np.array(t)
139
+
140
+
141
+ # Write the tokens to a JSON file
142
+ with open('tokens.json', 'w') as file:
143
+ json.dump(tokens, file)
144
+
145
+ def t_json():
146
+ # Read the tokens from the JSON file
147
+ with open('tokens.json', 'r') as file:
148
+ t = json.load(file)
149
+ t = np.array(t)
150
+
151
+ with open('tokens_2400x.json', 'w') as file:
152
+ json.dump(data, file)
153
+
154
+ def t_json2400single():
155
+ # Read the tokens from the JSON file
156
+ with open('tokens_2400x.json', 'r') as file:
157
+ t = json.load(file)
158
+
159
+ def t_mmap():
160
+ t = np.load('tokens.npy', mmap_mode='r')
161
+
162
+ # Write the tokens to bytes file
163
+
164
+
165
+
166
+
167
+ np.savetxt('tokens.ntxt', tokens)
168
+ def t_ntxt():
169
+ t = np.loadtxt('tokens.ntxt').astype(np.int32)
170
+
171
+ %timeit t_npz() # 139 us
172
+ %timeit t_mmap() # 3.12 ms
173
+ %timeit t_npy() # 87.8 us
174
+ %timeit t_txt() # 109 152 us
175
+ %timeit t_csv() # 145 190 us
176
+ %timeit t_json() # 72.8 119 us
177
+ %timeit t_ntxt() # 878 us
178
+
179
+ %timeit t_npy2400() # 212 ms; 2400 files in a folder
180
+ %timeit t_npz2400() # 296 ms; uncompreesed 1000 arrays in a single file
181
+
182
+ %timeit t_npy2400_take200() # 17.4 ms; 25 Mb
183
+ %timeit t_npz2400_take200() # 28.8 ms; 3.72 ms for 10 arrays; 25 Mb
184
+ %timeit t_npy2400single() # 4 ms; frozen dictionary containing 2400 arrays; 6.4 Mb; int16
185
+ %timeit t_mmap2400single() # dictionary is not supported
186
+ %timeit t_json2400single() # 175 ms; 17 Mb
187
+ # 2400 files from 100ms hop for 4 minutes
amt/src/extras/perceivertf_inspect.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def l2_normalize(matrix):
7
+ """
8
+ L2 Normalize the matrix along its rows.
9
+
10
+ Parameters:
11
+ matrix (numpy.ndarray): The input matrix.
12
+
13
+ Returns:
14
+ numpy.ndarray: The L2 normalized matrix.
15
+ """
16
+ l2_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
17
+ normalized_matrix = matrix / l2_norms
18
+ return normalized_matrix
19
+
20
+
21
+ def z_normalize(matrix):
22
+ """
23
+ Z-normalize the matrix along its rows (mean=0 and std=1).
24
+ Z-normalization is also known as "standardization", and derives from z-score.
25
+ Z = (X - mean) / std
26
+ Z-nomarlized, each row has mean=0 and std=1.
27
+
28
+ Parameters:
29
+ matrix (numpy.ndarray): The input matrix.
30
+
31
+ Returns:
32
+ numpy.ndarray: The Z normalized matrix.
33
+ """
34
+ mean = np.mean(matrix, axis=1, keepdims=True)
35
+ std = np.std(matrix, axis=1, keepdims=True)
36
+ normalized_matrix = (matrix - mean) / std
37
+ return normalized_matrix
38
+
39
+
40
+ def l2_normalize_tensors(tensor_tuple):
41
+ """
42
+ Applies L2 normalization on the last two dimensions for each tensor in a tuple.
43
+
44
+ Parameters:
45
+ tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, each of shape (1, k, 30, 30).
46
+
47
+ Returns:
48
+ tuple of torch.Tensor: A tuple containing N L2-normalized tensors.
49
+ """
50
+ normalized_tensors = []
51
+ for tensor in tensor_tuple:
52
+ # Ensure the tensor is a floating-point type
53
+ tensor = tensor.float()
54
+
55
+ # Calculate L2 norm on the last two dimensions, keeping the dimensions using keepdim=True
56
+ l2_norm = torch.linalg.norm(tensor, dim=(-2, -1), keepdim=True)
57
+
58
+ # Apply L2 normalization
59
+ normalized_tensor = tensor / (
60
+ l2_norm + 1e-7) # Small value to avoid division by zero
61
+
62
+ normalized_tensors.append(normalized_tensor)
63
+
64
+ return tuple(normalized_tensors)
65
+
66
+
67
+ def z_normalize_tensors(tensor_tuple):
68
+ """
69
+ Applies Z-normalization on the last two dimensions for each tensor in a tuple.
70
+
71
+ Parameters:
72
+ tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, each of shape (1, k, 30, 30).
73
+
74
+ Returns:
75
+ tuple of torch.Tensor: A tuple containing N Z-normalized tensors.
76
+ """
77
+ normalized_tensors = []
78
+ for tensor in tensor_tuple:
79
+ # Ensure the tensor is a floating-point type
80
+ tensor = tensor.float()
81
+
82
+ # Calculate mean and std on the last two dimensions
83
+ mean = tensor.mean(dim=(-2, -1), keepdim=True)
84
+ std = tensor.std(dim=(-2, -1), keepdim=True)
85
+
86
+ # Apply Z-normalization
87
+ normalized_tensor = (tensor - mean) / (
88
+ std + 1e-7) # Small value to avoid division by zero
89
+
90
+ normalized_tensors.append(normalized_tensor)
91
+
92
+ return tuple(normalized_tensors)
93
+
94
+
95
+ def apply_temperature_to_attention_tensors(tensor_tuple, temperature=1.0):
96
+ """
97
+ Applies temperature scaling to the attention weights in each tensor in a tuple.
98
+
99
+ Parameters:
100
+ tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors,
101
+ each of shape (1, k, 30, 30).
102
+ temperature (float): Temperature parameter to control the sharpness
103
+ of the attention weights. Default is 1.0.
104
+
105
+ Returns:
106
+ tuple of torch.Tensor: A tuple containing N tensors with scaled attention weights.
107
+ """
108
+ scaled_attention_tensors = []
109
+
110
+ for tensor in tensor_tuple:
111
+ # Ensure the tensor is a floating-point type
112
+ tensor = tensor.float()
113
+
114
+ # Flatten the last two dimensions
115
+ flattened_tensor = tensor.reshape(1, tensor.shape[1],
116
+ -1) # Modified line here
117
+
118
+ # Apply temperature scaling and softmax along the last dimension
119
+ scaled_attention = flattened_tensor / temperature
120
+ scaled_attention = F.softmax(scaled_attention, dim=-1)
121
+
122
+ # Reshape to original shape
123
+ scaled_attention = scaled_attention.view_as(tensor)
124
+
125
+ scaled_attention_tensors.append(scaled_attention)
126
+
127
+ return tuple(scaled_attention_tensors)
128
+
129
+
130
+ def shorten_att(tensor_tuple, length=30):
131
+ shortend_tensors = []
132
+ for tensor in tensor_tuple:
133
+ shortend_tensors.append(tensor[:, :, :length, :length])
134
+ return tuple(shortend_tensors)
135
+
136
+
137
+ def keep_top_k(matrix, k=6):
138
+ """
139
+ Keep only the top k values in each row, set the rest to 0.
140
+
141
+ Parameters:
142
+ matrix (numpy.ndarray): The input matrix.
143
+ k (int): The number of top values to keep in each row.
144
+
145
+ Returns:
146
+ numpy.ndarray: The transformed matrix.
147
+ """
148
+ topk_indices_per_row = np.argpartition(matrix, -k, axis=1)[:, -k:]
149
+ result_matrix = np.zeros_like(matrix)
150
+
151
+ for i in range(matrix.shape[0]):
152
+ result_matrix[i, topk_indices_per_row[i]] = matrix[
153
+ i, topk_indices_per_row[i]]
154
+ return result_matrix
155
+
156
+
157
+ def test_case_forward_enc_perceiver_tf_dec_t5():
158
+ import torch
159
+ from model.ymt3 import YourMT3
160
+ from config.config import audio_cfg, model_cfg, shared_cfg
161
+ model_cfg["encoder_type"] = "perceiver-tf"
162
+ model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = True
163
+ model_cfg["encoder"]["perceiver-tf"]["num_latents"] = 24
164
+ model_cfg["decoder_type"] = "t5"
165
+ model_cfg["pre_decoder_type"] = "default"
166
+
167
+ audio_cfg["codec"] = "spec"
168
+ audio_cfg["hop_length"] = 300
169
+ model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg)
170
+ model.eval()
171
+
172
+ # x = torch.randn(2, 1, 32767)
173
+ # labels = torch.randint(0, 400, (2, 1024), requires_grad=False)
174
+
175
+ # # forward
176
+ # output = model.forward(x, labels)
177
+
178
+ # # inference
179
+ # result = model.inference(x, None)
180
+
181
+ # display latents
182
+ checkpoint = torch.load(
183
+ "../logs/ymt3/ptf_all_cross_rebal5_spec300_xk2_amp0811_edr_005_attend_c_full_plus_b52/checkpoints/model.ckpt",
184
+ map_location="cpu")
185
+ state_dict = checkpoint['state_dict']
186
+ new_state_dict = {
187
+ k: v
188
+ for k, v in state_dict.items() if 'pitchshift' not in k
189
+ }
190
+ model.load_state_dict(new_state_dict, strict=False)
191
+
192
+ latents = model.encoder.latent_array.latents.detach().numpy()
193
+ import matplotlib.pyplot as plt
194
+ import numpy as np
195
+ from sklearn.metrics.pairwise import cosine_similarity
196
+ cos = cosine_similarity(latents)
197
+
198
+ from utils.data_modules import AMTDataModule
199
+ from einops import rearrange
200
+ dm = AMTDataModule(data_preset_multi={"presets": ["slakh"]})
201
+ dm.setup("test")
202
+ dl = dm.test_dataloader()
203
+ ds = list(dl.values())[0].dataset
204
+ audio, notes, tokens, _ = ds.__getitem__(7)
205
+ x = audio[[16], ::]
206
+ label = tokens[[16], :]
207
+ # spectrogram
208
+ x_spec = model.spectrogram(x)
209
+ plt.imshow(x_spec[0].detach().numpy().T, aspect='auto', origin='lower')
210
+ plt.title("spectrogram")
211
+ plt.xlabel('time step')
212
+ plt.ylabel('frequency bin')
213
+ plt.show()
214
+ x_conv = model.pre_encoder(x_spec)
215
+ # Create a larger figure
216
+ plt.figure(
217
+ figsize=(15,
218
+ 10)) # Adjust these numbers as needed for width and height
219
+ plt.subplot(2, 4, 1)
220
+ plt.imshow(x_spec[0].detach().numpy().T, aspect='auto', origin='lower')
221
+ plt.title("spectrogram")
222
+ plt.xlabel('time step')
223
+ plt.ylabel('frequency bin')
224
+ plt.subplot(2, 4, 2)
225
+ plt.imshow(x_conv[0][:, :, 0].detach().numpy().T,
226
+ aspect='auto',
227
+ origin='lower')
228
+ plt.title("conv(spec), ch=0")
229
+ plt.xlabel('time step')
230
+ plt.ylabel('F')
231
+ plt.subplot(2, 4, 3)
232
+ plt.imshow(x_conv[0][:, :, 42].detach().numpy().T,
233
+ aspect='auto',
234
+ origin='lower')
235
+ plt.title("ch=42")
236
+ plt.xlabel('time step')
237
+ plt.ylabel('F')
238
+ plt.subplot(2, 4, 4)
239
+ plt.imshow(x_conv[0][:, :, 80].detach().numpy().T,
240
+ aspect='auto',
241
+ origin='lower')
242
+ plt.title("ch=80")
243
+ plt.xlabel('time step')
244
+ plt.ylabel('F')
245
+ plt.subplot(2, 4, 5)
246
+ plt.imshow(x_conv[0][:, :, 11].detach().numpy().T,
247
+ aspect='auto',
248
+ origin='lower')
249
+ plt.title("ch=11")
250
+ plt.xlabel('time step')
251
+ plt.ylabel('F')
252
+ plt.subplot(2, 4, 6)
253
+ plt.imshow(x_conv[0][:, :, 20].detach().numpy().T,
254
+ aspect='auto',
255
+ origin='lower')
256
+ plt.title("ch=20")
257
+ plt.xlabel('time step')
258
+ plt.ylabel('F')
259
+ plt.subplot(2, 4, 7)
260
+ plt.imshow(x_conv[0][:, :, 77].detach().numpy().T,
261
+ aspect='auto',
262
+ origin='lower')
263
+ plt.title("ch=77")
264
+ plt.xlabel('time step')
265
+ plt.ylabel('F')
266
+ plt.subplot(2, 4, 8)
267
+ plt.imshow(x_conv[0][:, :, 90].detach().numpy().T,
268
+ aspect='auto',
269
+ origin='lower')
270
+ plt.title("ch=90")
271
+ plt.xlabel('time step')
272
+ plt.ylabel('F')
273
+ plt.tight_layout()
274
+ plt.show()
275
+
276
+ # encoding
277
+ output = model.encoder(inputs_embeds=x_conv,
278
+ output_hidden_states=True,
279
+ output_attentions=True)
280
+ enc_hs_all, att, catt = output["hidden_states"], output[
281
+ "attentions"], output["cross_attentions"]
282
+ enc_hs_last = enc_hs_all[2]
283
+
284
+ # enc_hs: time-varying encoder hidden state
285
+ plt.subplot(2, 3, 1)
286
+ plt.imshow(enc_hs_all[0][0][:, :, 21].detach().numpy().T)
287
+ plt.title('ENC_HS B0, d21')
288
+ plt.colorbar(orientation='horizontal')
289
+ plt.ylabel('latent k')
290
+ plt.xlabel('t')
291
+ plt.subplot(2, 3, 4)
292
+ plt.imshow(enc_hs_all[0][0][:, :, 127].detach().numpy().T)
293
+ plt.colorbar(orientation='horizontal')
294
+ plt.title('B0, d127')
295
+ plt.ylabel('latent k')
296
+ plt.xlabel('t')
297
+ plt.subplot(2, 3, 2)
298
+ plt.imshow(enc_hs_all[1][0][:, :, 21].detach().numpy().T)
299
+ plt.colorbar(orientation='horizontal')
300
+ plt.title('B1, d21')
301
+ plt.ylabel('latent k')
302
+ plt.xlabel('t')
303
+ plt.subplot(2, 3, 5)
304
+ plt.imshow(enc_hs_all[1][0][:, :, 127].detach().numpy().T)
305
+ plt.colorbar(orientation='horizontal')
306
+ plt.title('B1, d127')
307
+ plt.ylabel('latent k')
308
+ plt.xlabel('t')
309
+ plt.subplot(2, 3, 3)
310
+ plt.imshow(enc_hs_all[2][0][:, :, 21].detach().numpy().T)
311
+ plt.colorbar(orientation='horizontal')
312
+ plt.title('B2, d21')
313
+ plt.ylabel('latent k')
314
+ plt.xlabel('t')
315
+ plt.subplot(2, 3, 6)
316
+ plt.imshow(enc_hs_all[2][0][:, :, 127].detach().numpy().T)
317
+ plt.colorbar(orientation='horizontal')
318
+ plt.title('B2, d127')
319
+ plt.ylabel('latent k')
320
+ plt.xlabel('t')
321
+ plt.tight_layout()
322
+ plt.show()
323
+
324
+ enc_hs_proj = model.pre_decoder(enc_hs_last)
325
+ plt.imshow(enc_hs_proj[0].detach().numpy())
326
+ plt.title(
327
+ 'ENC_HS_PROJ: linear projection of encoder output, which is used for enc-dec cross attention'
328
+ )
329
+ plt.colorbar(orientation='horizontal')
330
+ plt.ylabel('latent k')
331
+ plt.xlabel('d')
332
+ plt.show()
333
+
334
+ plt.subplot(221)
335
+ plt.imshow(enc_hs_all[2][0][0, :, :].detach().numpy(), aspect='auto')
336
+ plt.title('enc_hs, t=0')
337
+ plt.ylabel('latent k')
338
+ plt.xlabel('d')
339
+ plt.subplot(222)
340
+ plt.imshow(enc_hs_all[2][0][10, :, :].detach().numpy(), aspect='auto')
341
+ plt.title('enc_hs, t=10')
342
+ plt.ylabel('latent k')
343
+ plt.xlabel('d')
344
+ plt.subplot(223)
345
+ plt.imshow(enc_hs_all[2][0][20, :, :].detach().numpy(), aspect='auto')
346
+ plt.title('enc_hs, t=20')
347
+ plt.ylabel('latent k')
348
+ plt.xlabel('d')
349
+ plt.subplot(224)
350
+ plt.imshow(enc_hs_all[2][0][30, :, :].detach().numpy(), aspect='auto')
351
+ plt.title('enc_hs, t=30')
352
+ plt.ylabel('latent k')
353
+ plt.xlabel('d')
354
+ plt.tight_layout()
355
+ plt.show()
356
+
357
+ # enc_hs correlation: which dim has most unique info?
358
+ plt.subplot(1, 3, 1)
359
+ a = rearrange(enc_hs_last, '1 t k d -> t (k d)').detach().numpy()
360
+ plt.imshow(cosine_similarity(a))
361
+ plt.title("enc hs, t x t cos_sim")
362
+ plt.subplot(1, 3, 2)
363
+ b = rearrange(enc_hs_last, '1 t k d -> k (t d)').detach().numpy()
364
+ plt.imshow(cosine_similarity(b))
365
+ plt.title("enc hs, k x k cos_sim")
366
+ plt.subplot(1, 3, 3)
367
+ c = rearrange(enc_hs_last, '1 t k d -> d (k t)').detach().numpy()
368
+ plt.imshow(cosine_similarity(c))
369
+ plt.title("cross att, d x d cos_sim")
370
+ plt.tight_layout()
371
+ plt.show()
372
+
373
+ # enc latent
374
+ plt.imshow(model.encoder.latent_array.latents.detach().numpy())
375
+ plt.title('latent array')
376
+ plt.xlabel('d')
377
+ plt.ylabel('latent k')
378
+ plt.show()
379
+
380
+ # enc Spectral Cross Attention: (T x head x K x D). How latent K attends to conv channel C?
381
+ plt.subplot(311)
382
+ plt.imshow(
383
+ torch.sum(torch.sum(catt[0][0], axis=0), axis=0).detach().numpy())
384
+ plt.title('block=0')
385
+ plt.ylabel('latent k')
386
+ plt.xlabel('conv channel')
387
+ plt.subplot(312)
388
+ plt.imshow(
389
+ torch.sum(torch.sum(catt[1][0], axis=0), axis=0).detach().numpy())
390
+ plt.title('block=1')
391
+ plt.ylabel('latent k')
392
+ plt.xlabel('conv channel')
393
+ plt.subplot(313)
394
+ plt.imshow(
395
+ torch.sum(torch.sum(catt[2][0], axis=0), axis=0).detach().numpy())
396
+ plt.title('block=2')
397
+ plt.ylabel('latent k')
398
+ plt.xlabel('conv channel')
399
+ plt.tight_layout()
400
+ plt.show()
401
+ # enc Latent Self-attention: How latent K attends to K?
402
+ plt.subplot(231)
403
+ plt.imshow(torch.sum(torch.sum(att[0][0], axis=1),
404
+ axis=0).detach().numpy(),
405
+ origin='upper')
406
+ plt.title('B0L0')
407
+ plt.xlabel('latent k')
408
+ plt.ylabel('latent k')
409
+ plt.subplot(234)
410
+ plt.imshow(torch.sum(torch.sum(att[0][1], axis=1),
411
+ axis=0).detach().numpy(),
412
+ origin='upper')
413
+ plt.title('B0L1')
414
+ plt.xlabel('latent k')
415
+ plt.ylabel('latent k')
416
+ plt.subplot(232)
417
+ plt.imshow(torch.sum(torch.sum(att[1][0], axis=1),
418
+ axis=0).detach().numpy(),
419
+ origin='upper')
420
+ plt.title('B1L0')
421
+ plt.xlabel('latent k')
422
+ plt.ylabel('latent k')
423
+ plt.subplot(235)
424
+ plt.imshow(torch.sum(torch.sum(att[1][1], axis=1),
425
+ axis=0).detach().numpy(),
426
+ origin='upper')
427
+ plt.title('B1L1')
428
+ plt.xlabel('latent k')
429
+ plt.ylabel('latent k')
430
+ plt.subplot(233)
431
+ plt.imshow(torch.sum(torch.sum(att[2][0], axis=1),
432
+ axis=0).detach().numpy(),
433
+ origin='upper')
434
+ plt.title('B2L0')
435
+ plt.xlabel('latent k')
436
+ plt.ylabel('latent k')
437
+ plt.subplot(236)
438
+ plt.imshow(torch.sum(torch.sum(att[2][1], axis=1),
439
+ axis=0).detach().numpy(),
440
+ origin='upper')
441
+ plt.title('B2L1')
442
+ plt.xlabel('latent k')
443
+ plt.ylabel('latent k')
444
+ plt.tight_layout()
445
+ plt.show()
446
+ # Time varying, different head for latent self-attention
447
+ plt.subplot(231)
448
+ plt.imshow(att[0][0][30, 3, :, :].detach().numpy())
449
+ plt.title('B0L0, t=30, Head=3')
450
+ plt.colorbar(orientation='horizontal')
451
+ plt.xlabel('k')
452
+ plt.ylabel('k')
453
+ plt.subplot(234)
454
+ plt.imshow(att[0][1][30, 3, :, :].detach().numpy())
455
+ plt.title('B0L1, t=30, Head=3')
456
+ plt.colorbar(orientation='horizontal')
457
+ plt.xlabel('k')
458
+ plt.ylabel('k')
459
+ plt.subplot(232)
460
+ plt.imshow(att[1][0][30, 3, :, :].detach().numpy())
461
+ plt.title('B1L0, t=30, Head=3')
462
+ plt.colorbar(orientation='horizontal')
463
+ plt.xlabel('k')
464
+ plt.ylabel('k')
465
+ plt.subplot(235)
466
+ plt.imshow(att[1][1][30, 3, :, :].detach().numpy())
467
+ plt.title('B1L1, t=30, Head=3')
468
+ plt.colorbar(orientation='horizontal')
469
+ plt.xlabel('k')
470
+ plt.ylabel('k')
471
+ plt.subplot(233)
472
+ plt.imshow(att[2][0][30, 3, :, :].detach().numpy())
473
+ plt.title('B2L0, t=30, Head=3')
474
+ plt.colorbar(orientation='horizontal')
475
+ plt.xlabel('k')
476
+ plt.ylabel('k')
477
+ plt.subplot(236)
478
+ plt.imshow(att[2][1][30, 3, :, :].detach().numpy())
479
+ plt.title('B2L1, t=30, Head=3')
480
+ plt.colorbar(orientation='horizontal')
481
+ plt.xlabel('k')
482
+ plt.ylabel('k')
483
+ plt.tight_layout()
484
+ plt.show()
485
+ plt.subplot(231)
486
+ plt.imshow(att[0][0][30, 5, :, :].detach().numpy())
487
+ plt.title('B0L0, t=30, Head=5')
488
+ plt.colorbar(orientation='horizontal')
489
+ plt.xlabel('k')
490
+ plt.ylabel('k')
491
+ plt.subplot(234)
492
+ plt.imshow(att[0][1][30, 5, :, :].detach().numpy())
493
+ plt.title('B0L1, t=30, Head=5')
494
+ plt.colorbar(orientation='horizontal')
495
+ plt.xlabel('k')
496
+ plt.ylabel('k')
497
+ plt.subplot(232)
498
+ plt.imshow(att[1][0][30, 5, :, :].detach().numpy())
499
+ plt.title('B1L0, t=30, Head=5')
500
+ plt.colorbar(orientation='horizontal')
501
+ plt.xlabel('k')
502
+ plt.ylabel('k')
503
+ plt.subplot(235)
504
+ plt.imshow(att[1][1][30, 5, :, :].detach().numpy())
505
+ plt.title('B1L1, t=30, Head=5')
506
+ plt.colorbar(orientation='horizontal')
507
+ plt.xlabel('k')
508
+ plt.ylabel('k')
509
+ plt.subplot(233)
510
+ plt.imshow(att[2][0][30, 5, :, :].detach().numpy())
511
+ plt.title('B2L0, t=30, Head=5')
512
+ plt.colorbar(orientation='horizontal')
513
+ plt.xlabel('k')
514
+ plt.ylabel('k')
515
+ plt.subplot(236)
516
+ plt.imshow(att[2][1][30, 5, :, :].detach().numpy())
517
+ plt.title('B2L1, t=30, Head=5')
518
+ plt.colorbar(orientation='horizontal')
519
+ plt.xlabel('k')
520
+ plt.ylabel('k')
521
+ plt.tight_layout()
522
+ plt.show()
523
+
524
+ # Temporal Self-attention: (K x H x T x T) How time t attends to time t?
525
+ plt.subplot(231)
526
+ plt.imshow(torch.sum(torch.sum(att[0][2], axis=1),
527
+ axis=0).detach().numpy(),
528
+ origin='upper')
529
+ plt.title('B0L2')
530
+ plt.xlabel('t')
531
+ plt.ylabel('t')
532
+ plt.subplot(234)
533
+ plt.imshow(torch.sum(torch.sum(att[0][3], axis=1),
534
+ axis=0).detach().numpy(),
535
+ origin='upper')
536
+ plt.title('B0L3')
537
+ plt.xlabel('t')
538
+ plt.ylabel('t')
539
+ plt.subplot(232)
540
+ plt.imshow(torch.sum(torch.sum(att[1][2], axis=1),
541
+ axis=0).detach().numpy(),
542
+ origin='upper')
543
+ plt.title('B1L2')
544
+ plt.xlabel('t')
545
+ plt.ylabel('t')
546
+ plt.subplot(235)
547
+ plt.imshow(torch.sum(torch.sum(att[1][3], axis=1),
548
+ axis=0).detach().numpy(),
549
+ origin='upper')
550
+ plt.title('B1L3')
551
+ plt.xlabel('t')
552
+ plt.ylabel('t')
553
+ plt.subplot(233)
554
+ plt.imshow(torch.sum(torch.sum(att[2][2], axis=1),
555
+ axis=0).detach().numpy(),
556
+ origin='upper')
557
+ plt.title('B2L2')
558
+ plt.xlabel('t')
559
+ plt.ylabel('t')
560
+ plt.subplot(236)
561
+ plt.imshow(torch.sum(torch.sum(att[2][3], axis=1),
562
+ axis=0).detach().numpy(),
563
+ origin='upper')
564
+ plt.title('B2L3')
565
+ plt.xlabel('t')
566
+ plt.ylabel('t')
567
+ plt.tight_layout()
568
+ plt.show()
569
+
570
+ # decoding
571
+ dec_input_ids = model.shift_right_fn(label)
572
+ dec_inputs_embeds = model.embed_tokens(dec_input_ids)
573
+ dec_output = model.decoder(inputs_embeds=dec_inputs_embeds,
574
+ encoder_hidden_states=enc_hs_proj,
575
+ output_attentions=True,
576
+ output_hidden_states=True,
577
+ return_dict=True)
578
+ dec_att, dec_catt = dec_output.attentions, dec_output.cross_attentions
579
+ dec_hs_all = dec_output.hidden_states
580
+
581
+ # dec att
582
+ plt.subplot(1, 2, 1)
583
+ plt.imshow(torch.sum(dec_att[0][0], axis=0).detach().numpy())
584
+ plt.title('decoder attention, layer0')
585
+ plt.xlabel('decoder time step')
586
+ plt.ylabel('decoder time step')
587
+ plt.subplot(1, 2, 2)
588
+ plt.imshow(torch.sum(dec_att[7][0], axis=0).detach().numpy())
589
+ plt.title('decoder attention, layer8')
590
+ plt.xlabel('decoder time step')
591
+ plt.show()
592
+ # dec catt
593
+ plt.imshow(np.rot90((torch.sum(dec_catt[7][0],
594
+ axis=0))[:1000, :].detach().numpy()),
595
+ origin='upper',
596
+ aspect='auto')
597
+ plt.colorbar()
598
+ plt.title('decoder cross att, layer8')
599
+ plt.xlabel('decoder time step')
600
+ plt.ylabel('encoder frame')
601
+ plt.show()
602
+ # dec catt by head with xxx
603
+ dec_att_z = z_normalize_tensors(shorten_att(dec_att))
604
+ plt.imshow(dec_att_z[0][0, 0, :, :].detach().numpy())
605
+ from bertviz import head_view
606
+ token = []
607
+ for i in label[0, :30]:
608
+ token.append(str(i))
609
+ head_view(dec_att_z, tokens)
610
+
611
+ # dec_hs
612
+ plt.subplot(1, 2, 1)
613
+ plt.imshow(dec_hs_all[0][0].detach().numpy(), origin='upper')
614
+ plt.colorbar(orientation='horizontal')
615
+ plt.title('decoder hidden state, layer1')
616
+ plt.xlabel('hidden dim')
617
+ plt.ylabel('time step')
618
+ plt.subplot(1, 2, 2)
619
+ plt.imshow(dec_hs_all[7][0].detach().numpy(), origin='upper')
620
+ plt.colorbar(orientation='horizontal')
621
+ plt.title('decoder hidden state, layer8')
622
+ plt.xlabel('hidden dim')
623
+ plt.show()
624
+
625
+ # lm head
626
+ logits = model.lm_head(dec_hs_all[0])
627
+ plt.imshow(logits[0][0:200, :].detach().numpy(), origin='upper')
628
+ plt.title('lm head softmax')
629
+ plt.xlabel('vocab dim')
630
+ plt.ylabel('time step')
631
+ plt.xlim([1000, 1350])
632
+ plt.show()
633
+ softmax = torch.nn.Softmax(dim=2)
634
+ logits_sm = softmax(logits)
635
+ plt.imshow(logits_sm[0][0:200, :].detach().numpy(), origin='upper')
636
+ plt.title('lm head softmax')
637
+ plt.xlabel('vocab dim')
638
+ plt.ylabel('time step')
639
+ plt.xlim([1000, 1350])
640
+ plt.show()
amt/src/extras/perceivertf_multi_inspect.py ADDED
@@ -0,0 +1,778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ from matplotlib.animation import FuncAnimation
6
+
7
+ def l2_normalize(matrix):
8
+ """
9
+ L2 Normalize the matrix along its rows.
10
+
11
+ Parameters:
12
+ matrix (numpy.ndarray): The input matrix.
13
+
14
+ Returns:
15
+ numpy.ndarray: The L2 normalized matrix.
16
+ """
17
+ l2_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
18
+ normalized_matrix = matrix / l2_norms
19
+ return normalized_matrix
20
+
21
+
22
+ def z_normalize(matrix):
23
+ """
24
+ Z-normalize the matrix along its rows (mean=0 and std=1).
25
+ Z-normalization is also known as "standardization", and derives from z-score.
26
+ Z = (X - mean) / std
27
+ Z-nomarlized, each row has mean=0 and std=1.
28
+
29
+ Parameters:
30
+ matrix (numpy.ndarray): The input matrix.
31
+
32
+ Returns:
33
+ numpy.ndarray: The Z normalized matrix.
34
+ """
35
+ mean = np.mean(matrix, axis=1, keepdims=True)
36
+ std = np.std(matrix, axis=1, keepdims=True)
37
+ normalized_matrix = (matrix - mean) / std
38
+ return normalized_matrix
39
+
40
+
41
+ def l2_normalize_tensors(tensor_tuple):
42
+ """
43
+ Applies L2 normalization on the last two dimensions for each tensor in a tuple.
44
+
45
+ Parameters:
46
+ tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, each of shape (1, k, 30, 30).
47
+
48
+ Returns:
49
+ tuple of torch.Tensor: A tuple containing N L2-normalized tensors.
50
+ """
51
+ normalized_tensors = []
52
+ for tensor in tensor_tuple:
53
+ # Ensure the tensor is a floating-point type
54
+ tensor = tensor.float()
55
+
56
+ # Calculate L2 norm on the last two dimensions, keeping the dimensions using keepdim=True
57
+ l2_norm = torch.linalg.norm(tensor, dim=(-2, -1), keepdim=True)
58
+
59
+ # Apply L2 normalization
60
+ normalized_tensor = tensor / (
61
+ l2_norm + 1e-7) # Small value to avoid division by zero
62
+
63
+ normalized_tensors.append(normalized_tensor)
64
+
65
+ return tuple(normalized_tensors)
66
+
67
+
68
+ def z_normalize_tensors(tensor_tuple):
69
+ """
70
+ Applies Z-normalization on the last two dimensions for each tensor in a tuple.
71
+
72
+ Parameters:
73
+ tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, each of shape (1, k, 30, 30).
74
+
75
+ Returns:
76
+ tuple of torch.Tensor: A tuple containing N Z-normalized tensors.
77
+ """
78
+ normalized_tensors = []
79
+ for tensor in tensor_tuple:
80
+ # Ensure the tensor is a floating-point type
81
+ tensor = tensor.float()
82
+
83
+ # Calculate mean and std on the last two dimensions
84
+ mean = tensor.mean(dim=(-2, -1), keepdim=True)
85
+ std = tensor.std(dim=(-2, -1), keepdim=True)
86
+
87
+ # Apply Z-normalization
88
+ normalized_tensor = (tensor - mean) / (
89
+ std + 1e-7) # Small value to avoid division by zero
90
+
91
+ normalized_tensors.append(normalized_tensor)
92
+
93
+ return tuple(normalized_tensors)
94
+
95
+
96
+ def apply_temperature_to_attention_tensors(tensor_tuple, temperature=1.0):
97
+ """
98
+ Applies temperature scaling to the attention weights in each tensor in a tuple.
99
+
100
+ Parameters:
101
+ tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors,
102
+ each of shape (1, k, 30, 30).
103
+ temperature (float): Temperature parameter to control the sharpness
104
+ of the attention weights. Default is 1.0.
105
+
106
+ Returns:
107
+ tuple of torch.Tensor: A tuple containing N tensors with scaled attention weights.
108
+ """
109
+ scaled_attention_tensors = []
110
+
111
+ for tensor in tensor_tuple:
112
+ # Ensure the tensor is a floating-point type
113
+ tensor = tensor.float()
114
+
115
+ # Flatten the last two dimensions
116
+ flattened_tensor = tensor.reshape(1, tensor.shape[1],
117
+ -1) # Modified line here
118
+
119
+ # Apply temperature scaling and softmax along the last dimension
120
+ scaled_attention = flattened_tensor / temperature
121
+ scaled_attention = F.softmax(scaled_attention, dim=-1)
122
+
123
+ # Reshape to original shape
124
+ scaled_attention = scaled_attention.view_as(tensor)
125
+
126
+ scaled_attention_tensors.append(scaled_attention)
127
+
128
+ return tuple(scaled_attention_tensors)
129
+
130
+
131
+ def shorten_att(tensor_tuple, length=30):
132
+ shortend_tensors = []
133
+ for tensor in tensor_tuple:
134
+ shortend_tensors.append(tensor[:, :, :length, :length])
135
+ return tuple(shortend_tensors)
136
+
137
+
138
+ def keep_top_k(matrix, k=6):
139
+ """
140
+ Keep only the top k values in each row, set the rest to 0.
141
+
142
+ Parameters:
143
+ matrix (numpy.ndarray): The input matrix.
144
+ k (int): The number of top values to keep in each row.
145
+
146
+ Returns:
147
+ numpy.ndarray: The transformed matrix.
148
+ """
149
+ topk_indices_per_row = np.argpartition(matrix, -k, axis=1)[:, -k:]
150
+ result_matrix = np.zeros_like(matrix)
151
+
152
+ for i in range(matrix.shape[0]):
153
+ result_matrix[i, topk_indices_per_row[i]] = matrix[
154
+ i, topk_indices_per_row[i]]
155
+ return result_matrix
156
+
157
+
158
+ def test_case_forward_enc_perceiver_tf_dec_multi_t5():
159
+ import torch
160
+ from model.ymt3 import YourMT3
161
+ from config.config import audio_cfg, model_cfg, shared_cfg
162
+ model_cfg["encoder_type"] = "perceiver-tf"
163
+
164
+ model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = True
165
+ model_cfg["encoder"]["perceiver-tf"]["num_latents"] = 26
166
+
167
+ model_cfg["decoder_type"] = "multi-t5"
168
+
169
+ audio_cfg["codec"] = "spec"
170
+ audio_cfg["hop_length"] = 300
171
+ model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg)
172
+ model.eval()
173
+
174
+ # x = torch.randn(2, 1, 32767)
175
+ # labels = torch.randint(0, 400, (2, 1024), requires_grad=False)
176
+
177
+ # # forward
178
+ # output = model.forward(x, labels)
179
+
180
+ # # inference
181
+ # result = model.inference(x, None)
182
+
183
+ # display latents
184
+ checkpoint = torch.load(
185
+ "../logs/ymt3/ptf_mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k/checkpoints/model.ckpt",
186
+ map_location="cpu")
187
+ state_dict = checkpoint['state_dict']
188
+ new_state_dict = {
189
+ k: v
190
+ for k, v in state_dict.items() if 'pitchshift' not in k
191
+ }
192
+ model.load_state_dict(new_state_dict, strict=False)
193
+
194
+ latents = model.encoder.latent_array.latents.detach().numpy()
195
+ import matplotlib.pyplot as plt
196
+ import numpy as np
197
+ from sklearn.metrics.pairwise import cosine_similarity
198
+ cos = cosine_similarity(latents)
199
+
200
+ from utils.data_modules import AMTDataModule
201
+ from einops import rearrange
202
+ # dm = AMTDataModule(data_preset_multi={"presets": ["slakh"]})
203
+ #dm.setup("test")
204
+ # dl = dm.test_dataloader()
205
+ # ds = list(dl.values())[0].dataset
206
+ # audio, notes, tokens, _ = ds.__getitem__(7)
207
+ # x = audio[[16], ::]
208
+ # label = tokens[[16], :]
209
+
210
+ # from utils.task_manager import TaskManager
211
+ # tm = TaskManager(task_name='mc13_256')
212
+ # dm = AMTDataModule(data_preset_multi={"presets": ["slakh"]},
213
+ # task_manager=tm,
214
+ # train_stem_iaug_prob=None,
215
+ # train_stem_xaug_policy=None)
216
+ # dm.setup('fit')
217
+ # dl = dm.train_dataloader()
218
+ # ds = dl.flattened[0].dataset
219
+ # audio,tokens, _, _ = ds.__getitem__(67)
220
+ # x = audio[[5], ::]
221
+ # label = tokens[[5], :]
222
+ # save audio
223
+ # torchaudio.save("singing.wav", x[0, :, :], 16000)
224
+
225
+ x, _ = torchaudio.load('piano.wav')#'test.wav')
226
+ x = x.unsqueeze(0)
227
+
228
+ # spectrogram
229
+ x_spec = model.spectrogram(x)
230
+ x_conv = model.pre_encoder(x_spec)
231
+ # Create a larger figure
232
+ plt.figure(
233
+ figsize=(15,
234
+ 10)) # Adjust these numbers as needed for width and height
235
+ plt.subplot(2, 4, 1)
236
+ plt.imshow(x_spec[0].detach().numpy().T, aspect='auto', origin='lower')
237
+ plt.title("spectrogram")
238
+ plt.xlabel('time step')
239
+ plt.ylabel('frequency bin')
240
+ plt.subplot(2, 4, 2)
241
+ plt.imshow(x_conv[0][:, :, 0].detach().numpy().T,
242
+ aspect='auto',
243
+ origin='lower')
244
+ plt.title("conv(spec), ch=0")
245
+ plt.xlabel('time step')
246
+ plt.ylabel('F')
247
+ plt.subplot(2, 4, 3)
248
+ plt.imshow(x_conv[0][:, :, 42].detach().numpy().T,
249
+ aspect='auto',
250
+ origin='lower')
251
+ plt.title("ch=42")
252
+ plt.xlabel('time step')
253
+ plt.ylabel('F')
254
+ plt.subplot(2, 4, 4)
255
+ plt.imshow(x_conv[0][:, :, 80].detach().numpy().T,
256
+ aspect='auto',
257
+ origin='lower')
258
+ plt.title("ch=80")
259
+ plt.xlabel('time step')
260
+ plt.ylabel('F')
261
+ plt.subplot(2, 4, 5)
262
+ plt.imshow(x_conv[0][:, :, 11].detach().numpy().T,
263
+ aspect='auto',
264
+ origin='lower')
265
+ plt.title("ch=11")
266
+ plt.xlabel('time step')
267
+ plt.ylabel('F')
268
+ plt.subplot(2, 4, 6)
269
+ plt.imshow(x_conv[0][:, :, 20].detach().numpy().T,
270
+ aspect='auto',
271
+ origin='lower')
272
+ plt.title("ch=20")
273
+ plt.xlabel('time step')
274
+ plt.ylabel('F')
275
+ plt.subplot(2, 4, 7)
276
+ plt.imshow(x_conv[0][:, :, 77].detach().numpy().T,
277
+ aspect='auto',
278
+ origin='lower')
279
+ plt.title("ch=77")
280
+ plt.xlabel('time step')
281
+ plt.ylabel('F')
282
+ plt.subplot(2, 4, 8)
283
+ plt.imshow(x_conv[0][:, :, 90].detach().numpy().T,
284
+ aspect='auto',
285
+ origin='lower')
286
+ plt.title("ch=90")
287
+ plt.xlabel('time step')
288
+ plt.ylabel('F')
289
+ plt.tight_layout()
290
+ plt.show()
291
+
292
+ # encoding
293
+ output = model.encoder(inputs_embeds=x_conv,
294
+ output_hidden_states=True,
295
+ output_attentions=True)
296
+ enc_hs_all, att, catt = output["hidden_states"], output[
297
+ "attentions"], output["cross_attentions"]
298
+ enc_hs_last = enc_hs_all[2]
299
+
300
+ # enc_hs: time-varying encoder hidden state
301
+ plt.subplot(2, 3, 1)
302
+ plt.imshow(enc_hs_all[0][0][:, :, 21].detach().numpy().T)
303
+ plt.title('ENC_HS B0, d21')
304
+ plt.colorbar(orientation='horizontal')
305
+ plt.ylabel('latent k')
306
+ plt.xlabel('t')
307
+ plt.subplot(2, 3, 4)
308
+ plt.imshow(enc_hs_all[0][0][:, :, 127].detach().numpy().T)
309
+ plt.colorbar(orientation='horizontal')
310
+ plt.title('B0, d127')
311
+ plt.ylabel('latent k')
312
+ plt.xlabel('t')
313
+ plt.subplot(2, 3, 2)
314
+ plt.imshow(enc_hs_all[1][0][:, :, 21].detach().numpy().T)
315
+ plt.colorbar(orientation='horizontal')
316
+ plt.title('B1, d21')
317
+ plt.ylabel('latent k')
318
+ plt.xlabel('t')
319
+ plt.subplot(2, 3, 5)
320
+ plt.imshow(enc_hs_all[1][0][:, :, 127].detach().numpy().T)
321
+ plt.colorbar(orientation='horizontal')
322
+ plt.title('B1, d127')
323
+ plt.ylabel('latent k')
324
+ plt.xlabel('t')
325
+ plt.subplot(2, 3, 3)
326
+ plt.imshow(enc_hs_all[2][0][:, :, 21].detach().numpy().T)
327
+ plt.colorbar(orientation='horizontal')
328
+ plt.title('B2, d21')
329
+ plt.ylabel('latent k')
330
+ plt.xlabel('t')
331
+ plt.subplot(2, 3, 6)
332
+ plt.imshow(enc_hs_all[2][0][:, :, 127].detach().numpy().T)
333
+ plt.colorbar(orientation='horizontal')
334
+ plt.title('B2, d127')
335
+ plt.ylabel('latent k')
336
+ plt.xlabel('t')
337
+ plt.tight_layout()
338
+ plt.show()
339
+
340
+ # enc_hs: time-varying encoder hidden state by k (block, 1, t, k, d)
341
+ # --> (t, d) for each k in last block
342
+ data = enc_hs_all[2][0].detach().numpy() # (T, K, D)
343
+ fig, axs = plt.subplots(
344
+ 5, 5, figsize=(10, 9)) # 25 subplots arranged in 5 rows and 5 columns
345
+ axs = axs.flatten(
346
+ ) # Flatten the 2D array of axes to 1D for easy iteration
347
+
348
+ for k in range(25): # Iterating through K indices from 0 to 24
349
+ axs[k].imshow(data[:, k, :].T,
350
+ cmap='viridis') # Transposing the matrix to swap T and D
351
+ axs[k].set_title(f'k={k}')
352
+ axs[k].set_xlabel('Time step')
353
+ axs[k].set_ylabel('Dim')
354
+
355
+ # Adjusting layout for better visibility
356
+ plt.tight_layout()
357
+ plt.show()
358
+
359
+ #!! Projected encoder hidden state for 13 channels, that is conditioning for decoder
360
+ enc_hs_proj = model.pre_decoder(enc_hs_last)
361
+ fig, axs = plt.subplots(1, 13, figsize=(26, 8)) # 13 subplots in a row
362
+ data = enc_hs_proj[0].detach().numpy()
363
+ for ch in range(13):
364
+ axs[ch].imshow(np.rot90(data[ch]), cmap='viridis') # Rotate 90 degrees
365
+ axs[ch].set_title(f'ch: {ch}')
366
+ axs[ch].set_xlabel('Time step')
367
+ axs[ch].set_ylabel('Dim')
368
+ plt.suptitle(
369
+ 'linear projection of encoder outputs by channel, which is conditioning for enc-dec cross attention',
370
+ y=0.1,
371
+ fontsize=12)
372
+ plt.tight_layout(rect=[0, 0.1, 1, 1])
373
+ plt.show()
374
+
375
+ plt.subplot(221)
376
+ plt.imshow(enc_hs_all[2][0][0, :, :].detach().numpy(), aspect='auto')
377
+ plt.title('enc_hs, t=0')
378
+ plt.ylabel('latent k')
379
+ plt.xlabel('d')
380
+ plt.subplot(222)
381
+ plt.imshow(enc_hs_all[2][0][10, :, :].detach().numpy(), aspect='auto')
382
+ plt.title('enc_hs, t=10')
383
+ plt.ylabel('latent k')
384
+ plt.xlabel('d')
385
+ plt.subplot(223)
386
+ plt.imshow(enc_hs_all[2][0][20, :, :].detach().numpy(), aspect='auto')
387
+ plt.title('enc_hs, t=20')
388
+ plt.ylabel('latent k')
389
+ plt.xlabel('d')
390
+ plt.subplot(224)
391
+ plt.imshow(enc_hs_all[2][0][30, :, :].detach().numpy(), aspect='auto')
392
+ plt.title('enc_hs, t=30')
393
+ plt.ylabel('latent k')
394
+ plt.xlabel('d')
395
+ plt.tight_layout()
396
+ plt.show()
397
+
398
+ # enc_hs correlation: which dim has most unique info?
399
+ plt.subplot(1, 3, 1)
400
+ a = rearrange(enc_hs_last, '1 t k d -> t (k d)').detach().numpy()
401
+ plt.imshow(cosine_similarity(a))
402
+ plt.title("enc hs, t x t cos_sim")
403
+ plt.subplot(1, 3, 2)
404
+ b = rearrange(enc_hs_last, '1 t k d -> k (t d)').detach().numpy()
405
+ plt.imshow(cosine_similarity(b))
406
+ plt.title("enc hs, k x k cos_sim")
407
+ plt.subplot(1, 3, 3)
408
+ c = rearrange(enc_hs_last, '1 t k d -> d (k t)').detach().numpy()
409
+ plt.imshow(cosine_similarity(c))
410
+ plt.title("cross att, d x d cos_sim")
411
+ plt.tight_layout()
412
+ plt.show()
413
+
414
+ #!! enc latent
415
+ plt.imshow(model.encoder.latent_array.latents.detach().numpy())
416
+ plt.title('latent array')
417
+ plt.xlabel('d')
418
+ plt.ylabel('latent k')
419
+ plt.show()
420
+
421
+ #!! enc Spectral Cross Attention: (T x head x K x D). How latent K attends to conv channel C?
422
+ plt.subplot(311)
423
+ plt.imshow(
424
+ torch.sum(torch.sum(catt[0][0], axis=0), axis=0).detach().numpy())
425
+ plt.title('block=0')
426
+ plt.ylabel('latent k')
427
+ plt.xlabel('conv channel')
428
+ plt.subplot(312)
429
+ plt.imshow(
430
+ torch.sum(torch.sum(catt[1][0], axis=0), axis=0).detach().numpy())
431
+ plt.title('block=1')
432
+ plt.ylabel('latent k')
433
+ plt.xlabel('conv channel')
434
+ plt.subplot(313)
435
+ plt.imshow(
436
+ torch.sum(torch.sum(catt[2][0], axis=0), axis=0).detach().numpy())
437
+ plt.title('block=2')
438
+ plt.ylabel('latent k')
439
+ plt.xlabel('conv channel')
440
+ # f'spectral cross attention. T-C-F Model',
441
+ # y=0,
442
+ # fontsize=12)
443
+ plt.tight_layout()
444
+ plt.show()
445
+
446
+ #!! Animation of SCA for varying time, head in last block
447
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6)) # Adjusted figsize for better layout
448
+
449
+ # Function to update the plots for each frame in the animation
450
+ def update(t):
451
+ # Clear previous images
452
+ ax1.clear()
453
+ ax2.clear()
454
+
455
+ # Update subplot for h=3
456
+ ax1.imshow(catt[2][0][t, 3, :, :].detach().numpy())
457
+ ax1.set_title(f'block=2, t={t}, head=3')
458
+ ax1.set_ylabel('latent k'); ax1.set_xlabel('conv channel')
459
+
460
+ # Update subplot for h=5
461
+ ax2.imshow(catt[2][0][t, 5, :, :].detach().numpy())
462
+ ax2.set_title(f'block=2, t={t}, head=5')
463
+ ax2.set_ylabel('latent k'); ax2.set_xlabel('conv channel')
464
+
465
+ # Adjust layout
466
+ fig.tight_layout()
467
+
468
+ # Create the animation
469
+ anim = FuncAnimation(fig, update, frames=range(0, 110), interval=200)
470
+ anim.save('animation.gif', writer='pillow', fps=5)
471
+
472
+
473
+
474
+ fig, axs = plt.subplots(3, 1, figsize=(12, 18), gridspec_kw={'height_ratios': [1, 1, 0.5]}) # Adjusted for different subplot sizes
475
+
476
+ # Subplots for catt visualization (h=3 and h=5)
477
+ ax_catt3, ax_catt5, ax_att_row = axs
478
+
479
+ # Creating 8 subplots for att visualization within the third row
480
+ for i in range(8):
481
+ ax_att_row = fig.add_subplot(3, 8, 17 + i) # Adding subplots in the third row
482
+
483
+ # Update function for the combined animation
484
+ def combined_update_smaller_att(t):
485
+ # Update subplot for catt with h=3
486
+ ax_catt3.clear()
487
+ ax_catt3.imshow(catt[2][0][t, 3, :, :].detach().numpy())
488
+ ax_catt3.set_title(f'block=2, t={t}, head=3')
489
+ ax_catt3.set_ylabel('latent k'); ax_catt3.set_xlabel('conv channel')
490
+
491
+ # Update subplot for catt with h=5
492
+ ax_catt5.clear()
493
+ ax_catt5.imshow(catt[2][0][t, 5, :, :].detach().numpy())
494
+ ax_catt5.set_title(f'block=2, t={t}, head=5')
495
+ ax_catt5.set_ylabel('latent k'); ax_catt5.set_xlabel('conv channel')
496
+
497
+ # Update subplots for att (8 heads in one row)
498
+ for i in range(8):
499
+ ax = fig.add_subplot(3, 8, 17 + i)
500
+ ax.clear()
501
+ ax.imshow(att[0][1][t, i, :, :].detach().numpy(), cmap='viridis')
502
+ ax.set_title(f't={t}, head={i}')
503
+ ax.set_xlabel('k')
504
+ ax.set_ylabel('k')
505
+ ax.axis('square') # Make each subplot square-shaped
506
+
507
+ # Adjust layout
508
+ fig.tight_layout()
509
+ combined_anim_smaller_att = FuncAnimation(fig, combined_update_smaller_att, frames=range(0, 110), interval=200)
510
+ combined_anim_smaller_att.save('combined_animation_smaller_att.gif', writer='pillow', fps=5)
511
+
512
+
513
+
514
+
515
+
516
+ # enc Latent Self-attention: How latent K attends to K?
517
+ plt.subplot(231)
518
+ plt.imshow(torch.sum(torch.sum(att[0][0], axis=1),
519
+ axis=0).detach().numpy(),
520
+ origin='upper')
521
+ plt.title('B0L0')
522
+ plt.xlabel('latent k')
523
+ plt.ylabel('latent k')
524
+ plt.subplot(234)
525
+ plt.imshow(torch.sum(torch.sum(att[0][1], axis=1),
526
+ axis=0).detach().numpy(),
527
+ origin='upper')
528
+ plt.title('B0L1')
529
+ plt.xlabel('latent k')
530
+ plt.ylabel('latent k')
531
+ plt.subplot(232)
532
+ plt.imshow(torch.sum(torch.sum(att[1][0], axis=1),
533
+ axis=0).detach().numpy(),
534
+ origin='upper')
535
+ plt.title('B1L0')
536
+ plt.xlabel('latent k')
537
+ plt.ylabel('latent k')
538
+ plt.subplot(235)
539
+ plt.imshow(torch.sum(torch.sum(att[1][1], axis=1),
540
+ axis=0).detach().numpy(),
541
+ origin='upper')
542
+ plt.title('B1L1')
543
+ plt.xlabel('latent k')
544
+ plt.ylabel('latent k')
545
+ plt.subplot(233)
546
+ plt.imshow(torch.sum(torch.sum(att[2][0], axis=1),
547
+ axis=0).detach().numpy(),
548
+ origin='upper')
549
+ plt.title('B2L0')
550
+ plt.xlabel('latent k')
551
+ plt.ylabel('latent k')
552
+ plt.subplot(236)
553
+ plt.imshow(torch.sum(torch.sum(att[2][1], axis=1),
554
+ axis=0).detach().numpy(),
555
+ origin='upper')
556
+ plt.title('B2L1')
557
+ plt.xlabel('latent k')
558
+ plt.ylabel('latent k')
559
+ plt.tight_layout()
560
+ plt.show()
561
+ # Time varying, different head for latent self-attention
562
+ #!!! Display latent self-attention for each head
563
+ bl = 0 # first latent transformer block, last layer att
564
+ data = att[bl][1].detach().numpy()
565
+ time_steps = [30, 50, 100]
566
+ fig, axs = plt.subplots(
567
+ len(time_steps), 8,
568
+ figsize=(16, 6)) # Subplots for each time step and head
569
+ for i, t in enumerate(time_steps):
570
+ for head in range(8):
571
+ axs[i, head].imshow(data[t, head, :, :], cmap='viridis')
572
+ axs[i, head].set_title(f't={t}, head={head}')
573
+ axs[i, head].set_xlabel('k')
574
+ axs[i, head].set_ylabel('k')
575
+ plt.suptitle(
576
+ f'latent transformer block={bl}, last layer self-attention over time',
577
+ y=0,
578
+ fontsize=12)
579
+ plt.tight_layout()
580
+ plt.show()
581
+
582
+ bl = 1 # second latent transformer block, last layer att
583
+ data = att[bl][1].detach().numpy()
584
+ time_steps = [30, 50, 100]
585
+ fig, axs = plt.subplots(
586
+ len(time_steps), 8,
587
+ figsize=(16, 6)) # Subplots for each time step and head
588
+ for i, t in enumerate(time_steps):
589
+ for head in range(8):
590
+ axs[i, head].imshow(data[t, head, :, :], cmap='viridis')
591
+ axs[i, head].set_title(f't={t}, head={head}')
592
+ axs[i, head].set_xlabel('k')
593
+ axs[i, head].set_ylabel('k')
594
+ plt.suptitle(
595
+ f'latent transformer block={bl}, last layer self-attention over time',
596
+ y=0,
597
+ fontsize=12)
598
+ plt.tight_layout()
599
+ plt.show()
600
+
601
+ bl = 2 # last latent transformer block, last layer att
602
+ data = att[bl][1].detach().numpy()
603
+ time_steps = [30, 50, 100]
604
+ fig, axs = plt.subplots(
605
+ len(time_steps), 8,
606
+ figsize=(16, 6)) # Subplots for each time step and head
607
+ for i, t in enumerate(time_steps):
608
+ for head in range(8):
609
+ axs[i, head].imshow(data[t, head, :, :], cmap='viridis')
610
+ axs[i, head].set_title(f't={t}, head={head}')
611
+ axs[i, head].set_xlabel('k')
612
+ axs[i, head].set_ylabel('k')
613
+ plt.suptitle(
614
+ f'latent transformer block={bl}, last layer self-attention over time',
615
+ y=0,
616
+ fontsize=12)
617
+ plt.tight_layout()
618
+ plt.show()
619
+
620
+ # Temporal Self-attention: (K x H x T x T) How time t attends to time t?
621
+ plt.subplot(231)
622
+ plt.imshow(torch.sum(torch.sum(att[0][2], axis=1),
623
+ axis=0).detach().numpy(),
624
+ origin='upper')
625
+ plt.title('B0L2')
626
+ plt.xlabel('t')
627
+ plt.ylabel('t')
628
+ plt.subplot(234)
629
+ plt.imshow(torch.sum(torch.sum(att[0][3], axis=1),
630
+ axis=0).detach().numpy(),
631
+ origin='upper')
632
+ plt.title('B0L3')
633
+ plt.xlabel('t')
634
+ plt.ylabel('t')
635
+ plt.subplot(232)
636
+ plt.imshow(torch.sum(torch.sum(att[1][2], axis=1),
637
+ axis=0).detach().numpy(),
638
+ origin='upper')
639
+ plt.title('B1L2')
640
+ plt.xlabel('t')
641
+ plt.ylabel('t')
642
+ plt.subplot(235)
643
+ plt.imshow(torch.sum(torch.sum(att[1][3], axis=1),
644
+ axis=0).detach().numpy(),
645
+ origin='upper')
646
+ plt.title('B1L3')
647
+ plt.xlabel('t')
648
+ plt.ylabel('t')
649
+ plt.subplot(233)
650
+ plt.imshow(torch.sum(torch.sum(att[2][2], axis=1),
651
+ axis=0).detach().numpy(),
652
+ origin='upper')
653
+ plt.title('B2L2')
654
+ plt.xlabel('t')
655
+ plt.ylabel('t')
656
+ plt.subplot(236)
657
+ plt.imshow(torch.sum(torch.sum(att[2][3], axis=1),
658
+ axis=0).detach().numpy(),
659
+ origin='upper')
660
+ plt.title('B2L3')
661
+ plt.xlabel('t')
662
+ plt.ylabel('t')
663
+ plt.tight_layout()
664
+ plt.show()
665
+
666
+ # decoding
667
+ dec_input_ids = model.shift_right_fn(label)
668
+ dec_inputs_embeds = model.embed_tokens(dec_input_ids)
669
+ dec_output = model.decoder(inputs_embeds=dec_inputs_embeds,
670
+ encoder_hidden_states=enc_hs_proj,
671
+ output_attentions=True,
672
+ output_hidden_states=True,
673
+ return_dict=True)
674
+ dec_att, dec_catt = dec_output.attentions, dec_output.cross_attentions
675
+ dec_hs_all = dec_output.hidden_states
676
+ dec_last_hs = dec_output.last_hidden_state
677
+
678
+ # lm head
679
+ logits = model.lm_head(dec_last_hs)
680
+
681
+ # pred ids
682
+ pred_ids = torch.argmax(logits, dim=3)
683
+
684
+ # dec att
685
+ plt.subplot(1, 2, 1)
686
+ plt.imshow(torch.sum(dec_att[5][0], axis=0).detach().numpy())
687
+ plt.title('decoder attention, layer0')
688
+ plt.xlabel('decoder time step')
689
+ plt.ylabel('decoder time step')
690
+ plt.subplot(1, 2, 2)
691
+ plt.imshow(torch.sum(dec_att[7][0], axis=0).detach().numpy())
692
+ plt.title('decoder attention, final layer')
693
+ plt.xlabel('decoder step')
694
+ plt.show()
695
+
696
+
697
+ # dec catt
698
+ def remove_values_after_eos(catt_np, pred_ids, max_k):
699
+ # catt_np: (k, head, t, t)
700
+ # pred_ids: (1, k, t))
701
+ max_length = pred_ids.shape[-1]
702
+ seq_lengths = np.zeros((max_k), dtype=np.int32)
703
+ for k in range(max_k):
704
+ for t in range(max_length):
705
+ if pred_ids[0, k, t] == 1:
706
+ break
707
+ catt_np[k, :, t+1:, :] = 0
708
+ # catt_np[k, :, :, t+1:] = 0
709
+ seq_lengths[k] = t+1
710
+ return catt_np, seq_lengths
711
+
712
+ # data = dec_catt[1].detach().numpy() # last layer's cross attention
713
+ l = 4
714
+ data = dec_catt[l].detach().numpy()
715
+ data, seq_lengths = remove_values_after_eos(data, pred_ids, max_k=13)
716
+ seq_lengths[:]= 256
717
+
718
+ fig, axs = plt.subplots(13, 6, figsize=(21, 39)) # 13 rows (for k=0:12) and 7 columns (for head=0:6)
719
+ for k in range(13):
720
+ s = seq_lengths[k]
721
+ for head in range(6):
722
+ axs[k, head].imshow(data[k, head, :s, :].T, aspect='auto', cmap='viridis')
723
+ axs[k, head].set_title(f'Layer {l}, k={k}, head={head}')
724
+ axs[k, head].set_xlabel('Decoder step')
725
+ axs[k, head].set_ylabel('Encoder frame')
726
+ plt.tight_layout()
727
+ plt.show()
728
+
729
+
730
+ # # dec catt by head with xxx
731
+ # dec_att_z = z_normalize_tensors(shorten_att(dec_att))
732
+ # plt.imshow(dec_att_z[0][0, 0, :, :].detach().numpy())
733
+ # from bertviz import head_view
734
+ # token = []
735
+ # for i in label[0, :30]:
736
+ # token.append(str(i))
737
+ # head_view(dec_att_z, tokens)
738
+
739
+ # dec_hs
740
+ plt.subplot(1, 2, 1)
741
+ k=2
742
+ plt.imshow(dec_last_hs[0][k].detach().numpy(), origin='upper')
743
+ plt.colorbar(orientation='horizontal')
744
+ plt.title('decoder last hidden state, k=0')
745
+ plt.xlabel('hidden dim')
746
+ plt.ylabel('time step')
747
+ plt.subplot(1, 2, 2)
748
+ k=12
749
+ plt.imshow(dec_last_hs[0][k].detach().numpy(), origin='upper')
750
+ plt.colorbar(orientation='horizontal')
751
+ plt.title('decoder last hidden state, k=12')
752
+ plt.xlabel('hidden dim')
753
+ plt.show()
754
+
755
+ # lm head
756
+ logits = model.lm_head(dec_last_hs)
757
+ k=6
758
+ plt.imshow(logits[0][k][0:200, :].detach().numpy().T, origin='upper')
759
+ plt.title('lm head output')
760
+ plt.xlabel('vocab dim')
761
+ plt.ylabel('time step')
762
+ plt.show()
763
+ softmax = torch.nn.Softmax(dim=3)
764
+ logits_sm = softmax(logits) # B, K, T, V
765
+ k=6
766
+ plt.imshow(logits_sm[0][k][:255, :].detach().numpy().T, origin='upper')
767
+ plt.title('lm head softmax')
768
+ plt.xlabel('vocab dim')
769
+ plt.ylabel('time step')
770
+ # plt.xlim([1000, 1350])
771
+ plt.show()
772
+
773
+ k = 10
774
+ print(torch.argmax(logits, dim=3)[0,k,:])
775
+
776
+
777
+
778
+
amt/src/extras/pitch_shift_benchmark.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Test the speed of the augmentation """
2
+ import torch
3
+ import torchaudio
4
+
5
+ # Device
6
+ device = torch.device("cuda")
7
+ # device = torch.device("cpu")
8
+
9
+ # Music
10
+ # x, _ = torchaudio.load("music.wav")
11
+ # slice_length = 32767
12
+ # n_slices = 80
13
+ # slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)]
14
+ # x = torch.stack(slices) # (80, 32767)
15
+ # Sine wave
16
+ t = torch.arange(0, 2.0479, 1 / 16000) # 2.05 seconds at 16kHz
17
+ x = torch.sin(2 * torch.pi * 440 * t) * 0.5
18
+ x = x.reshape(1, 1, 32767).tile(80, 1, 1)
19
+ x = x.to(device)
20
+
21
+ ############################################################################################
22
+ # torch-audiomentation: https://github.com/asteroid-team/torch-audiomentation
23
+ #
24
+ # process time <CPU>: 1.18 s ± 5.35 ms
25
+ # process time <GPU>: 58 ms
26
+ # GPU memory usage: 3.8 GB per 1 semitone
27
+ ############################################################################################
28
+ import torch
29
+ from torch_audiomentations import Compose, PitchShift, Gain, PolarityInversion
30
+
31
+ apply_augmentation = Compose(transforms=[
32
+ # Gain(
33
+ # min_gain_in_db=-15.0,
34
+ # max_gain_in_db=5.0,
35
+ # p=0.5,
36
+ # ),
37
+ # PolarityInversion(p=0.5)
38
+ PitchShift(
39
+ min_transpose_semitones=0,
40
+ max_transpose_semitones=2.2,
41
+ mode="per_batch", #"per_example",
42
+ p=1.0,
43
+ p_mode="per_batch",
44
+ sample_rate=16000,
45
+ target_rate=16000)
46
+ ])
47
+ x_am = apply_augmentation(x, sample_rate=16000)
48
+
49
+ ############################################################################################
50
+ # torchaudio:
51
+ #
52
+ # process time <CPU>: 4.01 s ± 19.6 ms per loop
53
+ # process time <GPU>: 25.1 ms ± 161 µs per loop
54
+ # memory usage <GPU>: 1.2 (growth to 5.49) GB per 1 semitone
55
+ ############################################################################################
56
+ from torchaudio import transforms
57
+
58
+ ta_transform = transforms.PitchShift(16000, n_steps=2).to(device)
59
+ x_ta = ta_transform(x)
60
+
61
+ ############################################################################################
62
+ # YourMT3 pitch_shift_layer:
63
+ #
64
+ # process time <CPU>: 389ms ± 22ms, (stretch=143 ms, resampler=245 ms)
65
+ # process time <GPU>: 7.18 ms ± 17.3 µs (stretch=6.47 ms, resampler=0.71 ms)
66
+ # memory usage: 16 MB per 1 semitone (average)
67
+ ############################################################################################
68
+ from model.pitchshift_layer import PitchShiftLayer
69
+
70
+ ps_ymt3 = PitchShiftLayer(pshift_range=[2, 2], fs=16000, min_gcd=16, n_fft=2048).to(device)
71
+ x_ymt3 = ps_ymt3(x, 2)
72
+
73
+ ############################################################################################
74
+ # Plot 1: Comparison of Process Time and GPU Memory Usage for 3 Pitch Shifting Methods
75
+ ############################################################################################
76
+ import matplotlib.pyplot as plt
77
+
78
+ # Model names
79
+ models = ['torch-audiomentation', 'torchaudio', 'YourMT3:PitchShiftLayer']
80
+
81
+ # Process time (CPU) in seconds
82
+ cpu_time = [1.18, 4.01, 0.389]
83
+
84
+ # Process time (GPU) in milliseconds
85
+ gpu_time = [58, 25.1, 7.18]
86
+
87
+ # GPU memory usage in GB
88
+ gpu_memory = [3.8, 5.49, 0.016]
89
+
90
+ # Creating subplots
91
+ fig, axs = plt.subplots(1, 3, figsize=(15, 5))
92
+
93
+ # Creating bar charts
94
+ bar1 = axs[0].bar(models, cpu_time, color=['#FFB6C1', '#ADD8E6', '#98FB98'])
95
+ bar2 = axs[1].bar(models, gpu_time, color=['#FFB6C1', '#ADD8E6', '#98FB98'])
96
+ bar3 = axs[2].bar(models, gpu_memory, color=['#FFB6C1', '#ADD8E6', '#98FB98'])
97
+
98
+ # Adding labels and titles
99
+ axs[0].set_ylabel('Time (s)')
100
+ axs[0].set_title('Process Time (CPU) bsz=80')
101
+ axs[1].set_ylabel('Time (ms)')
102
+ axs[1].set_title('Process Time (GPU) bsz=80')
103
+ axs[2].set_ylabel('Memory (GB)')
104
+ axs[2].set_title('GPU Memory Usage per semitone')
105
+
106
+ # Adding grid for better readability of the plots
107
+ for ax in axs:
108
+ ax.grid(axis='y')
109
+ ax.set_yscale('log')
110
+ ax.set_xticklabels(models, rotation=45, ha="right")
111
+
112
+ # Adding text labels above the bars
113
+ for i, rect in enumerate(bar1):
114
+ axs[0].text(
115
+ rect.get_x() + rect.get_width() / 2,
116
+ rect.get_height(),
117
+ f'{cpu_time[i]:.2f} s',
118
+ ha='center',
119
+ va='bottom')
120
+ for i, rect in enumerate(bar2):
121
+ axs[1].text(
122
+ rect.get_x() + rect.get_width() / 2,
123
+ rect.get_height(),
124
+ f'{gpu_time[i]:.2f} ms',
125
+ ha='center',
126
+ va='bottom')
127
+ for i, rect in enumerate(bar3):
128
+ axs[2].text(
129
+ rect.get_x() + rect.get_width() / 2,
130
+ rect.get_height(),
131
+ f'{gpu_memory[i]:.3f} GB',
132
+ ha='center',
133
+ va='bottom')
134
+ plt.tight_layout()
135
+ plt.show()
136
+
137
+ ############################################################################################
138
+ # Plot 2: Stretch and Resampler Processing Time Contribution
139
+ ############################################################################################
140
+ # Data
141
+ processing_type = ['Stretch (Phase Vocoder)', 'Resampler (Conv1D)']
142
+ cpu_times = [143, 245] # [Stretch, Resampler] times for CPU in milliseconds
143
+ gpu_times = [6.47, 0.71] # [Stretch, Resampler] times for GPU in milliseconds
144
+
145
+ # Creating subplots
146
+ fig, axs = plt.subplots(1, 2, figsize=(12, 6))
147
+
148
+ # Plotting bar charts
149
+ axs[0].bar(processing_type, cpu_times, color=['#ADD8E6', '#98FB98'])
150
+ axs[1].bar(processing_type, gpu_times, color=['#ADD8E6', '#98FB98'])
151
+
152
+ # Adding labels and titles
153
+ axs[0].set_ylabel('Time (ms)')
154
+ axs[0].set_title('Contribution of CPU Processing Time: YMT3-PS (BSZ=80)')
155
+ axs[1].set_title('Contribution of GPU Processing Time: YMT3-PS (BSZ=80)')
156
+
157
+ # Adding grid for better readability of the plots
158
+ for ax in axs:
159
+ ax.grid(axis='y')
160
+ ax.set_yscale('log') # Log scale to better visualize the smaller values
161
+
162
+ # Adding values on top of the bars
163
+ for ax, times in zip(axs, [cpu_times, gpu_times]):
164
+ for idx, time in enumerate(times):
165
+ ax.text(idx, time, f"{time:.2f} ms", ha='center', va='bottom', fontsize=8)
166
+ plt.tight_layout()
167
+ plt.show()
amt/src/extras/remove_silence_musicnet_midi.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+
4
+ from utils.midi import midi2note
5
+ from utils.note2event import note2note_event
6
+ from utils.note_event_dataclasses import Note
7
+ from utils.note_event_dataclasses import NoteEvent
8
+ from utils.midi import note_event2midi
9
+
10
+ data_home = '../../data'
11
+ dataset_name = 'musicnet'
12
+ base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
13
+ mid_pattern = os.path.join(base_dir, '*_midi', '*.mid')
14
+ mid_files = glob.glob(mid_pattern, recursive=True)
15
+
16
+ for mid_file in mid_files:
17
+ notes, _ = midi2note(mid_file)
18
+ first_onset_time = notes[0].onset
19
+ fixed_notes = []
20
+ for note in notes:
21
+ fixed_notes.append(
22
+ Note(
23
+ is_drum=note.is_drum,
24
+ program=note.program,
25
+ onset=note.onset - first_onset_time,
26
+ offset=note.offset - first_onset_time,
27
+ pitch=note.pitch,
28
+ velocity=note.velocity))
29
+ assert len(notes) == len(fixed_notes)
30
+ fixed_note_events = note2note_event(fixed_notes, return_activity=False)
31
+ note_event2midi(fixed_note_events, mid_file)
32
+ print(f'Overwriting {mid_file}')
amt/src/extras/rotary_positional_embedding.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """rotary_positional_embedding.py - Rotary Positional Embedding
2
+
3
+ code from github.com/lucidrains/rotary-embedding-torch
4
+
5
+ MIT License
6
+ """
7
+
8
+ from math import pi, log
9
+ import torch
10
+ from torch import nn, einsum
11
+ from einops import rearrange, repeat
12
+
13
+
14
+ def exists(val):
15
+ return val is not None
16
+
17
+
18
+ def broadcat(tensors, dim=-1):
19
+ num_tensors = len(tensors)
20
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
21
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
22
+ shape_len = list(shape_lens)[0]
23
+
24
+ dim = (dim + shape_len) if dim < 0 else dim
25
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
26
+
27
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
28
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)
29
+ ]), 'invalid dimensions for broadcastable concatentation'
30
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
31
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
32
+ expanded_dims.insert(dim, (dim, dims[dim]))
33
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
34
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
35
+ return torch.cat(tensors, dim=dim)
36
+
37
+
38
+ # rotary embedding helper functions
39
+ def rotate_half(x):
40
+ x = rearrange(x, '... (d r) -> ... d r', r=2)
41
+ x1, x2 = x.unbind(dim=-1)
42
+ x = torch.stack((-x2, x1), dim=-1)
43
+ return rearrange(x, '... d r -> ... (d r)')
44
+
45
+
46
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.):
47
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
48
+ freqs = freqs[-seq_len:, :]
49
+
50
+ freqs = freqs.to(t)
51
+ end_index = start_index + rot_dim
52
+ assert rot_dim <= t.shape[
53
+ -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
54
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
55
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
56
+ return torch.cat((t_left, t, t_right), dim=-1)
57
+
58
+
59
+ # learned rotation helpers
60
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
61
+ if exists(freq_ranges):
62
+ rotations = einsum('..., f -> ... f', rotations, freq_ranges)
63
+ rotations = rearrange(rotations, '... r f -> ... (r f)')
64
+
65
+ rotations = repeat(rotations, '... n -> ... (n r)', r=2)
66
+ return apply_rotary_emb(rotations, t, start_index=start_index)
67
+
68
+
69
+ # classes
70
+ class RotaryEmbedding(nn.Module):
71
+
72
+ def __init__(self,
73
+ dim,
74
+ custom_freqs=None,
75
+ freqs_for='lang',
76
+ theta=10000,
77
+ max_freq=10,
78
+ num_freqs=1,
79
+ learned_freq=False,
80
+ use_xpos=False,
81
+ xpos_scale_base=512,
82
+ interpolate_factor=1.,
83
+ theta_rescale_factor=1.):
84
+ super().__init__()
85
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
86
+ # has some connection to NTK literature
87
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
88
+ theta *= theta_rescale_factor**(dim / (dim - 2))
89
+
90
+ if exists(custom_freqs):
91
+ freqs = custom_freqs
92
+ elif freqs_for == 'lang':
93
+ freqs = 1. / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
94
+ elif freqs_for == 'pixel':
95
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
96
+ elif freqs_for == 'constant':
97
+ freqs = torch.ones(num_freqs).float()
98
+ else:
99
+ raise ValueError(f'unknown modality {freqs_for}')
100
+
101
+ self.cache = dict()
102
+ self.cache_scale = dict()
103
+ self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
104
+
105
+ # interpolation factors
106
+
107
+ assert interpolate_factor >= 1.
108
+ self.interpolate_factor = interpolate_factor
109
+
110
+ # xpos
111
+
112
+ self.use_xpos = use_xpos
113
+ if not use_xpos:
114
+ self.register_buffer('scale', None)
115
+ return
116
+
117
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
118
+ self.scale_base = xpos_scale_base
119
+ self.register_buffer('scale', scale)
120
+
121
+ def get_seq_pos(self, seq_len, device, dtype, offset=0):
122
+ return (torch.arange(seq_len, device=device, dtype=dtype) +
123
+ offset) / self.interpolate_factor
124
+
125
+ def rotate_queries_or_keys(self, t, seq_dim=-2, offset=0, freq_seq_len=None):
126
+ assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
127
+
128
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
129
+
130
+ if exists(freq_seq_len):
131
+ assert freq_seq_len >= seq_len
132
+ seq_len = freq_seq_len
133
+
134
+ freqs = self.forward(
135
+ lambda: self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset),
136
+ cache_key=f'freqs:{seq_len}|offset:{offset}')
137
+ return apply_rotary_emb(freqs, t)
138
+
139
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim=-2):
140
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
141
+ assert q_len <= k_len
142
+ q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, freq_seq_len=k_len)
143
+ k = self.rotate_queries_or_keys(k, seq_dim=seq_dim)
144
+ return q, k
145
+
146
+ def rotate_queries_and_keys(self, q, k, seq_dim=-2):
147
+ assert self.use_xpos
148
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
149
+ seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
150
+ freqs = self.forward(lambda: seq, cache_key=f'freqs:{seq_len}')
151
+ scale = self.get_scale(lambda: seq, cache_key=f'scale:{seq_len}').to(dtype)
152
+ rotated_q = apply_rotary_emb(freqs, q, scale=scale)
153
+ rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1)
154
+ return rotated_q, rotated_k
155
+
156
+ def get_scale(self, t, cache_key=None):
157
+ assert self.use_xpos
158
+
159
+ if exists(cache_key) and cache_key in self.cache:
160
+ return self.cache[cache_key]
161
+
162
+ if callable(t):
163
+ t = t()
164
+
165
+ scale = 1.
166
+ if self.use_xpos:
167
+ power = (t - len(t) // 2) / self.scale_base
168
+ scale = self.scale**rearrange(power, 'n -> n 1')
169
+ scale = torch.cat((scale, scale), dim=-1)
170
+
171
+ if exists(cache_key):
172
+ self.cache[cache_key] = scale
173
+
174
+ return scale
175
+
176
+ def forward(self, t, cache_key=None):
177
+ if exists(cache_key) and cache_key in self.cache:
178
+ return self.cache[cache_key]
179
+
180
+ if callable(t):
181
+ t = t()
182
+
183
+ freqs = self.freqs
184
+
185
+ freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
186
+ freqs = repeat(freqs, '... n -> ... (n r)', r=2)
187
+
188
+ if exists(cache_key):
189
+ self.cache[cache_key] = freqs
190
+
191
+ return freqs
amt/src/extras/run_spleeter_mir1k.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ shopt -s globstar
3
+ for file in "$1"/**/*.wav; do
4
+ echo $file
5
+ output_dir="tmp"
6
+ spleeter separate -b 256k -B tensorflow -p spleeter:2stems -o $output_dir $file -f {instrument}.{codec}
7
+ sox --ignore-length tmp/accompaniment.wav -r 16000 -c 1 -b 16 tmp/accompaniment_16k.wav
8
+ sox --ignore-length tmp/vocals.wav -r 16000 -c 1 -b 16 tmp/vocals_16k.wav
9
+ acc_file="${file//.wav/_accompaniment.wav}"
10
+ voc_file="${file//.wav/_vocals.wav}"
11
+ mv -f "tmp/accompaniment_16k.wav" $acc_file
12
+ mv -f "tmp/vocals_16k.wav" $voc_file
13
+ echo $acc_file
14
+ echo $voc_file
15
+ rm -rf tmp
16
+ done
17
+ rm -rf pretrained_models
amt/src/extras/run_spleeter_mirst500.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ shopt -s globstar
3
+ for file in "$1"/**/*.wav; do
4
+ output_dir="${file%/*}"
5
+ input_file="$output_dir/converted_Mixture.wav"
6
+ spleeter separate -p spleeter:2stems -o $output_dir $input_file -f {instrument}.{codec}
7
+ ffmpeg -i "$output_dir/vocals.wav" -acodec pcm_s16le -ac 1 -ar 16000 -y "$output_dir/vocals_16k.wav"
8
+ ffmpeg -i "$output_dir/accompaniment.wav" -acodec pcm_s16le -ac 1 -ar 16000 -y "$output_dir/accompaniment_16k.wav"
9
+ rm "$output_dir/vocals.wav"
10
+ rm "$output_dir/accompaniment.wav"
11
+ mv "$output_dir/vocals_16k.wav" "$output_dir/vocals.wav"
12
+ mv "$output_dir/accompaniment_16k.wav" "$output_dir/accompaniment.wav"
13
+ done
amt/src/extras/run_spleeter_mirst500_cmedia.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ shopt -s globstar
3
+ for file in "$1"/**/*.wav; do
4
+ output_dir="${file%/*}"
5
+ input_file="$output_dir/converted_Mixture.wav"
6
+ spleeter separate -p spleeter:2stems -o $output_dir $input_file -f {instrument}.{codec}
7
+ ffmpeg -i "$output_dir/vocals.wav" -acodec pcm_s16le -ac 1 -ar 16000 -y "$output_dir/vocals_16k.wav"
8
+ ffmpeg -i "$output_dir/accompaniment.wav" -acodec pcm_s16le -ac 1 -ar 16000 -y "$output_dir/accompaniment_16k.wav"
9
+ rm "$output_dir/vocals.wav"
10
+ rm "$output_dir/accompaniment.wav"
11
+ mv "$output_dir/vocals_16k.wav" "$output_dir/vocals.wav"
12
+ mv "$output_dir/accompaniment_16k.wav" "$output_dir/accompaniment.wav"
13
+ done
amt/src/extras/swap_channel.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ a = np.arange(12).reshape(2, 3, 2) # (batch, channel, dim)
4
+ print(a)
5
+ array([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]])
6
+
7
+ swap_mat = create_swap_channel_mat(input_shape, swap_channel=(1, 2))
8
+
9
+ # will swap channel 1 and 2 of batch 0 with channel 1 and 2 of batch 1
10
+ b = a @ swap_mat
11
+ print(b)
12
+ # expected output
13
+ array([[[0, 1], [8, 9], [10, 11]], [[6, 7], [2, 3], [4, 5]]])
14
+
15
+ import torch
16
+
17
+
18
+ def swap_channels_between_batches(a_tensor, swap_channels):
19
+ # Copy the tensor to avoid modifying the original tensor
20
+ result_tensor = a_tensor.clone()
21
+
22
+ # Unpack the channels to be swapped
23
+ ch1, ch2 = swap_channels
24
+
25
+ # Swap the specified channels between batches
26
+ result_tensor[0, ch1, :], result_tensor[1, ch1, :] = a_tensor[1, ch1, :].clone(), a_tensor[0, ch1, :].clone()
27
+ result_tensor[0, ch2, :], result_tensor[1, ch2, :] = a_tensor[1, ch2, :].clone(), a_tensor[0, ch2, :].clone()
28
+
29
+ return result_tensor
30
+
31
+
32
+ # Define a sample tensor 'a_tensor'
33
+ a_tensor = torch.tensor([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], dtype=torch.float32)
34
+
35
+ # Define channels to swap
36
+ swap_channels = (1, 2) # Channels to swap between batches
37
+
38
+ # Swap the channels between batches
39
+ swapped_tensor = swap_channels_between_batches(a_tensor, swap_channels)
40
+
41
+ # Print the original tensor and the tensor after swapping channels between batches
42
+ print("Original Tensor 'a_tensor':")
43
+ print(a_tensor)
44
+ print("\nTensor after swapping channels between batches:")
45
+ print(swapped_tensor)
46
+
47
+ #-------------------------------------------------
48
+
49
+ import torch
50
+ from einops import rearrange
51
+
52
+
53
+ def shift(arr, num, fill_value=np.nan):
54
+ result = np.empty_like(arr)
55
+ if num > 0:
56
+ result[:num] = fill_value
57
+ result[num:] = arr[:-num]
58
+ elif num < 0:
59
+ result[num:] = fill_value
60
+ result[:num] = arr[-num:]
61
+ else:
62
+ result[:] = arr
63
+ return result
64
+
65
+
66
+ def create_batch_swap_matrix(batch_size, channels, swap_channels):
67
+ swap_mat = np.eye(batch_size * channels)
68
+
69
+ for c in swap_channels:
70
+ idx1 = c # 첫 번째 배치의 교환할 채널 인덱스
71
+ idx2 = c + channels # 두 번째 배치의 교환할 채널 인덱스
72
+
73
+ swap_mat[idx1, idx1], swap_mat[idx2, idx2] = 0, 0 # 대각선 값을 0으로 설정
74
+ swap_mat[idx1, idx2], swap_mat[idx2, idx1] = 1, 1 # 해당 채널을 교환
75
+ return swap_mat
76
+
77
+
78
+ def create_batch_swap_matrix(batch_size, channels, swap_channels):
79
+ swap_mat = np.eye(batch_size * channels)
80
+
81
+ # 모든 채널에 대해 교환 수행
82
+ for c in swap_channels:
83
+ idx1 = np.arange(c, batch_size * channels, channels) # 현재 채널의 모든 배치 인덱스
84
+ idx2 = (idx1 + channels) % (batch_size * channels) # 순환을 위해 modulo 사용
85
+
86
+ swap_mat[idx1, idx1] = 0
87
+ swap_mat[idx2, idx2] = 0
88
+ swap_mat[idx1, idx2] = 1
89
+ swap_mat[idx2, idx1] = 1
90
+
91
+ return swap_mat
92
+
93
+
94
+ def swap_channels_between_batches(input_tensor, swap_matrix):
95
+ reshaped_tensor = rearrange(input_tensor, 'b c d -> (b c) d')
96
+ swapped_tensor = swap_matrix @ reshaped_tensor
97
+ return rearrange(swapped_tensor, '(b c) d -> b c d', b=input_tensor.shape[0])
98
+
99
+
100
+ # 예제 파라미터
101
+ batch_size = 2
102
+ channels = 3
103
+ # swap_info = {
104
+ # : [1, 2] # batch_index: [channel_indices]
105
+ # }
106
+ swap_channels = [1, 2] # 교환할 채널
107
+
108
+ # 예제 텐서 생성
109
+ input_tensor = torch.tensor([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], dtype=torch.float32)
110
+
111
+ # swap matrix 생성
112
+ swap_matrix = create_batch_swap_matrix(batch_size, channels, swap_channels)
113
+ swap_matrix = torch.Tensor(swap_matrix)
114
+
115
+ # 채널 교환 수행
116
+ swapped_tensor = swap_channels_between_batches(input_tensor, swap_matrix)
117
+
118
+ # 결과 출력
119
+ print("Original Tensor:")
120
+ print(input_tensor)
121
+ print("\nSwapped Tensor:")
122
+ print(swapped_tensor)
amt/src/extras/t5_dev.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import T5Config
3
+ from model.t5mod import T5ForConditionalGeneration
4
+
5
+ a = {
6
+ "architectures": ["T5ForConditionalGeneration"],
7
+ "d_ff": 1024, # size of the intermediate feed forward layer in each T5Block
8
+ "d_kv": 64, # d_kv has to be equal to d_model // num_heads.
9
+ # "d_model": 512, # encoder hiddnen size, defined by model_cfg
10
+ "decoder_start_token_id": 0,
11
+ "dense_act_fn": "gelu_new",
12
+ # "dropout_rate": 0.05, # can be overwritten by args in ymt3
13
+ "eos_token_id": 1,
14
+ "feed_forward_proj": "gated-gelu",
15
+ "initializer_factor": 1.0,
16
+ # "is_encoder_decoder": True,
17
+ "is_gated_act": True,
18
+ "layer_norm_epsilon": 1e-06,
19
+ "model_type": "t5",
20
+ # "num_decoder_layers": 8,
21
+ "num_heads": 6,
22
+ "num_layers": 8,
23
+ "output_past": True,
24
+ "pad_token_id": 0,
25
+ "relative_attention_num_buckets": 32,
26
+ "use_cache": True,
27
+ "vocab_size": 1391 # vocab_size is automatically set by the task manager...
28
+ }
29
+ cfg = T5Config(**a)
30
+ cfg.num_decoder_layers = 4
31
+ cfg.num_layers = 0
32
+
33
+ model = T5ForConditionalGeneration(cfg)
34
+ print(model)
35
+
36
+ x = torch.rand(((2, 256, 512)))
37
+ out = model.encoder.forward(inputs_embeds=x)
38
+
39
+ enc_hs = torch.rand((2, 256, 512))
40
+ labels = torch.randint(0, 1391, (2, 256))
41
+ pred = model(encoder_outputs=(enc_hs,), labels=labels) # important (enc_hs,) comma!
amt/src/extras/t5perceiver.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ """ Bare wrapper of HF PyTorch T5 and Perceiver with the following modifications:
11
+ - PerceiverTF encoder
12
+ - ResConv pre-encoder
13
+ - Projection layers for dynamic dimension matching
14
+ - Sinusoidal absolute positional embeddings
15
+ - Positional embeddings from Perceiver implementation
16
+ - Task conditioning on encoder and decoder by input tokens
17
+ """
18
+ import copy
19
+ import warnings
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+ from torch.utils.checkpoint import checkpoint
26
+
27
+ from transformers.utils import logging
28
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.models.t5.modeling_t5 import (T5LayerNorm, T5Block, PARALLELIZE_DOCSTRING, DEPARALLELIZE_DOCSTRING,
31
+ T5_START_DOCSTRING, T5_INPUTS_DOCSTRING, _CONFIG_FOR_DOC,
32
+ __HEAD_MASK_WARNING_MSG)
33
+ from transformers.modeling_outputs import (Seq2SeqLMOutput, BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions)
34
+ from transformers import T5Config #, T5PreTrainedModel
35
+ from model.ops import FixedSinusoidalPositionalEmbedding
36
+
37
+ # additional imports
38
+ from model.t5mod import T5Stack
39
+ from transformers.models.t5.modeling_t5 import (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5DenseActDense,
40
+ T5DenseGatedActDense, T5Attention, load_tf_weights_in_t5,
41
+ is_torch_fx_proxy)
42
+
43
+ from transformers.utils import (DUMMY_INPUTS, DUMMY_MASK)
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+
48
+ class T5PerceiverPreTrainedModel(PreTrainedModel):
49
+ """
50
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
51
+ models.
52
+ """
53
+
54
+ config_class = None
55
+ load_tf_weights = load_tf_weights_in_t5
56
+ base_model_prefix = "transformer"
57
+ is_parallelizable = True
58
+ supports_gradient_checkpointing = True
59
+ _no_split_modules = ["T5Block"]
60
+ _keep_in_fp32_modules = ["wo"]
61
+
62
+ @property
63
+ def dummy_inputs(self):
64
+ input_ids = torch.tensor(DUMMY_INPUTS)
65
+ input_mask = torch.tensor(DUMMY_MASK)
66
+ dummy_inputs = {
67
+ "decoder_input_ids": input_ids,
68
+ "input_ids": input_ids,
69
+ "decoder_attention_mask": input_mask,
70
+ }
71
+ return dummy_inputs
72
+
73
+ def _init_weights(self, module):
74
+ """Initialize the weights"""
75
+ factor = self.config.initializer_factor # Used for testing weights initialization
76
+ if isinstance(module, T5LayerNorm):
77
+ module.weight.data.fill_(factor * 1.0)
78
+ elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
79
+ # Mesh TensorFlow embeddings initialization
80
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
81
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
82
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
83
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
84
+ elif isinstance(module, T5DenseActDense):
85
+ # Mesh TensorFlow FF initialization
86
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
87
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
88
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5))
89
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
90
+ module.wi.bias.data.zero_()
91
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff)**-0.5))
92
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
93
+ module.wo.bias.data.zero_()
94
+ elif isinstance(module, T5DenseGatedActDense):
95
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5))
96
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
97
+ module.wi_0.bias.data.zero_()
98
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5))
99
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
100
+ module.wi_1.bias.data.zero_()
101
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff)**-0.5))
102
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
103
+ module.wo.bias.data.zero_()
104
+ elif isinstance(module, T5Attention):
105
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
106
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
107
+ d_model = self.config.d_model
108
+ key_value_proj_dim = self.config.d_kv
109
+ n_heads = self.config.num_heads
110
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim)**-0.5))
111
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
112
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
113
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim)**-0.5))
114
+ if module.has_relative_attention_bias:
115
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model)**-0.5))
116
+
117
+ def _set_gradient_checkpointing(self, module, value=False):
118
+ if isinstance(module, (T5Attention, T5Stack)):
119
+ module.gradient_checkpointing = value
120
+
121
+ def _shift_right(self, input_ids):
122
+ decoder_start_token_id = self.config.decoder_start_token_id
123
+ pad_token_id = self.config.pad_token_id
124
+
125
+ assert decoder_start_token_id is not None, (
126
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
127
+ " See T5 docs for more information")
128
+
129
+ # shift inputs to the right
130
+ if is_torch_fx_proxy(input_ids):
131
+ # Item assignment is not supported natively for proxies.
132
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
133
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
134
+ else:
135
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
136
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
137
+ shifted_input_ids[..., 0] = decoder_start_token_id
138
+
139
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
140
+ # replace possible -100 values in labels by `pad_token_id`
141
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
142
+
143
+ return shifted_input_ids
144
+
145
+
146
+ class T5PerceiverForConditionalGeneration(T5PerceiverPreTrainedModel):
147
+ config_class = None
148
+ load_tf_weights = load_tf_weights_in_t5
149
+ base_model_prefix = "transformer"
150
+ is_parallelizable = True
151
+ supports_gradient_checkpointing = True
152
+ _no_split_modules = ["T5Block"]
153
+ _keep_in_fp32_modules = ["wo"]
154
+
155
+ @property
156
+ def dummy_inputs(self):
157
+ input_ids = torch.tensor(DUMMY_INPUTS)
158
+ input_mask = torch.tensor(DUMMY_MASK)
159
+ dummy_inputs = {
160
+ "decoder_input_ids": input_ids,
161
+ "input_ids": input_ids,
162
+ "decoder_attention_mask": input_mask,
163
+ }
164
+ return dummy_inputs
165
+
166
+ def __init__(
167
+ self,
168
+ model_cfg: dict,
169
+ # config: T5Config,
170
+ # use_fixed_absolute_pe: bool = True,
171
+ # num_max_positions: int = 1025
172
+ ):
173
+ super().__init__(config)
174
+ self.model_dim = config.d_model
175
+ """ mod: absolute position embedding """
176
+ self.use_fixed_absolute_pe = use_fixed_absolute_pe
177
+
178
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
179
+
180
+ encoder_config = copy.deepcopy(config)
181
+ encoder_config.is_decoder = False
182
+ encoder_config.use_cache = False
183
+ encoder_config.is_encoder_decoder = False
184
+ self.encoder = T5Stack(encoder_config,
185
+ self.shared,
186
+ use_fixed_absolute_pe=use_fixed_absolute_pe,
187
+ num_max_positions=num_max_positions)
188
+
189
+ decoder_config = copy.deepcopy(config)
190
+ decoder_config.is_decoder = True
191
+ decoder_config.is_encoder_decoder = False
192
+ decoder_config.num_layers = config.num_decoder_layers
193
+ self.decoder = T5Stack(decoder_config,
194
+ self.shared,
195
+ use_fixed_absolute_pe=use_fixed_absolute_pe,
196
+ num_max_positions=num_max_positions)
197
+
198
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
199
+
200
+ # Initialize weights and apply final processing
201
+ self.post_init()
202
+
203
+ # Model parallel
204
+ self.model_parallel = False
205
+ self.device_map = None
206
+
207
+ def get_input_embeddings(self):
208
+ return self.shared
209
+
210
+ def set_input_embeddings(self, new_embeddings):
211
+ self.shared = new_embeddings
212
+ self.encoder.set_input_embeddings(new_embeddings)
213
+ self.decoder.set_input_embeddings(new_embeddings)
214
+
215
+ def set_output_embeddings(self, new_embeddings):
216
+ self.lm_head = new_embeddings
217
+
218
+ def get_output_embeddings(self):
219
+ return self.lm_head
220
+
221
+ def get_encoder(self):
222
+ return self.encoder
223
+
224
+ def get_decoder(self):
225
+ return self.decoder
226
+
227
+ def forward(
228
+ self,
229
+ input_ids: Optional[torch.LongTensor] = None,
230
+ attention_mask: Optional[torch.FloatTensor] = None,
231
+ decoder_input_ids: Optional[torch.LongTensor] = None,
232
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
233
+ head_mask: Optional[torch.FloatTensor] = None,
234
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
235
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
236
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
237
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
238
+ inputs_embeds: Optional[torch.FloatTensor] = None,
239
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
240
+ labels: Optional[torch.LongTensor] = None,
241
+ use_cache: Optional[bool] = None,
242
+ output_attentions: Optional[bool] = None,
243
+ output_hidden_states: Optional[bool] = None,
244
+ return_dict: Optional[bool] = None,
245
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
246
+ r"""
247
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
248
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
249
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
250
+ labels in `[0, ..., config.vocab_size]`
251
+
252
+ Returns:
253
+
254
+ Examples:
255
+
256
+ ```python
257
+ >>> from transformers import AutoTokenizer, T5ForConditionalGeneration
258
+
259
+ >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
260
+ >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
261
+
262
+ >>> # training
263
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
264
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
265
+ >>> outputs = model(input_ids=input_ids, labels=labels)
266
+ >>> loss = outputs.loss
267
+ >>> logits = outputs.logits
268
+
269
+ >>> # inference
270
+ >>> input_ids = tokenizer(
271
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
272
+ ... ).input_ids # Batch size 1
273
+ >>> outputs = model.generate(input_ids)
274
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
275
+ >>> # studies have shown that owning a dog is good for you.
276
+ ```"""
277
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
278
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
279
+
280
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
281
+ if head_mask is not None and decoder_head_mask is None:
282
+ if self.config.num_layers == self.config.num_decoder_layers:
283
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
284
+ decoder_head_mask = head_mask
285
+
286
+ # Encode if needed (training, first prediction pass)
287
+ if encoder_outputs is None:
288
+ # Convert encoder inputs in embeddings if needed
289
+ encoder_outputs = self.encoder(
290
+ input_ids=input_ids,
291
+ attention_mask=attention_mask,
292
+ inputs_embeds=inputs_embeds,
293
+ head_mask=head_mask,
294
+ output_attentions=output_attentions,
295
+ output_hidden_states=output_hidden_states,
296
+ return_dict=return_dict,
297
+ )
298
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
299
+ encoder_outputs = BaseModelOutput(
300
+ last_hidden_state=encoder_outputs[0],
301
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
302
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
303
+ )
304
+
305
+ hidden_states = encoder_outputs[0]
306
+
307
+ if self.model_parallel:
308
+ torch.cuda.set_device(self.decoder.first_device)
309
+
310
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
311
+ # get decoder inputs from shifting lm labels to the right
312
+ decoder_input_ids = self._shift_right(labels)
313
+
314
+ # Set device for model parallelism
315
+ if self.model_parallel:
316
+ torch.cuda.set_device(self.decoder.first_device)
317
+ hidden_states = hidden_states.to(self.decoder.first_device)
318
+ if decoder_input_ids is not None:
319
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
320
+ if attention_mask is not None:
321
+ attention_mask = attention_mask.to(self.decoder.first_device)
322
+ if decoder_attention_mask is not None:
323
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
324
+
325
+ # Decode
326
+ decoder_outputs = self.decoder(
327
+ input_ids=decoder_input_ids,
328
+ attention_mask=decoder_attention_mask,
329
+ inputs_embeds=decoder_inputs_embeds,
330
+ past_key_values=past_key_values,
331
+ encoder_hidden_states=hidden_states,
332
+ encoder_attention_mask=attention_mask,
333
+ head_mask=decoder_head_mask,
334
+ cross_attn_head_mask=cross_attn_head_mask,
335
+ use_cache=use_cache,
336
+ output_attentions=output_attentions,
337
+ output_hidden_states=output_hidden_states,
338
+ return_dict=return_dict,
339
+ )
340
+
341
+ sequence_output = decoder_outputs[0]
342
+
343
+ # Set device for model parallelism
344
+ if self.model_parallel:
345
+ torch.cuda.set_device(self.encoder.first_device)
346
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
347
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
348
+
349
+ if self.config.tie_word_embeddings:
350
+ # Rescale output before projecting on vocab
351
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
352
+ sequence_output = sequence_output * (self.model_dim**-0.5)
353
+
354
+ lm_logits = self.lm_head(sequence_output)
355
+
356
+ loss = None
357
+ if labels is not None:
358
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
359
+ # move labels to correct device to enable PP
360
+ labels = labels.to(lm_logits.device)
361
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
362
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
363
+
364
+ if not return_dict:
365
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
366
+ return ((loss,) + output) if loss is not None else output
367
+
368
+ return Seq2SeqLMOutput(
369
+ loss=loss,
370
+ logits=lm_logits,
371
+ past_key_values=decoder_outputs.past_key_values,
372
+ decoder_hidden_states=decoder_outputs.hidden_states,
373
+ decoder_attentions=decoder_outputs.attentions,
374
+ cross_attentions=decoder_outputs.cross_attentions,
375
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
376
+ encoder_hidden_states=encoder_outputs.hidden_states,
377
+ encoder_attentions=encoder_outputs.attentions,
378
+ )
379
+
380
+ def prepare_inputs_for_generation(
381
+ self,
382
+ input_ids,
383
+ past_key_values=None,
384
+ attention_mask=None,
385
+ head_mask=None,
386
+ decoder_head_mask=None,
387
+ cross_attn_head_mask=None,
388
+ use_cache=None,
389
+ encoder_outputs=None,
390
+ **kwargs,
391
+ ):
392
+ # cut decoder_input_ids if past is used
393
+ if past_key_values is not None:
394
+ input_ids = input_ids[:, -1:]
395
+
396
+ return {
397
+ "decoder_input_ids": input_ids,
398
+ "past_key_values": past_key_values,
399
+ "encoder_outputs": encoder_outputs,
400
+ "attention_mask": attention_mask,
401
+ "head_mask": head_mask,
402
+ "decoder_head_mask": decoder_head_mask,
403
+ "cross_attn_head_mask": cross_attn_head_mask,
404
+ "use_cache": use_cache,
405
+ }
406
+
407
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
408
+ return self._shift_right(labels)
409
+
410
+ def _reorder_cache(self, past_key_values, beam_idx):
411
+ # if decoder past is not included in output
412
+ # speedy decoding is disabled and no need to reorder
413
+ if past_key_values is None:
414
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
415
+ return past_key_values
416
+
417
+ reordered_decoder_past = ()
418
+ for layer_past_states in past_key_values:
419
+ # get the correct batch idx from layer past batch dim
420
+ # batch dim of `past` is at 2nd position
421
+ reordered_layer_past_states = ()
422
+ for layer_past_state in layer_past_states:
423
+ # need to set correct `past` for each of the four key / value states
424
+ reordered_layer_past_states = reordered_layer_past_states + (layer_past_state.index_select(
425
+ 0, beam_idx.to(layer_past_state.device)),)
426
+
427
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
428
+ assert len(reordered_layer_past_states) == len(layer_past_states)
429
+
430
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
431
+ return reordered_decoder_past
432
+
433
+
434
+ from transformers import PreTrainedModel, PretrainedConfig
435
+ from transformers import AutoModel, AutoConfig
436
+
437
+
438
+ class MyConfig(T5Config, PerceiverConfig):
439
+ model_type = 'mymodel'
440
+
441
+ def __init__(self, important_param=42, **kwargs):
442
+ super().__init__(**kwargs)
443
+ self.important_param = important_param
amt/src/extras/unimax_sampler/README.md ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # UniMax Language Dataset Sampler with DDP support
2
+
3
+ This repository contains an unofficial implementation of the UNIMAX sampling algorithm using PyTorch. The UNIMAX algorithm ["UniMax: Fairer and more Effective Language Sampling for Large-Scale Multilingual Pretraining" by HW Chung et al. (ICLR 2023)](https://arxiv.org/abs/2304.09151) is used to generate a sampling distribution of languages based on their character counts, a total character budget, and a specified number of epochs per language. This can be useful for training language models on datasets with imbalanced language distribution.
4
+
5
+ ## Contents
6
+
7
+ 1. `unimax_sampler.py`: This Python file contains the `UnimaxSampler` class, a PyTorch `Sampler` that uses the UNIMAX algorithm.
8
+
9
+ 2. `test_unimax_sampler.py`: This Python file contains a unit test for the `UnimaxSampler` class to ensure its correct functionality.
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from unimax_sampler import UnimaxSampler
16
+
17
+ # Define your parameters
18
+ language_character_counts = [100, 200, 300, 400, 500]
19
+ total_character_budget = 1000
20
+ num_epochs = 2
21
+
22
+ # Create the UnimaxSampler
23
+ unimax_sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs)
24
+ ```
25
+
26
+ Then, use the sampler as the sampler argument when creating a DataLoader.
27
+
28
+ ```python
29
+ # Disable shuffle when using custom sampler...
30
+ data_loader = DataLoader(my_dataset, batch_size=2, shuffle=None, sampler=unimax_sampler)
31
+ ```
32
+
33
+ For DDP,
34
+ ```python
35
+ if torch.distributed.is_initialized():
36
+ sampler = DistributedUnimaxSampler(...)
37
+ else:
38
+ return unimax_sampler(...)
39
+ ```
40
+
41
+ ## Note
42
+ The initial version of this code was created by [Chat GPT-4](https://chat.openai.com/), based on the pseudocode provided in the [UNIMAX](https://arxiv.org/abs/2304.09151) paper. Subsequently, the code was manually revised for `PyTorch` Distributed Data Parallel ([DDP](https://pytorch.org/docs/stable/notes/ddp.html)) framework. The DistributedSamplerWrapper implementation is derived from an earlier version found in the [Catalyst](https://github.com/catalyst-team/catalyst) project.
43
+
44
+ ## License
45
+ This project is licensed under the MIT License.
amt/src/extras/unimax_sampler/demo.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.unimax_sampler.unimax_sampler import UnimaxSampler
2
+
3
+ language_character_counts = [100, 200, 300, 400, 500]
4
+ total_character_budget = 1000
5
+ num_epochs = 2
6
+
7
+ # Create the UnimaxSampler.
8
+ sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs)
9
+
10
+ # Define the expected output. This will depend on your specific implementation of Unimax.
11
+ expected_output = torch.tensor([0.1, 0.2, 0.3, 0.2, 0.2])
12
+
13
+ # Use PyTorch's allclose function to compare the computed and expected outputs.
14
+ # The absolute tolerance parameter atol specifies the maximum difference allowed for the test to pass.
15
+ self.assertTrue(torch.allclose(sampler.p, expected_output, atol=1e-6))
amt/src/extras/unimax_sampler/unimax_sampler.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DistributedSampler
3
+ from torch.utils.data import Dataset, Sampler
4
+ from torch.utils.data import RandomSampler
5
+ from operator import itemgetter
6
+ from typing import List, Union, Iterator, Optional
7
+
8
+
9
+ class DatasetFromSampler(Dataset):
10
+ """Dataset to create indexes from `Sampler`. From catalyst library.
11
+
12
+ Args:
13
+ sampler: PyTorch sampler
14
+ """
15
+
16
+ def __init__(self, sampler: Sampler):
17
+ """Initialisation for DatasetFromSampler."""
18
+ self.sampler = sampler
19
+ self.sampler_list = None
20
+
21
+ def __getitem__(self, index: int):
22
+ """Gets element of the dataset.
23
+
24
+ Args:
25
+ index: index of the element in the dataset
26
+
27
+ Returns:
28
+ Single element by index
29
+ """
30
+ if self.sampler_list is None:
31
+ self.sampler_list = list(self.sampler)
32
+ return self.sampler_list[index]
33
+
34
+ def __len__(self) -> int:
35
+ """
36
+ Returns:
37
+ int: length of the dataset
38
+ """
39
+ return len(self.sampler)
40
+
41
+
42
+ class DistributedSamplerWrapper(DistributedSampler):
43
+ """
44
+ Wrapper over `Sampler` for distributed training.
45
+ Allows you to use any sampler in distributed mode.
46
+ From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py
47
+
48
+ It is especially useful in conjunction with
49
+ `torch.nn.parallel.DistributedDataParallel`. In such case, each
50
+ process can pass a DistributedSamplerWrapper instance as a DataLoader
51
+ sampler, and load a subset of subsampled data of the original dataset
52
+ that is exclusive to it.
53
+
54
+ .. note::
55
+ Sampler is assumed to be of constant size.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ sampler,
61
+ num_replicas: Optional[int] = None,
62
+ rank: Optional[int] = None,
63
+ shuffle: bool = True,
64
+ ):
65
+ """
66
+
67
+ Args:
68
+ sampler: Sampler used for subsampling
69
+ num_replicas (int, optional): Number of processes participating in
70
+ distributed training
71
+ rank (int, optional): Rank of the current process
72
+ within ``num_replicas``
73
+ shuffle (bool, optional): If true (default),
74
+ sampler will shuffle the indices
75
+ """
76
+ super(DistributedSamplerWrapper, self).__init__(
77
+ DatasetFromSampler(sampler),
78
+ num_replicas=num_replicas,
79
+ rank=rank,
80
+ shuffle=shuffle,
81
+ )
82
+ self.sampler = sampler
83
+
84
+ def __iter__(self) -> Iterator[int]:
85
+ """Iterate over sampler.
86
+
87
+ Returns:
88
+ python iterator
89
+ """
90
+ self.dataset = DatasetFromSampler(self.sampler)
91
+ indexes_of_indexes = super().__iter__()
92
+ subsampler_indexes = self.dataset
93
+ return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
94
+
95
+
96
+ class UnimaxSampler(Sampler):
97
+ # Initialize the sampler with the character counts for each language,
98
+ # the total character budget, and the number of epochs per language.
99
+ def __init__(self, language_character_counts: List[int], total_character_budget: int,
100
+ num_epochs: int) -> None:
101
+ self.language_character_counts = torch.tensor(language_character_counts)
102
+ self.total_character_budget = total_character_budget
103
+ self.num_epochs = num_epochs
104
+ # Compute the sampling distribution p.
105
+ self.p = self._unimax()
106
+
107
+ # Define how to iterate over the data. We'll use PyTorch's multinomial
108
+ # function to generate indices according to the distribution p.
109
+ def __iter__(self) -> iter:
110
+ return iter(torch.multinomial(self.p, len(self.p), replacement=True).tolist())
111
+
112
+ # Define the length of the sampler as the number of languages.
113
+ def __len__(self) -> int:
114
+ return len(self.p)
115
+
116
+ # Implement the UNIMAX algorithm to compute the sampling distribution p.
117
+ def _unimax(self) -> torch.Tensor:
118
+ # Sort languages by character count.
119
+ L, indices = torch.sort(self.language_character_counts)
120
+ # Initialize the remaining budget to the total character budget.
121
+ B = float(self.total_character_budget)
122
+ i = 0
123
+ # Initialize the budget per language.
124
+ U = torch.zeros_like(L)
125
+ # For each language...
126
+ for idx in indices:
127
+ # Compute the remaining budget per-language.
128
+ bl = B / (len(L) - i)
129
+ cl = L[idx]
130
+ # If per-language budget exceeds N epochs of the language, use N epochs.
131
+ if bl > cl * self.num_epochs:
132
+ Ul = cl * self.num_epochs
133
+ # Otherwise use uniform per-language budget.
134
+ else:
135
+ Ul = bl
136
+ # Store the computed budget.
137
+ U[idx] = Ul
138
+ # Update the remaining budget.
139
+ B -= Ul
140
+ # Move to the next language.
141
+ i += 1
142
+ # Normalize the budget to create a distribution.
143
+ p = U / U.sum()
144
+ # Return the computed distribution.
145
+ return p
146
+
147
+
148
+ class DistributedUnimaxSampler(UnimaxSampler):
149
+
150
+ def __init__(self,
151
+ language_character_counts: List[int],
152
+ total_character_budget: int,
153
+ num_epochs: int,
154
+ num_replicas: Optional[int] = None,
155
+ rank: Optional[int] = None,
156
+ shuffle: bool = True) -> None:
157
+
158
+ super().__init__(language_character_counts, total_character_budget, num_epochs)
159
+ self.distributed_sampler = DistributedSamplerWrapper(self, num_replicas, rank, shuffle)
160
+
161
+ def __iter__(self):
162
+ return iter(self.distributed_sampler)
163
+
164
+ def __len__(self):
165
+ return len(self.distributed_sampler)
166
+
167
+ def set_epoch(self, epoch):
168
+ self.distributed_sampler.set_epoch(epoch)
amt/src/install_dataset.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ """ install_dataset.py """
11
+ import os
12
+ import argparse
13
+ import mirdata
14
+ from typing import Optional, Tuple, Union
15
+ from utils.preprocess.generate_dataset_stats import generate_dataset_stats_for_all_datasets, update_dataset_stats_for_new_dataset
16
+ from utils.mirdata_dev.datasets import slakh16k
17
+ from utils.preprocess.preprocess_slakh import preprocess_slakh16k, add_program_and_is_drum_info_to_file_list
18
+ from utils.preprocess.preprocess_musicnet import preprocess_musicnet16k
19
+ from utils.preprocess.preprocess_maps import preprocess_maps16k
20
+ from utils.preprocess.preprocess_maestro import preprocess_maestro16k
21
+ from utils.preprocess.preprocess_guitarset import preprocess_guitarset16k, create_filelist_by_style_guitarset16k
22
+ from utils.preprocess.preprocess_enstdrums import preprocess_enstdrums16k, create_filelist_dtm_random_enstdrums16k
23
+ from utils.preprocess.preprocess_mir_st500 import preprocess_mir_st500_16k
24
+ from utils.preprocess.preprocess_cmedia import preprocess_cmedia_16k
25
+ from utils.preprocess.preprocess_rwc_pop_full import preprocess_rwc_pop_full16k
26
+ from utils.preprocess.preprocess_rwc_pop import preprocess_rwc_pop16k
27
+ from utils.preprocess.preprocess_egmd import preprocess_egmd16k
28
+ from utils.preprocess.preprocess_mir1k import preprocess_mir1k_16k
29
+ from utils.preprocess.preprocess_urmp import preprocess_urmp16k
30
+ from utils.preprocess.preprocess_idmt_smt_bass import preprocess_idmt_smt_bass_16k
31
+ from utils.preprocess.preprocess_geerdes import preprocess_geerdes16k
32
+ from utils.utils import download_and_extract #, download_and_extract_zenodo_restricted
33
+
34
+ # zenodo_token = "eyJhbGciOiJIUzUxMiIsImlhdCI6MTcxMDE1MDYzNywiZXhwIjoxNzEyNzA3MTk5fQ.eyJpZCI6ImRmODA5NzZlLTBjM2QtNDk5NS05YjM0LWFiNGM4NzJhMmZhMSIsImRhdGEiOnt9LCJyYW5kb20iOiIwMzY5ZDcxZjc2NTMyN2UyYmVmN2ExYjJkMmMyYTRhNSJ9.0aHnNC-7ivWQO6l8twjLR0NDH4boC0uOolAAmogVt7XRi2PHU5MEKBQoK7-wgDdnmWEIqEIvoLO6p8KTnsY9dg"
35
+
36
+
37
+ def install_slakh(data_home=os.PathLike, no_down=False) -> None:
38
+ if not no_down:
39
+ ds = slakh16k.Dataset(data_home, version='2100-yourmt3-16k')
40
+ ds.download(partial_download=['2100-yourmt3-16k', 'index'])
41
+ del (ds)
42
+ preprocess_slakh16k(data_home, delete_source_files=False, fix_bass_octave=True)
43
+ add_program_and_is_drum_info_to_file_list(data_home)
44
+
45
+
46
+ def install_musicnet(data_home=os.PathLike, no_down=False) -> None:
47
+ if not no_down:
48
+ url = "https://zenodo.org/record/7811639/files/musicnet_yourmt3_16k.tar.gz?download=1"
49
+ checksum = "a2da7c169e26d452a4e8b9bef498b3d7"
50
+ download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum)
51
+ preprocess_musicnet16k(data_home, dataset_name='musicnet')
52
+
53
+
54
+ def install_maps(data_home=os.PathLike, no_down=False, sanity_check=False) -> None:
55
+ if not no_down:
56
+ url = "https://zenodo.org/record/7812075/files/maps_yourmt3_16k.tar.gz?download=1"
57
+ checksum = "6b070d162c931cd5e69c16ef2398a649"
58
+ download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum)
59
+ preprocess_maps16k(data_home, dataset_name='maps', ignore_pedal=False, sanity_check=sanity_check)
60
+
61
+
62
+ def install_maestro(data_home=os.PathLike, no_down=False, sanity_check=False) -> None:
63
+ if not no_down:
64
+ url = "https://zenodo.org/record/7852176/files/maestro_yourmt3_16k.tar.gz?download=1"
65
+ checksum = "c17c6a188d936e5ff3870ef27144d397"
66
+ download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum)
67
+ preprocess_maestro16k(data_home, dataset_name='maestro', ignore_pedal=False, sanity_check=sanity_check)
68
+
69
+
70
+ def install_guitarset(data_home=os.PathLike, no_down=False) -> None:
71
+ if not no_down:
72
+ url = "https://zenodo.org/record/7831843/files/guitarset_yourmt3_16k.tar.gz?download=1"
73
+ checksum = "e3cfe0cc9394d91d9c290ce888821360"
74
+ download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum)
75
+ preprocess_guitarset16k(data_home, dataset_name='guitarset')
76
+ create_filelist_by_style_guitarset16k(data_home, dataset_name='guitarset')
77
+
78
+
79
+ def install_enstdrums(data_home, no_down=False) -> None:
80
+ if not no_down:
81
+ url = "https://zenodo.org/record/7831843/files/enstdrums_yourmt3_16k.tar.gz?download=1"
82
+ checksum = "7e28c2a923e4f4162b3d83877cedb5eb"
83
+ download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum)
84
+ preprocess_enstdrums16k(data_home, dataset_name='enstdrums')
85
+ create_filelist_dtm_random_enstdrums16k(data_home, dataset_name='enstdrums')
86
+
87
+
88
+ def install_egmd(data_home, no_down=False) -> None:
89
+ if not no_down:
90
+ url = "https://zenodo.org/record/7831072/files/egmc_yourmt3_16k.tar.gz?download=1"
91
+ checksum = "4f615157ea4c52a64c6c9dcf68bf2bde"
92
+ download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum)
93
+ preprocess_egmd16k(data_home, dataset_name='egmd')
94
+
95
+
96
+ def install_mirst500(data_home, zenodo_token, no_down=False, sanity_check=True, apply_correction=False) -> None:
97
+ """ Update Oct 2023: MIR-ST500 with FULL audio files"""
98
+ if not no_down:
99
+ url = "https://zenodo.org/records/10016397/files/mir_st500_yourmt3_16k.tar.gz?download=1"
100
+ checksum = "98eb52eb2456ce4034e21750f309da13"
101
+ download_and_extract(data_home, url, check_sum=checksum, zenodo_token=zenodo_token)
102
+ preprocess_mir_st500_16k(data_home, dataset_name='mir_st500', sanity_check=sanity_check)
103
+
104
+
105
+ def install_cmedia(data_home, zenodo_token, no_down=False, sanity_check=True) -> None:
106
+ if not no_down:
107
+ url = "https://zenodo.org/records/10016397/files/cmedia_yourmt3_16k.tar.gz?download=1"
108
+ checksum = "e6cca23577ba7588e9ed9711a398f7cf"
109
+ download_and_extract(data_home, url, check_sum=checksum, zenodo_token=zenodo_token)
110
+ preprocess_cmedia_16k(data_home, dataset_name='cmedia', sanity_check=sanity_check, apply_correction=True)
111
+
112
+
113
+ def install_rwc_pop(data_home, zenodo_token, no_down=False) -> None:
114
+ if not no_down:
115
+ url = "https://zenodo.org/records/10016397/files/rwc_pop_yourmt3_16k.tar.gz?download=1"
116
+ checksum = "ad459f9fa1b6b87676b2fb37c0ba5dfc"
117
+ download_and_extract(data_home, url, check_sum=checksum, zenodo_token=zenodo_token)
118
+ preprocess_rwc_pop16k(data_home, dataset_name='rwc_pop') # bass transcriptions
119
+ preprocess_rwc_pop_full16k(data_home, dataset_name='rwc_pop') # full transcriptions
120
+
121
+
122
+ def install_mir1k(data_home, no_down=False) -> None:
123
+ if not no_down:
124
+ url = "https://zenodo.org/record/7955481/files/mir1k_yourmt3_16k.tar.gz?download=1"
125
+ checksum = "4cbac56a4e971432ca807efd5cb76d67"
126
+ download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum)
127
+ # preprocess_mir1k_16k(data_home, dataset_name='mir1k')
128
+
129
+
130
+ def install_urmp(data_home, no_down=False) -> None:
131
+ if not no_down:
132
+ url = "https://zenodo.org/record/8021437/files/urmp_yourmt3_16k.tar.gz?download=1"
133
+ checksum = "4f539c71678a77ba34f6dfca41072102"
134
+ download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum)
135
+ preprocess_urmp16k(data_home, dataset_name='urmp')
136
+
137
+
138
+ def install_idmt_smt_bass(data_home, no_down=False) -> None:
139
+ if not no_down:
140
+ url = "https://zenodo.org/records/10009959/files/idmt_smt_bass_yourmt3_16k.tar.gz?download=1"
141
+ checksum = "0c95f91926a1e95b1f5d075c05b7eb76"
142
+ download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum)
143
+ preprocess_idmt_smt_bass_16k(data_home, dataset_name='idmt_smt_bass', sanity_check=True,
144
+ edit_audio=False) # the donwloaded audio has already been edited
145
+
146
+
147
+ def install_random_nsynth(data_home, no_down=False) -> None:
148
+ return
149
+
150
+
151
+ def install_geerdes(data_home) -> None:
152
+ try:
153
+ preprocess_geerdes16k(data_home, dataset_name='geerdes', sanity_check=False)
154
+ except Exception as e:
155
+ print(e)
156
+ print("Geerdes dataset is not available for download. Please contact the dataset provider.")
157
+
158
+
159
+ def regenerate_dataset_stats(data_home) -> None:
160
+ generate_dataset_stats_for_all_datasets(data_home)
161
+
162
+
163
+ def get_cached_zenodo_token() -> str:
164
+ # check if cached token exists
165
+ if not os.path.exists('.cached_zenodo_token'):
166
+ raise Exception("Cached Zenodo token not found. Please enter your Zenodo token.")
167
+ # read cached token
168
+ with open('.cached_zenodo_token', 'r') as f:
169
+ zenodo_token = f.read().strip()
170
+ print(f"Using cached Zenodo token: {zenodo_token}")
171
+ return zenodo_token
172
+
173
+
174
+ def cache_zenodo_token(zenodo_token: str) -> None:
175
+ with open('.cached_zenodo_token', 'w') as f:
176
+ f.write(zenodo_token)
177
+ print("Your Zenodo token is cached.")
178
+
179
+
180
+ def option_prompt(data_home: os.PathLike, no_download: bool = False) -> None:
181
+ print("Select the dataset(s) to install (enter comma-separated numbers):")
182
+ print("1. Slakh")
183
+ print("2. MusicNet")
184
+ print("3. MAPS")
185
+ print("4. Maestro")
186
+ print("5. GuitarSet")
187
+ print("6. ENST-drums")
188
+ print("7. EGMD")
189
+ print("8. MIR-ST500 ** Restricted Access **")
190
+ print("9. CMedia ** Restricted Access **")
191
+ print("10. RWC-Pop (Bass and Full) ** Restricted Access **")
192
+ print("11. MIR-1K (NOT SUPPORTED)")
193
+ print("12. URMP")
194
+ print("13. IDMT-SMT-Bass")
195
+ print("14. Random-NSynth")
196
+ print("15. Geerdes")
197
+ print("16. Regenerate Dataset Stats (experimental)")
198
+ print("17. Request Token for ** Restricted Access **")
199
+ print("18. Exit")
200
+
201
+ choice = input("Enter your choices (multiple choices with comma): ")
202
+ choices = [c.strip() for c in choice.split(',')]
203
+
204
+ if "18" in choices:
205
+ print("Exiting.")
206
+ else:
207
+ # ask for Zenodo token
208
+ for c in choices:
209
+ if int(c) in [8, 9, 10]:
210
+ if no_download is True:
211
+ zenodo_token = None
212
+ else:
213
+ zenodo_token = input("Enter Zenodo token, or press enter to use the cached token:")
214
+ if zenodo_token == "":
215
+ zenodo_token = get_cached_zenodo_token()
216
+ else:
217
+ cache_zenodo_token(zenodo_token)
218
+ break
219
+
220
+ if "1" in choices:
221
+ install_slakh(data_home, no_down=no_download)
222
+ if "2" in choices:
223
+ install_musicnet(data_home, no_down=no_download)
224
+ if "3" in choices:
225
+ install_maps(data_home, no_down=no_download)
226
+ if "4" in choices:
227
+ install_maestro(data_home, no_down=no_download)
228
+ if "5" in choices:
229
+ install_guitarset(data_home, no_down=no_download)
230
+ if "6" in choices:
231
+ install_enstdrums(data_home, no_down=no_download)
232
+ if "7" in choices:
233
+ install_egmd(data_home, no_down=no_download)
234
+ if "8" in choices:
235
+ install_mirst500(data_home, zenodo_token, no_down=no_download)
236
+ if "9" in choices:
237
+ install_cmedia(data_home, zenodo_token, no_down=no_download)
238
+ if "10" in choices:
239
+ install_rwc_pop(data_home, zenodo_token, no_down=no_download)
240
+ if "11" in choices:
241
+ install_mir1k(data_home, no_down=no_download)
242
+ if "12" in choices:
243
+ install_urmp(data_home, no_down=no_download)
244
+ if "13" in choices:
245
+ install_idmt_smt_bass(data_home, no_down=no_download)
246
+ if "14" in choices:
247
+ install_random_nsynth(data_home, no_down=no_download)
248
+ if "15" in choices:
249
+ install_geerdes(data_home) # not available for download
250
+ if "16" in choices:
251
+ regenerate_dataset_stats(data_home, no_down=no_download)
252
+ if "17" in choices:
253
+ print("\nPlease visit https://zenodo.org/records/10016397 to request a Zenodo token.")
254
+ print("Upon submitting your request, you will receive an email with a link labeled 'Access the record'.")
255
+ print("Copy the token that follows 'token=' in that link.")
256
+ if not any(int(c) in range(16) for c in choices):
257
+ print("Invalid choice(s). Please enter valid numbers separated by commas.")
258
+
259
+
260
+ if __name__ == "__main__":
261
+
262
+ parser = argparse.ArgumentParser(description='Dataset installer script.')
263
+ # data home dir
264
+ parser.add_argument(
265
+ 'data_home',
266
+ type=str,
267
+ nargs='?',
268
+ default=None,
269
+ help='Path to data home directory. If None, use the default path defined in src/config/config.py')
270
+ # `no_download` option
271
+ parser.add_argument('--nodown',
272
+ '-nd',
273
+ action='store_true',
274
+ help='Flag to control downloading. If set, no downloading will occur.')
275
+ args = parser.parse_args()
276
+
277
+ if args.data_home is None:
278
+ from config.config import shared_cfg
279
+ data_home = shared_cfg["PATH"]["data_home"]
280
+ else:
281
+ data_home = args.data_home
282
+ os.makedirs(data_home, exist_ok=True)
283
+ no_download = args.nodown
284
+
285
+ option_prompt(data_home, no_download)
amt/src/model/RoPE/RoPE.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """rotary_embedding.py - Rotary Embedding based on https://github.com/lucidrains/rotary-embedding-torch"""
2
+ from typing import Literal, Union, Optional
3
+ from math import pi, log
4
+ from einops import rearrange, repeat
5
+
6
+ import torch
7
+ from torch.nn import Module, ModuleList
8
+ from torch.cuda.amp import autocast
9
+ from torch import nn, einsum, broadcast_tensors, Tensor
10
+
11
+
12
+ # helper functions
13
+ def exists(val):
14
+ return val is not None
15
+
16
+
17
+ def default(val, d):
18
+ return val if exists(val) else d
19
+
20
+
21
+ # broadcat, as tortoise-tts was using it
22
+ def broadcat(tensors, dim=-1):
23
+ broadcasted_tensors = broadcast_tensors(*tensors)
24
+ return torch.cat(broadcasted_tensors, dim=dim)
25
+
26
+
27
+ # rotary embedding helper functions
28
+ def rotate_half(x):
29
+ x = rearrange(x, '... (d r) -> ... d r', r=2)
30
+ x1, x2 = x.unbind(dim=-1)
31
+ x = torch.stack((-x2, x1), dim=-1)
32
+ return rearrange(x, '... d r -> ... (d r)')
33
+
34
+
35
+ @autocast(enabled=False)
36
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1., seq_dim=-2):
37
+ """Applies rotary embedding for pixels."""
38
+ if t.ndim == 3:
39
+ seq_len = t.shape[seq_dim]
40
+ freqs = freqs[-seq_len:].to(t)
41
+
42
+ rot_dim = freqs.shape[-1]
43
+ end_index = start_index + rot_dim
44
+
45
+ assert rot_dim <= t.shape[
46
+ -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
47
+
48
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
49
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
50
+ return torch.cat((t_left, t, t_right), dim=-1)
51
+
52
+
53
+ # learned rotation helpers
54
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
55
+ if exists(freq_ranges):
56
+ rotations = einsum('..., f -> ... f', rotations, freq_ranges)
57
+ rotations = rearrange(rotations, '... r f -> ... (r f)')
58
+
59
+ rotations = repeat(rotations, '... n -> ... (n r)', r=2)
60
+ return apply_rotary_emb(rotations, t, start_index=start_index)
61
+
62
+
63
+ # classes
64
+ class RotaryEmbedding(Module):
65
+
66
+ def __init__(self,
67
+ dim,
68
+ custom_freqs: Optional[Tensor] = None,
69
+ freqs_for: Union[Literal['lang'], Literal['pixel'], Literal['constant']] = 'lang',
70
+ theta=10000,
71
+ max_freq=10,
72
+ num_freqs=1,
73
+ learned_freq=False,
74
+ use_xpos=False,
75
+ xpos_scale_base=512,
76
+ interpolate_factor=1.,
77
+ theta_rescale_factor=1.,
78
+ seq_before_head_dim=False,
79
+ cache_if_possible=True):
80
+ super().__init__()
81
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
82
+ # has some connection to NTK literature
83
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
84
+
85
+ theta *= theta_rescale_factor**(dim / (dim - 2))
86
+
87
+ self.freqs_for = freqs_for
88
+
89
+ if exists(custom_freqs):
90
+ freqs = custom_freqs
91
+ elif freqs_for == 'lang':
92
+ freqs = 1. / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
93
+ elif freqs_for == 'pixel':
94
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
95
+ elif freqs_for == 'constant':
96
+ freqs = torch.ones(num_freqs).float()
97
+
98
+ self.cache_if_possible = cache_if_possible
99
+
100
+ self.tmp_store('cached_freqs', None)
101
+ self.tmp_store('cached_scales', None)
102
+
103
+ self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
104
+
105
+ self.learned_freq = learned_freq
106
+
107
+ # dummy for device
108
+
109
+ self.tmp_store('dummy', torch.tensor(0))
110
+
111
+ # default sequence dimension
112
+
113
+ self.seq_before_head_dim = seq_before_head_dim
114
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
115
+
116
+ # interpolation factors
117
+
118
+ assert interpolate_factor >= 1.
119
+ self.interpolate_factor = interpolate_factor
120
+
121
+ # xpos
122
+
123
+ self.use_xpos = use_xpos
124
+ if not use_xpos:
125
+ self.tmp_store('scale', None)
126
+ return
127
+
128
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
129
+ self.scale_base = xpos_scale_base
130
+ self.tmp_store('scale', scale)
131
+
132
+ @property
133
+ def device(self):
134
+ return self.dummy.device
135
+
136
+ def tmp_store(self, key, value):
137
+ self.register_buffer(key, value, persistent=False)
138
+
139
+ def get_seq_pos(self, seq_len, device, dtype, offset=0):
140
+ return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
141
+
142
+ def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, freq_seq_len=None):
143
+ seq_dim = default(seq_dim, self.default_seq_dim)
144
+
145
+ assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
146
+
147
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
148
+
149
+ if exists(freq_seq_len):
150
+ assert freq_seq_len >= seq_len
151
+ seq_len = freq_seq_len
152
+
153
+ freqs = self.forward(self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset),
154
+ seq_len=seq_len,
155
+ offset=offset)
156
+
157
+ if seq_dim == -3:
158
+ freqs = rearrange(freqs, 'n d -> n 1 d')
159
+
160
+ return apply_rotary_emb(freqs, t, seq_dim=seq_dim)
161
+
162
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
163
+ seq_dim = default(seq_dim, self.default_seq_dim)
164
+
165
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
166
+ assert q_len <= k_len
167
+ rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, freq_seq_len=k_len)
168
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim)
169
+
170
+ rotated_q = rotated_q.type(q.dtype)
171
+ rotated_k = rotated_k.type(k.dtype)
172
+
173
+ return rotated_q, rotated_k
174
+
175
+ def rotate_queries_and_keys(self, q, k, seq_dim=None):
176
+ seq_dim = default(seq_dim, self.default_seq_dim)
177
+
178
+ assert self.use_xpos
179
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
180
+
181
+ seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
182
+
183
+ freqs = self.forward(seq, seq_len=seq_len)
184
+ scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
185
+
186
+ if seq_dim == -3:
187
+ freqs = rearrange(freqs, 'n d -> n 1 d')
188
+ scale = rearrange(scale, 'n d -> n 1 d')
189
+
190
+ rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
191
+ rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
192
+
193
+ rotated_q = rotated_q.type(q.dtype)
194
+ rotated_k = rotated_k.type(k.dtype)
195
+
196
+ return rotated_q, rotated_k
197
+
198
+ def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
199
+ assert self.use_xpos
200
+
201
+ should_cache = (self.cache_if_possible and exists(seq_len))
202
+
203
+ if (
204
+ should_cache and \
205
+ exists(self.cached_scales) and \
206
+ (seq_len + offset) <= self.cached_scales.shape[0]
207
+ ):
208
+ return self.cached_scales[offset:(offset + seq_len)]
209
+
210
+ scale = 1.
211
+ if self.use_xpos:
212
+ power = (t - len(t) // 2) / self.scale_base
213
+ scale = self.scale**rearrange(power, 'n -> n 1')
214
+ scale = torch.cat((scale, scale), dim=-1)
215
+
216
+ if should_cache:
217
+ self.tmp_store('cached_scales', scale)
218
+
219
+ return scale
220
+
221
+ def get_axial_freqs(self, *dims):
222
+ Colon = slice(None)
223
+ all_freqs = []
224
+
225
+ for ind, dim in enumerate(dims):
226
+ if self.freqs_for == 'pixel':
227
+ pos = torch.linspace(-1, 1, steps=dim, device=self.device)
228
+ else:
229
+ pos = torch.arange(dim, device=self.device)
230
+
231
+ freqs = self.forward(pos, seq_len=dim)
232
+
233
+ all_axis = [None] * len(dims)
234
+ all_axis[ind] = Colon
235
+
236
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
237
+ all_freqs.append(freqs[new_axis_slice])
238
+
239
+ all_freqs = broadcast_tensors(*all_freqs)
240
+ return torch.cat(all_freqs, dim=-1)
241
+
242
+ @autocast(enabled=False)
243
+ def forward(self, t: Tensor, seq_len=None, offset=0):
244
+ should_cache = (
245
+ self.cache_if_possible and \
246
+ not self.learned_freq and \
247
+ exists(seq_len) and \
248
+ self.freqs_for != 'pixel'
249
+ )
250
+
251
+ if (
252
+ should_cache and \
253
+ exists(self.cached_freqs) and \
254
+ (offset + seq_len) <= self.cached_freqs.shape[0]
255
+ ):
256
+ return self.cached_freqs[offset:(offset + seq_len)].detach()
257
+
258
+ freqs = self.freqs
259
+
260
+ freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
261
+ freqs = repeat(freqs, '... n -> ... (n r)', r=2)
262
+
263
+ if should_cache:
264
+ self.tmp_store('cached_freqs', freqs.detach())
265
+
266
+ return freqs
267
+
268
+ # custom method for applying rotary embeddings
269
+ @torch.compiler.disable
270
+ def apply_rotary_custom(self, t: torch.Tensor):
271
+ """Apply rotary embeddings to queries and keys, if k is None, only q is rotated.
272
+ Depending on the freqs type, the rotation will be different."""
273
+ if self.freqs_for == 'lang':
274
+ return self.rotate_queries_or_keys(t, seq_dim=-2)
275
+ elif self.freqs_for == 'pixel':
276
+ return apply_rotary_emb(self.get_axial_freqs(t.shape[-2]), t)
277
+ else:
278
+ raise ValueError(f"freqs_for must be 'lang' or 'pixel', but got {self.freqs_for}")
279
+
280
+
281
+ def test_rotary_embedding_lang():
282
+ d = 32 # d by head
283
+ q = torch.ones(1, 4, 110, 32) # (B, H, T, D) for multi-head attention
284
+ rdim = d // 2 # will do a partial rotation on half, or d
285
+
286
+ rotary = RotaryEmbedding(dim=rdim, freqs_for="lang")
287
+ q = rotary.rotate_queries_or_keys(q, seq_dim=-2)
288
+
289
+ # visualize
290
+ import matplotlib.pyplot as plt
291
+ plt.imshow(q[0, 0, :, :].numpy().T, origin='lower')
292
+
293
+
294
+ def test_rotary_embedding_pixel():
295
+ d = 32 # d by head
296
+ q = torch.ones(1, 4, 128, 32) # (B*T, H, F, C/H) for multi-head attention
297
+ rdim = d // 2 # will do a partial rotation on half
298
+
299
+ rotary = RotaryEmbedding(dim=rdim, freqs_for="pixel", max_freq=10)
300
+ freqs = rotary.get_axial_freqs(128)
301
+
302
+ q = apply_rotary_emb(freqs, q) # also k, if needed
303
+
304
+ # visualize
305
+ import matplotlib.pyplot as plt
306
+ plt.imshow(q[0, 0, :, :].numpy().T, origin='lower')
amt/src/model/conformer_helper.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ import math
11
+ from typing import Optional, Union
12
+
13
+ from torch import nn
14
+ from transformers.configuration_utils import PretrainedConfig
15
+ from transformers.modeling_utils import PreTrainedModel
16
+
17
+
18
+ class ConformerYMT3Config(PretrainedConfig):
19
+ r"""
20
+ This is the configuration class to store the configuration of a [`ConformerYMT3Encoder`]. It is used to
21
+ instantiate an ConformerYMT3Encoder according to the specified arguments, defining the model architecture.
22
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Conformer
23
+ [facebook/wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large)
24
+ architecture.
25
+
26
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
27
+ documentation from [`PretrainedConfig`] for more information.
28
+
29
+
30
+ Args:
31
+ d_model (`int`, *optional*, defaults to 512):
32
+ Dimensionality of the encoder layers and the pooler layer.
33
+ num_layers (`int`, *optional*, defaults to 12):
34
+ Number of hidden layers in the Transformer encoder.
35
+ num_heads (`int`, *optional*, defaults to 12):
36
+ Number of attention heads for each attention layer in the Transformer encoder.
37
+ intermediate_size (`int`, *optional*, defaults to 2048):
38
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
39
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
40
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
41
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
42
+ dropout_rate (`float`, *optional*, defaults to 0.05):
43
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
44
+ layerdrop (`float`, *optional*, defaults to 0.1):
45
+ The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more
46
+ details.
47
+ initializer_range (`float`, *optional*, defaults to 0.02):
48
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
49
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
50
+ The epsilon used by the layer normalization layers.
51
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
52
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
53
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
54
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
55
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
56
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
57
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
58
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
59
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
60
+ *conv_dim*.
61
+ conv_bias (`bool`, *optional*, defaults to `False`):
62
+ Whether the 1D convolutional layers have a bias.
63
+ output_hidden_size (`int`, *optional*):
64
+ Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
65
+ if `add_adapter is True`.
66
+ position_encoding_type (`str`, *optional*, defaults to `"relative"`):
67
+ Can be specified to `relative` or `rotary` for relative or rotary position embeddings respectively. If left
68
+ `None` no relative position embedding is applied.
69
+ rotary_embedding_base (`int`, *optional*, defaults to 10000):
70
+ If `"rotary"` position embeddings are used, defines the size of the embedding base.
71
+ num_max_positions (`int`, *optional*, defaults to 5000):
72
+ if `"relative"` position embeddings are used, defines the maximum source input positions.
73
+ conv_depthwise_kernel_size (`int`, defaults to 31):
74
+ Kernel size of convolutional depthwise 1D layer in Conformer blocks.
75
+
76
+ Example:
77
+
78
+ ```python
79
+ >>> from transformers import ConformerYMT3Config, ConformerYMT3Encoder
80
+
81
+ >>> # Initializing a ConformerYMT3Encoder configuration
82
+ >>> configuration = ConformerYMT3Config()
83
+
84
+ >>> # Initializing a model (with random weights) from the facebook/wav2vec2-conformer-rel-pos-large style configuration
85
+ >>> model = ConformerYMT3Encoder(configuration)
86
+
87
+ >>> # Accessing the model configuration
88
+ >>> configuration = model.config
89
+ ```"""
90
+ model_type = "conformer-ymt3"
91
+
92
+ def __init__(
93
+ self,
94
+ d_model=512, # 768
95
+ num_layers=8, # ConformerYMT3Encoder
96
+ num_heads=8, # ConformerYMT3SelfAttention
97
+ intermediate_size=2048, # 3072,# used in intermediate_dense of ConformerYMT3FeedForward
98
+ hidden_act="gelu", # used in intermediate_act_fn of ConformerYMT3FeedForward
99
+ dropout_rate=0.1,
100
+ layerdrop=0.1,
101
+ initializer_range=0.02,
102
+ layer_norm_eps=1e-5,
103
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
104
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
105
+ conv_kernel=(10, 3, 3, 3, 3, 3, 3),
106
+ conv_bias=False,
107
+ position_encoding_type="rotary",
108
+ rotary_embedding_base=10000,
109
+ num_max_positions=1024,
110
+ conv_depthwise_kernel_size=31,
111
+ **kwargs,
112
+ ):
113
+ super().__init__(**kwargs)
114
+ self.d_model = d_model
115
+ self.conv_dim = list(conv_dim)
116
+ self.conv_stride = list(conv_stride)
117
+ self.conv_kernel = list(conv_kernel)
118
+ self.conv_bias = conv_bias
119
+ self.num_layers = num_layers
120
+ self.intermediate_size = intermediate_size
121
+ self.hidden_act = hidden_act
122
+ self.num_heads = num_heads
123
+ self.dropout_rate = dropout_rate
124
+
125
+ self.layerdrop = layerdrop
126
+ self.layer_norm_eps = layer_norm_eps
127
+ self.initializer_range = initializer_range
128
+ self.num_max_positions = num_max_positions
129
+ self.position_encoding_type = position_encoding_type
130
+ self.rotary_embedding_base = rotary_embedding_base
131
+
132
+ # Conformer-block related
133
+ self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
134
+
135
+
136
+ class ConformerYMT3PreTrainedModel(PreTrainedModel):
137
+ """
138
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
139
+ models.
140
+ """
141
+
142
+ config_class = ConformerYMT3Config
143
+ base_model_prefix = "wav2vec2_conformer"
144
+ main_input_name = "input_values"
145
+ supports_gradient_checkpointing = True
146
+
147
+ def _init_weights(self, module):
148
+ """Initialize the weights"""
149
+ if module.__class__.__name__ == "ConformerYMT3SelfAttention":
150
+ if hasattr(module, "pos_bias_u"):
151
+ nn.init.xavier_uniform_(module.pos_bias_u)
152
+ if hasattr(module, "pos_bias_v"):
153
+ nn.init.xavier_uniform_(module.pos_bias_v)
154
+ elif isinstance(module, nn.Linear):
155
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
156
+ if module.bias is not None:
157
+ module.bias.data.zero_()
158
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
159
+ module.bias.data.zero_()
160
+ module.weight.data.fill_(1.0)
161
+ elif isinstance(module, nn.Conv1d):
162
+ nn.init.kaiming_normal_(module.weight)
163
+ if module.bias is not None:
164
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
165
+ nn.init.uniform_(module.bias, a=-k, b=k)
166
+
167
+ def _set_gradient_checkpointing(self, module, value=False):
168
+ if module.__class__.__name__ == "ConformerYMT3Encoder":
169
+ module.gradient_checkpointing = value
amt/src/model/conformer_mod.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ from typing import Tuple, Literal, Any, Optional
11
+ import math
12
+
13
+ import torch
14
+ from torch import nn
15
+ from transformers.activations import ACT2FN
16
+ from transformers.modeling_outputs import BaseModelOutput
17
+
18
+ from model.conformer_helper import ConformerYMT3Config, ConformerYMT3PreTrainedModel
19
+ from model.positional_encoding import (Wav2Vec2ConformerRelPositionalEmbedding,
20
+ Wav2Vec2ConformerRotaryPositionalEmbedding)
21
+
22
+
23
+ class ConformerYMT3FeedForward(nn.Module):
24
+
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ self.intermediate_dropout = nn.Dropout(config.dropout_rate)
28
+
29
+ self.intermediate_dense = nn.Linear(config.d_model, config.intermediate_size)
30
+ if isinstance(config.hidden_act, str):
31
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
32
+ else:
33
+ self.intermediate_act_fn = config.hidden_act
34
+
35
+ self.output_dense = nn.Linear(config.intermediate_size, config.d_model)
36
+ self.output_dropout = nn.Dropout(config.dropout_rate)
37
+
38
+ def forward(self, hidden_states):
39
+ hidden_states = self.intermediate_dense(hidden_states)
40
+ hidden_states = self.intermediate_act_fn(hidden_states)
41
+ hidden_states = self.intermediate_dropout(hidden_states)
42
+
43
+ hidden_states = self.output_dense(hidden_states)
44
+ hidden_states = self.output_dropout(hidden_states)
45
+ return hidden_states
46
+
47
+
48
+ class ConformerYMT3ConvolutionModule(nn.Module):
49
+ """Convolution block used in the conformer block"""
50
+
51
+ def __init__(self, config):
52
+ super().__init__()
53
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
54
+ raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
55
+ self.layer_norm = nn.LayerNorm(config.d_model)
56
+ self.pointwise_conv1 = torch.nn.Conv1d(
57
+ config.d_model,
58
+ 2 * config.d_model,
59
+ kernel_size=1,
60
+ stride=1,
61
+ padding=0,
62
+ bias=False,
63
+ )
64
+ self.glu = torch.nn.GLU(dim=1)
65
+ self.depthwise_conv = torch.nn.Conv1d(
66
+ config.d_model,
67
+ config.d_model,
68
+ config.conv_depthwise_kernel_size,
69
+ stride=1,
70
+ padding=(config.conv_depthwise_kernel_size - 1) // 2,
71
+ groups=config.d_model,
72
+ bias=False,
73
+ )
74
+ self.batch_norm = torch.nn.BatchNorm1d(config.d_model)
75
+ self.activation = ACT2FN[config.hidden_act]
76
+ self.pointwise_conv2 = torch.nn.Conv1d(
77
+ config.d_model,
78
+ config.d_model,
79
+ kernel_size=1,
80
+ stride=1,
81
+ padding=0,
82
+ bias=False,
83
+ )
84
+ self.dropout = torch.nn.Dropout(config.dropout_rate)
85
+
86
+ def forward(self, hidden_states):
87
+ hidden_states = self.layer_norm(hidden_states)
88
+ # exchange the temporal dimension and the feature dimension
89
+ hidden_states = hidden_states.transpose(1, 2)
90
+
91
+ # GLU mechanism
92
+ # => (batch, 2*channel, dim)
93
+ hidden_states = self.pointwise_conv1(hidden_states)
94
+ # => (batch, channel, dim)
95
+ hidden_states = self.glu(hidden_states)
96
+
97
+ # 1D Depthwise Conv
98
+ hidden_states = self.depthwise_conv(hidden_states)
99
+ hidden_states = self.batch_norm(hidden_states)
100
+ hidden_states = self.activation(hidden_states)
101
+
102
+ hidden_states = self.pointwise_conv2(hidden_states)
103
+ hidden_states = self.dropout(hidden_states)
104
+ hidden_states = hidden_states.transpose(1, 2)
105
+ return hidden_states
106
+
107
+
108
+ class ConformerYMT3SelfAttention(nn.Module):
109
+ """Construct a ConformerSelfAttention object.
110
+ Can be enhanced with rotary or relative position embeddings.
111
+ """
112
+
113
+ def __init__(self, config):
114
+ super().__init__()
115
+
116
+ self.head_size = config.d_model // config.num_heads
117
+ self.num_heads = config.num_heads
118
+ self.position_encoding_type = config.position_encoding_type
119
+
120
+ self.linear_q = nn.Linear(config.d_model, config.d_model)
121
+ self.linear_k = nn.Linear(config.d_model, config.d_model)
122
+ self.linear_v = nn.Linear(config.d_model, config.d_model)
123
+ self.linear_out = nn.Linear(config.d_model, config.d_model)
124
+
125
+ self.dropout = nn.Dropout(p=config.dropout_rate)
126
+
127
+ if self.position_encoding_type == "relative":
128
+ # linear transformation for positional encoding
129
+ self.linear_pos = nn.Linear(config.d_model, config.d_model, bias=False)
130
+ # these two learnable bias are used in matrix c and matrix d
131
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
132
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
133
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
134
+
135
+ def forward(
136
+ self,
137
+ hidden_states: torch.Tensor,
138
+ attention_mask: Optional[torch.Tensor] = None,
139
+ relative_position_embeddings: Optional[torch.Tensor] = None,
140
+ output_attentions: bool = False,
141
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
142
+ # self-attention mechanism
143
+ batch_size, sequence_length, d_model = hidden_states.size()
144
+
145
+ # make sure query/key states can be != value states
146
+ query_key_states = hidden_states
147
+ value_states = hidden_states
148
+
149
+ if self.position_encoding_type == "rotary":
150
+ if relative_position_embeddings is None:
151
+ raise ValueError(
152
+ "`relative_position_embeddings` has to be defined when `self.position_encoding_type == 'rotary'")
153
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
154
+
155
+ # project query_key_states and value_states
156
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
157
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
158
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
159
+
160
+ # => (batch, head, time1, d_k)
161
+ query = query.transpose(1, 2)
162
+ key = key.transpose(1, 2)
163
+ value = value.transpose(1, 2)
164
+
165
+ if self.position_encoding_type == "relative":
166
+ if relative_position_embeddings is None:
167
+ raise ValueError("`relative_position_embeddings` has to be defined when `self.position_encoding_type =="
168
+ " 'relative'")
169
+ # apply relative_position_embeddings to qk scores
170
+ # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
171
+ scores = self._apply_relative_embeddings(query=query,
172
+ key=key,
173
+ relative_position_embeddings=relative_position_embeddings)
174
+ else:
175
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
176
+
177
+ # apply attention_mask if necessary
178
+ if attention_mask is not None:
179
+ scores = scores + attention_mask
180
+
181
+ # => (batch, head, time1, time2)
182
+ probs = torch.softmax(scores, dim=-1)
183
+ probs = self.dropout(probs)
184
+
185
+ # => (batch, head, time1, d_k)
186
+ hidden_states = torch.matmul(probs, value)
187
+
188
+ # => (batch, time1, d_model)
189
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
190
+ hidden_states = self.linear_out(hidden_states)
191
+
192
+ return hidden_states, probs
193
+
194
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
195
+ batch_size, sequence_length, d_model = hidden_states.size()
196
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
197
+
198
+ cos = relative_position_embeddings[0, :sequence_length, ...]
199
+ sin = relative_position_embeddings[1, :sequence_length, ...]
200
+
201
+ # rotate hidden_states with rotary embeddings
202
+ hidden_states = hidden_states.transpose(0, 1)
203
+ rotated_states_begin = hidden_states[..., :self.head_size // 2]
204
+ rotated_states_end = hidden_states[..., self.head_size // 2:]
205
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
206
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
207
+ hidden_states = hidden_states.transpose(0, 1)
208
+
209
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
210
+
211
+ return hidden_states
212
+
213
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
214
+ # 1. project positional embeddings
215
+ # => (batch, head, 2*time1-1, d_k)
216
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
217
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(relative_position_embeddings.size(0),
218
+ -1, self.num_heads, self.head_size)
219
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
220
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
221
+
222
+ # 2. Add bias to query
223
+ # => (batch, head, time1, d_k)
224
+ query = query.transpose(1, 2)
225
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
226
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
227
+
228
+ # 3. attention score: first compute matrix a and matrix c
229
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
230
+ # => (batch, head, time1, time2)
231
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
232
+
233
+ # 4. then compute matrix b and matrix d
234
+ # => (batch, head, time1, 2*time1-1)
235
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
236
+
237
+ # 5. shift matrix b and matrix d
238
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
239
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
240
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
241
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
242
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
243
+ scores_bd = scores_bd[:, :, :, :scores_bd.size(-1) // 2 + 1]
244
+
245
+ # 6. sum matrices
246
+ # => (batch, head, time1, time2)
247
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
248
+
249
+ return scores
250
+
251
+
252
+ class ConformerYMT3EncoderLayer(nn.Module):
253
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
254
+
255
+ def __init__(self, config):
256
+ super().__init__()
257
+ embed_dim = config.d_model
258
+ dropout = config.dropout_rate
259
+
260
+ # Feed-forward 1
261
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
262
+ self.ffn1 = ConformerYMT3FeedForward(config)
263
+
264
+ # Self-Attention
265
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
266
+ self.self_attn_dropout = torch.nn.Dropout(dropout)
267
+ self.self_attn = ConformerYMT3SelfAttention(config)
268
+
269
+ # Conformer Convolution
270
+ self.conv_module = ConformerYMT3ConvolutionModule(config)
271
+
272
+ # Feed-forward 2
273
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
274
+ self.ffn2 = ConformerYMT3FeedForward(config)
275
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
276
+
277
+ def forward(
278
+ self,
279
+ hidden_states,
280
+ attention_mask: Optional[torch.Tensor] = None,
281
+ relative_position_embeddings: Optional[torch.Tensor] = None,
282
+ output_attentions: bool = False,
283
+ ):
284
+ hidden_states = hidden_states
285
+
286
+ # 1. Feed-Forward 1 layer
287
+ residual = hidden_states
288
+ hidden_states = self.ffn1_layer_norm(hidden_states)
289
+ hidden_states = self.ffn1(hidden_states)
290
+ hidden_states = hidden_states * 0.5 + residual
291
+ residual = hidden_states
292
+
293
+ # 2. Self-Attention layer
294
+ hidden_states = self.self_attn_layer_norm(hidden_states)
295
+ hidden_states, attn_weigts = self.self_attn(
296
+ hidden_states=hidden_states,
297
+ attention_mask=attention_mask,
298
+ relative_position_embeddings=relative_position_embeddings,
299
+ output_attentions=output_attentions,
300
+ )
301
+ hidden_states = self.self_attn_dropout(hidden_states)
302
+ hidden_states = hidden_states + residual
303
+
304
+ # 3. Convolutional Layer
305
+ residual = hidden_states
306
+ hidden_states = self.conv_module(hidden_states)
307
+ hidden_states = residual + hidden_states
308
+
309
+ # 4. Feed-Forward 2 Layer
310
+ residual = hidden_states
311
+ hidden_states = self.ffn2_layer_norm(hidden_states)
312
+ hidden_states = self.ffn2(hidden_states)
313
+ hidden_states = hidden_states * 0.5 + residual
314
+ hidden_states = self.final_layer_norm(hidden_states)
315
+
316
+ return hidden_states, attn_weigts
317
+
318
+
319
+ class ConformerYMT3Encoder(nn.Module):
320
+
321
+ def __init__(self, config):
322
+ super().__init__()
323
+ self.config = config
324
+
325
+ if config.position_encoding_type == "relative":
326
+ self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
327
+ elif config.position_encoding_type == "rotary":
328
+ self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
329
+ else:
330
+ self.embed_positions = None
331
+
332
+ # self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
333
+ self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
334
+ self.dropout = nn.Dropout(config.dropout_rate)
335
+ self.layers = nn.ModuleList([ConformerYMT3EncoderLayer(config) for _ in range(config.num_layers)])
336
+ self.gradient_checkpointing = False
337
+
338
+ def forward(
339
+ self,
340
+ inputs_embeds: torch.FloatTensor, # (B, T, D)
341
+ attention_mask: Optional[torch.FloatTensor] = None,
342
+ output_attentions: Optional[bool] = False,
343
+ output_hidden_states: Optional[bool] = False,
344
+ return_dict: Optional[bool] = True,
345
+ ):
346
+ if output_attentions is None:
347
+ output_attentions = self.config.output_attentions
348
+ if output_hidden_states is None:
349
+ output_hidden_states = self.config.output_hidden_states
350
+ if return_dict is None:
351
+ return_dict = self.config.use_return_dict
352
+ all_hidden_states = () if output_hidden_states else None
353
+ all_self_attentions = () if output_attentions else None
354
+
355
+ # inputs_embeds as hidden_states
356
+ hidden_states = inputs_embeds
357
+
358
+ if attention_mask is not None:
359
+ # make sure padded tokens output 0
360
+ hidden_states[~attention_mask] = 0.0
361
+
362
+ # extend attention_mask
363
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
364
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
365
+ attention_mask = attention_mask.expand(attention_mask.shape[0], 1, attention_mask.shape[-1],
366
+ attention_mask.shape[-1])
367
+
368
+ hidden_states = self.dropout(hidden_states)
369
+
370
+ if self.embed_positions is not None:
371
+ relative_position_embeddings = self.embed_positions(hidden_states)
372
+ else:
373
+ relative_position_embeddings = None
374
+
375
+ for i, layer in enumerate(self.layers):
376
+ if output_hidden_states:
377
+ all_hidden_states = all_hidden_states + (hidden_states,)
378
+
379
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
380
+ dropout_probability = torch.rand([])
381
+
382
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
383
+ if not skip_the_layer:
384
+ # under deepspeed zero3 all gpus must run in sync
385
+ if self.gradient_checkpointing and self.training:
386
+ # create gradient checkpointing function
387
+ def create_custom_forward(module):
388
+
389
+ def custom_forward(*inputs):
390
+ return module(*inputs, output_attentions)
391
+
392
+ return custom_forward
393
+
394
+ layer_outputs = torch.utils.checkpoint.checkpoint(
395
+ create_custom_forward(layer),
396
+ hidden_states,
397
+ attention_mask,
398
+ relative_position_embeddings,
399
+ )
400
+ else:
401
+ layer_outputs = layer(
402
+ hidden_states,
403
+ attention_mask=attention_mask,
404
+ relative_position_embeddings=relative_position_embeddings,
405
+ output_attentions=output_attentions,
406
+ )
407
+ hidden_states = layer_outputs[0]
408
+
409
+ if skip_the_layer:
410
+ layer_outputs = (None, None)
411
+
412
+ if output_attentions:
413
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
414
+
415
+ hidden_states = self.layer_norm(hidden_states)
416
+ if output_hidden_states:
417
+ all_hidden_states = all_hidden_states + (hidden_states,)
418
+
419
+ if not return_dict:
420
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
421
+ return BaseModelOutput(
422
+ last_hidden_state=hidden_states,
423
+ hidden_states=all_hidden_states,
424
+ attentions=all_self_attentions,
425
+ )
426
+
427
+
428
+ def test():
429
+ import torch
430
+ from model.conformer_mod import ConformerYMT3Encoder
431
+ from model.conformer_helper import ConformerYMT3Config
432
+ from model.ops import count_parameters
433
+ config = ConformerYMT3Config()
434
+ encoder = ConformerYMT3Encoder(config)
435
+ encoder.eval()
436
+ # num params: 48,468,992 w/ intermediate_size=2048
437
+ # num params: 23,278,592 w/ intermediate_size=512
438
+ x = torch.randn(2, 256, 512) # (B, T, D)
439
+ enc_hs = encoder.forward(inputs_embeds=x)['last_hidden_state'] # (B, T, D)
amt/src/model/conv_block.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ from typing import Literal
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange
15
+
16
+
17
+ def init_layer(layer: nn.Module) -> None:
18
+ """Initialize a Linear or Convolutional layer."""
19
+ nn.init.xavier_uniform_(layer.weight)
20
+ if hasattr(layer, "bias") and layer.bias is not None:
21
+ layer.bias.data.zero_()
22
+
23
+
24
+ def init_bn(bn: nn.Module) -> None:
25
+ """Initialize a Batchnorm layer."""
26
+ bn.bias.data.zero_()
27
+ bn.weight.data.fill_(1.0)
28
+ bn.running_mean.data.zero_()
29
+ bn.running_var.data.fill_(1.0)
30
+
31
+
32
+ def act(x: torch.Tensor, activation: str) -> torch.Tensor:
33
+ """Activation function."""
34
+ funcs = {"relu": F.relu_, "leaky_relu": lambda x: F.leaky_relu_(x, 0.01), "swish": lambda x: x * torch.sigmoid(x)}
35
+ return funcs.get(activation, lambda x: Exception("Incorrect activation!"))(x)
36
+
37
+
38
+ class Res2DAVPBlock(nn.Module):
39
+
40
+ def __init__(self, in_channels, out_channels, kernel_size, avp_kernel_size, activation):
41
+ """Convolutional residual block modified fromr bytedance/music_source_separation."""
42
+ super().__init__()
43
+
44
+ padding = kernel_size[0] // 2, kernel_size[1] // 2
45
+
46
+ self.activation = activation
47
+ self.bn1, self.bn2 = nn.BatchNorm2d(out_channels), nn.BatchNorm2d(out_channels)
48
+
49
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
50
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding, bias=False)
51
+
52
+ self.is_shortcut = in_channels != out_channels
53
+ if self.is_shortcut:
54
+ self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
55
+
56
+ self.avp = nn.AvgPool2d(avp_kernel_size)
57
+ self.init_weights()
58
+
59
+ def init_weights(self):
60
+ for m in [self.conv1, self.conv2] + ([self.shortcut] if self.is_shortcut else []):
61
+ init_layer(m)
62
+ for m in [self.bn1, self.bn2]:
63
+ init_bn(m)
64
+
65
+ def forward(self, x):
66
+ origin = x
67
+ x = act(self.bn1(self.conv1(x)), self.activation)
68
+ x = self.bn2(self.conv2(x))
69
+ x += self.shortcut(origin) if self.is_shortcut else origin
70
+ x = act(x, self.activation)
71
+ return self.avp(x)
72
+
73
+
74
+ class PreEncoderBlockRes3B(nn.Module):
75
+
76
+ def __init__(self, in_channels, out_channels, kernel_size=(3, 3), avp_kernerl_size=(1, 2), activation='relu'):
77
+ """Pre-Encoder with 3 Res2DAVPBlocks."""
78
+ super().__init__()
79
+
80
+ self.blocks = nn.ModuleList([
81
+ Res2DAVPBlock(in_channels if i == 0 else out_channels, out_channels, kernel_size, avp_kernerl_size,
82
+ activation) for i in range(3)
83
+ ])
84
+
85
+ def forward(self, x): # (B, T, F)
86
+ x = rearrange(x, 'b t f -> b 1 t f')
87
+ for block in self.blocks:
88
+ x = block(x)
89
+ return rearrange(x, 'b c t f -> b t f c')
90
+
91
+
92
+ def test_res3b():
93
+ # mel-spec input
94
+ x = torch.randn(2, 256, 512) # (B, T, F)
95
+ pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128)
96
+ x = pre(x) # (2, 256, 64, 128): B T,F,C
97
+
98
+ x = torch.randn(2, 110, 1024) # (B, T, F)
99
+ pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128)
100
+ x = pre(x) # (2, 110, 128, 128): B,T,F,C
101
+
102
+
103
+ # ====================================================================================================================
104
+ # PreEncoderBlockHFTT: hFT-Transformer-like Pre-encoder
105
+ # ====================================================================================================================
106
+ class PreEncoderBlockHFTT(nn.Module):
107
+
108
+ def __init__(self, margin_pre=15, margin_post=16) -> None:
109
+ """Pre-Encoder with hFT-Transformer-like convolutions."""
110
+ super().__init__()
111
+
112
+ self.margin_pre, self.margin_post = margin_pre, margin_post
113
+ self.conv = nn.Conv2d(1, 4, kernel_size=(1, 5), padding='same', padding_mode='zeros')
114
+ self.emb_freq = nn.Linear(128, 128)
115
+
116
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
117
+ # x: (B, T, F)
118
+ x = rearrange(x, 'b t f -> b 1 f t') # (B, 1, F, T) or (2, 1, 128, 110)
119
+ x = F.pad(x, (self.margin_pre, self.margin_post), value=1e-7) # (B, 1, F, T+margin) or (2,1,128,141)
120
+ x = self.conv(x) # (B, C, F, T+margin) or (2, 4, 128, 141)
121
+ x = x.unfold(dimension=3, size=32, step=1) # (B, c1, T, F, c2) or (2, 4, 128, 110, 32)
122
+ x = rearrange(x, 'b c1 f t c2 -> b t f (c1 c2)') # (B, T, F, C) or (2, 110, 128, 128)
123
+ return self.emb_freq(x) # (B, T, F, C) or (2, 110, 128, 128)
124
+
125
+
126
+ def test_hftt():
127
+ # from model.spectrogram import get_spectrogram_layer_from_audio_cfg
128
+ # from config.config import audio_cfg as default_audio_cfg
129
+ # audio_cfg = default_audio_cfg
130
+ # audio_cfg['codec'] = 'melspec'
131
+ # audio_cfg['hop_length'] = 300
132
+ # audio_cfg['n_mels'] = 128
133
+ # x = torch.randn(2, 1, 32767)
134
+ # mspec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg)
135
+ # x = mspec(x)
136
+ x = torch.randn(2, 110, 128) # (B, T, F)
137
+ pre_enc_hftt = PreEncoderBlockHFTT()
138
+ y = pre_enc_hftt(x) # (2, 110, 128, 128): B, T, F, C
139
+
140
+
141
+ # ====================================================================================================================
142
+ # PreEncoderBlockRes3BHFTT: hFT-Transformer-like Pre-encoder with Res2DAVPBlock and spec input
143
+ # ====================================================================================================================
144
+ class PreEncoderBlockRes3BHFTT(nn.Module):
145
+
146
+ def __init__(self, margin_pre: int = 15, margin_post: int = 16) -> None:
147
+ """Pre-Encoder with hFT-Transformer-like convolutions.
148
+
149
+ Args:
150
+ margin_pre (int): padding before the input
151
+ margin_post (int): padding after the input
152
+ stack_dim (Literal['c', 'f']): stack dimension. channel or frequency
153
+
154
+ """
155
+ super().__init__()
156
+ self.margin_pre, self.margin_post = margin_pre, margin_post
157
+ self.res3b = PreEncoderBlockRes3B(in_channels=1, out_channels=4)
158
+ self.emb_freq = nn.Linear(128, 128)
159
+
160
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
161
+ # x: (B, T, F) or (2, 110, 1024), input spectrogram
162
+ x = rearrange(x, 'b t f -> b f t') # (2, 1024, 110): B,F,T
163
+ x = F.pad(x, (self.margin_pre, self.margin_post), value=1e-7) # (2, 1024, 141): B,F,T+margin
164
+ x = rearrange(x, 'b f t -> b t f') # (2, 141, 1024): B,T+margin,F
165
+ x = self.res3b(x) # (2, 141, 128, 4): B,T+margin,F,C
166
+ x = x.unfold(dimension=1, size=32, step=1) # (B, T, F, C1, C2) or (2, 110, 128, 4, 32)
167
+ x = rearrange(x, 'b t f c1 c2 -> b t f (c1 c2)') # (B, T, F, C) or (2, 110, 128, 128)
168
+ return self.emb_freq(x) # (B, T, F, C) or (2, 110, 128, 128)
169
+
170
+
171
+ def test_res3b_hftt():
172
+ # from model.spectrogram import get_spectrogram_layer_from_audio_cfg
173
+ # from config.config import audio_cfg as default_audio_cfg
174
+ # audio_cfg = default_audio_cfg
175
+ # audio_cfg['codec'] = 'spec'
176
+ # audio_cfg['hop_length'] = 300
177
+ # x = torch.randn(2, 1, 32767)
178
+ # spec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg)
179
+ # x = spec(x) # (2, 110, 1024): B,T,F
180
+ x = torch.randn(2, 110, 1024) # (B, T, F)
181
+ pre_enc_res3b_hftt = PreEncoderBlockRes3BHFTT()
182
+ y = pre_enc_res3b_hftt(x) # (2, 110, 128, 128): B, T, F, C
183
+
184
+
185
+ # # ====================================================================================================================
186
+ # # PreEncoderBlockConv1D: Pre-encoder without activation, with Melspec input
187
+ # # ====================================================================================================================
188
+ # class PreEncoderBlockConv1D(nn.Module):
189
+
190
+ # def __init__(self,
191
+ # in_channels,
192
+ # out_channels,
193
+ # kernel_size=3) -> None:
194
+ # """Pre-Encoder with 1D convolution."""
195
+ # super().__init__()
196
+ # self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1)
197
+ # self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=1)
198
+
199
+ # def forward(self, x: torch.Tensor) -> torch.Tensor:
200
+ # # x: (B, T, F) or (2, 128, 256), input melspec
201
+ # x = rearrange(x, 'b t f -> b f t') # (2, 256, 128): B,F,T
202
+ # x = self.conv1(x) # (2, 128, 128): B,F,T
203
+ # return rearrange(x, 'b f t -> b t f') # (2, 110, 128): B,T,F
204
+
205
+ # def test_conv1d():
206
+ # # from model.spectrogram import get_spectrogram_layer_from_audio_cfg
207
+ # # from config.config import audio_cfg as default_audio_cfg
208
+ # # audio_cfg = default_audio_cfg
209
+ # # audio_cfg['codec'] = 'melspec'
210
+ # # audio_cfg['hop_length'] = 256
211
+ # # audio_cfg['n_mels'] = 512
212
+ # # x = torch.randn(2, 1, 32767)
213
+ # # mspec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg)
214
+ # # x = mspec(x)
215
+ # x = torch.randn(2, 128, 128) # (B, T, F)
216
+ # pre_enc_conv1d = PreEncoderBlockConv1D(in_channels=1, out_channels=128)
217
+ # y = pre_enc_conv1d(x) # (2, 110, 128, 128): B, T, F, C
amt/src/model/ff_layer.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ """ff_layer.py
11
+
12
+ This module contains the implementation of the feedforward layers.
13
+
14
+ Supported ff_layer_type:
15
+ 'mlp': Multi-Layer Perceptron
16
+ 'gmlp': Gated Multi-Layer Perceptron, simplified version of Mixtral Expert with num_experts=1 and top_k=1.
17
+ This is not the spatial gating MLP (https://arxiv.org/abs/2105.08050).
18
+ 'moe': Mixtral of Experts, modified from the original source code:
19
+ https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/mixtral/modeling_mixtral.py
20
+
21
+ Usage:
22
+ from model.ff_layer import get_ff_layer
23
+
24
+ config = PerceiverTFConfig() # or any type of PretrainedConfig()
25
+ config.ff_layer_type = 'moe' # or 'mlp'
26
+ config.moe_num_experts = 4
27
+ config.moe_topk = 2
28
+ config.hidden_act = 'gelu' # or any type of activation function, e.g., 'silu'
29
+
30
+ ff_layer = get_ff_layer(config, input_size, widening_factor)
31
+
32
+ What ff_layer returns:
33
+ - It returns (hidden_states, router_logits) for MoE and (hidden_states, None) for MLP.
34
+ - router_logits has the shape of (batch_size * sequence_length, n_experts) for MoE.
35
+
36
+
37
+ """
38
+ from typing import Any, Tuple
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ from transformers.configuration_utils import PretrainedConfig
43
+ from transformers.activations import ACT2FN
44
+ from model.ops import get_layer_norm
45
+ from model.ops import optional_compiler_disable, optional_compiler_dynamic
46
+
47
+
48
+ class MixtralBlockSparseTop2MLP(nn.Module):
49
+ """
50
+ The Gated Multilayer Perceptron (GMLP) used in Mixtral of Experts (MoE).
51
+
52
+ """
53
+
54
+ def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int):
55
+ super().__init__()
56
+ self.hidden_dim = input_size
57
+ self.ffn_dim = int(input_size * widening_factor)
58
+
59
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
60
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
61
+ self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
62
+ self.act_fn = ACT2FN[config.hidden_act]
63
+
64
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
65
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.gate(hidden_states)
66
+ current_hidden_states = self.w2(current_hidden_states)
67
+ return current_hidden_states
68
+
69
+
70
+ class MixtralSparseMoeBlock(nn.Module):
71
+ """
72
+ This implementation is
73
+ strictly equivalent to standard MoE with full capacity (no
74
+ dropped tokens). It's faster since it formulates MoE operations
75
+ in terms of block-sparse operations to accomodate imbalanced
76
+ assignments of tokens to experts, whereas standard MoE either
77
+ (1) drop tokens at the cost of reduced performance or (2) set
78
+ capacity factor to number of experts and thus waste computation
79
+ and memory on padding.
80
+ """
81
+
82
+ def __init__(self, config, input_size: int, widening_factor: int):
83
+ super().__init__()
84
+ self.hidden_dim = input_size
85
+ self.widening_factor = widening_factor
86
+ self.num_experts = config.moe_num_experts
87
+ self.top_k = config.moe_topk
88
+
89
+ # gating
90
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
91
+ self.experts = nn.ModuleList(
92
+ [MixtralBlockSparseTop2MLP(config, self.hidden_dim, self.widening_factor) for _ in range(self.num_experts)])
93
+
94
+ @optional_compiler_disable
95
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
96
+ """ """
97
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
98
+ hidden_states = hidden_states.view(-1, hidden_dim)
99
+ # router_logits: (batch * sequence_length, n_experts)
100
+ router_logits = self.gate(hidden_states)
101
+
102
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
103
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
104
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
105
+ # we cast back to the input dtype
106
+ routing_weights = routing_weights.to(hidden_states.dtype)
107
+
108
+ final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
109
+ dtype=hidden_states.dtype,
110
+ device=hidden_states.device)
111
+
112
+ # One hot encode the selected experts to create an expert mask
113
+ # this will be used to easily index which expert is going to be sollicitated
114
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
115
+
116
+ # Loop over all available experts in the model and perform the computation on each expert
117
+ for expert_idx in range(self.num_experts):
118
+ expert_layer = self.experts[expert_idx]
119
+ idx, top_x = torch.where(expert_mask[expert_idx])
120
+
121
+ if top_x.shape[0] == 0:
122
+ continue
123
+
124
+ # in torch it is faster to index using lists than torch tensors
125
+ top_x_list = top_x.tolist()
126
+ idx_list = idx.tolist()
127
+
128
+ # Index the correct hidden states and compute the expert hidden state for
129
+ # the current expert. We need to make sure to multiply the output hidden
130
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
131
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
132
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
133
+
134
+ # However `index_add_` only support torch tensors for indexing so we'll use
135
+ # the `top_x` tensor here.
136
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
137
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
138
+ return final_hidden_states, router_logits
139
+
140
+
141
+ class MLP(nn.Module):
142
+ """A Standard Transformer-style dense module to follow attention."""
143
+
144
+ def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int):
145
+ super().__init__()
146
+ self.dense1 = nn.Linear(input_size, widening_factor * input_size)
147
+ self.dense2 = nn.Linear(widening_factor * input_size, input_size)
148
+
149
+ if isinstance(config.hidden_act, str):
150
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
151
+ else:
152
+ self.intermediate_act_fn = config.hidden_act
153
+
154
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Any]:
155
+ hidden_states = self.dense1(hidden_states)
156
+ hidden_states = self.intermediate_act_fn(hidden_states)
157
+ hidden_states = self.dense2(hidden_states)
158
+ return hidden_states, None
159
+
160
+
161
+ class SimpleGMLP(nn.Module):
162
+ """A Simple Gated Multilayer Perceptron (aka. 'gmlp'), without the spatial gating mechanism.
163
+
164
+ Note that this is not the spatial gating MLP (https://arxiv.org/abs/2105.08050).
165
+ - A simplified MLP w/ gating mechanism adapted from Mixtral Expert, as when
166
+ the number of experts and top_k are both set to 1.)
167
+ - Added a dropout layer.
168
+ - This was also used in T5 v1.1.
169
+ """
170
+
171
+ def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int):
172
+ super().__init__()
173
+ self.hidden_dim = input_size
174
+ self.ffn_dim = int(input_size * widening_factor)
175
+
176
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
177
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
178
+ self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
179
+ self.act_fn = ACT2FN[config.hidden_act]
180
+ self.dropout1 = nn.Dropout(config.dropout_rate)
181
+ self.dropout2 = nn.Dropout(config.dropout_rate)
182
+
183
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
184
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.gate(hidden_states)
185
+ current_hidden_states = self.dropout1(current_hidden_states)
186
+ current_hidden_states = self.w2(current_hidden_states)
187
+ current_hidden_states = self.dropout2(
188
+ current_hidden_states) # Residual connection is applied outside of this module.
189
+ return current_hidden_states, None
190
+
191
+
192
+ def get_ff_layer(config: PretrainedConfig, input_size: int, widening_factor: int):
193
+ if config.ff_layer_type == 'moe':
194
+ assert hasattr(config, 'moe_num_experts') and hasattr(config, 'moe_topk') and hasattr(config, 'hidden_act')
195
+ return MixtralSparseMoeBlock(config, input_size, widening_factor)
196
+ elif config.ff_layer_type == 'mlp':
197
+ assert hasattr(config, 'hidden_act')
198
+ return MLP(config, input_size, widening_factor)
199
+ elif config.ff_layer_type == 'gmlp':
200
+ assert hasattr(config, 'hidden_act')
201
+ return SimpleGMLP(config, input_size, widening_factor)
202
+ else:
203
+ raise ValueError(
204
+ f"Unsupported ff_layer_type: {config.ff_layer_type}. Supported types are 'moe', 'mlp' and 'gmlp'.")
205
+
206
+
207
+ def test_get_ff_layer():
208
+ from model.ff_layer import get_ff_layer
209
+ from model.perceiver_helper import PerceiverTFConfig
210
+ input_size = 32
211
+ widening_factor = 1
212
+
213
+ # Test for MoE
214
+ config = PerceiverTFConfig() # or any type of PretrainedConfig()
215
+ config.ff_layer_type = 'moe'
216
+ config.moe_num_experts = 4
217
+ config.moe_topk = 2
218
+ config.hidden_act = 'silu'
219
+
220
+ ff_layer = get_ff_layer(config, input_size, widening_factor)
221
+ x = torch.rand(2, 8, input_size)
222
+ hidden_states, router_logits = ff_layer(x)
223
+ print(hidden_states.shape, router_logits.shape) # (2, 8, 32), (2*8, 4)
224
+
225
+ # Test for MLP
226
+ config.ff_layer_type = 'mlp'
227
+ config.hidden_act = 'gelu'
228
+
229
+ ff_layer = get_ff_layer(config, input_size, widening_factor)
230
+ hidden_states, _ = ff_layer(x)
231
+ print(hidden_states.shape) # (2, 8, 32)
232
+
233
+ # Test for (simple)gMLP
234
+ config.ff_layer_type = 'gmlp'
235
+ config.hidden_act = 'silu'
236
+ ff_layer = get_ff_layer(config, input_size, widening_factor)
237
+ hidden_states, _ = ff_layer(x)
238
+ print(hidden_states.shape) # (2, 8, 32)
amt/src/model/init_train.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ """init_train.py"""
11
+ from typing import Tuple, Literal, Any
12
+ from copy import deepcopy
13
+ import os
14
+ import argparse
15
+ import pytorch_lightning as pl
16
+ from pytorch_lightning.loggers import WandbLogger
17
+ from pytorch_lightning.callbacks import ModelCheckpoint
18
+ from pytorch_lightning.callbacks import LearningRateMonitor
19
+ from pytorch_lightning.utilities import rank_zero_only
20
+ from config.config import shared_cfg as default_shared_cfg
21
+ from config.config import audio_cfg as default_audio_cfg
22
+ from config.config import model_cfg as default_model_cfg
23
+ from config.config import DEEPSPEED_CFG
24
+
25
+
26
+ def initialize_trainer(args: argparse.Namespace,
27
+ stage: Literal['train', 'test'] = 'train') -> Tuple[pl.Trainer, WandbLogger, dict]:
28
+ """Initialize trainer and logger"""
29
+ shared_cfg = deepcopy(default_shared_cfg)
30
+
31
+ # create save dir
32
+ os.makedirs(shared_cfg["WANDB"]["save_dir"], exist_ok=True)
33
+
34
+ # collecting specific checkpoint from exp_id with extension (@xxx where xxx is checkpoint name)
35
+ if "@" in args.exp_id:
36
+ args.exp_id, checkpoint_name = args.exp_id.split("@")
37
+ else:
38
+ checkpoint_name = "last.ckpt"
39
+
40
+ # checkpoint dir
41
+ lightning_dir = os.path.join(shared_cfg["WANDB"]["save_dir"], args.project, args.exp_id)
42
+
43
+ # create logger
44
+ if args.wandb_mode is not None:
45
+ shared_cfg["WANDB"]["mode"] = str(args.wandb_mode)
46
+ if shared_cfg["WANDB"].get("cache_dir", None) is not None:
47
+ os.environ["WANDB_CACHE_DIR"] = shared_cfg["WANDB"].get("cache_dir")
48
+ del shared_cfg["WANDB"]["cache_dir"] # remove cache_dir from shared_cfg
49
+ wandb_logger = WandbLogger(log_model="all",
50
+ project=args.project,
51
+ id=args.exp_id,
52
+ allow_val_change=True,
53
+ **shared_cfg['WANDB'])
54
+
55
+ # check if any checkpoint exists
56
+ last_ckpt_path = os.path.join(lightning_dir, "checkpoints", checkpoint_name)
57
+ if os.path.exists(os.path.join(last_ckpt_path)):
58
+ print(f'Resuming from {last_ckpt_path}')
59
+ elif stage == 'train':
60
+ print(f'No checkpoint found in {last_ckpt_path}. Starting from scratch')
61
+ last_ckpt_path = None
62
+ else:
63
+ raise ValueError(f'No checkpoint found in {last_ckpt_path}. Quit...')
64
+
65
+ # add info
66
+ dir_info = dict(lightning_dir=lightning_dir, last_ckpt_path=last_ckpt_path)
67
+
68
+ # define checkpoint callback
69
+ checkpoint_callback = ModelCheckpoint(**shared_cfg["CHECKPOINT"],)
70
+
71
+ # define lr scheduler monitor callback
72
+ lr_monitor = LearningRateMonitor(logging_interval='step')
73
+
74
+ # deepspeed strategy
75
+ if args.strategy == 'deepspeed':
76
+ strategy = pl.strategies.DeepSpeedStrategy(config=DEEPSPEED_CFG)
77
+
78
+ # validation interval
79
+ if stage == 'train' and args.val_interval is not None:
80
+ shared_cfg["TRAINER"]["check_val_every_n_epoch"] = None
81
+ shared_cfg["TRAINER"]["val_check_interval"] = int(args.val_interval)
82
+
83
+ # define trainer
84
+ sync_batchnorm = False
85
+ if stage == 'train':
86
+ # train batch size
87
+ if args.train_batch_size is not None:
88
+ train_sub_bsz = int(args.train_batch_size[0])
89
+ train_local_bsz = int(args.train_batch_size[1])
90
+ if train_local_bsz % train_sub_bsz == 0:
91
+ shared_cfg["BSZ"]["train_sub"] = train_sub_bsz
92
+ shared_cfg["BSZ"]["train_local"] = train_local_bsz
93
+ else:
94
+ raise ValueError(
95
+ f'Local batch size {train_local_bsz} must be divisible by sub batch size {train_sub_bsz}')
96
+
97
+ # ddp strategy
98
+ if args.strategy == 'ddp':
99
+ args.strategy = 'ddp_find_unused_parameters_true' # fix for conformer or pitchshifter having unused parameter issue
100
+
101
+ # sync-batchnorm
102
+ if args.sync_batchnorm is True:
103
+ sync_batchnorm = True
104
+
105
+ train_params = dict(**shared_cfg["TRAINER"],
106
+ devices=args.num_gpus if args.num_gpus == 'auto' else int(args.num_gpus),
107
+ num_nodes=int(args.num_nodes),
108
+ strategy=strategy if args.strategy == 'deepspeed' else args.strategy,
109
+ precision=args.precision,
110
+ max_epochs=args.max_epochs if stage == 'train' else None,
111
+ max_steps=args.max_steps if stage == 'train' else -1,
112
+ logger=wandb_logger,
113
+ callbacks=[checkpoint_callback, lr_monitor],
114
+ sync_batchnorm=sync_batchnorm)
115
+ trainer = pl.trainer.trainer.Trainer(**train_params)
116
+
117
+ # Update wandb logger (for DDP)
118
+ if trainer.global_rank == 0:
119
+ wandb_logger.experiment.config.update(args, allow_val_change=True)
120
+
121
+ return trainer, wandb_logger, dir_info, shared_cfg
122
+
123
+
124
+ def update_config(args, shared_cfg, stage: Literal['train', 'test'] = 'train'):
125
+ """Update audio/model/shared configurations with args"""
126
+ audio_cfg = default_audio_cfg
127
+ model_cfg = default_model_cfg
128
+
129
+ # Only update config when training
130
+ if stage == 'train':
131
+ # Augmentation parameters
132
+ if args.random_amp_range is not None:
133
+ shared_cfg["AUGMENTATION"]["train_random_amp_range"] = list(
134
+ (float(args.random_amp_range[0]), float(args.random_amp_range[1])))
135
+ if args.stem_iaug_prob is not None:
136
+ shared_cfg["AUGMENTATION"]["train_stem_iaug_prob"] = float(args.stem_iaug_prob)
137
+
138
+ if args.xaug_max_k is not None:
139
+ shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["max_k"] = int(args.xaug_max_k)
140
+ if args.xaug_tau is not None:
141
+ shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["tau"] = float(args.xaug_tau)
142
+ if args.xaug_alpha is not None:
143
+ shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["alpha"] = float(args.xaug_alpha)
144
+ if args.xaug_no_instr_overlap is not None:
145
+ shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["no_instr_overlap"] = bool(args.xaug_no_instr_overlap)
146
+ if args.xaug_no_drum_overlap is not None:
147
+ shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["no_drum_overlap"] = bool(args.xaug_no_drum_overlap)
148
+ if args.uhat_intra_stem_augment is not None:
149
+ shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["uhat_intra_stem_augment"] = bool(
150
+ args.uhat_intra_stem_augment)
151
+
152
+ if args.pitch_shift_range is not None:
153
+ if args.pitch_shift_range in [["0", "0"], [0, 0]]:
154
+ shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] = None
155
+ else:
156
+ shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] = list(
157
+ (int(args.pitch_shift_range[0]), int(args.pitch_shift_range[1])))
158
+
159
+ train_stem_iaug_prob = shared_cfg["AUGMENTATION"]["train_stem_iaug_prob"]
160
+ random_amp_range = shared_cfg["AUGMENTATION"]["train_random_amp_range"]
161
+ train_stem_xaug_policy = shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]
162
+ print(f'Random amp range: {random_amp_range}\n' +
163
+ f'Intra-stem augmentation probability: {train_stem_iaug_prob}\n' +
164
+ f'Stem augmentation policy: {train_stem_xaug_policy}\n' +
165
+ f'Pitch shift range: {shared_cfg["AUGMENTATION"]["train_pitch_shift_range"]}\n')
166
+
167
+ # Update audio config
168
+ if args.audio_codec != None:
169
+ assert args.audio_codec in ['spec', 'melspec']
170
+ audio_cfg["codec"] = str(args.audio_codec)
171
+ if args.hop_length != None:
172
+ audio_cfg["hop_length"] = int(args.hop_length)
173
+ if args.n_mels != None:
174
+ audio_cfg["n_mels"] = int(args.n_mels)
175
+ if args.input_frames != None:
176
+ audio_cfg["input_frames"] = int(args.input_frames)
177
+
178
+ # Update shared config
179
+ if shared_cfg["TOKENIZER"]["max_shift_steps"] == "auto":
180
+ shift_steps_ms = shared_cfg["TOKENIZER"]["shift_step_ms"]
181
+ input_frames = audio_cfg["input_frames"]
182
+ fs = audio_cfg["sample_rate"]
183
+ max_shift_steps = (input_frames / fs) // (shift_steps_ms / 1000) + 2 # 206 by default
184
+ shared_cfg["TOKENIZER"]["max_shift_steps"] = int(max_shift_steps)
185
+
186
+ # Update model config
187
+ if args.encoder_type != None:
188
+ model_cfg["encoder_type"] = str(args.encoder_type)
189
+ if args.decoder_type != None:
190
+ model_cfg["decoder_type"] = str(args.decoder_type)
191
+ if args.pre_encoder_type != "default":
192
+ model_cfg["pre_encoder_type"] = str(args.pre_encoder_type)
193
+ if args.pre_decoder_type != 'default':
194
+ model_cfg["pre_decoder_type"] = str(args.pre_decoder_type)
195
+ if args.conv_out_channels != None:
196
+ model_cfg["conv_out_channels"] = int(args.conv_out_channels)
197
+ assert isinstance(args.task_cond_decoder, bool) and isinstance(args.task_cond_encoder, bool)
198
+ model_cfg["use_task_conditional_encoder"] = args.task_cond_encoder
199
+ model_cfg["use_task_conditional_decoder"] = args.task_cond_decoder
200
+
201
+ if args.encoder_position_encoding_type != 'default':
202
+ if args.encoder_position_encoding_type in ['None', 'none', '0']:
203
+ model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = None
204
+ elif args.encoder_position_encoding_type in [
205
+ 'sinusoidal', 'rope', 'trainable', 'alibi', 'alibit', 'tkd', 'td', 'tk', 'kdt'
206
+ ]:
207
+ model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = str(
208
+ args.encoder_position_encoding_type)
209
+ else:
210
+ raise ValueError(f'Encoder PE type {args.encoder_position_encoding_type} not supported')
211
+ if args.decoder_position_encoding_type != 'default':
212
+ if args.decoder_position_encoding_type in ['None', 'none', '0']:
213
+ raise ValueError('Decoder PE type cannot be None')
214
+ elif args.decoder_position_encoding_type in ['sinusoidal', 'trainable']:
215
+ model_cfg["decoder"][model_cfg["decoder_type"]]["position_encoding_type"] = str(
216
+ args.decoder_position_encoding_type)
217
+ else:
218
+ raise ValueError(f'Decoder PE {args.decoder_position_encoding_type} not supported')
219
+
220
+ if args.tie_word_embedding is not None:
221
+ model_cfg["tie_word_embedding"] = bool(args.tie_word_embedding)
222
+
223
+ if args.d_feat != None:
224
+ model_cfg["d_feat"] = int(args.d_feat)
225
+ if args.d_latent != None:
226
+ model_cfg['encoder']['perceiver-tf']["d_latent"] = int(args.d_latent)
227
+ if args.num_latents != None:
228
+ model_cfg['encoder']['perceiver-tf']['num_latents'] = int(args.num_latents)
229
+ if args.perceiver_tf_d_model != None:
230
+ model_cfg['encoder']['perceiver-tf']['d_model'] = int(args.perceiver_tf_d_model)
231
+ if args.num_perceiver_tf_blocks != None:
232
+ model_cfg["encoder"]["perceiver-tf"]["num_blocks"] = int(args.num_perceiver_tf_blocks)
233
+ if args.num_perceiver_tf_local_transformers_per_block != None:
234
+ model_cfg["encoder"]["perceiver-tf"]["num_local_transformers_per_block"] = int(
235
+ args.num_perceiver_tf_local_transformers_per_block)
236
+ if args.num_perceiver_tf_temporal_transformers_per_block != None:
237
+ model_cfg["encoder"]["perceiver-tf"]["num_temporal_transformers_per_block"] = int(
238
+ args.num_perceiver_tf_temporal_transformers_per_block)
239
+ if args.attention_to_channel != None:
240
+ model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = bool(args.attention_to_channel)
241
+ if args.sca_use_query_residual != None:
242
+ model_cfg["encoder"]["perceiver-tf"]["sca_use_query_residual"] = bool(args.sca_use_query_residual)
243
+ if args.layer_norm_type != None:
244
+ model_cfg["encoder"]["perceiver-tf"]["layer_norm"] = str(args.layer_norm_type)
245
+ if args.ff_layer_type != None:
246
+ model_cfg["encoder"]["perceiver-tf"]["ff_layer_type"] = str(args.ff_layer_type)
247
+ if args.ff_widening_factor != None:
248
+ model_cfg["encoder"]["perceiver-tf"]["ff_widening_factor"] = int(args.ff_widening_factor)
249
+ if args.moe_num_experts != None:
250
+ model_cfg["encoder"]["perceiver-tf"]["moe_num_experts"] = int(args.moe_num_experts)
251
+ if args.moe_topk != None:
252
+ model_cfg["encoder"]["perceiver-tf"]["moe_topk"] = int(args.moe_topk)
253
+ if args.hidden_act != None:
254
+ model_cfg["encoder"]["perceiver-tf"]["hidden_act"] = str(args.hidden_act)
255
+ if args.rotary_type != None:
256
+ assert len(
257
+ args.rotary_type
258
+ ) == 3, "rotary_type must be a 3-letter string (e.g. 'ppl': 'pixel' for SCA, 'pixel' for latent, 'lang' for temporal transformer)"
259
+ model_cfg["encoder"]["perceiver-tf"]["rotary_type_sca"] = str(args.rotary_type)[0]
260
+ model_cfg["encoder"]["perceiver-tf"]["rotary_type_latent"] = str(args.rotary_type)[1]
261
+ model_cfg["encoder"]["perceiver-tf"]["rotary_type_temporal"] = str(args.rotary_type)[2]
262
+ if args.rope_apply_to_keys != None:
263
+ model_cfg["encoder"]["perceiver-tf"]["rope_apply_to_keys"] = bool(args.rope_apply_to_keys)
264
+ if args.rope_partial_pe != None:
265
+ model_cfg["encoder"]["perceiver-tf"]["rope_partial_pe"] = bool(args.rope_partial_pe)
266
+
267
+ if args.decoder_ff_layer_type != None:
268
+ model_cfg["decoder"][model_cfg["decoder_type"]]["ff_layer_type"] = str(args.decoder_ff_layer_type)
269
+ if args.decoder_ff_widening_factor != None:
270
+ model_cfg["decoder"][model_cfg["decoder_type"]]["ff_widening_factor"] = int(args.decoder_ff_widening_factor)
271
+
272
+ if args.event_length != None:
273
+ model_cfg["event_length"] = int(args.event_length)
274
+
275
+ if stage == 'train':
276
+ if args.encoder_dropout_rate != None:
277
+ model_cfg["encoder"][model_cfg["encoder_type"]]["dropout_rate"] = float(args.encoder_dropout_rate)
278
+ if args.decoder_dropout_rate != None:
279
+ model_cfg["decoder"][model_cfg["decoder_type"]]["dropout_rate"] = float(args.decoder_dropout_rate)
280
+
281
+ return shared_cfg, audio_cfg, model_cfg # return updated configs
amt/src/model/lm_head.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ """lm_head.py"""
11
+ import torch
12
+ from torch import nn
13
+ from typing import Optional, Dict
14
+
15
+
16
+ class LMHead(nn.Module):
17
+ """Language Model Head with tied weights."""
18
+
19
+ def __init__(self, decoder_config: Dict, init_factor: float = 1.0, tie_word_embeddings: bool = True):
20
+
21
+ super().__init__()
22
+ self.d_model = decoder_config["d_model"]
23
+ self.init_factor = init_factor
24
+ self.tie_word_embeddings = tie_word_embeddings
25
+
26
+ self.lm_head = nn.Linear(decoder_config["d_model"], decoder_config["vocab_size"], bias=False)
27
+ self._init_weights()
28
+
29
+ def _init_weights(self):
30
+ if self.tie_word_embeddings is False:
31
+ self.lm_head.weight.data.normal_(mean=0.0, std=self.init_factor * 1.0)
32
+
33
+ def forward(self, decoder_hs: torch.FloatTensor) -> torch.FloatTensor:
34
+ if self.tie_word_embeddings is True:
35
+ # Rescale output before projecting on vocab
36
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
37
+ decoder_hs = decoder_hs * (self.d_model**-0.5)
38
+
39
+ lm_logits = self.lm_head(decoder_hs)
40
+ return lm_logits
amt/src/model/lr_scheduler.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ """lr_schedule.py"""
11
+ import torch
12
+ from typing import Dict, Optional
13
+
14
+
15
+ def get_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_name: str, base_lr: float, scheduler_cfg: Dict):
16
+
17
+ if scheduler_name.lower() == 'cosine':
18
+ from torch.optim.lr_scheduler import (
19
+ SequentialLR,
20
+ LinearLR,
21
+ CosineAnnealingLR,
22
+ )
23
+
24
+ scheduler1 = LinearLR(
25
+ optimizer,
26
+ start_factor=0.5,
27
+ end_factor=1,
28
+ total_iters=scheduler_cfg["warmup_steps"],
29
+ last_epoch=-1,
30
+ )
31
+
32
+ scheduler2 = CosineAnnealingLR(
33
+ optimizer,
34
+ T_max=scheduler_cfg["total_steps"] - scheduler_cfg["warmup_steps"],
35
+ eta_min=scheduler_cfg["final_cosine"],
36
+ )
37
+
38
+ lr_scheduler = SequentialLR(optimizer,
39
+ schedulers=[scheduler1, scheduler2],
40
+ milestones=[scheduler_cfg["warmup_steps"]])
41
+ elif scheduler_name.lower() == 'legacy':
42
+ import math
43
+ from torch.optim.lr_scheduler import (
44
+ SequentialLR,
45
+ LinearLR,
46
+ LambdaLR,
47
+ )
48
+
49
+ msg = "You are using T5 legacy LR Schedule, it's independent from the optim.base_lr"
50
+ print(msg)
51
+
52
+ num_steps_optimizer1 = math.ceil(scheduler_cfg["total_steps"] * 0.9)
53
+ iters_left_for_optimizer2 = scheduler_cfg["total_steps"] - num_steps_optimizer1
54
+
55
+ scheduler1 = LambdaLR(optimizer, lambda step: min(base_lr, 1.0 / math.sqrt(step)) / base_lr
56
+ if step else base_lr / base_lr)
57
+
58
+ scheduler2 = LinearLR(optimizer,
59
+ start_factor=(min(base_lr, 1.0 / math.sqrt(num_steps_optimizer1)) / base_lr),
60
+ end_factor=0,
61
+ total_iters=iters_left_for_optimizer2,
62
+ last_epoch=-1)
63
+
64
+ lr_scheduler = SequentialLR(
65
+ optimizer,
66
+ schedulers=[scheduler1, scheduler2],
67
+ milestones=[num_steps_optimizer1],
68
+ )
69
+ elif scheduler_name.lower() == 'constant':
70
+ from transformers import get_scheduler
71
+ lr_scheduler = get_scheduler(
72
+ name=scheduler_name.lower(),
73
+ optimizer=optimizer,
74
+ )
75
+ else:
76
+ raise NotImplementedError
77
+
78
+ return lr_scheduler
79
+
80
+
81
+ def extra_stats(args, model, optimizer):
82
+ stats = {}
83
+
84
+ if args.logging.weights_l2:
85
+ weights_l2 = sum(p.detach().norm(2).item()**2 for p in model.parameters())**0.5
86
+ stats['weights_l2'] = weights_l2
87
+
88
+ cur_lr = optimizer.param_groups[0]['lr']
89
+ stats['lr'] = cur_lr
90
+
91
+ return stats
amt/src/model/ops.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ """ op.py """
11
+ import math
12
+ from packaging.version import parse as VersionParse
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from einops import rearrange
18
+ from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm
19
+
20
+
21
+ def get_layer_norm(dim: int, layer_norm_type: str = "layer_norm", layer_norm_eps: float = 1e-5):
22
+ """Get layer normalization layer.
23
+ Args:
24
+ dim (int): Feature dimension
25
+ layer_norm_type (str): "layer_norm" or "rms_norm"
26
+ layer_norm_eps (float): Epsilon value for numerical stability
27
+
28
+ Returns:
29
+ nn.Module: Layer normalization layer
30
+ """
31
+ if layer_norm_type == "rms_norm":
32
+ # T5LayerNorm is equivalent to RMSNorm. https://arxiv.org/abs/1910.07467
33
+ return RMSNorm(hidden_size=dim, eps=layer_norm_eps)
34
+ else:
35
+ return nn.LayerNorm(normalized_shape=dim, eps=layer_norm_eps)
36
+
37
+
38
+ def check_all_elements_equal(x: torch.Tensor) -> bool:
39
+ return x.eq(x[0]).all().item()
40
+
41
+
42
+ def minmax_normalize(x: torch.Tensor, eps: float = 0.008) -> torch.FloatTensor:
43
+ """Min-max normalization:
44
+
45
+ x_norm = (x - x_min) / (x_max - x_min + eps)
46
+
47
+ Args:
48
+ x (torch.Tensor): (B, T, F)
49
+ Returns:
50
+ torch.Tensor: (B, T, F) with output range of [0, 1]
51
+ """
52
+ x_max = rearrange(x, "b t f -> b (t f)").max(1, keepdim=True)[0]
53
+ x_min = rearrange(x, "b t f -> b (f t)").min(1, keepdim=True)[0]
54
+ x_max = x_max[:, None, :] # (B,1,1)
55
+ x_min = x_min[:, None, :] # (B,1,1)
56
+ return (x - x_min) / (x_max - x_min + eps)
57
+
58
+
59
+ def count_parameters(model):
60
+ num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
61
+ num_params = sum(p.numel() for p in model.parameters())
62
+ return num_trainable_params, num_params
63
+
64
+
65
+ def adjust_b_to_gcd(a, b, min_gcd=16):
66
+ """
67
+ Adjust the value of b to ensure the GCD(a, b) is at least min_gcd with minimum change to b.
68
+
69
+ Parameters:
70
+ - a (int): A positive integer
71
+ - b (int): A positive integer
72
+ - min_gcd (int): The minimum desired GCD
73
+
74
+ Returns:
75
+ - int: The adjusted value of b
76
+ """
77
+ current_gcd = math.gcd(a, b)
78
+
79
+ # If current GCD is already greater than or equal to min_gcd, return b as it is.
80
+ if current_gcd >= min_gcd:
81
+ return b
82
+
83
+ # If a is less than min_gcd, then it's impossible to get a GCD of at least min_gcd.
84
+ if a < min_gcd:
85
+ raise ValueError("a must be at least as large as min_gcd.")
86
+
87
+ # Adjust b by trying increments and decrements, preferring the smallest absolute change.
88
+ adjusted_b_up = b
89
+ adjusted_b_down = b
90
+
91
+ while True:
92
+ adjusted_b_up += 1
93
+ adjusted_b_down -= 1
94
+
95
+ if math.gcd(a, adjusted_b_up) >= min_gcd:
96
+ return adjusted_b_up
97
+ elif math.gcd(a, adjusted_b_down) >= min_gcd:
98
+ return adjusted_b_down
99
+
100
+
101
+ def optional_compiler_disable(func):
102
+ if VersionParse(torch.__version__) >= VersionParse("2.1"):
103
+ # If the version is 2.1 or higher, apply the torch.compiler.disable decorator.
104
+ return torch.compiler.disable(func)
105
+ else:
106
+ # If the version is below 2.1, return the original function.
107
+ return func
108
+
109
+
110
+ def optional_compiler_dynamic(func):
111
+ return torch.compile(func, dynamic=True)
amt/src/model/optimizers.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ optimizers.py
2
+
3
+ Code based on nanoT5 project:
4
+ https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/copied_utils.py
5
+
6
+ + D-adapt Adam from https://github.com/facebookresearch/dadaptation
7
+ """
8
+ import importlib
9
+ import math
10
+ import torch
11
+
12
+ from typing import Iterable, Tuple
13
+ from torch import nn
14
+ from torch.optim import Optimizer
15
+ from transformers import Adafactor
16
+ from torch.optim import AdamW
17
+
18
+
19
+ class AdamWScale(Optimizer):
20
+ """
21
+ This AdamW implementation is copied from Huggingface.
22
+ We modified it with Adagrad scaling by rms of a weight tensor
23
+
24
+ Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
25
+ Regularization](https://arxiv.org/abs/1711.05101).
26
+
27
+ Parameters:
28
+ params (`Iterable[nn.parameter.Parameter]`):
29
+ Iterable of parameters to optimize or dictionaries defining parameter groups.
30
+ lr (`float`, *optional*, defaults to 1e-3):
31
+ The learning rate to use.
32
+ betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
33
+ Adam's betas parameters (b1, b2).
34
+ eps (`float`, *optional*, defaults to 1e-6):
35
+ Adam's epsilon for numerical stability.
36
+ weight_decay (`float`, *optional*, defaults to 0):
37
+ Decoupled weight decay to apply.
38
+ correct_bias (`bool`, *optional*, defaults to `True`):
39
+ Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
40
+ no_deprecation_warning (`bool`, *optional*, defaults to `False`):
41
+ A flag used to disable the deprecation warning (set to `True` to disable the warning).
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ params: Iterable[nn.parameter.Parameter],
47
+ lr: float = 1e-3,
48
+ betas: Tuple[float, float] = (0.9, 0.999),
49
+ eps: float = 1e-6,
50
+ weight_decay: float = 0.0,
51
+ correct_bias: bool = True,
52
+ ):
53
+ if lr < 0.0:
54
+ raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
55
+ if not 0.0 <= betas[0] < 1.0:
56
+ raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
57
+ if not 0.0 <= betas[1] < 1.0:
58
+ raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
59
+ if not 0.0 <= eps:
60
+ raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
61
+ defaults = dict(
62
+ lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
63
+ super().__init__(params, defaults)
64
+
65
+ @staticmethod
66
+ def _rms(tensor):
67
+ return tensor.norm(2) / (tensor.numel()**0.5)
68
+
69
+ def step(self, closure=None):
70
+ """
71
+ Performs a single optimization step.
72
+
73
+ Arguments:
74
+ closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
75
+ """
76
+ loss = None
77
+ if closure is not None:
78
+ loss = closure()
79
+
80
+ for group in self.param_groups:
81
+ for p in group["params"]:
82
+ if p.grad is None:
83
+ continue
84
+ grad = p.grad.data
85
+ if grad.is_sparse:
86
+ raise RuntimeError(
87
+ "Adam does not support sparse gradients, please consider SparseAdam instead"
88
+ )
89
+
90
+ state = self.state[p]
91
+ beta1, beta2 = group["betas"]
92
+
93
+ # State initialization
94
+ if len(state) == 0:
95
+ state["step"] = 0
96
+ # Exponential moving average of gradient values
97
+ state["exp_avg"] = torch.zeros_like(p.data)
98
+ # Exponential moving average of squared gradient values
99
+ state["exp_avg_sq"] = torch.zeros_like(p.data)
100
+
101
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
102
+
103
+ state["step"] += 1
104
+
105
+ # Decay the first and second moment running average coefficient
106
+ # In-place operations to update the averages at the same time
107
+ exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
108
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
109
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
110
+
111
+ step_size = group["lr"]
112
+ if group["correct_bias"]: # No bias correction for Bert
113
+ bias_correction1 = 1.0 - beta1**state["step"]
114
+ bias_correction2 = 1.0 - beta2**state["step"]
115
+ step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
116
+
117
+ # /Adapt Step from Adagrad
118
+ step_size = step_size * max(1e-3, self._rms(p.data))
119
+ # /Adapt Step from Adagrad
120
+
121
+ p.data.addcdiv_(exp_avg, denom, value=-step_size)
122
+
123
+ # Just adding the square of the weights to the loss function is *not*
124
+ # the correct way of using L2 regularization/weight decay with Adam,
125
+ # since that will interact with the m and v parameters in strange ways.
126
+ #
127
+ # Instead we want to decay the weights in a manner that doesn't interact
128
+ # with the m/v parameters. This is equivalent to adding the square
129
+ # of the weights to the loss with plain (non-momentum) SGD.
130
+ # Add weight decay at the end (fixed version)
131
+ if group["weight_decay"] > 0.0:
132
+ p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"]))
133
+
134
+ return loss
135
+
136
+
137
+ # def get_optimizer(models_dict: nn.ModuleDict,
138
+ # optimizer_name: str,
139
+ # base_lr: float,
140
+ # weight_decay: float = 0.):
141
+
142
+ # no_decay = [
143
+ # "bias", "LayerNorm", "layernorm", "layer_norm", "ln", "BatchNorm", "bn", "batch_norm",
144
+ # "batchnorm"
145
+ # ]
146
+
147
+
148
+ # optimizer_grouped_parameters = []
149
+ # for name, current_model in models_dict.items():
150
+ # if current_model is None:
151
+ # continue
152
+ # optimizer_grouped_parameters += [
153
+ # {
154
+ # "params": [
155
+ # p for n, p in current_model.named_parameters()
156
+ # if not any(nd in n for nd in no_decay)
157
+ # ],
158
+ # "weight_decay": weight_decay,
159
+ # },
160
+ # {
161
+ # "params": [
162
+ # p for n, p in current_model.named_parameters()
163
+ # if any(nd in n for nd in no_decay)
164
+ # ],
165
+ # "weight_decay": 0.0,
166
+ # },
167
+ # ]
168
+ def get_optimizer(models_dict: nn.ModuleDict,
169
+ optimizer_name: str,
170
+ base_lr: float,
171
+ weight_decay: float = 0.):
172
+
173
+ no_decay = [
174
+ "bias", "LayerNorm", "layernorm", "layer_norm", "ln", "BatchNorm", "bn", "batch_norm",
175
+ "batchnorm"
176
+ ]
177
+ optimizer_grouped_parameters = []
178
+ for n, p in models_dict:
179
+ # drop pitch shifter
180
+ if 'pshifters' in n:
181
+ continue
182
+ # no decay
183
+ if n in no_decay:
184
+ optimizer_grouped_parameters.append({"params": [p], "weight_decay": 0.0})
185
+ else:
186
+ optimizer_grouped_parameters.append({"params": [p], "weight_decay": weight_decay})
187
+
188
+ if optimizer_name.lower() == 'adamw':
189
+ base_lr = 1e-03 if base_lr == None else float(base_lr)
190
+ opt = AdamW(optimizer_grouped_parameters, lr=base_lr)
191
+ elif optimizer_name.lower() == 'adafactor':
192
+ if base_lr == None:
193
+ opt = Adafactor(
194
+ optimizer_grouped_parameters,
195
+ lr=None,
196
+ scale_parameter=True,
197
+ relative_step=True,
198
+ warmup_init=True)
199
+ else:
200
+ opt = Adafactor(optimizer_grouped_parameters, lr=base_lr, relative_step=False)
201
+ elif optimizer_name.lower() == 'adamwscale':
202
+ base_lr = 1e-02 if base_lr == None else float(base_lr)
203
+ opt = AdamWScale(
204
+ optimizer_grouped_parameters,
205
+ lr=base_lr,
206
+ )
207
+ elif optimizer_name.lower() == 'cpuadam':
208
+ dspd = importlib.import_module('deepspeed')
209
+ base_lr = 1e-03 if base_lr == None else float(base_lr)
210
+ opt = dspd.ops.adam.cpu_adam.DeepSpeedCPUAdam(optimizer_grouped_parameters, lr=base_lr)
211
+ elif optimizer_name.lower() == 'dadaptadam':
212
+ dadaptation = importlib.import_module('dadaptation')
213
+ base_lr = 1.0 if base_lr == None else float(base_lr)
214
+ opt = dadaptation.DAdaptAdam(optimizer_grouped_parameters, lr=base_lr)
215
+ else:
216
+ raise NotImplementedError(optimizer_name)
217
+
218
+ return opt, base_lr
amt/src/model/perceiver_helper.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The YourMT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Please see the details in the LICENSE file.
10
+ from dataclasses import dataclass
11
+ from typing import Optional, Tuple
12
+ import torch
13
+ from torch import nn
14
+ from transformers.utils import ModelOutput
15
+ from transformers.configuration_utils import PretrainedConfig
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ # from transformers.models.perceiver.modeling_perceiver import (PerceiverAbstractPositionEncoding,
18
+ # PerceiverTrainablePositionEncoding,
19
+ # PerceiverFourierPositionEncoding)
20
+
21
+
22
+ class PerceiverTFConfig(PretrainedConfig):
23
+ r"""
24
+ This is the configuration class to store the configuration of a [`PerceiverTF`]. It is used to instantiate an
25
+ Perceiver model according to the specified arguments, defining the model architecture. Instantiating a
26
+ configuration with the defaults will yield a similar configuration to that of the Perceiver
27
+ [deepmind/language-perceiver](https://huggingface.co/deepmind/language-perceiver) architecture.
28
+
29
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
30
+ documentation from [`PretrainedConfig`] for more information.
31
+
32
+ Args:
33
+ num_latents (`int`, *optional*, defaults to 256):
34
+ The number of latents.
35
+ d_latents (`int`, *optional*, defaults to 1280):
36
+ Dimension of the latent embeddings.
37
+ d_model (`int`, *optional*, defaults to 768):
38
+ Dimension of the inputs. Should only be provided in case [*PerceiverTextPreprocessor*] is used or no
39
+ preprocessor is provided.
40
+ kv_dim (`int`, *optional*, defaults to 128):
41
+ num_blocks (`int`, *optional*, defaults to 1):
42
+ Number of blocks in the Transformer encoder.
43
+ num_self_attention_heads (`int`, *optional*, defaults to 8):
44
+ Number of attention heads for each self-attention layer in the Transformer encoder.
45
+ num_cross_attention_heads (`int`, *optional*, defaults to 8):
46
+ Number of attention heads for each cross-attention layer in the Transformer encoder.
47
+ num_local_transformers_per_block (`int`, *optional*, defaults to 2):
48
+ Number of local Transformer layers per Transformer block in the Transformer encoder.
49
+ num_temporal_transformers_per_block (`int`, *optional*, defaults to 2):
50
+ Number of temporal Transformer layers per Transformer block in the Transformer encoder.
51
+ shared_parallel_temporal_transformers (`bool`, *optional*, defaults to `False`):
52
+ Whether to share the parameters across the K parallel temporal Transformers in each block.
53
+ qk_channels (`int`, *optional*):
54
+ Dimension to project the queries + keys before applying attention in the cross-attention and self-attention
55
+ layers of the encoder. Will default to preserving the dimension of the queries if not specified.
56
+ v_channels (`int`, *optional*):
57
+ Dimension to project the values before applying attention in the cross-attention and self-attention layers
58
+ of the encoder. Will default to preserving the dimension of the queries if not specified.
59
+ ** DEPRECATED ** cross_attention_shape_for_attention (`str`, *optional*, defaults to `'kv'`):
60
+ Dimension to use when downsampling the queries and keys in the cross-attention layer of the encoder.
61
+ ** DEPRECATED ** self_attention_widening_factor (`int`, *optional*, defaults to 1):
62
+ Dimension of the feed-forward layer in the cross-attention layer of the Transformer encoder.
63
+ cross_attention_widening_factor (`int`, *optional*, defaults to 1):
64
+ Dimension of the feed-forward layer in the self-attention layers of the Transformer encoder.
65
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
66
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
67
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
68
+ dropout_rate (`float`, *optional*, defaults to 0.1):
69
+ The dropout ratio for the attention probabilities.
70
+ initializer_range (`float`, *optional*, defaults to 0.02):
71
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
72
+ layer_norm_type (`str`, *optional*, defaults to `'layer_norm'`):
73
+ The type of layer normalization to use. Can be one of {'layer_norm', 'rms_norm'}.
74
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
75
+ The epsilon used by the layer normalization layers.
76
+ sca_use_query_residual (`bool`, *optional*, defaults to `True`):
77
+ Whether to add a query residual in the spectral cross attention (SCA) layer of the encoder.
78
+ use_query_residual (`float`, *optional*, defaults to `True`):
79
+ Whether to add a query residual in the cross-attention layer of the encoder.
80
+ position_encoding_type (`str`, *optional*, defaults to `'trainable'`):
81
+ Type of position encoding to use. Can be one of {'trainable', 'alibi', 'alibit', 'rope', None}.
82
+ num_max_positions (`int`, *optional*, defaults to 331):
83
+ Maximum number of positions to use for the position encoding.
84
+ vocab_size (`int`, *optional*, defaults to 262):
85
+ Vocabulary size for the masked language modeling model.
86
+ attention_to_channel (`bool`, defaults to `False`):
87
+ Whether SCA should attend to the channel dimension. If False, attention to frequency bin dimension.
88
+ ff_layer_type (`str`, *optional*, defaults to `'mlp'`):
89
+ Type of feed-forward layer to use. Can be one of {'mlp', 'moe'}.
90
+ ff_widening_factor (`int`, *optional*, defaults to 1):
91
+ Widening factor for the feed-forward layers in the MLP/MoE.
92
+ moe_num_experts (`int`, *optional*, defaults to 4):
93
+ Number of experts to use in the mixture of experts (MoE) feed-forward layer.
94
+ Only used if `ff_layer_type` is set to `'moe'`.
95
+ moe_topk (`int`, *optional*, defaults to 2):
96
+ Number of top experts to use in the mixture of experts (MoE) feed-forward layer.
97
+ Only used if `ff_layer_type` is set to `'moe'`.
98
+ rope_type_sca (`str`, *optional*, defaults to `pixel`): Can be one of {'l'|lang', 'p'|'pixel', None}.
99
+ RoPE index type for SCA. Only used if `position_encoding_type` is set to `rope`.
100
+ rope_type_latent (`str`, *optional*, defaults to `pixel`): Can be one of {'l'|'lang', 'p'|'pixel', None}.
101
+ RoPE index type for Latent Transformer. Only used if `position_encoding_type` is set to `'rope'`.
102
+ rope_type_temporal (`str`, *optional*, defaults to `lang`): Can be one of {'l'|'lang', 'p'|'pixel', None}.
103
+ RoPE index type for Temporal Transformer. Only used if `position_encoding_type` is set to `'rope'`.
104
+ rope_apply_to_keys (`bool`, *optional*, defaults to `False`): Whether to apply RoPE to the keys in the
105
+ self/cross-attention layers. Only used if `position_encoding_type` is set to `'rope'`.
106
+ rope_partial_pe (`bool`, *optional*, defaults to `False`): Whether to use partial RoPE in the self/cross-attention.
107
+ Only used if `position_encoding_type` is set to `'rope'`.
108
+ rope_trainable (`bool`, *optional*, defaults to `False`): Whether to make the RoPE trainable. Only used if
109
+
110
+ Example:
111
+
112
+ ```python
113
+ >>> from model.perceiver_mod import PerceiverTFEncodel, PerceiverTFConfig
114
+
115
+ >>> # Initializing a Perceiver deepmind/language-perceiver style configuration
116
+ >>> configuration = PerceiverTFConfig()
117
+
118
+ >>> # Initializing a model from the deepmind/language-perceiver style configuration
119
+ >>> model = PerceiverTFEncoder(configuration)
120
+
121
+ >>> # Accessing the model configuration
122
+ >>> configuration = model.config
123
+ ```"""
124
+ model_type = "perceivertf"
125
+
126
+ def __init__(
127
+ self,
128
+ num_latents=24,
129
+ d_latents=128,
130
+ d_model=128,
131
+ kv_dim=128,
132
+ num_blocks=3,
133
+ num_self_attention_heads=8,
134
+ num_cross_attention_heads=8,
135
+ num_local_transformers_per_block=2,
136
+ num_temporal_transformers_per_block=2,
137
+ qk_channels=128,
138
+ v_channels=128,
139
+ cross_attention_shape_for_attention="q",
140
+ # self_attention_widening_factor=1, ** DEPRECATED **
141
+ # cross_attention_widening_factor=1, ** DEPRECATED **
142
+ hidden_act="gelu",
143
+ dropout_rate=0.1,
144
+ initializer_range=0.02,
145
+ layer_norm_type="layer_norm",
146
+ layer_norm_eps=1e-5,
147
+ sca_use_query_residual=True,
148
+ use_query_residual=True,
149
+ position_encoding_type="trainable",
150
+ num_max_positions=330,
151
+ vocab_size=1391,
152
+ attention_to_channel=False,
153
+ ff_layer_type="mlp",
154
+ ff_widening_factor=1,
155
+ moe_num_experts=4,
156
+ moe_topk=2,
157
+ rope_type_sca="pixel",
158
+ rope_type_latent="pixel",
159
+ rope_type_temporal="lang",
160
+ rope_apply_to_keys=False,
161
+ rope_partial_pe=False,
162
+ rope_trainable=False,
163
+ **kwargs,
164
+ ):
165
+ super().__init__(**kwargs)
166
+
167
+ self.num_latents = num_latents
168
+ self.d_latents = d_latents
169
+ self.d_model = d_model
170
+ self.kv_dim = kv_dim
171
+ self.qk_channels = qk_channels
172
+ self.v_channels = v_channels
173
+
174
+ self.num_blocks = num_blocks
175
+ self.num_self_attention_heads = num_self_attention_heads
176
+ self.num_cross_attention_heads = num_cross_attention_heads
177
+ self.num_local_transformers_per_block = num_local_transformers_per_block
178
+ self.num_temporal_transformers_per_block = num_temporal_transformers_per_block
179
+ self.sca_use_query_residual = sca_use_query_residual
180
+ self.use_query_residual = use_query_residual
181
+ self.position_encoding_type = position_encoding_type
182
+ self.num_max_positions = num_max_positions
183
+ # self.self_attention_widening_factor = self_attention_widening_factor
184
+ # self.cross_attention_widening_factor = cross_attention_widening_factor
185
+ self.cross_attention_shape_for_attention = cross_attention_shape_for_attention
186
+ self.attention_to_channel = attention_to_channel
187
+ self.ff_layer_type = ff_layer_type
188
+ self.ff_widening_factor = ff_widening_factor
189
+ self.moe_num_experts = moe_num_experts
190
+ self.moe_topk = moe_topk
191
+ self.rope_type_sca = rope_type_sca
192
+ self.rope_type_latent = rope_type_latent
193
+ self.rope_type_temporal = rope_type_temporal
194
+ self.rope_apply_to_keys = rope_apply_to_keys
195
+ self.rope_partial_pe = rope_partial_pe
196
+ self.rope_trainable = rope_trainable
197
+
198
+ self.hidden_act = hidden_act
199
+ self.dropout_rate = dropout_rate
200
+ self.initializer_range = initializer_range
201
+ self.layer_norm_type = layer_norm_type
202
+ self.layer_norm_eps = layer_norm_eps
203
+
204
+ # masked language modeling attributes
205
+ self.vocab_size = vocab_size
206
+
207
+
208
+ class PerceiverTFPreTrainedModel(PreTrainedModel):
209
+ """
210
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
211
+ models.
212
+ """
213
+
214
+ config_class = PerceiverTFConfig
215
+ base_model_prefix = "perceivertf"
216
+ main_input_name = "inputs"
217
+
218
+ def _init_weights(self, module):
219
+ """Initialize the weights"""
220
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
221
+ # Slightly different from the TF version which uses truncated_normal for initialization
222
+ # cf https://github.com/pytorch/pytorch/pull/5617
223
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
224
+ if module.bias is not None:
225
+ module.bias.data.zero_()
226
+ elif hasattr(module, "latents"):
227
+ module.latents.data.normal_(mean=0.0, std=self.config.initializer_range)
228
+ elif hasattr(module, "_pos_emb") and isinstance(module._pos_emb, nn.Parameter):
229
+ # initialize PerceiverTFTrainablePE
230
+ module._pos_emb.data.normal_(mean=0.0, std=self.config.initializer_range)
231
+ elif hasattr(module, "_pos_emb_temporal"):
232
+ # initialize PerceiverTFTrainablePE
233
+ module._pos_emb_temporal.data.normal_(mean=0.0, std=self.config.initializer_range)
234
+ elif hasattr(module, "slopes") and isinstance(module.slopes, nn.Parameter):
235
+ # initialize AlibiPositionalBias
236
+ module.reset_parameters()
237
+ elif isinstance(module, nn.ParameterDict):
238
+ for modality in module.keys():
239
+ module[modality].data.normal_(mean=0.0, std=self.config.initializer_range)
240
+ elif isinstance(module, nn.Embedding):
241
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
242
+ if module.padding_idx is not None:
243
+ module.weight.data[module.padding_idx].zero_()
244
+ elif isinstance(module, nn.LayerNorm):
245
+ module.bias.data.zero_()
246
+ module.weight.data.fill_(1.0)
247
+ # elif hasattr(module, "position_embeddings") and isinstance(
248
+ # module, PerceiverTrainablePositionEncoding):
249
+ # module.position_embeddings.data.normal_(mean=0.0, std=self.config.initializer_range)
250
+
251
+
252
+ # Replace the 'ModelOutputWithCrossAttentions' with 'MoEModelOutputWithCrossAttentions' for MoE
253
+ @dataclass
254
+ class MoEModelOutputWithCrossAttentions(ModelOutput):
255
+ """
256
+ Base class for model's outputs, with potential hidden states and attentions.
257
+ Plus, router_probs for Mixture of Experts models.
258
+
259
+ Args:
260
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
261
+ Sequence of hidden-states at the output of the last layer of the model.
262
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
263
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
264
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
265
+
266
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
267
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
268
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
269
+ sequence_length)`.
270
+
271
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
272
+ heads.
273
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
274
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
275
+ sequence_length)`.
276
+
277
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
278
+ weighted average in the cross-attention heads.
279
+ router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
280
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
281
+
282
+ Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary
283
+ loss and the z_loss for Mixture of Experts models.
284
+ """
285
+
286
+ last_hidden_state: torch.FloatTensor = None
287
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
288
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
289
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
290
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None