# NOTE: the hosting page reported "Spaces: Runtime error" for this file; that banner is page residue, not code.
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import cv2 | |
from basicsr.utils import img2tensor, tensor2img | |
# Normalization layer class instantiated by _ConvBnReLU and matched by
# DeepLabV2.freeze_bn (through _ConvBnReLU.BATCH_NORM).
_BATCH_NORM = nn.BatchNorm2d
# A bottleneck block's inner channel count is out_ch // _BOTTLENECK_EXPANSION.
_BOTTLENECK_EXPANSION = 4
import blobfile as bf | |
def _list_image_files_recursively(data_dir):
    """Collect image file paths below ``data_dir``, descending into sub-directories.

    Entries are visited in sorted order. An entry is recorded when its name
    contains a dot and its last dot-separated piece is, ignoring case, one of
    jpg / jpeg / png / gif. Any other entry that ``bf.isdir`` reports as a
    directory is searched recursively.

    Returns:
        A list of joined paths in visiting order.
    """
    image_extensions = ["jpg", "jpeg", "png", "gif"]
    found = []
    for name in sorted(bf.listdir(data_dir)):
        path = bf.join(data_dir, name)
        suffix = name.split(".")[-1].lower()
        if "." in name and suffix in image_extensions:
            found.append(path)
            continue
        # Only non-image entries are probed as directories.
        if bf.isdir(path):
            found += _list_image_files_recursively(path)
    return found
def uint82bin(n, count=8):
    """Return the lowest ``count`` bits of integer ``n`` as a '0'/'1' string.

    The most significant of the ``count`` bits comes first.
    """
    bits = []
    for shift in reversed(range(count)):
        bits.append(str((n >> shift) & 1))
    return "".join(bits)
