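# Gin configuration for training a transformer prior over downsampled pitch sequences.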
from __gin__ import dynamic_registration
from gamadhani import src
from gamadhani.src import dataset
from gamadhani.src import model_transformer
from gamadhani.src import task_functions
from gamadhani.utils import utils
import torch.optim

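# Model and training hyperparameters, referenced below as %MACRO values.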
MODEL_DIM = 512
EMB_DIM = 512
NUM_TOKENS = 7928
NUM_QUANTIZERS = 1
DROPOUT_RATE = 0.3
NUM_HEADS = 8
SEQ_LEN = 1200
HEAD_DIM = 32
NUM_LAYERS = 8
LR = 1e-3

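# Transformer prior architecture (sizes set by the macros above).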
model_transformer.XTransformerPrior:
    num_tokens = %NUM_TOKENS
    seq_len = %SEQ_LEN
    model_dim = %MODEL_DIM
    emb_dim = %EMB_DIM
    head_dim = %HEAD_DIM
    num_layers = %NUM_LAYERS
    num_heads = %NUM_HEADS
    dropout_rate = %DROPOUT_RATE

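# Data-reading task: read and invert downsampled pitch contours.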
src.dataset.Task:
    read_fn = @src.task_functions.pitch_read_downsample
    invert_fn = @src.task_functions.invert_pitch_read_downsample
    kwargs = {"seq_len": %SEQ_LEN,
              "decoder_key": "pitch",
              "min_norm_pitch": -4915,
              "time_downsample": 2,
              "pitch_downsample": 10,
              "base_tonic": 440.}

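# Sequence dataset built from the task above; transforms are disabled.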
src.dataset.SequenceDataset:
    task = @dataset.Task()
    apply_transform = False

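# Optimizer and learning-rate scheduler classes passed to configure_optimizers.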
model_transformer.XTransformerPrior.configure_optimizers:
    optimizer_cls = @torch.optim.AdamW
    scheduler_cls = @utils.build_warmed_exponential_lr_scheduler

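# Learning-rate schedule: warmup to the peak rate, then exponential decay.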
utils.build_warmed_exponential_lr_scheduler:
    start_factor = .01
    peak_iteration = 10000
    cycle_length = 394600
    eta_min = 0.1
    eta_max = %LR

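# Global random seed for reproducibility.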
utils.set_seed:
    seed = 2023

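# AdamW settings; the LR macro is shared with the scheduler's eta_max.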
torch.optim.AdamW:
    lr = %LR
    betas = (.9, .98)
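
# Example of loading this config (a sketch: the filename and the
# direct-instantiation pattern are assumptions, not taken from this file):
#   import gin
#   from gamadhani.src import model_transformer
#   gin.parse_config_file("pitch_prior.gin")  # hypothetical path to this file
#   Prior = gin.get_configurable(model_transformer.XTransformerPrior)
#   model = Prior()  # the bindings above are injected by gin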