from collections import OrderedDict
from typing import List

import torch
import torch.nn as nn

from . import _utils as utils


class EncoderMixin:
    """Mixin providing shared encoder helpers.

    Provides:
      - the channel count of each feature tensor the encoder emits;
      - re-wiring of the first convolution for a non-default input channel count;
      - conversion of late-stage strides into dilations.

    Classes using this mixin are expected to define ``_out_channels``,
    ``_depth`` and ``get_stages()``.
    """

    # Default overall downsampling factor; ``make_dilated`` overwrites it.
    _output_stride = 32

    @property
    def out_channels(self):
        """Channel count of every feature tensor produced by the encoder.

        Only the first ``_depth + 1`` entries of ``_out_channels`` are exposed.
        """
        end = self._depth + 1
        return self._out_channels[:end]

    @property
    def output_stride(self):
        """Effective stride: ``_output_stride`` capped at ``2 ** _depth``."""
        depth_cap = 2 ** self._depth
        return min(self._output_stride, depth_cap)

    def set_in_channels(self, in_channels, pretrained=True):
        """Adapt the first convolution to accept ``in_channels`` input channels.

        Requesting 3 channels is a no-op. Otherwise the new count is stored on
        ``_in_channels``. The leading entry of ``_out_channels`` is replaced when
        it is 3. The actual layer change is then delegated to
        ``utils.patch_first_conv``.

        Args:
            in_channels: number of channels the input tensor will have.
            pretrained: forwarded unchanged to ``utils.patch_first_conv``.
        """
        if in_channels == 3:
            return

        self._in_channels = in_channels
        # Only rewrite the leading entry when it still holds the value 3.
        if self._out_channels[0] == 3:
            tail = list(self._out_channels)[1:]
            self._out_channels = tuple([in_channels, *tail])

        utils.patch_first_conv(
            model=self, new_in_channels=in_channels, pretrained=pretrained
        )

    def get_stages(self):
        """Return the encoder's stages, indexable by stage number.

        Subclasses must override this; the mixin has no stages of its own.
        """
        raise NotImplementedError

    def make_dilated(self, output_stride):
        """Replace strides with dilation in the last stage(s) of the encoder.

        Args:
            output_stride: 16 dilates stage 5 (rate 2); 8 dilates stages 4
                and 5 (rates 2 and 4).

        Raises:
            ValueError: if ``output_stride`` is neither 16 nor 8. The check
                happens before any state is modified.
        """
        # Each entry: (output stride, ((stage index, dilation rate), ...)).
        dilation_plans = (
            (16, ((5, 2),)),
            (8, ((4, 2), (5, 4))),
        )
        for stride, plan in dilation_plans:
            if output_stride == stride:
                break
        else:
            raise ValueError(
                "Output stride should be 16 or 8, got {}.".format(output_stride)
            )

        self._output_stride = output_stride
        stages = self.get_stages()
        for stage_idx, rate in plan:
            utils.replace_strides_with_dilation(
                module=stages[stage_idx],
                dilation_rate=rate,
            )