Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import List, Optional | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class ImageEncoder(nn.Module):
    """Chain a backbone trunk and a neck, and package the neck outputs.

    The trunk output is handed directly to the neck, which must return a
    ``(features, pos)`` pair of sequences that can be sliced and indexed.
    """

    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
    ):
        """Store the sub-modules and check that their channel lists agree.

        :param trunk: backbone module; must expose a ``channel_list`` attribute
        :param neck: neck module; must expose a ``backbone_channel_list`` attribute
        :param scalp: number of trailing entries to drop from the neck outputs
            in ``forward``; only applied when strictly positive
        """
        super().__init__()
        self.trunk = trunk
        self.neck = neck
        self.scalp = scalp

        # The neck is built for a specific list of input channel sizes; refuse
        # to pair it with a trunk that advertises a different list.
        channels_agree = self.trunk.channel_list == self.neck.backbone_channel_list
        assert (
            channels_agree
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    def forward(self, sample: torch.Tensor):
        """Encode ``sample`` and return the neck outputs as a dict.

        :param sample: input passed unchanged to the trunk
        :return: dict with keys ``"vision_features"`` (the last retained
            feature entry), ``"vision_pos_enc"`` (retained positional
            encodings) and ``"backbone_fpn"`` (retained features)
        """
        # Forward through backbone, then through the neck
        trunk_out = self.trunk(sample)
        features, pos = self.neck(trunk_out)

        num_dropped = self.scalp
        if num_dropped > 0:
            # Discard the lowest resolution features (the trailing entries)
            features = features[:-num_dropped]
            pos = pos[:-num_dropped]

        return {
            "vision_features": features[-1],
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
class FpnNeck(nn.Module):
    """
    A modified variant of Feature Pyramid Network (FPN) neck
    (we remove output conv and upsample the top-down features with a
    configurable interpolation mode, ``"bilinear"`` by default)
    """

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
        """Initialize the neck

        :param position_encoding: the positional encoding to use; called on
            every output feature map
        :param d_model: the dimension of the model (output channels of every
            lateral conv)
        :param backbone_channel_list: input channels of the lateral convs.
            ``convs[j]`` is built from ``backbone_channel_list[j]`` and is
            applied to ``xs[len(xs) - 1 - j]`` in ``forward``, i.e. this list
            is in the reverse order of the feature maps passed to ``forward``
        :param kernel_size: kernel size of the lateral convs
        :param stride: stride of the lateral convs
        :param padding: padding of the lateral convs
        :param fpn_interp_model: ``mode`` passed to ``F.interpolate`` when
            upsampling the top-down features
        :param fuse_type: how lateral and top-down features are combined,
            ``"sum"`` or ``"avg"``
        :param fpn_top_down_levels: levels whose outputs receive top-down
            features; ``None`` means all levels
        """
        super().__init__()
        assert fuse_type in [
            "sum",
            "avg",
        ], f"fuse_type must be 'sum' or 'avg', got {fuse_type!r}"

        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        for dim in backbone_channel_list:
            # Keep the nn.Sequential wrapper and the "conv" name: together
            # they define the parameter names (convs.<j>.conv.weight / .bias).
            current = nn.Sequential()
            current.add_module(
                "conv",
                nn.Conv2d(
                    in_channels=dim,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )
            self.convs.append(current)
        self.fpn_interp_model = fpn_interp_model
        self.fuse_type = fuse_type

        # levels to have top-down features in its outputs
        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
        # have top-down propagation, while outputs of level 0 and level 1 have only
        # lateral features from the same backbone level.
        if fpn_top_down_levels is None:
            # default is to have top-down features on all levels
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

    def forward(self, xs: List[torch.Tensor]):
        """Compute the FPN outputs and their positional encodings.

        :param xs: one feature map per lateral conv. Each level's top-down
            input is the next level's output upsampled by a fixed factor of
            2, so level ``i`` must have twice the spatial size of level
            ``i + 1`` wherever top-down fusion applies
        :return: ``(out, pos)``, two lists indexed like ``xs``: the fused
            feature maps and ``position_encoding`` applied to each of them
        """
        num_levels = len(self.convs)
        assert (
            len(xs) == num_levels
        ), f"Expected {num_levels} feature maps (one per lateral conv), got {len(xs)}"
        out = [None] * num_levels
        pos = [None] * num_levels

        # fpn forward pass
        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
        prev_features = None
        # forward in top-down order (from low to high resolution)
        n = num_levels - 1
        for i in range(n, -1, -1):
            x = xs[i]
            # convs are stored in the reverse order of xs (see __init__)
            lateral_features = self.convs[n - i](x)
            if i in self.fpn_top_down_levels and prev_features is not None:
                # Interpolation is always done in float32, whatever the dtype
                # of the incoming features.
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode=self.fpn_interp_model,
                    # "nearest" does not accept an align_corners value
                    align_corners=(
                        None if self.fpn_interp_model == "nearest" else False
                    ),
                    antialias=False,
                )
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            else:
                # No top-down contribution: first (lowest-resolution) level,
                # or a level not listed in fpn_top_down_levels.
                prev_features = lateral_features
            x_out = prev_features
            out[i] = x_out
            pos[i] = self.position_encoding(x_out).to(x_out.dtype)

        return out, pos