vict0rsch's picture
initial commit from `vict0rsch/climateGAN`
ce190ee
raw
history blame
8.79 kB
"""
https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/resnet.py
"""
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from climategan.deeplab.mobilenet_v3 import SeparableConv2d
from climategan.utils import find_target_size
class _DeepLabHead(nn.Module):
    """Segmentation head: two separable 3x3 convolutions and a 1x1 classifier.

    Args:
        nclass: number of output channels of the final 1x1 convolution.
        c1_channels: accepted for signature compatibility; not used by this head.
        c4_channels: number of channels of the tensor passed to ``forward``.
        norm_layer: normalization layer class forwarded to ``SeparableConv2d``.
    """

    def __init__(
        self, nclass, c1_channels=256, c4_channels=2048, norm_layer=nn.BatchNorm2d
    ):
        super().__init__()
        hidden = 256
        # Only the c4 features feed the head: no c1 projection is built, so the
        # first layer's input width is exactly ``c4_channels``.
        layers = [
            SeparableConv2d(
                c4_channels, hidden, 3, norm_layer=norm_layer, relu_first=False
            ),
            SeparableConv2d(
                hidden, hidden, 3, norm_layer=norm_layer, relu_first=False
            ),
            nn.Conv2d(hidden, nclass, 1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x, c1=None):
        """Run ``x`` through the head; ``c1`` is accepted but ignored."""
        return self.block(x)
class ConvBNReLU(nn.Module):
    """
    2D convolution (with bias) followed by 2D batch normalization.

    Despite the class name, ``forward`` applies no activation: the output is
    the raw batch-norm result.

    Adapted from:
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        ks: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding. The default is 1 whatever ``ks`` is, so a
            ``ks=1`` layer built without an explicit ``padding`` grows the
            spatial size by 2 in each dimension.
        dilation: convolution dilation.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(
        self, in_chan, out_chan, ks=3, stride=1, padding=1, dilation=1, *args, **kwargs
    ):
        super().__init__()
        conv_options = dict(
            kernel_size=ks,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        )
        self.conv = nn.Conv2d(in_chan, out_chan, **conv_options)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        """Return ``bn(conv(x))``."""
        return self.bn(self.conv(x))

    def init_weight(self):
        """Kaiming-normal init (``a=1``) for direct Conv2d children; zero biases."""
        conv_children = (
            child for child in self.children() if isinstance(child, nn.Conv2d)
        )
        for conv in conv_children:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is None:
                continue
            nn.init.constant_(conv.bias, 0)
class ASPPv3Plus(nn.Module):
    """
    Atrous spatial pyramid pooling: four parallel conv branches (1x1, and 3x3
    with dilations 6, 12 and 18) concatenated and fused by a 1x1 conv block.

    Adapted from:
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Args:
        backbone: ``"mobilenet"`` selects 320 input channels; any other value
            selects 2048.
        no_init: when true, ``init_weight`` is not called at construction.
    """

    def __init__(self, backbone, no_init):
        super().__init__()
        in_chan = 320 if backbone == "mobilenet" else 2048
        # The global-pooling branch is hard-disabled.
        self.with_gp = False

        self.conv1 = ConvBNReLU(in_chan, 256, ks=1, dilation=1, padding=0)
        self.conv2 = ConvBNReLU(in_chan, 256, ks=3, dilation=6, padding=6)
        self.conv3 = ConvBNReLU(in_chan, 256, ks=3, dilation=12, padding=12)
        self.conv4 = ConvBNReLU(in_chan, 256, ks=3, dilation=18, padding=18)

        n_branches = 4
        if self.with_gp:
            self.avg = nn.AdaptiveAvgPool2d((1, 1))
            self.conv1x1 = ConvBNReLU(in_chan, 256, ks=1)
            n_branches = 5
        # Built with ks=1 and ConvBNReLU's default padding.
        self.conv_out = ConvBNReLU(256 * n_branches, 256, ks=1)

        if not no_init:
            self.init_weight()

    def forward(self, x):
        """Apply every branch to ``x``, concatenate on dim 1 and fuse."""
        height, width = x.size()[2:]
        branches = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        if self.with_gp:
            pooled = self.conv1x1(self.avg(x))
            branches.append(
                F.interpolate(
                    pooled, (height, width), mode="bilinear", align_corners=True
                )
            )
        return self.conv_out(torch.cat(branches, 1))

    def init_weight(self):
        """Kaiming-normal init (``a=1``) for direct Conv2d children; zero biases.

        Only direct children are inspected. As constructed in ``__init__`` these
        are ``ConvBNReLU`` blocks (plus the pooling layer when enabled), so no
        child matches ``nn.Conv2d`` here.
        """
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)
class Decoder(nn.Module):
    """
    DeepLabv3+ style decoder: project one feature map to 48 channels, upsample
    the other to its spatial size, concatenate and classify.

    Adapted from:
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Args:
        n_classes: number of output channels of the final 1x1 convolution.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.conv_low = ConvBNReLU(256, 48, ks=1, padding=0)
        # 304 = 48 projected channels + 256 channels of the upsampled map.
        fusion_layers = [
            ConvBNReLU(304, 256, ks=3, padding=1),
            ConvBNReLU(256, 256, ks=3, padding=1),
        ]
        self.conv_cat = nn.Sequential(*fusion_layers)
        self.conv_out = nn.Conv2d(256, n_classes, kernel_size=1, bias=False)

    def forward(self, feat_low, feat_aspp):
        """Return logits at the spatial resolution of ``feat_low``.

        Args:
            feat_low: tensor fed to ``conv_low``; its last two dimensions give
                the interpolation target size.
            feat_aspp: tensor bilinearly resized to that size before the
                channel-wise concatenation.
        """
        height, width = feat_low.size()[2:]
        projected = self.conv_low(feat_low)
        upsampled = F.interpolate(
            feat_aspp, (height, width), mode="bilinear", align_corners=True
        )
        fused = self.conv_cat(torch.cat([projected, upsampled], dim=1))
        return self.conv_out(fused)
"""
https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/deeplab.py
"""
class DeepLabV3Decoder(nn.Module):
    """Segmentation decoder operating on a pair of encoder feature maps.

    With the ``"resnet"`` backbone the decoder is an ``ASPPv3Plus`` module
    followed by a ``Decoder``; with any other backbone it is a single
    ``_DeepLabHead`` built for 320 input channels. The output is bilinearly
    interpolated to ``self._target_size``.

    Args:
        opts: nested options object. This class reads ``opts.tasks``,
            ``opts.gen.s.output_dim``, ``opts.gen.s.use_dada``,
            ``opts.gen.deeplabv3.backbone`` and
            ``opts.gen.deeplabv3.pretrained_model.{resnet,mobilenet}``.
        no_init: when true, skip both the weight initialization and the
            loading of pre-trained weights.
        freeze_bn: flag stored on the instance as ``_freeze_bn``; it is not
            acted upon in this class. Call ``freeze_bn()`` to actually put the
            batch-norm layers in eval mode.
    """

    def __init__(
        self,
        opts,
        no_init=False,
        freeze_bn=False,
    ):
        super().__init__()

        num_classes = opts.gen.s.output_dim
        self.backbone = opts.gen.deeplabv3.backbone
        self.use_dada = ("d" in opts.tasks) and opts.gen.s.use_dada
        # Stored under a private name: assigning the flag to ``self.freeze_bn``
        # would shadow the ``freeze_bn()`` method with a bool and make it
        # impossible to call.
        self._freeze_bn = freeze_bn

        if self.backbone == "resnet":
            self.aspp = ASPPv3Plus(self.backbone, no_init)
            self.decoder = Decoder(num_classes)
        else:
            self.head = _DeepLabHead(num_classes, c4_channels=320)

        self._target_size = find_target_size(opts, "s")
        print(
            "      - {}:  setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

        if not no_init:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out")
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.zeros_(m.bias)

            self.load_pretrained(opts)

    def load_pretrained(self, opts):
        """Load pre-trained decoder weights for the configured backbone.

        The checkpoint path is read from
        ``opts.gen.deeplabv3.pretrained_model.<backbone>`` and must exist.
        Tensors whose first dimension is 19 are skipped (presumably the
        classification layer of a 19-class pre-trained model -- TODO confirm),
        which is why those loads are non-strict.

        Raises:
            AssertionError: if the backbone is unknown or the checkpoint file
                for the selected backbone does not exist.
        """
        backbone = opts.gen.deeplabv3.backbone
        assert backbone in {"resnet", "mobilenet"}

        pretrained = opts.gen.deeplabv3.pretrained_model
        # Validate the checkpoint that will actually be loaded, not always the
        # resnet one.
        ckpt_path = pretrained.resnet if backbone == "resnet" else pretrained.mobilenet
        assert Path(ckpt_path).exists(), "Pre-trained model not found: {}".format(
            ckpt_path
        )

        # Load on CPU so GPU-saved checkpoints also work on CPU-only hosts;
        # load_state_dict copies values onto each parameter's own device.
        std = torch.load(ckpt_path, map_location="cpu")

        def is_classifier(tensor):
            return len(tensor.shape) > 0 and tensor.shape[0] == 19

        if backbone == "resnet":
            aspp_prefix = "aspp."
            decoder_prefix = "decoder."
            # Slice the prefix off instead of str.replace, which would also
            # remove later occurrences of the same substring in a key.
            self.aspp.load_state_dict(
                {
                    k[len(aspp_prefix):]: v
                    for k, v in std.items()
                    if k.startswith(aspp_prefix)
                }
            )
            self.decoder.load_state_dict(
                {
                    k[len(decoder_prefix):]: v
                    for k, v in std.items()
                    if k.startswith(decoder_prefix) and not is_classifier(v)
                },
                strict=False,
            )
            print(
                "- Loaded pre-trained DeepLabv3+ (Resnet) Decoder & ASPP as Seg Decoder"
            )
        else:
            self.load_state_dict(
                {
                    k: v
                    for k, v in std.items()
                    if k.startswith("head.") and not is_classifier(v)
                },
                strict=False,
            )
            print(
                "      - Loaded pre-trained DeepLabv3+ (MobileNetV2) Head as Seg Decoder"
            )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z, z_depth=None):
        """Decode a pair of feature maps into segmentation logits.

        Args:
            z (tuple, list): two elements, unpacked as ``(z_high, z_low)``.
            z_depth: optional tensor; when given and ``self.use_dada`` is true,
                ``z_high`` is multiplied element-wise by it.

        Returns:
            Logits bilinearly interpolated (``align_corners=True``) to
            ``self._target_size``.

        Raises:
            ValueError: if the target size has not been set.
        """
        assert isinstance(z, (tuple, list))
        if self._target_size is None:
            error = "self._target_size should be set with self.set_target_size() "
            error += "to interpolate logits to the target seg map's size"
            raise ValueError(error)

        z_high, z_low = z

        if z_depth is not None and self.use_dada:
            z_high = z_high * z_depth

        if self.backbone == "resnet":
            z_high = self.aspp(z_high)
            # NOTE(review): Decoder.forward names its first parameter
            # ``feat_low`` yet receives the ASPP output here; confirm the
            # ordering of ``z`` against the encoder before changing anything.
            s = self.decoder(z_high, z_low)
        else:
            s = self.head(z_high)

        s = F.interpolate(
            s, size=self._target_size, mode="bilinear", align_corners=True
        )
        return s

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule in eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()