Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,858 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from mmcv.transforms import Compose
from mmengine.hooks import Hook
from mmdet.registry import HOOKS
@HOOKS.register_module()
class PipelineSwitchHook(Hook):
    """Switch the training data pipeline once ``switch_epoch`` is reached.

    Before the first training epoch whose index is ``>= switch_epoch``, the
    training dataset's ``pipeline`` attribute is replaced with
    ``Compose(switch_pipeline)``. The swap happens at most once per hook
    instance.

    Args:
        switch_epoch (int): Switch the pipeline at (or after) this epoch.
        switch_pipeline (list[dict]): The pipeline to switch to; it is passed
            to ``Compose``.
    """

    def __init__(self, switch_epoch, switch_pipeline):
        self.switch_epoch = switch_epoch
        self.switch_pipeline = switch_pipeline
        # True while the dataloader's private init flag has been cleared by
        # the switch and still has to be restored on a later epoch.
        self._restart_dataloader = False
        # Guards against swapping the pipeline more than once.
        self._has_switched = False

    def before_train_epoch(self, runner):
        """Swap the pipeline, or restore the dataloader flag after a swap.

        Args:
            runner: The runner driving training. Its ``epoch``,
                ``train_dataloader`` and ``logger`` attributes are used.
        """
        epoch = runner.epoch
        train_loader = runner.train_dataloader
        if epoch >= self.switch_epoch and not self._has_switched:
            runner.logger.info('Switch pipeline now!')
            # The dataset pipeline cannot be updated when persistent_workers
            # is True, so we need to force the dataloader's multi-process
            # restart. This is a very hacky approach.
            train_loader.dataset.pipeline = Compose(self.switch_pipeline)
            if getattr(train_loader, 'persistent_workers', False) is True:
                # Clear the DataLoader's name-mangled private init flag and
                # drop its cached iterator, so that a fresh iterator (and,
                # presumably, fresh workers holding the new pipeline) is
                # created for this epoch.
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True
            self._has_switched = True
        elif self._restart_dataloader:
            # Once the restart is complete, we need to restore the
            # initialization flag. Clear our marker so this happens exactly
            # once instead of on every remaining epoch.
            train_loader._DataLoader__initialized = True
            self._restart_dataloader = False
|