unpairedelectron07
committed on
Upload 7 files
Browse files- audiocraft/grids/musicgen/_explorers.py +93 -0
- audiocraft/grids/musicgen/musicgen_base_32khz.py +43 -0
- audiocraft/grids/musicgen/musicgen_base_cached_32khz.py +67 -0
- audiocraft/grids/musicgen/musicgen_clapemb_32khz.py +32 -0
- audiocraft/grids/musicgen/musicgen_melody_32khz.py +65 -0
- audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py +99 -0
- audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py +57 -0
audiocraft/grids/musicgen/_explorers.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import treetable as tt
|
10 |
+
|
11 |
+
from .._base_explorers import BaseExplorer
|
12 |
+
|
13 |
+
|
14 |
+
class LMExplorer(BaseExplorer):
    """Explorer reporting `train` and `valid` metrics, plus the best validation perplexity."""
    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages reported by this explorer."""
        return ['train', 'valid']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        train_group = tt.group(
            'train',
            [
                tt.leaf('epoch'),
                tt.leaf('duration', '.1f'),  # duration in minutes
                tt.leaf('ping'),
                tt.leaf('ce', '.4f'),  # cross entropy
                tt.leaf("ppl", '.3f'),  # perplexity
            ],
            align='>',
        )
        valid_group = tt.group(
            'valid',
            [
                tt.leaf('ce', '.4f'),
                tt.leaf('ppl', '.3f'),
                tt.leaf('best_ppl', '.3f'),
            ],
            align='>',
        )
        return [train_group, valid_group]

    def process_sheep(self, sheep, history):
        """Extend the parent's result with the best validation value of each tracked metric.

        The parent implementation produces `parts`. `history` is then scanned and, for
        every entry holding a 'valid' section, the best value seen so far is kept for
        each tracked metric. When `parts` has a 'valid' section, these values are added
        to it under `best_<metric>` keys.
        """
        parts = super().process_sheep(sheep, history)

        # Direction of improvement for each tracked metric: 'lower' or 'higher'.
        track_by = {'ppl': 'lower'}

        # Start from the worst possible value so that any observed value wins.
        best_metrics = {}
        for name, mode in track_by.items():
            best_metrics[name] = float('inf') if mode == 'lower' else -float('inf')

        def improves(mode, candidate, incumbent):
            if mode == 'lower':
                return candidate < incumbent
            return candidate > incumbent

        # Keeping the best validation metrics makes it convenient to compare
        # runs of the grid against each other.
        for metrics in history:
            for key, sub in metrics.items():
                if key != 'valid':
                    continue
                for metric, mode in track_by.items():
                    if metric not in sub:
                        continue
                    if improves(mode, sub[metric], best_metrics[metric]):
                        best_metrics[metric] = sub[metric]

        if 'valid' in parts:
            best_entries = {}
            for name, value in best_metrics.items():
                best_entries[f'best_{name}'] = value
            parts['valid'].update(best_entries)
        return parts
|
67 |
+
|
68 |
+
|
69 |
+
class GenerationEvalExplorer(BaseExplorer):
    """Explorer reporting the metrics of the single `evaluate` stage."""
    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages reported by this explorer."""
        return ['evaluate']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        evaluate_leaves = [
            tt.leaf('epoch', '.3f'),
            tt.leaf('duration', '.1f'),
            tt.leaf('ping'),
            tt.leaf('ce', '.4f'),
            tt.leaf('ppl', '.3f'),
            tt.leaf('fad', '.3f'),
            tt.leaf('kld', '.3f'),
            tt.leaf('text_consistency', '.3f'),
            tt.leaf('chroma_cosine', '.3f'),
        ]
        evaluate_group = tt.group('evaluate', evaluate_leaves, align='>')
        return [evaluate_group]
