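# NOTE(review): the original first three lines read "Spaces:" / "Paused" /
# "Paused". They look like web-page residue rather than source; kept here as a
# comment so nothing is lost -- delete once confirmed.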
import logging
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn

import annotator.uniformer.mmcv as mmcv
from annotator.uniformer.mmcv.runner import auto_fp16
class BaseSegmentor(nn.Module):
    """Base class for segmentors.

    The class provides the train/test dispatch (:meth:`forward`,
    :meth:`forward_test`), the per-iteration helpers used by a runner
    (:meth:`train_step`, :meth:`val_step`), loss aggregation
    (:meth:`_parse_losses`) and result visualisation (:meth:`show_result`).
    Subclasses supply the actual model by overriding the placeholder
    methods ``extract_feat``, ``encode_decode``, ``forward_train``,
    ``simple_test`` and ``aug_test``.

    ``show_result`` reads ``self.CLASSES`` (and, optionally,
    ``self.PALETTE``); neither is set by this class.
    """

    # Python 2 style metaclass hook. On Python 3 this attribute is an
    # ordinary class attribute, so the ``@abstractmethod`` markers below are
    # informational only and do not prevent instantiation.
    __metaclass__ = ABCMeta

    def __init__(self):
        super(BaseSegmentor, self).__init__()
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the segmentor has neck"""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_auxiliary_head(self):
        """bool: whether the segmentor has auxiliary head"""
        return (hasattr(self, 'auxiliary_head')
                and self.auxiliary_head is not None)

    @property
    def with_decode_head(self):
        """bool: whether the segmentor has decode head"""
        return hasattr(self, 'decode_head') and self.decode_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Placeholder for extract features from images."""

    @abstractmethod
    def encode_decode(self, img, img_metas):
        """Placeholder for encode images with backbone and decode into a
        semantic segmentation map of the same size as input."""

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """Placeholder for Forward function for training."""

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Placeholder for single image test."""

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Placeholder for augmentation test."""

    def init_weights(self, pretrained=None):
        """Initialize the weights in segmentor.

        This base implementation only logs the checkpoint path on the root
        logger; it does not load anything itself.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if pretrained is not None:
            logger = logging.getLogger()
            logger.info(f'load model from: {pretrained}')

    def forward_test(self, imgs, img_metas, **kwargs):
        """Validate test-time inputs and dispatch to the right test method.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.

        Returns:
            Whatever :meth:`simple_test` (exactly one augmentation) or
            :meth:`aug_test` (any other count) returns.

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If the two lists have different lengths.
            AssertionError: If images within one augmentation disagree on
                ``ori_shape``, ``img_shape`` or ``pad_shape``.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got '
                                f'{type(var)}')

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(imgs)}) != '
                             f'num of image meta ({len(img_metas)})')

        # All images in the same aug batch must share the same ori_shape,
        # img_shape and pad_shape.
        for img_meta in img_metas:
            for key in ('ori_shape', 'img_shape', 'pad_shape'):
                shapes = [meta[key] for meta in img_meta]
                assert all(shape == shapes[0] for shape in shapes)

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        return self.aug_test(imgs, img_metas, **kwargs)

    # NOTE(review): ``auto_fp16`` is imported by this file and
    # ``fp16_enabled`` is set in ``__init__``, but neither is used in this
    # class. Presumably ``forward`` was meant to carry an ``@auto_fp16(...)``
    # decorator -- confirm against the upstream implementation.
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        return self.forward_test(img, img_metas, **kwargs)

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data_batch (dict): The output of dataloader. It is unpacked as
                keyword arguments of :meth:`forward` and must contain the
                key ``img_metas``.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        losses = self(**data_batch)
        loss, log_vars = self._parse_losses(losses)

        return dict(
            loss=loss,
            log_vars=log_vars,
            num_samples=len(data_batch['img_metas']))

    def val_step(self, data_batch, **kwargs):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.
        """
        return self(**data_batch, **kwargs)

    @staticmethod
    def _parse_losses(losses):
        """Parse the raw outputs (losses) of the network.

        Every entry is reduced to a scalar (mean of a tensor, or the sum of
        the means of a list of tensors). Entries whose key contains the
        substring ``'loss'`` are summed into the total loss, which is also
        stored under ``log_vars['loss']``.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars contains
                all the variables to be sent to the logger (as Python
                numbers obtained via ``.item()``).

        Raises:
            TypeError: If a value is neither a tensor nor a list.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)
        log_vars['loss'] = loss

        # The process-group state does not change inside the loop, so query
        # it once.
        distributed = dist.is_available() and dist.is_initialized()
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if distributed:
                # Work on a detached copy so the in-place division does not
                # touch the tensor that is returned as ``loss``.
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            # Only values of existing keys are replaced, so mutating the
            # dict while iterating is safe here.
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

    def show_result(self,
                    img,
                    result,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed. It is passed to
                ``mmcv.imread``.
            result (Tensor): The semantic segmentation results to draw over
                `img`. Only ``result[0]`` is drawn.
            palette (list[list[int]] | np.ndarray | None): The palette of
                segmentation map. If None is given, ``self.PALETTE`` is used
                when it is set, otherwise a random palette will be
                generated. Default: None
            win_name (str): The window name.
            wait_time (int): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image. Ignored (treated as
                False) when ``out_file`` is given.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.
            opacity(float): Opacity of painted segmentation map.
                Default 0.5.
                Must be in (0, 1] range.

        Returns:
            img (Tensor): Only if not `show` or `out_file`
        """
        img = mmcv.imread(img)
        img = img.copy()
        seg = result[0]

        if palette is None:
            # Tolerate subclasses that never define PALETTE.
            palette = getattr(self, 'PALETTE', None)
            if palette is None:
                palette = np.random.randint(
                    0, 255, size=(len(self.CLASSES), 3))
        palette = np.array(palette)
        # Check the rank first so a malformed palette fails the assertion
        # instead of raising IndexError on ``shape[1]``.
        assert len(palette.shape) == 2
        assert palette.shape[0] == len(self.CLASSES)
        assert palette.shape[1] == 3
        assert 0 < opacity <= 1.0

        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]

        img = img * (1 - opacity) + color_seg * opacity
        img = img.astype(np.uint8)

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        if show:
            mmcv.imshow(img, win_name, wait_time)
        if out_file is not None:
            mmcv.imwrite(img, out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')
            return img