import argparse
import logging
import os
import pathlib
from functools import partial
from typing import List, NoReturn

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin

from bytesep.callbacks import get_callbacks
from bytesep.data.augmentors import Augmentor
from bytesep.data.batch_data_preprocessors import (
    get_batch_data_preprocessor_class,
)
from bytesep.data.data_modules import DataModule, Dataset
from bytesep.data.samplers import SegmentSampler
from bytesep.losses import get_loss_function
from bytesep.models.lightning_modules import (
    LitSourceSeparation,
    get_model_class,
)
from bytesep.optimizers.lr_schedulers import get_lr_lambda
from bytesep.utils import (
    create_logging,
    get_pitch_shift_factor,
    read_yaml,
    check_configs_gramma,
)

def get_dirs(
    workspace: str, task_name: str, filename: str, config_yaml: str, gpus: int
) -> List[str]:
    r"""Get directories.

    Args:
        workspace: str
        task_name: str, e.g., 'musdb18'
        filename: str
        config_yaml: str
        gpus: int, e.g., 0 for CPU and 8 for training with 8 GPU cards

    Returns:
        checkpoints_dir: str
        logs_dir: str
        logger: pl.loggers.TensorBoardLogger
        statistics_path: str
    """
    # save checkpoints dir
    checkpoints_dir = os.path.join(
        workspace,
        "checkpoints",
        task_name,
        filename,
        "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
    )
    os.makedirs(checkpoints_dir, exist_ok=True)

    # logs dir
    logs_dir = os.path.join(
        workspace,
        "logs",
        task_name,
        filename,
        "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
    )
    os.makedirs(logs_dir, exist_ok=True)
    # logging (note: `args` here is the module-level namespace parsed in __main__)
    create_logging(logs_dir, filemode='w')
    logging.info(args)
    # tensorboard logs dir
    tb_logs_dir = os.path.join(workspace, "tensorboard_logs")
    os.makedirs(tb_logs_dir, exist_ok=True)

    experiment_name = os.path.join(task_name, filename, pathlib.Path(config_yaml).stem)
    logger = pl.loggers.TensorBoardLogger(save_dir=tb_logs_dir, name=experiment_name)

    # statistics path
    statistics_path = os.path.join(
        workspace,
        "statistics",
        task_name,
        filename,
        "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
        "statistics.pkl",
    )
    os.makedirs(os.path.dirname(statistics_path), exist_ok=True)

    return checkpoints_dir, logs_dir, logger, statistics_path

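# For illustration only (hypothetical values): with workspace='./workspaces/bytesep',
# task_name='musdb18', filename='train', a config_yaml named 'vocals-accompaniment,unet.yaml'
# and gpus=1, get_dirs() creates and returns:
#   checkpoints_dir: ./workspaces/bytesep/checkpoints/musdb18/train/config=vocals-accompaniment,unet,gpus=1
#   logs_dir:        ./workspaces/bytesep/logs/musdb18/train/config=vocals-accompaniment,unet,gpus=1
#   statistics_path: ./workspaces/bytesep/statistics/musdb18/train/config=vocals-accompaniment,unet,gpus=1/statistics.pkl
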
def _get_data_module(
    workspace: str, config_yaml: str, num_workers: int, distributed: bool
) -> DataModule:
    r"""Create data_module. Mini-batch data can be obtained by:

    .. code-block:: python

        data_module.setup()

        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break

    Args:
        workspace: str
        config_yaml: str
        num_workers: int, e.g., 0 for non-parallel loading and 8 for preparing
            data in parallel with 8 CPU workers
        distributed: bool

    Returns:
        data_module: DataModule
    """
    configs = read_yaml(config_yaml)
    input_source_types = configs['train']['input_source_types']
    indexes_path = os.path.join(workspace, configs['train']['indexes_dict'])
    sample_rate = configs['train']['sample_rate']
    segment_seconds = configs['train']['segment_seconds']
    mixaudio_dict = configs['train']['augmentations']['mixaudio']
    augmentations = configs['train']['augmentations']
    max_pitch_shift = max(
        [
            augmentations['pitch_shift'][source_type]
            for source_type in input_source_types
        ]
    )
    batch_size = configs['train']['batch_size']
    steps_per_epoch = configs['train']['steps_per_epoch']

    segment_samples = int(segment_seconds * sample_rate)
    ex_segment_samples = int(segment_samples * get_pitch_shift_factor(max_pitch_shift))
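    # Note: a segment longer than `segment_samples` is drawn so that a full segment
    # is still available after the pitch-shift (resampling) augmentation; this
    # assumes get_pitch_shift_factor() returns the largest resampling ratio for
    # the configured shift in semitones.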
    # sampler
    train_sampler = SegmentSampler(
        indexes_path=indexes_path,
        segment_samples=ex_segment_samples,
        mixaudio_dict=mixaudio_dict,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
    )

    # augmentor
    augmentor = Augmentor(augmentations=augmentations)

    # dataset
    train_dataset = Dataset(augmentor, segment_samples)

    # data module
    data_module = DataModule(
        train_sampler=train_sampler,
        train_dataset=train_dataset,
        num_workers=num_workers,
        distributed=distributed,
    )

    return data_module

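# A minimal sketch of the 'train' section of the config read above. The key names
# are taken from this function; the values are placeholders, not project defaults:
#
#   train:
#     input_source_types: ['vocals', 'accompaniment']
#     indexes_dict: "indexes/musdb18/train.pkl"      # joined onto the workspace path
#     sample_rate: 44100
#     segment_seconds: 3.0
#     batch_size: 16
#     steps_per_epoch: 10000
#     augmentations:
#       mixaudio: {'vocals': 2, 'accompaniment': 2}
#       pitch_shift: {'vocals': 0, 'accompaniment': 0}
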
def train(args) -> NoReturn:
    r"""Train, evaluate, and save checkpoints.

    Args:
        workspace: str, directory of workspace
        gpus: int
        config_yaml: str, path of config file for training
    """

    # arguments & parameters
    workspace = args.workspace
    gpus = args.gpus
    config_yaml = args.config_yaml
    filename = args.filename

    num_workers = 8
    distributed = gpus > 1
    evaluate_device = "cuda" if gpus > 0 else "cpu"

    # Read config file.
    configs = read_yaml(config_yaml)
    check_configs_gramma(configs)
    task_name = configs['task_name']
    target_source_types = configs['train']['target_source_types']
    target_sources_num = len(target_source_types)
    channels = configs['train']['channels']
    batch_data_preprocessor_type = configs['train']['batch_data_preprocessor']
    model_type = configs['train']['model_type']
    loss_type = configs['train']['loss_type']
    optimizer_type = configs['train']['optimizer_type']
    learning_rate = float(configs['train']['learning_rate'])
    precision = configs['train']['precision']
    early_stop_steps = configs['train']['early_stop_steps']
    warm_up_steps = configs['train']['warm_up_steps']
    reduce_lr_steps = configs['train']['reduce_lr_steps']

    # paths
    checkpoints_dir, logs_dir, logger, statistics_path = get_dirs(
        workspace, task_name, filename, config_yaml, gpus
    )

    # training data module
    data_module = _get_data_module(
        workspace=workspace,
        config_yaml=config_yaml,
        num_workers=num_workers,
        distributed=distributed,
    )

    # batch data preprocessor
    BatchDataPreprocessor = get_batch_data_preprocessor_class(
        batch_data_preprocessor_type=batch_data_preprocessor_type
    )

    batch_data_preprocessor = BatchDataPreprocessor(
        target_source_types=target_source_types
    )

    # model
    Model = get_model_class(model_type=model_type)
    model = Model(input_channels=channels, target_sources_num=target_sources_num)

    # loss function
    loss_function = get_loss_function(loss_type=loss_type)

    # callbacks
    callbacks = get_callbacks(
        task_name=task_name,
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )
    # callbacks = []

    # learning rate reduce function
    lr_lambda = partial(
        get_lr_lambda, warm_up_steps=warm_up_steps, reduce_lr_steps=reduce_lr_steps
    )

    # pytorch-lightning model
    pl_model = LitSourceSeparation(
        batch_data_preprocessor=batch_data_preprocessor,
        model=model,
        optimizer_type=optimizer_type,
        loss_function=loss_function,
        learning_rate=learning_rate,
        lr_lambda=lr_lambda,
    )
    # trainer (pytorch_lightning 1.x API: `checkpoint_callback`, `accelerator="ddp"`, DDPPlugin)
    trainer = pl.Trainer(
        checkpoint_callback=False,
        gpus=gpus,
        callbacks=callbacks,
        max_steps=early_stop_steps,
        accelerator="ddp",
        sync_batchnorm=True,
        precision=precision,
        replace_sampler_ddp=False,
        plugins=[DDPPlugin(find_unused_parameters=True)],
        profiler='simple',
    )

    # Fit, evaluate, and save checkpoints.
    trainer.fit(pl_model, data_module)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Train a source separation system.")
    subparsers = parser.add_subparsers(dest="mode")

    parser_train = subparsers.add_parser("train")
    parser_train.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser_train.add_argument(
        "--gpus", type=int, required=True, help="Number of GPUs; 0 means training on CPU."
    )
    parser_train.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="Path of config file for training.",
    )

    args = parser.parse_args()
    args.filename = pathlib.Path(__file__).stem

    if args.mode == "train":
        train(args)
    else:
        raise Exception("Unknown mode: {}".format(args.mode))
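
# Example invocation (the script name and paths below are placeholders):
#   python train.py train \
#       --workspace="./workspaces/bytesep" \
#       --gpus=1 \
#       --config_yaml="./configs/musdb18/vocals-accompaniment,unet.yaml"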