import argparse
import os
from pathlib import Path


def get_parser():
    parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.")
    default_raw = None
    default_musdb = None
    if 'DEMUCS_RAW' in os.environ:
        default_raw = Path(os.environ['DEMUCS_RAW'])
    if 'DEMUCS_MUSDB' in os.environ:
        default_musdb = Path(os.environ['DEMUCS_MUSDB'])
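    # Dataset locations and loading options.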
    parser.add_argument(
        "--raw",
        type=Path,
        default=default_raw,
        help="Path to raw audio; loading can be faster, see python3 -m demucs.raw to extract.")
    parser.add_argument("--no_raw", action="store_const", const=None, dest="raw")
    parser.add_argument("-m",
                        "--musdb",
                        type=Path,
                        default=default_musdb,
                        help="Path to musdb root")
    parser.add_argument("--is_wav", action="store_true",
                        help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
    parser.add_argument("--metadata", type=Path, default=Path("metadata/"),
                        help="Folder where metadata information is stored.")
    parser.add_argument("--wav", type=Path,
                        help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
                             "subfolder.")
    parser.add_argument("--samplerate", type=int, default=44100)
    parser.add_argument("--audio_channels", type=int, default=2)
    parser.add_argument("--samples",
                        default=44100 * 10,
                        type=int,
                        help="Number of samples to feed in")
    parser.add_argument("--data_stride",
                        default=44100,
                        type=int,
                        help="Stride for chunks, shorter = longer epochs")
    parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers")
    parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers")
    parser.add_argument("-d",
                        "--device",
                        help="Device to train on, default is cuda if available else cpu")
    parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
    parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
    parser.add_argument("--test", help="Just run the test pipeline + one validation. "
                        "This should be a filename relative to the models/ folder.")
    parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, "
                        "on a pretrained model.")

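    # Distributed training options.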
parser.add_argument("--rank", default=0, type=int) |
|
parser.add_argument("--world_size", default=1, type=int) |
|
parser.add_argument("--master") |
|
|
|
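    # Output folders for checkpoints, evaluations, logs and trained models.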
parser.add_argument("--checkpoints", |
|
type=Path, |
|
default=Path("checkpoints"), |
|
help="Folder where to store checkpoints etc") |
|
parser.add_argument("--evals", |
|
type=Path, |
|
default=Path("evals"), |
|
help="Folder where to store evals and waveforms") |
|
parser.add_argument("--save", |
|
action="store_true", |
|
help="Save estimated for the test set waveforms") |
|
parser.add_argument("--logs", |
|
type=Path, |
|
default=Path("logs"), |
|
help="Folder where to store logs") |
|
parser.add_argument("--models", |
|
type=Path, |
|
default=Path("models"), |
|
help="Folder where to store trained models") |
|
parser.add_argument("-R", |
|
"--restart", |
|
action='store_true', |
|
help='Restart training, ignoring previous run') |
|
|
|
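    # Optimization and training hyperparameters.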
parser.add_argument("--seed", type=int, default=42) |
|
parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs") |
|
parser.add_argument("-r", |
|
"--repeat", |
|
type=int, |
|
default=2, |
|
help="Repeat the train set, longer epochs") |
|
parser.add_argument("-b", "--batch_size", type=int, default=64) |
|
parser.add_argument("--lr", type=float, default=3e-4) |
|
parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1") |
|
parser.add_argument("--init", help="Initialize from a pre-trained model.") |
|
|
|
|
|
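    # Data augmentation options.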
parser.add_argument("--no_augment", |
|
action="store_false", |
|
dest="augment", |
|
default=True, |
|
help="No basic data augmentation.") |
|
parser.add_argument("--repitch", type=float, default=0.2, |
|
help="Probability to do tempo/pitch change") |
|
parser.add_argument("--max_tempo", type=float, default=12, |
|
help="Maximum relative tempo change in %% when using repitch.") |
|
|
|
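    # Source remixing and shift/overlap prediction settings.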
parser.add_argument("--remix_group_size", |
|
type=int, |
|
default=4, |
|
help="Shuffle sources using group of this size. Useful to somewhat " |
|
"replicate multi-gpu training " |
|
"on less GPUs.") |
|
parser.add_argument("--shifts", |
|
type=int, |
|
default=10, |
|
help="Number of random shifts used for the shift trick.") |
|
parser.add_argument("--overlap", |
|
type=float, |
|
default=0.25, |
|
help="Overlap when --split_valid is passed.") |
|
|
|
|
|
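    # Demucs architecture hyperparameters.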
parser.add_argument("--growth", |
|
type=float, |
|
default=2., |
|
help="Number of channels between two layers will increase by this factor") |
|
parser.add_argument("--depth", |
|
type=int, |
|
default=6, |
|
help="Number of layers for the encoder and decoder") |
|
parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM") |
|
parser.add_argument("--channels", |
|
type=int, |
|
default=64, |
|
help="Number of channels for the first encoder layer") |
|
parser.add_argument("--kernel_size", |
|
type=int, |
|
default=8, |
|
help="Kernel size for the (transposed) convolutions") |
|
parser.add_argument("--conv_stride", |
|
type=int, |
|
default=4, |
|
help="Stride for the (transposed) convolutions") |
|
parser.add_argument("--context", |
|
type=int, |
|
default=3, |
|
help="Context size for the decoder convolutions " |
|
"before the transposed convolutions") |
|
parser.add_argument("--rescale", |
|
type=float, |
|
default=0.1, |
|
help="Initial weight rescale reference") |
|
parser.add_argument("--no_resample", action="store_false", |
|
default=True, dest="resample", |
|
help="No Resampling of the input/output x2") |
|
parser.add_argument("--no_glu", |
|
action="store_false", |
|
default=True, |
|
dest="glu", |
|
help="Replace all GLUs by ReLUs") |
|
parser.add_argument("--no_rewrite", |
|
action="store_false", |
|
default=True, |
|
dest="rewrite", |
|
help="No 1x1 rewrite convolutions") |
|
parser.add_argument("--normalize", action="store_true") |
|
parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True) |
|
|
|
|
|
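    # Conv-Tasnet baseline options.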
parser.add_argument("--tasnet", action="store_true") |
|
parser.add_argument("--split_valid", |
|
action="store_true", |
|
help="Predict chunks by chunks for valid and test. Required for tasnet") |
|
parser.add_argument("--X", type=int, default=8) |
|
|
|
|
|
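    # Utility modes: inspect or export a model without training.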
parser.add_argument("--show", |
|
action="store_true", |
|
help="Show model architecture, size and exit") |
|
parser.add_argument("--save_model", action="store_true", |
|
help="Skip traning, just save final model " |
|
"for the current checkpoint value.") |
|
parser.add_argument("--save_state", |
|
help="Skip training, just save state " |
|
"for the current checkpoint value. You should " |
|
"provide a model name as argument.") |
|
|
|
|
|
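    # Quantization options (QAT and DiffQ).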
parser.add_argument("--q-min-size", type=float, default=1, |
|
help="Only quantize layers over this size (in MB)") |
|
parser.add_argument( |
|
"--qat", type=int, help="If provided, use QAT training with that many bits.") |
|
|
|
parser.add_argument("--diffq", type=float, default=0) |
|
parser.add_argument( |
|
"--ms-target", type=float, default=162, |
|
help="Model size target in MB, when using DiffQ. Best model will be kept " |
|
"only if it is smaller than this target.") |
|
|
|
return parser |
|
|
|
|
|
def get_name(parser, args):
    """
    Return the name of an experiment given the args. Some parameters are ignored,
    for instance --workers, as they do not impact the final result.
    """
    ignore_args = set([
        "checkpoints",
        "deterministic",
        "eval",
        "evals",
        "eval_cpu",
        "eval_workers",
        "logs",
        "master",
        "rank",
        "restart",
        "save",
        "save_model",
        "save_state",
        "show",
        "workers",
        "world_size",
    ])
    parts = []
    name_args = dict(args.__dict__)
    for name, value in name_args.items():
        if name in ignore_args:
            continue
        if value != parser.get_default(name):
            if isinstance(value, Path):
                parts.append(f"{name}={value.name}")
            else:
                parts.append(f"{name}={value}")
    if parts:
        name = " ".join(parts)
    else:
        name = "default"
    return name
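

# Example usage (illustrative sketch, assuming this module is run directly):
# build the parser, parse the command line, and derive the experiment name
# used for checkpoints, evals and logs.
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    # get_name() keeps only the arguments that differ from their parser
    # defaults, so runs launched with identical settings share a name.
    name = get_name(parser, args)
    print(f"Experiment: {name}")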