Create backbone.py
backbone.py (new file, +115 lines)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from utils import NestedTensor, is_main_process

from position_encoding import build_position_encoding


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or ('layer2' not in name and 'layer3' not in name and 'layer4' not in name):
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


def build_backbone(config):
    position_embedding = build_position_encoding(config)
    train_backbone = config.lr_backbone > 0
    return_interm_layers = False
    backbone = Backbone(config.backbone, train_backbone, return_interm_layers, config.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
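
For reference, a minimal usage sketch of build_backbone (not part of the commit). It assumes the config fields this file actually reads (backbone, lr_backbone, dilation) plus two fields, hidden_dim and position_embedding, that build_position_encoding is assumed to consume in DETR-style code; NestedTensor comes from the utils module imported above.

# Hypothetical usage sketch -- fields marked (assumed) are not defined in
# this file and mirror DETR-style argparse options.
import argparse
import torch
from backbone import build_backbone
from utils import NestedTensor

config = argparse.Namespace(
    backbone='resnet50',        # any torchvision ResNet name
    lr_backbone=1e-5,           # > 0 leaves layer2-4 trainable
    dilation=False,             # True: stride -> dilation in layer4
    hidden_dim=256,             # (assumed) read by build_position_encoding
    position_embedding='sine',  # (assumed) 'sine' or 'learned'
)

model = build_backbone(config)
images = torch.rand(2, 3, 640, 480)
padding_mask = torch.zeros(2, 640, 480, dtype=torch.bool)  # False = real pixel
features, pos = model(NestedTensor(images, padding_mask))

print(features[0].tensors.shape)  # torch.Size([2, 2048, 20, 15]) for resnet50
print(pos[0].shape)               # positional encoding at the same spatial size

Since return_interm_layers is hard-coded to False here, the model returns a single feature level (layer4, downsampled 32x, or 16x with dilation=True), with the padding mask interpolated to match.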