akhaliq3
spaces demo
5019931
raw
history blame
8.66 kB
import argparse
import logging
import os
import pathlib
from functools import partial
from typing import List, NoReturn
import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin
from bytesep.callbacks import get_callbacks
from bytesep.data.augmentors import Augmentor
from bytesep.data.batch_data_preprocessors import (
get_batch_data_preprocessor_class,
)
from bytesep.data.data_modules import DataModule, Dataset
from bytesep.data.samplers import SegmentSampler
from bytesep.losses import get_loss_function
from bytesep.models.lightning_modules import (
LitSourceSeparation,
get_model_class,
)
from bytesep.optimizers.lr_schedulers import get_lr_lambda
from bytesep.utils import (
create_logging,
get_pitch_shift_factor,
read_yaml,
check_configs_gramma,
)
def get_dirs(
workspace: str, task_name: str, filename: str, config_yaml: str, gpus: int
) -> List[str]:
r"""Get directories.
Args:
workspace: str
task_name, str, e.g., 'musdb18'
filenmae: str
config_yaml: str
gpus: int, e.g., 0 for cpu and 8 for training with 8 gpu cards
Returns:
checkpoints_dir: str
logs_dir: str
logger: pl.loggers.TensorBoardLogger
statistics_path: str
"""
# save checkpoints dir
checkpoints_dir = os.path.join(
workspace,
"checkpoints",
task_name,
filename,
"config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
)
os.makedirs(checkpoints_dir, exist_ok=True)
# logs dir
logs_dir = os.path.join(
workspace,
"logs",
task_name,
filename,
"config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
)
os.makedirs(logs_dir, exist_ok=True)
# loggings
create_logging(logs_dir, filemode='w')
logging.info(args)
# tensorboard logs dir
tb_logs_dir = os.path.join(workspace, "tensorboard_logs")
os.makedirs(tb_logs_dir, exist_ok=True)
experiment_name = os.path.join(task_name, filename, pathlib.Path(config_yaml).stem)
logger = pl.loggers.TensorBoardLogger(save_dir=tb_logs_dir, name=experiment_name)
# statistics path
statistics_path = os.path.join(
workspace,
"statistics",
task_name,
filename,
"config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
"statistics.pkl",
)
os.makedirs(os.path.dirname(statistics_path), exist_ok=True)
return checkpoints_dir, logs_dir, logger, statistics_path
def _get_data_module(
workspace: str, config_yaml: str, num_workers: int, distributed: bool
) -> DataModule:
r"""Create data_module. Mini-batch data can be obtained by:
code-block:: python
data_module.setup()
for batch_data_dict in data_module.train_dataloader():
print(batch_data_dict.keys())
break
Args:
workspace: str
config_yaml: str
num_workers: int, e.g., 0 for non-parallel and 8 for using cpu cores
for preparing data in parallel
distributed: bool
Returns:
data_module: DataModule
"""
configs = read_yaml(config_yaml)
input_source_types = configs['train']['input_source_types']
indexes_path = os.path.join(workspace, configs['train']['indexes_dict'])
sample_rate = configs['train']['sample_rate']
segment_seconds = configs['train']['segment_seconds']
mixaudio_dict = configs['train']['augmentations']['mixaudio']
augmentations = configs['train']['augmentations']
max_pitch_shift = max(
[
augmentations['pitch_shift'][source_type]
for source_type in input_source_types
]
)
batch_size = configs['train']['batch_size']
steps_per_epoch = configs['train']['steps_per_epoch']
segment_samples = int(segment_seconds * sample_rate)
ex_segment_samples = int(segment_samples * get_pitch_shift_factor(max_pitch_shift))
# sampler
train_sampler = SegmentSampler(
indexes_path=indexes_path,
segment_samples=ex_segment_samples,
mixaudio_dict=mixaudio_dict,
batch_size=batch_size,
steps_per_epoch=steps_per_epoch,
)
# augmentor
augmentor = Augmentor(augmentations=augmentations)
# dataset
train_dataset = Dataset(augmentor, segment_samples)
# data module
data_module = DataModule(
train_sampler=train_sampler,
train_dataset=train_dataset,
num_workers=num_workers,
distributed=distributed,
)
return data_module
def train(args) -> NoReturn:
r"""Train & evaluate and save checkpoints.
Args:
workspace: str, directory of workspace
gpus: int
config_yaml: str, path of config file for training
"""
# arugments & parameters
workspace = args.workspace
gpus = args.gpus
config_yaml = args.config_yaml
filename = args.filename
num_workers = 8
distributed = True if gpus > 1 else False
evaluate_device = "cuda" if gpus > 0 else "cpu"
# Read config file.
configs = read_yaml(config_yaml)
check_configs_gramma(configs)
task_name = configs['task_name']
target_source_types = configs['train']['target_source_types']
target_sources_num = len(target_source_types)
channels = configs['train']['channels']
batch_data_preprocessor_type = configs['train']['batch_data_preprocessor']
model_type = configs['train']['model_type']
loss_type = configs['train']['loss_type']
optimizer_type = configs['train']['optimizer_type']
learning_rate = float(configs['train']['learning_rate'])
precision = configs['train']['precision']
early_stop_steps = configs['train']['early_stop_steps']
warm_up_steps = configs['train']['warm_up_steps']
reduce_lr_steps = configs['train']['reduce_lr_steps']
# paths
checkpoints_dir, logs_dir, logger, statistics_path = get_dirs(
workspace, task_name, filename, config_yaml, gpus
)
# training data module
data_module = _get_data_module(
workspace=workspace,
config_yaml=config_yaml,
num_workers=num_workers,
distributed=distributed,
)
# batch data preprocessor
BatchDataPreprocessor = get_batch_data_preprocessor_class(
batch_data_preprocessor_type=batch_data_preprocessor_type
)
batch_data_preprocessor = BatchDataPreprocessor(
target_source_types=target_source_types
)
# model
Model = get_model_class(model_type=model_type)
model = Model(input_channels=channels, target_sources_num=target_sources_num)
# loss function
loss_function = get_loss_function(loss_type=loss_type)
# callbacks
callbacks = get_callbacks(
task_name=task_name,
config_yaml=config_yaml,
workspace=workspace,
checkpoints_dir=checkpoints_dir,
statistics_path=statistics_path,
logger=logger,
model=model,
evaluate_device=evaluate_device,
)
# callbacks = []
# learning rate reduce function
lr_lambda = partial(
get_lr_lambda, warm_up_steps=warm_up_steps, reduce_lr_steps=reduce_lr_steps
)
# pytorch-lightning model
pl_model = LitSourceSeparation(
batch_data_preprocessor=batch_data_preprocessor,
model=model,
optimizer_type=optimizer_type,
loss_function=loss_function,
learning_rate=learning_rate,
lr_lambda=lr_lambda,
)
# trainer
trainer = pl.Trainer(
checkpoint_callback=False,
gpus=gpus,
callbacks=callbacks,
max_steps=early_stop_steps,
accelerator="ddp",
sync_batchnorm=True,
precision=precision,
replace_sampler_ddp=False,
plugins=[DDPPlugin(find_unused_parameters=True)],
profiler='simple',
)
# Fit, evaluate, and save checkpoints.
trainer.fit(pl_model, data_module)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="")
subparsers = parser.add_subparsers(dest="mode")
parser_train = subparsers.add_parser("train")
parser_train.add_argument(
"--workspace", type=str, required=True, help="Directory of workspace."
)
parser_train.add_argument("--gpus", type=int, required=True)
parser_train.add_argument(
"--config_yaml",
type=str,
required=True,
help="Path of config file for training.",
)
args = parser.parse_args()
args.filename = pathlib.Path(__file__).stem
if args.mode == "train":
train(args)
else:
raise Exception("Error argument!")