|
import timm |
|
import torch.nn as nn |
|
|
|
|
|
class TimmUniversalEncoder(nn.Module):
    """Multi-scale feature encoder backed by a ``timm`` model.

    The wrapped network is built through ``timm.create_model`` with
    ``features_only=True`` and ``out_indices=tuple(range(depth))``.
    ``forward`` returns the input itself followed by whatever the wrapped
    model returns.
    """

    def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32):
        """Build the underlying timm model and record encoder metadata.

        Args:
            name: Model name forwarded to ``timm.create_model``.
            pretrained: Forwarded to timm as ``pretrained``.
            in_channels: Forwarded to timm as ``in_chans``; also used as the
                first entry of ``out_channels``.
            depth: Number of feature stages requested, passed as
                ``out_indices=tuple(range(depth))``.
            output_stride: Forwarded to timm as ``output_stride`` unless it
                equals 32, in which case the argument is omitted.
        """
        super().__init__()

        # Keyword arguments are assembled in the order timm receives them.
        create_args = {"in_chans": in_channels, "features_only": True}
        # `output_stride` is only passed when it differs from 32
        # (presumably some timm models do not accept it; confirm against timm).
        if output_stride != 32:
            create_args["output_stride"] = output_stride
        create_args["pretrained"] = pretrained
        create_args["out_indices"] = tuple(range(depth))

        self.model = timm.create_model(name, **create_args)

        # The first "feature" returned by forward() is the raw input, so its
        # channel count leads the list reported by the timm model.
        stage_channels = self.model.feature_info.channels()
        self._out_channels = [in_channels] + stage_channels

        self._in_channels = in_channels
        self._depth = depth
        self._output_stride = output_stride

    def forward(self, x):
        """Return ``[x]`` followed by the feature maps produced by the model."""
        stage_maps = self.model(x)
        return [x] + stage_maps

    @property
    def out_channels(self):
        """Channel counts: ``in_channels`` followed by the timm stage channels."""
        return self._out_channels

    @property
    def output_stride(self):
        """Configured output stride, capped at ``2 ** depth``."""
        deepest_stride = 2 ** self._depth
        return min(self._output_stride, deepest_stride)
|
|