def labelcolormap(N):
    """Build an ``(N, 3)`` uint8 table that maps a label index to an RGB colour.

    Args:
        N: number of labels. The value 35 selects a fixed, hand-written
            palette (marked "cityscape"); any other value generates the
            palette from the bits of ``index + 1``.

    Returns:
        ``np.ndarray`` of shape ``(N, 3)`` and dtype ``uint8``.
    """
    if N == 35:  # cityscape
        return np.array(
            [
                (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
                (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
                (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
                (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
                (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142),
            ],
            dtype=np.uint8,
        )

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r, g, b = 0, 0, 0
        label_id = i + 1  # let's give 0 a color
        # Seven rounds, each consuming three bits of the id: bit 0 feeds red,
        # bit 1 green, bit 2 blue, written from the channel's top bit downwards.
        for j in range(7):
            shift = 7 - j
            r ^= (label_id & 1) << shift
            g ^= ((label_id >> 1) & 1) << shift
            b ^= ((label_id >> 2) & 1) << shift
            label_id >>= 3
        # Shifts run from 7 down to 1, so each channel is at most 254.
        cmap[i] = (r, g, b)
    return cmap
class Colorize(object):
    """Paint a 2-D label map using the colour table from ``labelcolormap``."""

    def __init__(self, n=182):
        # One colour row per label index.
        self.cmap = labelcolormap(n)

    def __call__(self, gray_image):
        """Return a ``(3, H, W)`` float array coloured by label.

        ``H`` and ``W`` are the first two entries of ``gray_image.shape``.
        Positions whose label has no row in ``self.cmap`` stay zero.
        """
        height, width = gray_image.shape[0], gray_image.shape[1]
        color_image = np.zeros((3, height, width))
        for label, color in enumerate(self.cmap):
            mask = (label == gray_image)
            for channel in range(3):
                color_image[channel][mask] = color[channel]
        return color_image
class _ConvBnReLU(nn.Sequential): | |
""" | |
Cascade of 2D convolution, batch norm, and ReLU. | |
""" | |
BATCH_NORM = _BATCH_NORM | |
def __init__( | |
self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True | |
): | |
super(_ConvBnReLU, self).__init__() | |
self.add_module( | |
"conv", | |
nn.Conv2d( | |
in_ch, out_ch, kernel_size, stride, padding, dilation, bias=False | |
), | |
) | |
self.add_module("bn", _BATCH_NORM(out_ch, eps=1e-5, momentum=1 - 0.999)) | |
if relu: | |
self.add_module("relu", nn.ReLU()) | |
class _Bottleneck(nn.Module):
    """
    Bottleneck residual block: 1x1 reduce, 3x3 (dilated), 1x1 increase, plus a shortcut.

    The shortcut is a strided 1x1 conv + batch norm when ``downsample`` is true
    and an identity otherwise. The stride is applied in the first 1x1 conv.
    """

    def __init__(self, in_ch, out_ch, stride, dilation, downsample):
        super().__init__()
        mid_ch = out_ch // _BOTTLENECK_EXPANSION
        self.reduce = _ConvBnReLU(
            in_ch, mid_ch, kernel_size=1, stride=stride, padding=0, dilation=1, relu=True
        )
        # Padding equals the dilation so the 3x3 conv keeps the spatial size.
        self.conv3x3 = _ConvBnReLU(
            mid_ch, mid_ch, kernel_size=3, stride=1, padding=dilation, dilation=dilation, relu=True
        )
        self.increase = _ConvBnReLU(
            mid_ch, out_ch, kernel_size=1, stride=1, padding=0, dilation=1, relu=False
        )
        if downsample:
            self.shortcut = _ConvBnReLU(
                in_ch, out_ch, kernel_size=1, stride=stride, padding=0, dilation=1, relu=False
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        residual = self.increase(self.conv3x3(self.reduce(x)))
        residual += self.shortcut(x)
        return F.relu(residual)
class _ResLayer(nn.Sequential):
    """
    Stack of ``n_layers`` bottleneck blocks named "block1", "block2", ...

    ``multi_grids`` holds one dilation multiplier per block and defaults to
    all ones.
    """

    def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None):
        super().__init__()

        if multi_grids is None:
            multi_grids = [1] * n_layers
        else:
            assert n_layers == len(multi_grids)

        # Only the first block changes channels / stride and uses a conv shortcut.
        for index in range(n_layers):
            is_first = index == 0
            block = _Bottleneck(
                in_ch=in_ch if is_first else out_ch,
                out_ch=out_ch,
                stride=stride if is_first else 1,
                dilation=dilation * multi_grids[index],
                downsample=is_first,
            )
            self.add_module("block{}".format(index + 1), block)
class _Stem(nn.Sequential):
    """
    The 1st conv layer: 7x7 stride-2 conv/bn/relu ("conv1") then 3x3 stride-2 max pool ("pool").
    Note that the max pooling is different from both MSRA and FAIR ResNet.
    """

    def __init__(self, out_ch):
        super().__init__()
        stem_conv = _ConvBnReLU(3, out_ch, kernel_size=7, stride=2, padding=3, dilation=1)
        stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
        self.add_module("conv1", stem_conv)
        self.add_module("pool", stem_pool)
class _ASPP(nn.Module): | |
""" | |
Atrous spatial pyramid pooling (ASPP) | |
""" | |
def __init__(self, in_ch, out_ch, rates): | |
super(_ASPP, self).__init__() | |
for i, rate in enumerate(rates): | |
self.add_module( | |
"c{}".format(i), | |
nn.Conv2d(in_ch, out_ch, 3, 1, padding=rate, dilation=rate, bias=True), | |
) | |
for m in self.children(): | |
nn.init.normal_(m.weight, mean=0, std=0.01) | |
nn.init.constant_(m.bias, 0) | |
def forward(self, x): | |
return sum([stage(x) for stage in self.children()]) | |
class MSC(nn.Module):
    """
    Multi-scale inputs: run ``base`` on the input and on rescaled copies, then
    take the element-wise maximum of the logits.
    """

    def __init__(self, base, scales=None):
        super().__init__()
        self.base = base
        # Any falsy ``scales`` (None, empty list) selects the default pair.
        self.scales = scales if scales else [0.5, 0.75]

    def forward(self, x):
        """Return logits at the resolution ``base`` produces for the unscaled input."""
        # Original
        full_logits = self.base(x)
        _, _, height, width = full_logits.shape

        # Scaled
        scaled_logits = []
        for scale in self.scales:
            resized = F.interpolate(
                x, scale_factor=scale, mode="bilinear", align_corners=False
            )
            scaled_logits.append(self.base(resized))

        # Bring every scaled prediction back to the full-resolution logit size.
        candidates = [full_logits]
        for item in scaled_logits:
            candidates.append(
                F.interpolate(
                    item, size=(height, width), mode="bilinear", align_corners=False
                )
            )

        # Pixel-wise max
        return torch.stack(candidates).max(dim=0).values
class DeepLabV2(nn.Sequential):
    """
    DeepLab v2: Dilated ResNet + ASPP
    Output stride is fixed at 8

    Sub-modules are registered as "layer1" (stem), "layer2".."layer5"
    (residual stages) and "aspp" (classifier).
    """

    def __init__(self, n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]):
        super().__init__()
        # Channel widths 64, 128, 256, 512, 1024, 2048.
        widths = [64 * 2 ** power for power in range(6)]

        # (name, block count, in channels, out channels, stride, dilation)
        stage_specs = [
            ("layer2", n_blocks[0], widths[0], widths[2], 1, 1),
            ("layer3", n_blocks[1], widths[2], widths[3], 2, 1),
            ("layer4", n_blocks[2], widths[3], widths[4], 1, 2),
            ("layer5", n_blocks[3], widths[4], widths[5], 1, 4),
        ]

        self.add_module("layer1", _Stem(widths[0]))
        for name, depth, c_in, c_out, stride, dilation in stage_specs:
            self.add_module(name, _ResLayer(depth, c_in, c_out, stride, dilation))
        self.add_module("aspp", _ASPP(widths[5], n_classes, atrous_rates))

    def freeze_bn(self):
        """Put every batch-norm layer of the network into eval mode."""
        batch_norm_cls = _ConvBnReLU.BATCH_NORM
        for module in self.modules():
            if isinstance(module, batch_norm_cls):
                module.eval()
def preprocessing(image, device):
    """Resize an image, subtract per-channel means and wrap it as a batched tensor.

    Args:
        image: array whose first two shape entries are height and width and
            whose last axis has three channels (assumed to be an image as
            returned by ``cv2.imread`` - TODO confirm with callers).
        device: target passed to ``Tensor.to``.

    Returns:
        ``(tensor, raw_image)``: a float tensor of shape ``(1, 3, H, W)`` on
        ``device`` and the resized image cast to ``uint8``.
    """
    # Resize so that the longer side becomes 640 pixels.
    longest_side = max(image.shape[:2])
    scale = 640 / longest_side
    image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
    raw_image = image.astype(np.uint8)

    # Subtract mean values (one per channel, in the image's channel order).
    channel_means = np.array([float(104.008), float(116.669), float(122.675)])
    image = image.astype(np.float32)
    image -= channel_means

    # Convert to torch.Tensor (channels first) and add "batch" axis.
    channels_first = image.transpose(2, 0, 1)
    image = torch.from_numpy(channels_first).float().unsqueeze(0)
    image = image.to(device)
    return image, raw_image
# Model setup | |
def seger(checkpoint_path='models/deeplabv2_resnet101_msc-cocostuff164k-100000.pth'):
    """Build the multi-scale DeepLabV2 segmenter and load its weights.

    Args:
        checkpoint_path: file holding the state dict for the ``MSC`` model.
            Defaults to the path that was previously hard-coded.

    Returns:
        The ``MSC`` model with weights loaded. Its parameters live on the CPU;
        the caller moves it to the desired device.

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable file, and
        ``RuntimeError`` from ``load_state_dict`` when the keys do not match
        (the load is strict).
    """
    model = MSC(
        base=DeepLabV2(
            n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]
        ),
        scales=[0.5, 0.75],
    )
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # host without CUDA; the freshly built model is on the CPU anyway.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    # Strict load: every key in the checkpoint must match the model.
    model.load_state_dict(state_dict)
    return model
if __name__ == '__main__':
    # Demo: segment one image and write the label map to mask.png.
    # Fall back to the CPU when CUDA is unavailable instead of failing in .to().
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = seger()
    model.to(device)
    model.eval()

    image_path = '/group/30042/chongmou/ft_local/Diffusion/baselines/SPADE/datasets/coco_stuff/val_img/000000000785.jpg'
    im = cv2.imread(image_path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None rather than raising.
    if im is None:
        raise FileNotFoundError('Could not read image: {}'.format(image_path))

    with torch.no_grad():
        # Use the selected device rather than a second hard-coded 'cuda'.
        im, _ = preprocessing(im, device)
        _, _, H, W = im.shape
        # Image -> Probability map
        logits = model(im)
        logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
        probs = F.softmax(logits, dim=1)[0]
        probs = probs.cpu().numpy()

    labelmap = np.argmax(probs, axis=0)
    print(labelmap.shape, np.max(labelmap), np.min(labelmap))
    # argmax yields a platform-integer array; cv2.imwrite expects 8/16-bit data
    # for PNG. The model built by seger() has 182 classes, so uint8 is lossless.
    cv2.imwrite('mask.png', labelmap.astype(np.uint8))