Spaces:
Runtime error
Runtime error
from pathlib import Path | |
import torch | |
import torch.nn as nn | |
from climategan.deeplab.deeplab_v2 import DeepLabV2Decoder | |
from climategan.deeplab.deeplab_v3 import DeepLabV3Decoder | |
from climategan.deeplab.mobilenet_v3 import MobileNetV2 | |
from climategan.deeplab.resnet101_v3 import ResNet101 | |
from climategan.deeplab.resnetmulti_v2 import ResNetMulti | |
def create_encoder(opts, no_init=False, verbose=0):
    """Build the shared encoder selected by ``opts.gen.encoder.architecture``.

    Args:
        opts: Nested options object. Reads ``opts.gen.encoder.architecture``
            and, when verbose on the DeepLabv3 path,
            ``opts.gen.deeplabv3.backbone``.
        no_init (bool): Forwarded to the selected encoder builder, which uses it
            to skip loading pre-trained weights.
        verbose (int): When > 0, print which encoder is added. Also forwarded
            to the selected encoder builder.

    Returns:
        A ``DeeplabV2Encoder`` for ``"deeplabv2"``, or whatever
        ``build_v3_backbone`` returns for ``"deeplabv3"``.

    Raises:
        NotImplementedError: If the architecture is neither ``"deeplabv2"`` nor
            ``"deeplabv3"``.
    """
    architecture = opts.gen.encoder.architecture
    if architecture == "deeplabv2":
        if verbose > 0:
            print(" - Add Deeplabv2 Encoder")
        return DeeplabV2Encoder(opts, no_init, verbose)
    if architecture == "deeplabv3":
        if verbose > 0:
            backbone = opts.gen.deeplabv3.backbone
            print(" - Add Deeplabv3 ({}) Encoder".format(backbone))
        # Forward ``verbose`` so the v3 builder honours the caller's verbosity
        # instead of always falling back to its default of 0.
        return build_v3_backbone(opts, no_init, verbose=verbose)
    raise NotImplementedError("Unknown encoder: {}".format(architecture))
def create_segmentation_decoder(opts, no_init=False, verbose=0):
    """Instantiate the segmentation decoder named by ``opts.gen.s.architecture``.

    Args:
        opts: Nested options object. Only ``opts.gen.s.architecture`` is read
            here; the full object is handed to the decoder constructor.
        no_init (bool): Passed to ``DeepLabV3Decoder`` only.
        verbose (int): When > 0, print which decoder is being added.

    Returns:
        A ``DeepLabV2Decoder`` or ``DeepLabV3Decoder`` instance.

    Raises:
        NotImplementedError: For any architecture other than ``"deeplabv2"`` or
            ``"deeplabv3"``.
    """
    arch = opts.gen.s.architecture

    # Reject unsupported names up front; the remaining code handles the two
    # known decoders.
    if arch not in ("deeplabv2", "deeplabv3"):
        raise NotImplementedError(
            "Unknown Segmentation architecture: {}".format(arch)
        )

    if arch == "deeplabv2":
        if verbose > 0:
            print(" - Add DeepLabV2Decoder")
        # NOTE(review): no_init is not forwarded here — confirm that
        # DeepLabV2Decoder has no pre-trained initialisation to skip.
        return DeepLabV2Decoder(opts)

    if verbose > 0:
        print(" - Add DeepLabV3Decoder")
    return DeepLabV3Decoder(opts, no_init)
