leeyunjai committed
Commit
7b6efae
1 Parent(s): 06af277

Create backbone.py

Files changed (1)
  1. backbone.py +115 -0
backbone.py ADDED
@@ -0,0 +1,115 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from utils import NestedTensor, is_main_process

from position_encoding import build_position_encoding

class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with eps added before rsqrt;
    without it, any model other than torchvision.models.resnet[18,34,50,101]
    produces NaNs.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Drop num_batches_tracked, which frozen BN does not keep as a buffer.
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # Move the reshapes to the beginning to make the op fuser-friendly.
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

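# --- Hedged sanity check (not part of the original commit) ---
# A FrozenBatchNorm2d loaded from an eval-mode nn.BatchNorm2d should produce
# the same output, since both use eps=1e-5 and the stored running statistics;
# the _load_from_state_dict hook above silently drops num_batches_tracked.
def _check_frozen_bn_matches_eval_bn():
    bn = nn.BatchNorm2d(8).eval()
    fbn = FrozenBatchNorm2d(8)
    fbn.load_state_dict(bn.state_dict())
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(bn(x), fbn(x), atol=1e-5)
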
class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            # Freeze everything when train_backbone is False; otherwise freeze
            # only the stem and layer1 (`and` binds tighter than `or` here).
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            # Downsample the padding mask to each feature map's resolution.
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out

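# --- Hedged sketch (not part of the original commit) ---
# What IntermediateLayerGetter yields for a ResNet-50 at 224x224 input:
# keys "0".."3" map to layer1..layer4 with shapes
# (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7).
def _demo_intermediate_layers():
    resnet = torchvision.models.resnet50()
    body = IntermediateLayerGetter(
        resnet,
        return_layers={"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"})
    feats = body(torch.randn(1, 3, 224, 224))
    for key, feat in feats.items():
        print(key, tuple(feat.shape))
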
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        # Only the main process downloads pretrained weights, avoiding
        # concurrent downloads in distributed training.
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

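# --- Hedged sketch (not part of the original commit) ---
# With dilation=True, layer4 trades its stride for dilation, so the output
# stride drops from 32 to 16 and the final feature map doubles per side
# (7x7 -> 14x14 for a 224x224 input).
def _demo_dilation():
    dilated = torchvision.models.resnet50(
        replace_stride_with_dilation=[False, False, True])
    body = IntermediateLayerGetter(dilated, return_layers={"layer4": "0"})
    feat = body(torch.randn(1, 3, 224, 224))["0"]
    assert feat.shape[-2:] == (14, 14)
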
class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos

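# --- Hedged usage sketch (not part of the original commit) ---
# Joiner returns two aligned lists: out[i] is a NestedTensor feature map and
# pos[i] is its positional encoding. This assumes utils.NestedTensor wraps a
# batched image tensor plus a boolean padding mask, as the code above implies.
def _demo_joiner(model):
    images = torch.randn(2, 3, 224, 224)
    mask = torch.zeros(2, 224, 224, dtype=torch.bool)  # False = no padding
    out, pos = model(NestedTensor(images, mask))
    return out, pos
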
def build_backbone(config):
    position_embedding = build_position_encoding(config)
    train_backbone = config.lr_backbone > 0
    return_interm_layers = False
    backbone = Backbone(config.backbone, train_backbone, return_interm_layers, config.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
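
# --- Hedged usage sketch (not part of the original commit) ---
# A minimal config: `backbone`, `lr_backbone`, and `dilation` are read above,
# while `hidden_dim` and `position_embedding` are assumptions based on the
# reference DETR build_position_encoding and may differ in this repo.
from types import SimpleNamespace

_config = SimpleNamespace(
    backbone="resnet50",
    lr_backbone=1e-5,           # > 0 unfreezes layer2-layer4
    dilation=False,
    hidden_dim=256,             # assumed: consumed by build_position_encoding
    position_embedding="sine",  # assumed: consumed by build_position_encoding
)
# model = build_backbone(_config)  # model.num_channels == 2048 for resnet50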