# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from collections import OrderedDict
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.utils import is_list_of

from ..base_module import BaseModule
from .data_preprocessor import BaseDataPreprocessor


class BaseModel(BaseModule):
    """Base class for all algorithmic models.

    ``BaseModel`` implements the basic functions of an algorithmic model,
    such as weight initialization, batch-input preprocessing (see
    :class:`BaseDataPreprocessor` for details), loss parsing, and model
    parameter updates. Subclasses that inherit from ``BaseModel`` only need
    to implement the ``forward`` method, which contains the logic for
    computing losses and predictions; the model can then be trained by the
    runner.
    Examples:
        >>> @MODELS.register_module()
        >>> class ToyModel(BaseModel):
        >>>
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.backbone = nn.Sequential()
        >>>         self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
        >>>         self.backbone.add_module('pool1', nn.MaxPool2d(2, 2))
        >>>         self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
        >>>         self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
        >>>         self.backbone.add_module('flatten', nn.Flatten())
        >>>         self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
        >>>         self.backbone.add_module('fc2', nn.Linear(120, 84))
        >>>         self.backbone.add_module('fc3', nn.Linear(84, 10))
        >>>
        >>>         self.criterion = nn.CrossEntropyLoss()
        >>>
        >>>     def forward(self, batch_inputs, data_samples, mode='tensor'):
        >>>         data_samples = torch.stack(data_samples)
        >>>         if mode == 'tensor':
        >>>             return self.backbone(batch_inputs)
        >>>         elif mode == 'predict':
        >>>             feats = self.backbone(batch_inputs)
        >>>             predictions = torch.argmax(feats, 1)
        >>>             return predictions
        >>>         elif mode == 'loss':
        >>>             feats = self.backbone(batch_inputs)
        >>>             loss = self.criterion(feats, data_samples)
        >>>             return dict(loss=loss)

    Args:
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
        init_cfg (dict, optional): The weight initialization config for
            :class:`BaseModule`.

    Attributes:
        data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
            pre-processing data sampled by the dataloader into the format
            accepted by :meth:`forward`.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg)
        if data_preprocessor is None:
            data_preprocessor = dict(type='BaseDataPreprocessor')
        if isinstance(data_preprocessor, nn.Module):
            self.data_preprocessor = data_preprocessor
        elif isinstance(data_preprocessor, dict):
            self.data_preprocessor = MODELS.build(data_preprocessor)
        else:
            raise TypeError('data_preprocessor should be a `dict` or '
                            f'`nn.Module` instance, but got '
                            f'{type(data_preprocessor)}')

    def train_step(self, data: Union[dict, tuple, list],
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """Implements the default model training process, including
        preprocessing, model forward propagation, loss calculation,
        optimization, and back-propagation.

        During non-distributed training, if subclasses do not override
        :meth:`train_step`, :class:`EpochBasedTrainLoop` or
        :class:`IterBasedTrainLoop` will call this method to update model
        parameters. The default parameter update process is as follows:

        1. Calls ``self.data_preprocessor(data, training=True)`` to collect
           ``batch_inputs`` and corresponding ``data_samples`` (labels).
        2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get the
           raw loss.
        3. Calls ``self.parse_losses`` to get a ``parsed_losses`` tensor used
           for the backward pass and a dict of loss tensors used for logging.
        4. Calls ``optim_wrapper.update_params(loss)`` to update the model.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.
            optim_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, torch.Tensor]: A ``dict`` of tensors for logging.
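
        Examples:
            A minimal sketch, assuming the illustrative ``ToyModel`` from the
            class docstring and a plain SGD-backed ``OptimWrapper`` (neither
            is part of this method's contract):

            >>> model = ToyModel()
            >>> optim_wrapper = OptimWrapper(
            ...     torch.optim.SGD(model.parameters(), lr=0.01))
            >>> data = dict(batch_inputs=torch.rand(2, 3, 32, 32),
            ...             data_samples=[torch.tensor(1), torch.tensor(0)])
            >>> log_vars = model.train_step(data, optim_wrapper)  # one update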
""" | |
# Enable automatic mixed precision training context. | |
with optim_wrapper.optim_context(self): | |
data = self.data_preprocessor(data, True) | |
losses = self._run_forward(data, mode='loss') # type: ignore | |
parsed_losses, log_vars = self.parse_losses(losses) # type: ignore | |
optim_wrapper.update_params(parsed_losses) | |
return log_vars | |

    def val_step(self, data: Union[tuple, dict, list]) -> list:
        """Gets the predictions for the given data.

        Calls ``self.data_preprocessor(data, False)`` and
        ``self(inputs, data_samples, mode='predict')`` in order, then returns
        the predictions, which will be passed to the evaluator.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.

        Returns:
            list: The predictions for the given data.
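
        Examples:
            A minimal sketch, reusing the illustrative ``ToyModel`` from the
            class docstring (its ``'predict'`` branch returns argmax indices):

            >>> model = ToyModel()
            >>> data = dict(batch_inputs=torch.rand(2, 3, 32, 32),
            ...             data_samples=[torch.tensor(1), torch.tensor(0)])
            >>> predictions = model.val_step(data)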
""" | |
data = self.data_preprocessor(data, False) | |
return self._run_forward(data, mode='predict') # type: ignore | |
def test_step(self, data: Union[dict, tuple, list]) -> list: | |
"""``BaseModel`` implements ``test_step`` the same as ``val_step``. | |
Args: | |
data (dict or tuple or list): Data sampled from dataset. | |
Returns: | |
list: The predictions of given data. | |
""" | |
data = self.data_preprocessor(data, False) | |
return self._run_forward(data, mode='predict') # type: ignore | |

    def parse_losses(
        self, losses: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Parses the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: There are two elements. The first is the
            loss tensor passed to the optim_wrapper, which may be a weighted
            sum of all losses, and the second is ``log_vars``, which will be
            sent to the logger.
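
        Examples:
            A minimal sketch of the parsing rules, assuming ``model`` is any
            ``BaseModel`` instance: values whose key contains ``'loss'`` are
            summed into the total loss, while other tensors are only logged:

            >>> losses = dict(
            ...     loss_cls=torch.tensor(1.0),
            ...     loss_reg=[torch.tensor(0.25), torch.tensor(0.25)],
            ...     accuracy=torch.tensor(0.9))
            >>> loss, log_vars = model.parse_losses(losses)
            >>> loss  # loss_cls + sum of loss_reg; accuracy is only logged
            tensor(1.5000)
            >>> list(log_vars)
            ['loss', 'loss_cls', 'loss_reg', 'accuracy']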
""" | |
log_vars = [] | |
for loss_name, loss_value in losses.items(): | |
if isinstance(loss_value, torch.Tensor): | |
log_vars.append([loss_name, loss_value.mean()]) | |
elif is_list_of(loss_value, torch.Tensor): | |
log_vars.append( | |
[loss_name, | |
sum(_loss.mean() for _loss in loss_value)]) | |
else: | |
raise TypeError( | |
f'{loss_name} is not a tensor or list of tensors') | |
loss = sum(value for key, value in log_vars if 'loss' in key) | |
log_vars.insert(0, ['loss', loss]) | |
log_vars = OrderedDict(log_vars) # type: ignore | |
return loss, log_vars # type: ignore | |

    def to(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to call :meth:`BaseDataPreprocessor.to`
        additionally.

        Returns:
            nn.Module: The model itself.
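
        Examples:
            A minimal sketch: moving the model also updates the device
            recorded by its data preprocessor (``device`` here is the public
            property of :class:`BaseDataPreprocessor`):

            >>> model = ToyModel()  # the illustrative model defined above
            >>> model = model.to('cpu')
            >>> model.data_preprocessor.device
            device(type='cpu')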
""" | |
# Since Torch has not officially merged | |
# the npu-related fields, using the _parse_to function | |
# directly will cause the NPU to not be found. | |
# Here, the input parameters are processed to avoid errors. | |
if args and isinstance(args[0], str) and 'npu' in args[0]: | |
import torch_npu | |
args = tuple([ | |
list(args)[0].replace( | |
'npu', torch_npu.npu.native_device if hasattr( | |
torch_npu.npu, 'native_device') else 'privateuseone') | |
]) | |
if kwargs and 'npu' in str(kwargs.get('device', '')): | |
import torch_npu | |
kwargs['device'] = kwargs['device'].replace( | |
'npu', torch_npu.npu.native_device if hasattr( | |
torch_npu.npu, 'native_device') else 'privateuseone') | |
device = torch._C._nn._parse_to(*args, **kwargs)[0] | |
if device is not None: | |
self._set_device(torch.device(device)) | |
return super().to(*args, **kwargs) | |

    def cuda(
        self,
        device: Optional[Union[int, str, torch.device]] = None,
    ) -> nn.Module:
        """Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
        additionally.

        Returns:
            nn.Module: The model itself.
        """
        if device is None or isinstance(device, int):
            device = torch.device('cuda', index=device)
        self._set_device(torch.device(device))
        return super().cuda(device)

    def mlu(
        self,
        device: Union[int, str, torch.device, None] = None,
    ) -> nn.Module:
        """Overrides this method to call :meth:`BaseDataPreprocessor.mlu`
        additionally.

        Note:
            The ``device`` argument is ignored; the current MLU device is
            always used.

        Returns:
            nn.Module: The model itself.
        """
        device = torch.device('mlu', torch.mlu.current_device())
        self._set_device(device)
        return super().mlu()

    def npu(
        self,
        device: Union[int, str, torch.device, None] = None,
    ) -> nn.Module:
        """Overrides this method to call :meth:`BaseDataPreprocessor.npu`
        additionally.

        Returns:
            nn.Module: The model itself.

        Note:
            This generation of NPU (Ascend 910) does not support using
            multiple cards in a single process, so the index here needs to
            be consistent with the default device.
        """
        device = torch.npu.current_device()
        self._set_device(device)
        return super().npu()

    def cpu(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
        additionally.

        Returns:
            nn.Module: The model itself.
        """
        self._set_device(torch.device('cpu'))
        return super().cpu()

    def _set_device(self, device: torch.device) -> None:
        """Recursively sets the device for every ``BaseDataPreprocessor``
        instance.

        Args:
            device (torch.device): The desired device of the parameters and
                buffers in this module.
        """

        def apply_fn(module):
            if not isinstance(module, BaseDataPreprocessor):
                return
            if device is not None:
                module._device = device

        self.apply(apply_fn)

    @abstractmethod
    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[list] = None,
                mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
        """Returns losses or predictions of the training, validation,
        testing, and simple inference processes.

        The ``forward`` method of ``BaseModel`` is abstract; subclasses must
        implement it. It accepts ``inputs`` and ``data_samples`` processed by
        :attr:`data_preprocessor` and returns results according to the
        ``mode`` argument.

        During non-distributed training, validation, and testing,
        ``forward`` is called directly by ``BaseModel.train_step``,
        ``BaseModel.val_step`` and ``BaseModel.test_step``.

        During distributed data-parallel training,
        ``MMSeparateDistributedDataParallel.train_step`` first calls
        ``DistributedDataParallel.forward`` to enable automatic gradient
        synchronization, and then calls ``forward`` to get the training loss.

        Args:
            inputs (torch.Tensor): Batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (list, optional): Data samples collated by
                :attr:`data_preprocessor`.
            mode (str): Should be one of ``loss``, ``predict`` and
                ``tensor``.

                - ``loss``: Called by ``train_step``; returns a loss ``dict``
                  used for logging.
                - ``predict``: Called by ``val_step`` and ``test_step``;
                  returns a list of results used for computing metrics.
                - ``tensor``: Called by custom code to get ``Tensor``-type
                  results.

        Returns:
            dict or list:

            - If ``mode == loss``, returns a ``dict`` of loss tensors used
              for the backward pass and logging.
            - If ``mode == predict``, returns a ``list`` of inference
              results.
            - If ``mode == tensor``, returns a tensor, a ``tuple`` of
              tensors, or a ``dict`` of tensors for custom use.
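
        Examples:
            A minimal sketch of the three modes, using the illustrative
            ``ToyModel`` implementation from the class docstring:

            >>> model = ToyModel()
            >>> inputs = torch.rand(2, 3, 32, 32)
            >>> samples = [torch.tensor(1), torch.tensor(0)]
            >>> feats = model(inputs, samples, mode='tensor')   # raw logits
            >>> preds = model(inputs, samples, mode='predict')  # argmax ids
            >>> losses = model(inputs, samples, mode='loss')    # loss dict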
""" | |

    def _run_forward(self, data: Union[dict, tuple, list],
                     mode: str) -> Union[Dict[str, torch.Tensor], list]:
        """Unpacks the data for :meth:`forward`.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.
            mode (str): Mode of forward.

        Returns:
            dict or list: Results of the training or testing mode.
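
        Examples:
            A minimal sketch of the two accepted layouts, again using the
            illustrative ``ToyModel``: dict data is unpacked as keyword
            arguments, list/tuple data as positional arguments:

            >>> inputs = torch.rand(2, 3, 32, 32)
            >>> samples = [torch.tensor(1), torch.tensor(0)]
            >>> out = model._run_forward(
            ...     dict(batch_inputs=inputs, data_samples=samples),
            ...     mode='tensor')
            >>> out = model._run_forward([inputs, samples], mode='tensor')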
""" | |
if isinstance(data, dict): | |
results = self(**data, mode=mode) | |
elif isinstance(data, (list, tuple)): | |
results = self(*data, mode=mode) | |
else: | |
raise TypeError('Output of `data_preprocessor` should be ' | |
f'list, tuple or dict, but got {type(data)}') | |
return results | |