# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional

import torch

try:
    import mmpretrain
    from mmpretrain.models.classifiers import ImageClassifier
except ImportError:
    # mmpretrain is optional at import time; fall back to ``object`` so the
    # class statement below still succeeds, and fail lazily in ``__init__``.
    mmpretrain = None
    ImageClassifier = object

from mmdet.registry import MODELS
from mmdet.structures import ReIDDataSample


@MODELS.register_module()
class BaseReID(ImageClassifier):
    """Base model for re-identification.

    A thin wrapper around mmpretrain's ``ImageClassifier`` whose
    :meth:`forward` additionally accepts a 5-D input with a leading batch
    dimension of size 1.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the model, forwarding all arguments to the parent.

        Raises:
            RuntimeError: If ``mmpretrain`` could not be imported.
        """
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and '
                               'run "mim install mmpretrain" to '
                               'install mmpretrain first.')
        super().__init__(*args, **kwargs)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[ReIDDataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`ReIDDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, H, W) or (N, T, C, H, W). For a 5-D input, N must
                be 1; the leading dimension is dropped and the remaining
                (T, C, H, W) tensor is passed to the parent classifier.
            data_samples (List[ReIDDataSample], optional): The annotation
                data of every sample. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`ReIDDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            AssertionError: If ``inputs`` is 5-D and its first dimension
                is not 1.
        """
        if inputs.dim() == 5:
            batch_size = inputs.size(0)
            # Raised explicitly (instead of a bare ``assert``) so the check
            # survives ``python -O``; otherwise ``inputs[0]`` would silently
            # discard every sample but the first. The exception type is kept
            # as AssertionError for backward compatibility.
            if batch_size != 1:
                raise AssertionError(
                    'A 5-D input must have a first dimension of size 1, '
                    f'but got {batch_size}.')
            inputs = inputs[0]
        return super().forward(inputs, data_samples, mode)