|
audiocraft/grids/musicgen/musicgen_base_32khz.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
def explorer(launcher):
    """Declare the training grid for the `musicgen/musicgen_base_32khz` solver.

    Three job arrays are scheduled, on 32, 64 and 96 GPUs. The first run adds no
    override, the second selects the medium model scale and the third the large one.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=slurm_partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp_opts = {'autocast': False, 'fsdp.use': True}
    medium_scale = {'model/lm/model_scale': 'medium'}
    large_scale = {'model/lm/model_scale': 'large'}

    cfg_dropout_low = {'classifier_free_guidance.training_dropout': 0.2}
    word_dropout_low = {'conditioners.description.t5.word_dropout': 0.2}

    adamw_opts = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    launcher.bind_(fsdp_opts)

    # One entry per job array: (number of GPUs, label, overrides passed to the run).
    runs = [
        (32, '32gpus', ()),
        (64, '64gpus', (medium_scale, adamw_opts)),
        (96, '96gpus', (large_scale, cfg_dropout_low, word_dropout_low, adamw_opts, {'optim.max_norm': 3})),
    ]
    for gpus, label, overrides in runs:
        launcher.slurm_(gpus=gpus).bind_(label=label)
        with launcher.job_array():
            sub = launcher.bind()
            sub(*overrides)
|
audiocraft/grids/musicgen/musicgen_base_cached_32khz.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
def explorer(launcher):
    """Declare the cached variant of the `musicgen/musicgen_base_32khz` training grid.

    The grid has two phases. The first schedules `num_shards` cache-writing jobs,
    one per shard. The function then returns on purpose; the training jobs below
    the `return` statement are only scheduled once that statement is removed.
    """
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp = {'autocast': False, 'fsdp.use': True}
    medium = {'model/lm/model_scale': 'medium'}
    large = {'model/lm/model_scale': 'large'}

    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}

    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    # BEGINNING OF CACHE WRITING JOBS.
    cache_write = {
        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
        'cache.write': True,
        'generate.every': 500,
        'evaluate.every': 500,
        'logging.log_updates': 50,
    }

    cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
    cache_sub.bind_({'deadlock.use': True})
    cache_sub.slurm_(gpus=8)
    with launcher.job_array():
        num_shards = 10  # total number of jobs running in parallel.
        for shard in range(num_shards):
            # Schedule through `cache_sub`, not `launcher`: otherwise the xsmall model,
            # the disabled conditioner, the deadlock option and the 8 GPU request
            # configured above would never be applied to the cache-writing jobs.
            cache_sub(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})

    # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
    # OR SUFFICIENTLY AHEAD.
    return

    cache = {
        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
    }
    launcher.bind_(fsdp, cache)

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub()

    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(medium, adam)

    launcher.slurm_(gpus=96).bind_(label='96gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
|
audiocraft/grids/musicgen/musicgen_clapemb_32khz.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
def explorer(launcher):
    """Declare the training grid for MusicGen with the `clapemb2music` conditioner.

    A single 32 GPU job array schedules four runs: no override, the `text_p`
    override, the cache path override, and both overrides together.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=slurm_partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')
    launcher.bind_(conditioner='clapemb2music')

    fsdp_opts = {'autocast': False, 'fsdp.use': True}
    clap_cache = {
        'conditioners.description.clap.cache_path':
            '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music',
    }
    text_wav_opts = {'conditioners.description.clap.text_p': 0.5}

    launcher.bind_(fsdp_opts)

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    # Overrides of each run, in scheduling order.
    variants = [
        (),
        (text_wav_opts,),
        (clap_cache,),
        (clap_cache, text_wav_opts),
    ]
    with launcher.job_array():
        for overrides in variants:
            launcher(*overrides)
