import os

import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
                         build_optimizer)
from mmpose.core import DistEvalHook, EvalHook, Fp16OptimizerHook
from mmpose.datasets import build_dataloader
from mmpose.utils import get_root_logger

from models.core.custom_hooks.shuffle_hooks import ShufflePairedSamplesHook

def train_model(model,
                dataset,
                val_dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Train model entry function.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset): Train dataset.
        val_dataset (Dataset): Validation dataset used by the eval hook.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    dataloader_setting = dict(
        samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
        workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
        # cfg.gpus will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        pin_memory=False,
    )
    dataloader_setting = dict(dataloader_setting,
                              **cfg.data.get('train_dataloader', {}))
    data_loaders = [
        build_dataloader(ds, **dataloader_setting) for ds in dataset
    ]
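
    # Example (an assumption for illustration, not taken from this repo) of the
    # cfg.data fields read above; real values come from the experiment config:
    #   data = dict(
    #       samples_per_gpu=16,
    #       workers_per_gpu=2,
    #       train_dataloader=dict(drop_last=True),
    #       ...)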
    # put model on gpus
    if distributed:
        # NOTE: the default was changed from True to False for faster training.
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config
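
    # Example (assumption) of the matching config entries; `fp16` is optional
    # and, when present, selects the Fp16OptimizerHook branch above:
    #   optimizer_config = dict(grad_clip=None)
    #   fp16 = dict(loss_scale=512.)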
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # custom hook that re-shuffles the paired samples of each train loader
    shuffle_cfg = cfg.get('shuffle_cfg', None)
    if shuffle_cfg is not None:
        for data_loader in data_loaders:
            runner.register_hook(
                ShufflePairedSamplesHook(data_loader, **shuffle_cfg))
    # register eval hooks
    if validate:
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['res_folder'] = os.path.join(cfg.work_dir,
                                              eval_cfg['res_folder'])
        dataloader_setting = dict(
            # evaluate with batch size 1 instead of cfg.data samples_per_gpu
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            shuffle=False,
            pin_memory=False,
        )
        dataloader_setting = dict(dataloader_setting,
                                  **cfg.data.get('val_dataloader', {}))
        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
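
    # Example (assumption) of the `evaluation` config consumed by the eval hook
    # above; 'res_folder' must be present because it is joined with
    # cfg.work_dir:
    #   evaluation = dict(interval=10, metric='PCK', res_folder='results')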
    # resume or load pretrained weights, then start training
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
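

# A minimal usage sketch (an assumption about the surrounding project, not part
# of this module; the config path and builder calls below are illustrative):
#
#   from mmcv import Config
#   from mmpose.datasets import build_dataset
#   from mmpose.models import build_posenet
#
#   cfg = Config.fromfile('configs/my_experiment.py')
#   model = build_posenet(cfg.model)
#   train_ds = build_dataset(cfg.data.train)
#   val_ds = build_dataset(cfg.data.val, dict(test_mode=True))
#   train_model(model, train_ds, val_ds, cfg, distributed=False, validate=True)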