def build_v3_backbone(opts, no_init, verbose=0):
    """Build the DeepLabv3+ backbone used as the shared encoder.

    Args:
        opts: Nested options object. Reads ``opts.gen.deeplabv3.backbone``,
            ``opts.gen.deeplabv3.output_stride`` and, when weights are loaded,
            ``opts.gen.deeplabv3.pretrained_model.resnet`` / ``.mobilenet``.
        no_init (bool): If True, pre-trained weights are neither required nor
            loaded by this function.
        verbose (int): Forwarded to ``ResNet101``.

    Returns:
        A ``ResNet101`` (``backbone == "resnet"``) or ``MobileNetV2``
        (``backbone == "mobilenet"``) instance.

    Raises:
        FileNotFoundError: If pre-trained weights are requested
            (``no_init`` is False) and the configured checkpoint is missing.
        NotImplementedError: For any other backbone name.
    """
    v3_opts = opts.gen.deeplabv3
    backbone = v3_opts.backbone
    output_stride = v3_opts.output_stride

    def _require_checkpoint(path):
        # Explicit check rather than ``assert`` so validation survives
        # ``python -O``.
        if not Path(path).exists():
            raise FileNotFoundError(
                "Pre-trained DeepLabv3+ {} weights not found: {}".format(
                    backbone, path
                )
            )
        return path

    if backbone == "resnet":
        if not no_init:
            # Validate before building the network so a bad path fails fast.
            ckpt_path = _require_checkpoint(v3_opts.pretrained_model.resnet)
        resnet = ResNet101(
            output_stride=output_stride,
            BatchNorm=nn.BatchNorm2d,
            verbose=verbose,
            no_init=no_init,
        )
        if not no_init:
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
            # hosts; load_state_dict copies the tensors into resnet's params.
            state_dict = torch.load(ckpt_path, map_location="cpu")
            prefix = "backbone."
            # Keep only the full-model checkpoint's backbone entries and strip
            # the leading prefix so keys match ResNet101's own names.
            resnet.load_state_dict(
                {
                    key[len(prefix):]: value
                    for key, value in state_dict.items()
                    if key.startswith(prefix)
                }
            )
            print(" - Loaded pre-trained DeepLabv3+ Resnet101 Backbone as Encoder")
        return resnet

    if backbone == "mobilenet":
        ckpt_path = v3_opts.pretrained_model.mobilenet
        # Only require the checkpoint when weights are requested, mirroring the
        # resnet branch. NOTE(review): assumes MobileNetV2 does not read
        # pretrained_path when no_init is True — confirm.
        if not no_init:
            _require_checkpoint(ckpt_path)
        mobilenet = MobileNetV2(
            no_init=no_init,
            pretrained_path=ckpt_path,
        )
        if not no_init:
            print(" - Loaded pre-trained DeepLabv3+ MobileNetV2 Backbone as Encoder")
        return mobilenet

    raise NotImplementedError("Unknown backbone in " + str(opts.gen.deeplabv3))
class DeeplabV2Encoder(nn.Module):
    """DeepLabv2 encoder: an ``nn.Module`` wrapper around ``ResNetMulti``."""

    def __init__(self, opts, no_init=False, verbose=0):
        """Build the wrapped ``ResNetMulti`` and optionally load pre-trained weights.

        Args:
            opts: Nested options object. Reads ``opts.gen.deeplabv2.nblocks``,
                ``opts.gen.encoder.n_res``, ``opts.gen.deeplabv2.use_pretrained``
                and ``opts.gen.deeplabv2.pretrained_model``.
            no_init (bool): If True, skip loading pre-trained weights even when
                ``use_pretrained`` is set.
            verbose (int): If > 0, print a message after weights are loaded.
        """
        super().__init__()
        self.model = ResNetMulti(opts.gen.deeplabv2.nblocks, opts.gen.encoder.n_res)
        if opts.gen.deeplabv2.use_pretrained and not no_init:
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
            # hosts; load_state_dict copies the tensors into the model's params.
            saved_state_dict = torch.load(
                opts.gen.deeplabv2.pretrained_model, map_location="cpu"
            )
            # Start from the model's own weights so entries not copied below
            # keep their current values.
            new_params = self.model.state_dict().copy()
            for key, value in saved_state_dict.items():
                # Drop the first dotted component of each checkpoint key.
                # NOTE(review): assumes every key has at least one "." — a key
                # without one raises IndexError here.
                parts = key.split(".")
                if parts[1] not in ("layer5", "resblock"):
                    new_params[".".join(parts[1:])] = value
            self.model.load_state_dict(new_params)
            if verbose > 0:
                print(" - Loaded pretrained weights")

    def forward(self, x):
        """Run the wrapped ``ResNetMulti`` on ``x`` and return its output."""
        return self.model(x)