Spaces:
Runtime error
Runtime error
""" | |
https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/resnet.py | |
""" | |
from pathlib import Path | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from climategan.deeplab.mobilenet_v3 import SeparableConv2d | |
from climategan.utils import find_target_size | |
class _DeepLabHead(nn.Module):
    """Classification head applied directly to the high-level feature map.

    The head is a stack of two ``SeparableConv2d`` layers (3x3, 256 output
    channels each) followed by a 1x1 ``nn.Conv2d`` producing ``nclass``
    channels. The low-level (``c1``) branch is disabled, so only the
    high-level features are used.

    Args:
        nclass: number of output channels of the final 1x1 convolution.
        c1_channels: accepted for interface compatibility; not used.
        c4_channels: number of channels of the input feature map.
        norm_layer: normalisation layer class forwarded to ``SeparableConv2d``.
    """

    def __init__(
        self, nclass, c1_channels=256, c4_channels=2048, norm_layer=nn.BatchNorm2d
    ):
        super().__init__()
        hidden = 256
        # Layers are created in the same order as their Sequential indices
        # (0, 1, 2), which is what checkpoint keys such as "block.0..." rely on.
        layers = [
            SeparableConv2d(
                c4_channels, hidden, 3, norm_layer=norm_layer, relu_first=False
            ),
            SeparableConv2d(
                hidden, hidden, 3, norm_layer=norm_layer, relu_first=False
            ),
            nn.Conv2d(hidden, nclass, 1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x, c1=None):
        """Return the head's output for ``x``; ``c1`` is ignored."""
        logits = self.block(x)
        return logits
class ConvBNReLU(nn.Module):
    """A biased ``nn.Conv2d`` followed by ``nn.BatchNorm2d``.

    Adapted from
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    NOTE(review): despite the class name, ``forward`` applies no ReLU (or any
    other activation) after the batch norm -- confirm this is intended before
    changing it, since trained weights depend on the current behaviour.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        ks: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding. The default is 1 whatever ``ks`` is, so
            a 1x1 kernel with the default padding enlarges the output.
        dilation: convolution dilation.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(
        self, in_chan, out_chan, ks=3, stride=1, padding=1, dilation=1, *args, **kwargs
    ):
        super().__init__()
        conv_settings = dict(
            kernel_size=ks,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        )
        self.conv = nn.Conv2d(in_chan, out_chan, **conv_settings)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        """Apply the convolution, then the batch norm."""
        return self.bn(self.conv(x))

    def init_weight(self):
        """Kaiming-normal (``a=1``) init for direct Conv2d children; zero bias."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
class ASPPv3Plus(nn.Module):
    """Parallel dilated-convolution branches fused by a 1x1 projection.

    Adapted from
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Four ``ConvBNReLU`` branches (1x1, then 3x3 with dilation 6, 12 and 18)
    are applied to the same input, concatenated on the channel axis and passed
    through ``conv_out``. The global-pooling branch is switched off
    (``with_gp`` is hard-coded to ``False``).

    NOTE(review): ``conv_out`` is a 1x1 convolution built with ConvBNReLU's
    default ``padding=1``, so its output is 2 pixels larger than its input in
    each spatial dimension -- confirm before changing, trained weights rely on
    the current behaviour.

    Args:
        backbone: ``"mobilenet"`` selects 320 input channels; any other value
            selects 2048.
        no_init: when falsy, ``init_weight`` is called after construction.
    """

    def __init__(self, backbone, no_init):
        super().__init__()
        in_chan = 320 if backbone == "mobilenet" else 2048
        self.with_gp = False

        self.conv1 = ConvBNReLU(in_chan, 256, ks=1, dilation=1, padding=0)
        # Dilated 3x3 branches: padding == dilation keeps the spatial size.
        self.conv2 = ConvBNReLU(in_chan, 256, ks=3, dilation=6, padding=6)
        self.conv3 = ConvBNReLU(in_chan, 256, ks=3, dilation=12, padding=12)
        self.conv4 = ConvBNReLU(in_chan, 256, ks=3, dilation=18, padding=18)

        n_branches = 4
        if self.with_gp:
            self.avg = nn.AdaptiveAvgPool2d((1, 1))
            self.conv1x1 = ConvBNReLU(in_chan, 256, ks=1)
            n_branches = 5
        self.conv_out = ConvBNReLU(256 * n_branches, 256, ks=1)

        if not no_init:
            self.init_weight()

    def forward(self, x):
        """Run every branch on ``x`` and project the concatenated result."""
        height, width = x.size()[2:]
        branches = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        if self.with_gp:
            pooled = self.conv1x1(self.avg(x))
            branches.append(
                F.interpolate(
                    pooled, (height, width), mode="bilinear", align_corners=True
                )
            )
        fused = torch.cat(branches, 1)
        return self.conv_out(fused)

    def init_weight(self):
        """Kaiming-normal (``a=1``) init for direct Conv2d children; zero bias.

        The direct children created in ``__init__`` are ``ConvBNReLU`` wrappers
        (plus the pooling layer when ``with_gp`` is set), none of which is an
        ``nn.Conv2d``, so with this class as written no child matches.
        """
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
class Decoder(nn.Module):
    """Fuse a low-level feature map with ASPP features and emit logits.

    Adapted from
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Args:
        n_classes: number of output channels of the final 1x1 convolution.
    """

    def __init__(self, n_classes):
        super().__init__()
        # 256 -> 48 channel reduction of the first input.
        self.conv_low = ConvBNReLU(256, 48, ks=1, padding=0)
        # 304 = 48 reduced channels + 256 channels of the resized second input.
        fusion_layers = [
            ConvBNReLU(304, 256, ks=3, padding=1),
            ConvBNReLU(256, 256, ks=3, padding=1),
        ]
        self.conv_cat = nn.Sequential(*fusion_layers)
        self.conv_out = nn.Conv2d(256, n_classes, kernel_size=1, bias=False)

    def forward(self, feat_low, feat_aspp):
        """Return logits at the spatial size of ``feat_low``.

        ``feat_aspp`` is bilinearly resized to ``feat_low``'s height and width,
        concatenated with the channel-reduced ``feat_low`` and passed through
        the fusion convolutions and the final 1x1 convolution.
        """
        height, width = feat_low.size()[2:]
        reduced_low = self.conv_low(feat_low)
        resized_aspp = F.interpolate(
            feat_aspp, (height, width), mode="bilinear", align_corners=True
        )
        fused = self.conv_cat(torch.cat([reduced_low, resized_aspp], dim=1))
        return self.conv_out(fused)
""" | |
https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/deeplab.py | |
""" | |
class DeepLabV3Decoder(nn.Module):
    """Segmentation decoder fed with a ``(z_high, z_low)`` pair of features.

    With the ``"resnet"`` backbone the high-level features go through
    ``ASPPv3Plus`` and are then combined with the low-level ones by
    ``Decoder``. With any other backbone a ``_DeepLabHead`` (320 input
    channels) is applied to the high-level features only. In both cases the
    result is bilinearly resized to the configured target size.

    Args:
        opts: configuration object. Reads ``opts.gen.s.output_dim``,
            ``opts.gen.s.use_dada``, ``opts.tasks``,
            ``opts.gen.deeplabv3.backbone`` and
            ``opts.gen.deeplabv3.pretrained_model.{resnet,mobilenet}``.
        no_init: skip the weight initialisation pass when truthy (pretrained
            weights are loaded either way).
        freeze_bn: flag stored as ``self.freeze_bn_requested``. It is only
            recorded here; call ``freeze_bn()`` to put the batch norms in
            eval mode.
    """

    def __init__(
        self,
        opts,
        no_init=False,
        freeze_bn=False,
    ):
        super().__init__()

        num_classes = opts.gen.s.output_dim
        self.backbone = opts.gen.deeplabv3.backbone
        self.use_dada = ("d" in opts.tasks) and opts.gen.s.use_dada
        # Kept under its own name: assigning the flag to ``self.freeze_bn``
        # would shadow the ``freeze_bn()`` method with a bool and make it
        # uncallable.
        self.freeze_bn_requested = freeze_bn

        if self.backbone == "resnet":
            self.aspp = ASPPv3Plus(self.backbone, no_init)
            self.decoder = Decoder(num_classes)
        else:
            self.head = _DeepLabHead(num_classes, c4_channels=320)

        self._target_size = find_target_size(opts, "s")
        print(
            " - {}: setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

        if not no_init:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out")
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.zeros_(m.bias)

        self.load_pretrained(opts)

    @staticmethod
    def _strip_prefix(state, prefix):
        """Return the entries of ``state`` under ``prefix``, prefix removed."""
        start = len(prefix)
        return {k[start:]: v for k, v in state.items() if k.startswith(prefix)}

    @staticmethod
    def _has_19_rows(tensor):
        """Tell whether ``tensor`` has a first dimension of size 19.

        Such entries are not loaded (presumably they belong to the 19-class
        output layer of the pretrained model -- TODO confirm).
        """
        return len(tensor.shape) > 0 and tensor.shape[0] == 19

    def load_pretrained(self, opts):
        """Load pretrained weights for the configured backbone.

        Args:
            opts: configuration object (see the class docstring).

        Raises:
            ValueError: if the backbone is neither "resnet" nor "mobilenet".
            FileNotFoundError: if the checkpoint of the selected backbone
                does not exist.
        """
        deeplab_opts = opts.gen.deeplabv3
        backbone = deeplab_opts.backbone
        if backbone not in {"resnet", "mobilenet"}:
            raise ValueError(
                "Unknown DeepLabv3 backbone {!r}: expected 'resnet' or "
                "'mobilenet'".format(backbone)
            )

        # Only the checkpoint that is actually going to be read is required.
        if backbone == "resnet":
            checkpoint = deeplab_opts.pretrained_model.resnet
        else:
            checkpoint = deeplab_opts.pretrained_model.mobilenet
        if not Path(checkpoint).exists():
            raise FileNotFoundError(
                "Pretrained DeepLabv3 checkpoint not found: {}".format(checkpoint)
            )

        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only hosts; load_state_dict then copies the values onto
        # whichever device the parameters live on.
        # NOTE(review): torch.load unpickles the file -- only point the
        # configuration at trusted checkpoints.
        std = torch.load(checkpoint, map_location="cpu")

        if backbone == "resnet":
            self.aspp.load_state_dict(self._strip_prefix(std, "aspp."))
            self.decoder.load_state_dict(
                {
                    k: v
                    for k, v in self._strip_prefix(std, "decoder.").items()
                    if not self._has_19_rows(v)
                },
                strict=False,
            )
            print(
                "- Loaded pre-trained DeepLabv3+ (Resnet) Decoder & ASPP as Seg Decoder"
            )
        else:
            self.load_state_dict(
                {
                    k: v
                    for k, v in std.items()
                    if k.startswith("head.") and not self._has_19_rows(v)
                },
                strict=False,
            )
            print(
                " - Loaded pre-trained DeepLabv3+ (MobileNetV2) Head as Seg Decoder"
            )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z, z_depth=None):
        """Decode ``z`` into a map resized to the target size.

        Args:
            z (tuple, list): ``(z_high, z_low)`` feature maps. ``z_low`` is
                only used with the "resnet" backbone.
            z_depth: optional tensor multiplied into ``z_high`` when
                ``self.use_dada`` is truthy.

        Raises:
            ValueError: if no target size has been set.
        """
        assert isinstance(z, (tuple, list))
        if self._target_size is None:
            error = "self._target_size should be set with self.set_target_size() "
            error += "to interpolate logits to the target seg map's size"
            raise ValueError(error)

        z_high, z_low = z

        if z_depth is not None and self.use_dada:
            z_high = z_high * z_depth

        if self.backbone == "resnet":
            z_high = self.aspp(z_high)
            # NOTE(review): Decoder.forward is declared as
            # (feat_low, feat_aspp), yet the ASPP output is passed first here.
            # The order is left untouched because existing checkpoints were
            # trained with it -- confirm whether this is intended.
            s = self.decoder(z_high, z_low)
        else:
            s = self.head(z_high)

        s = F.interpolate(
            s, size=self._target_size, mode="bilinear", align_corners=True
        )

        return s

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule in eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()