# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Union

from mmengine.runner import IterBasedTrainLoop
from torch.utils.data import DataLoader


class TrainLoop(IterBasedTrainLoop):
    """Iteration-based train loop that also accepts an epoch-style budget.

    Exactly one of ``max_iters`` or ``max_epochs`` must be given in
    ``train_cfg``. When ``max_epochs`` is provided, it is converted into an
    iteration count using the length of the (built) dataloader.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 max_iters: Optional[int] = None,
                 max_epochs: Optional[Union[int, float]] = None,
                 **kwargs) -> None:
        if max_iters is None and max_epochs is None:
            raise RuntimeError('Please specify the `max_iters` or '
                               '`max_epochs` in `train_cfg`.')
        elif max_iters is not None and max_epochs is not None:
            raise RuntimeError('Only one of `max_iters` or `max_epochs` can '
                               'exist in `train_cfg`.')
        else:
            if max_iters is not None:
                iters = int(max_iters)
                assert iters == max_iters, (
                    '`max_iters` should be an integer number, '
                    f'but got {max_iters}')
            elif max_epochs is not None:
                # Build the dataloader first so its length is known when
                # converting the epoch budget into an iteration budget.
                if isinstance(dataloader, dict):
                    diff_rank_seed = runner._randomness_cfg.get(
                        'diff_rank_seed', False)
                    dataloader = runner.build_dataloader(
                        dataloader,
                        seed=runner.seed,
                        diff_rank_seed=diff_rank_seed)
                # `max_epochs` may be a float; truncate to whole iterations
                # so the parent class receives an integer `max_iters`.
                iters = int(max_epochs * len(dataloader))
            else:
                raise NotImplementedError
        super().__init__(
            runner=runner, dataloader=dataloader, max_iters=iters, **kwargs)
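

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): assuming
# an mmengine-style config drives the runner, this loop would typically be
# selected through ``train_cfg``. The values below are hypothetical
# placeholders; exactly one of ``max_epochs`` / ``max_iters`` may be set,
# matching the check in ``TrainLoop.__init__`` above.
#
#     train_cfg = dict(type=TrainLoop, max_epochs=3)       # epoch budget
#     train_cfg = dict(type=TrainLoop, max_iters=10000)    # iteration budget
#
# Passing both keys, or neither, raises the RuntimeError defined above.
# ---------------------------------------------------------------------------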