|
audiocraft/grids/musicgen/musicgen_melody_32khz.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
def explorer(launcher):
    """Declare the training grid for the `musicgen/musicgen_melody_32khz` solver.

    Single-GPU cache generation jobs are scheduled first, one per shuffle seed,
    followed by the training job arrays on 32, 64 and 96 GPUs.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=slurm_partitions)
    launcher.bind_(solver='musicgen/musicgen_melody_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp_opts = {'autocast': False, 'fsdp.use': True}
    medium_scale = {'model/lm/model_scale': 'medium'}
    large_scale = {'model/lm/model_scale': 'large'}

    cfg_dropout_low = {'classifier_free_guidance.training_dropout': 0.2}
    word_dropout_low = {'conditioners.description.t5.word_dropout': 0.2}

    adamw_opts = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    chroma_cache = {
        'conditioners.self_wav.chroma_stem.cache_path':
            '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem',
    }

    # CACHE GENERATION JOBS
    n_cache_gen_jobs = 4
    gen_sub = launcher.slurm(gpus=1)
    cache_gen_opts = {
        # the cache is always computed over the whole file, so duration doesn't matter here.
        'dataset.segment_duration': 2.,
        'dataset.batch_size': 8,
        'dataset.train.permutation_on_files': True,  # try to not repeat files.
        'optim.epochs': 10,
        'model/lm/model_scale': 'xsmall',
    }
    gen_sub.bind_(chroma_cache, cache_gen_opts)
    with gen_sub.job_array():
        for seed in range(n_cache_gen_jobs):
            gen_sub({'dataset.train.shuffle_seed': seed})

    # ACTUAL TRAINING JOBS.
    launcher.bind_(fsdp_opts)

    # One entry per job array: (number of GPUs, label, overrides of each run).
    groups = [
        (32, '32gpus', [(), (chroma_cache,)]),
        (64, '64gpus', [(medium_scale, adamw_opts)]),
        (96, '96gpus', [(large_scale, cfg_dropout_low, word_dropout_low, adamw_opts, {'optim.max_norm': 3})]),
    ]
    for gpus, label, runs in groups:
        launcher.slurm_(gpus=gpus).bind_(label=label)
        with launcher.job_array():
            sub = launcher.bind()
            for overrides in runs:
                sub(*overrides)
|
audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Evaluation with objective metrics for the pretrained MusicGen models.
|
9 |
+
This grid takes signature from the training grid and runs evaluation-only stage.
|
10 |
+
|
11 |
+
When running the grid for the first time, please use:
|
12 |
+
REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval
|
13 |
+
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
|
14 |
+
|
15 |
+
Note that you need the proper metrics external libraries setup to use all
|
16 |
+
the objective metrics activated in this grid. Refer to the README for more information.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import os
|
20 |
+
|
21 |
+
from ._explorers import GenerationEvalExplorer
|
22 |
+
from ...environment import AudioCraftEnvironment
|
23 |
+
from ... import train
|
24 |
+
|
25 |
+
|
26 |
+
def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
    """Schedule the objective evaluation runs for one launcher.

    One run with the base objective metrics is always scheduled. When
    `eval_melody` is true, a second run is scheduled with the chroma-specific
    options added.

    Note: this function shadows the builtin `eval` within this module.

    Args:
        launcher: launcher the evaluation options are bound to.
        batch_size: value used for `+dataset.evaluate.batch_size`.
        eval_melody: whether to also schedule the chroma-specific evaluation.
    """
    base_opts = {
        'dset': 'audio/musiccaps_32khz',
        'solver/musicgen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 16,
    }
    # chroma-specific evaluation
    chroma_opts = {
        'dset': 'internal/music_400k_32khz',
        'dataset.evaluate.segment_duration': 30,
        'dataset.evaluate.num_samples': 1000,
        'evaluate.metrics.chroma_cosine': True,
        'evaluate.metrics.fad': False,
        'evaluate.metrics.kld': False,
        'evaluate.metrics.text_consistency': False,
    }
    # binary for FAD computation: replace this path with your own path
    fad_opts = {
        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
    }
    sampling_opts = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
    two_step_cfg_opts = {'transformer_lm.two_step_cfg': True}

    sub = launcher.bind(base_opts)
    sub.bind_(fad_opts)

    # The base objective metrics always run; the chroma-specific ones are optional.
    runs = [(sampling_opts, two_step_cfg_opts)]
    if eval_melody:
        runs.append((sampling_opts, two_step_cfg_opts, chroma_opts))
    for overrides in runs:
        sub(*overrides)
|
60 |
+
|
61 |
+
|
62 |
+
@GenerationEvalExplorer
def explorer(launcher):
    """Declare the evaluation grid for the pretrained MusicGen models.

    Without `REGEN` in the environment, the experiments are rebuilt from the
    symlinks found in this grid's folder and nothing else is scheduled. With
    `REGEN` set, evaluation runs are declared for the small, medium and large
    base models and for the melody model.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=4, partition=slurm_partitions)

    regenerate = 'REGEN' in os.environ
    if not regenerate:
        grid_name = __name__.split('.', 2)[-1]
        folder = train.main.dora.dir / 'grids' / grid_name
        with launcher.job_array():
            for entry in folder.iterdir():
                # Only symlinked entries are treated as experiment signatures.
                if entry.is_symlink():
                    xp = train.main.get_xp_from_sig(entry.name)
                    launcher(xp.argv)
        return

    with launcher.job_array():
        musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz")
        musicgen_base.bind_({'autocast': False, 'fsdp.use': True})

        # base musicgen models: (checkpoint to continue from, model scale override or None)
        base_models = [
            ('//pretrained/facebook/musicgen-small', None),
            ('//pretrained/facebook/musicgen-medium', 'medium'),
            ('//pretrained/facebook/musicgen-large', 'large'),
        ]
        for checkpoint, scale in base_models:
            model = musicgen_base.bind({'continue_from': checkpoint})
            if scale is not None:
                model.bind_({'model/lm/model_scale': scale})
            eval(model, batch_size=128)

        # melody musicgen model
        musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz")
        musicgen_melody.bind_({'autocast': False, 'fsdp.use': True})

        melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'})
        melody_medium.bind_({'model/lm/model_scale': 'medium'})
        eval(melody_medium, batch_size=128, eval_melody=True)
