from abc import ABCMeta, abstractmethod

from .decode_head import BaseDecodeHead


class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode head used in
    :class:`CascadeEncoderDecoder`.

    Unlike a plain decode head, a cascade head also consumes the output of
    the previous decode head: :meth:`forward`, :meth:`forward_train` and
    :meth:`forward_test` all take an extra ``prev_output`` argument.
    Subclasses must implement :meth:`forward`.
    """

    def __init__(self, *args, **kwargs):
        # All construction arguments are forwarded unchanged to
        # ``BaseDecodeHead``.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
        """

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this base implementation.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config. Not used by this base
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # ``img_metas`` and ``train_cfg`` stay in the signature so callers can
        # pass the same arguments to every head; they are unused here.
        seg_logits = self.forward(inputs, prev_output)
        return self.losses(seg_logits, gt_semantic_seg)

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this base implementation.
            test_cfg (dict): The testing config. Not used by this base
                implementation.

        Returns:
            Tensor: Output segmentation map, i.e. whatever :meth:`forward`
                returns for ``inputs`` and ``prev_output``.
        """
        return self.forward(inputs, prev_output)