|
audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
from ._explorers import LMExplorer
|
9 |
+
from ...environment import AudioCraftEnvironment
|
10 |
+
|
11 |
+
|
12 |
+
@LMExplorer
def explorer(launcher):
    """Declare the stereo fine-tuning grid for the `musicgen/musicgen_base_32khz` solver.

    Three job arrays are scheduled, on 32, 64 and 96 GPUs. Each run continues from
    a checkpoint file looked up under the `checkpoints` folder of the home directory.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=slurm_partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset, which needs to be stereo
    launcher.bind_(dset='audio/example')

    fsdp_opts = {'autocast': False, 'fsdp.use': True}
    medium_scale = {'model/lm/model_scale': 'medium'}
    large_scale = {'model/lm/model_scale': 'large'}

    cfg_dropout_low = {'classifier_free_guidance.training_dropout': 0.2}
    word_dropout_low = {'conditioners.description.t5.word_dropout': 0.2}

    adamw_opts = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    stereo_opts = {
        'codebooks_pattern.delay.delays': [0, 0, 1, 1, 2, 2, 3, 3],
        'transformer_lm.n_q': 8,
        'interleave_stereo_codebooks.use': True,
        'channels': 2,
    }

    # The fine tuning checkpoints must first be created following the
    # instructions in docs/MUSICGEN.md. They are expected under ~/checkpoints.
    checkpoints = Path.home() / 'checkpoints'

    launcher.bind_(fsdp_opts, stereo_opts, {'optim.epochs': 100})

    # One entry per job array: (number of GPUs, label, checkpoint file, overrides).
    runs = [
        (32, '32gpus', 'stereo_finetune_musicgen-small.th', ()),
        (64, '64gpus', 'stereo_finetune_musicgen-medium.th', (medium_scale, adamw_opts)),
        (96, '96gpus', 'stereo_finetune_musicgen-large.th',
         (large_scale, cfg_dropout_low, word_dropout_low, adamw_opts, {'optim.max_norm': 3})),
    ]
    for gpus, label, checkpoint_name, overrides in runs:
        launcher.slurm_(gpus=gpus).bind_(label=label)
        with launcher.job_array():
            sub = launcher.bind({'continue_from': str(checkpoints / checkpoint_name)})
            sub(*overrides)
|