T2I-Adapter / seger.py
Adapter's picture
add seg
a056b0b
raw
history blame
9.25 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
from basicsr.utils import img2tensor, tensor2img
_BATCH_NORM = nn.BatchNorm2d
_BOTTLENECK_EXPANSION = 4
import blobfile as bf
def _list_image_files_recursively(data_dir):
    """
    Collect image file paths found under ``data_dir``.

    Entries are visited in sorted order. An entry whose name contains a dot
    and whose last dot-separated component (case-insensitive) is one of
    jpg / jpeg / png / gif is recorded as an image. Any other entry that
    ``bf.isdir`` reports as a directory is searched recursively.

    Returns:
        list of joined paths, in traversal order.
    """
    image_exts = ["jpg", "jpeg", "png", "gif"]
    found = []
    for name in sorted(bf.listdir(data_dir)):
        path = bf.join(data_dir, name)
        suffix = name.split(".")[-1].lower()
        if "." in name and suffix in image_exts:
            found.append(path)
            continue
        if bf.isdir(path):
            found += _list_image_files_recursively(path)
    return found
def uint82bin(n, count=8):
    """Return the lowest ``count`` bits of integer ``n`` as a '0'/'1' string, most significant bit first."""
    digits = []
    for shift in reversed(range(count)):
        digits.append(str((n >> shift) & 1))
    return ''.join(digits)
def labelcolormap(N):
    """
    Build an ``(N, 3)`` uint8 colour table, one RGB row per label index.

    For ``N == 35`` a fixed, hand-written palette is returned (marked
    "cityscape" in the code). For every other ``N`` the colours are generated
    procedurally: label ``i`` uses the integer ``i + 1`` (so that label 0 is
    not black), whose bits are consumed three at a time. Within each triple
    the lowest bit goes to red, the next to green, the next to blue, and the
    first triple lands in the most significant bit of each channel.

    Args:
        N: number of labels (rows of the returned table).

    Returns:
        ``numpy.ndarray`` of shape ``(N, 3)`` and dtype ``uint8``.
    """
    if N == 35:  # cityscape
        return np.array(
            [
                (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
                (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70),
                (102, 102, 156), (190, 153, 153),
                (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153),
                (250, 170, 30), (220, 220, 0),
                (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0),
                (0, 0, 142), (0, 0, 70),
                (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32),
                (0, 0, 142),
            ],
            dtype=np.uint8,
        )

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        code = i + 1  # let's give 0 a color
        # Seven rounds of three bits each; plain integer arithmetic avoids
        # formatting the value as a bit-string and parsing it back.
        for j in range(7):
            shift = 7 - j
            r ^= (code & 1) << shift
            g ^= ((code >> 1) & 1) << shift
            b ^= ((code >> 2) & 1) << shift
            code >>= 3
        cmap[i] = (r, g, b)
    return cmap
class Colorize:
    """
    Callable that paints a label map with the palette from ``labelcolormap(n)``.
    """

    def __init__(self, n=182):
        self.cmap = labelcolormap(n)

    def __call__(self, gray_image):
        """
        Map each label in ``gray_image`` to its palette colour.

        Args:
            gray_image: label map whose first two shape entries give the
                output height and width.

        Returns:
            ``numpy.ndarray`` of shape ``(3, H, W)`` (default float dtype of
            ``np.zeros``). Positions whose label has no palette row stay 0.
        """
        shape = gray_image.shape
        height, width = shape[0], shape[1]
        color_image = np.zeros((3, height, width))
        for label, color in enumerate(self.cmap):
            mask = (label == gray_image)
            for channel in range(3):
                color_image[channel][mask] = color[channel]
        return color_image
class _ConvBnReLU(nn.Sequential):
    """
    Bias-free 2D convolution followed by the module-level ``_BATCH_NORM``
    layer and, unless ``relu`` is false, a ReLU.

    Sub-modules are registered under the names "conv", "bn" and "relu".
    """

    BATCH_NORM = _BATCH_NORM

    def __init__(
        self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True
    ):
        super().__init__()
        conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        norm = _BATCH_NORM(out_ch, eps=1e-5, momentum=1 - 0.999)
        stages = [("conv", conv), ("bn", norm)]
        if relu:
            stages.append(("relu", nn.ReLU()))
        for name, module in stages:
            self.add_module(name, module)
class _Bottleneck(nn.Module):
    """
    Bottleneck block of MSRA ResNet.

    Main path: 1x1 reduce (carries the stride) -> 3x3 (carries the dilation,
    padded by the same amount) -> 1x1 increase without ReLU. The result is
    added to the shortcut and passed through ReLU.
    """

    def __init__(self, in_ch, out_ch, stride, dilation, downsample):
        super().__init__()
        mid_ch = out_ch // _BOTTLENECK_EXPANSION
        self.reduce = _ConvBnReLU(
            in_ch, mid_ch, kernel_size=1, stride=stride, padding=0, dilation=1, relu=True
        )
        self.conv3x3 = _ConvBnReLU(
            mid_ch, mid_ch, kernel_size=3, stride=1, padding=dilation, dilation=dilation, relu=True
        )
        self.increase = _ConvBnReLU(
            mid_ch, out_ch, kernel_size=1, stride=1, padding=0, dilation=1, relu=False
        )
        # A projection shortcut is used only when requested; otherwise the
        # input is passed through unchanged.
        if downsample:
            self.shortcut = _ConvBnReLU(
                in_ch, out_ch, kernel_size=1, stride=stride, padding=0, dilation=1, relu=False
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        residual = self.increase(self.conv3x3(self.reduce(x)))
        residual += self.shortcut(x)
        return F.relu(residual)
class _ResLayer(nn.Sequential):
    """
    Residual layer with multi grids: ``n_layers`` bottleneck blocks
    registered as "block1", "block2", ...

    ``multi_grids`` supplies a per-block multiplier for ``dilation``; when
    omitted every multiplier is 1.
    """

    def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None):
        super().__init__()

        if multi_grids is not None:
            assert n_layers == len(multi_grids)
        else:
            multi_grids = [1] * n_layers

        # Only the first block changes channel count / stride and uses a
        # projection shortcut.
        for index in range(n_layers):
            is_first = index == 0
            block = _Bottleneck(
                in_ch=in_ch if is_first else out_ch,
                out_ch=out_ch,
                stride=stride if is_first else 1,
                dilation=dilation * multi_grids[index],
                downsample=is_first,
            )
            self.add_module("block{}".format(index + 1), block)
class _Stem(nn.Sequential):
    """
    The 1st conv layer: a 7x7 stride-2 conv block on 3 input channels
    ("conv1") followed by a 3x3 stride-2 max pooling in ceil mode ("pool").
    Note that the max pooling is different from both MSRA and FAIR ResNet.
    """

    def __init__(self, out_ch):
        super().__init__()
        first_conv = _ConvBnReLU(3, out_ch, 7, 2, 3, 1)
        pooling = nn.MaxPool2d(3, 2, 1, ceil_mode=True)
        self.add_module("conv1", first_conv)
        self.add_module("pool", pooling)
class _ASPP(nn.Module):
    """
    Atrous spatial pyramid pooling (ASPP).

    One 3x3 convolution per entry of ``rates`` (padding and dilation both
    equal to the rate), registered as "c0", "c1", ... The forward pass sums
    the outputs of all branches.
    """

    def __init__(self, in_ch, out_ch, rates):
        super().__init__()
        for index, rate in enumerate(rates):
            branch = nn.Conv2d(
                in_ch, out_ch, kernel_size=3, stride=1, padding=rate, dilation=rate, bias=True
            )
            self.add_module("c{}".format(index), branch)

        # Re-initialise only after every branch exists: small Gaussian
        # weights, zero biases.
        for branch in self.children():
            nn.init.normal_(branch.weight, mean=0, std=0.01)
            nn.init.constant_(branch.bias, 0)

    def forward(self, x):
        total = 0
        for branch in self.children():
            total = total + branch(x)
        return total
class MSC(nn.Module):
    """
    Multi-scale inputs.

    Runs ``base`` on the input and on bilinearly rescaled copies of it (one
    per entry in ``scales``), resizes every result to the full-scale output
    size and returns the element-wise maximum. A falsy ``scales`` selects
    the default ``[0.5, 0.75]``.
    """

    def __init__(self, base, scales=None):
        super().__init__()
        self.base = base
        self.scales = scales if scales else [0.5, 0.75]

    def forward(self, x):
        # Full-scale pass fixes the target spatial size
        full_scale = self.base(x)
        _, _, height, width = full_scale.shape

        candidates = [full_scale]
        for factor in self.scales:
            resized_input = F.interpolate(
                x, scale_factor=factor, mode="bilinear", align_corners=False
            )
            scaled_logits = self.base(resized_input)
            candidates.append(
                F.interpolate(
                    scaled_logits, size=(height, width), mode="bilinear", align_corners=False
                )
            )

        # Pixel-wise max over all scales
        return torch.stack(candidates).max(dim=0).values
class DeepLabV2(nn.Sequential):
    """
    DeepLab v2: Dilated ResNet + ASPP
    Output stride is fixed at 8

    Sub-modules are registered as "layer1" (stem), "layer2".."layer5"
    (residual layers) and "aspp".

    Args:
        n_classes: number of output channels produced by the ASPP head.
        n_blocks: four block counts, one per residual layer.
        atrous_rates: dilation rates, one ASPP branch per rate.
    """

    # Defaults are tuples rather than lists so no mutable object is shared
    # between calls; any indexable / iterable sequence is still accepted.
    def __init__(self, n_classes=182, n_blocks=(3, 4, 23, 3), atrous_rates=(6, 12, 18, 24)):
        super().__init__()
        # Channel widths 64, 128, 256, 512, 1024, 2048
        ch = [64 * 2 ** p for p in range(6)]
        self.add_module("layer1", _Stem(ch[0]))
        self.add_module("layer2", _ResLayer(n_blocks[0], ch[0], ch[2], 1, 1))
        self.add_module("layer3", _ResLayer(n_blocks[1], ch[2], ch[3], 2, 1))
        # The last two layers keep stride 1 and use dilation 2 and 4 instead
        self.add_module("layer4", _ResLayer(n_blocks[2], ch[3], ch[4], 1, 2))
        self.add_module("layer5", _ResLayer(n_blocks[3], ch[4], ch[5], 1, 4))
        self.add_module("aspp", _ASPP(ch[5], n_classes, atrous_rates))

    def freeze_bn(self):
        """Switch every batch-norm sub-module to eval mode."""
        for m in self.modules():
            if isinstance(m, _ConvBnReLU.BATCH_NORM):
                m.eval()
def preprocessing(image, device):
    """
    Resize an image and turn it into a mean-subtracted network input.

    Args:
        image: array whose first two axes are the spatial size and whose last
            axis holds three channels (presumably BGR as returned by
            ``cv2.imread`` - confirm against callers).
        device: target passed to ``Tensor.to``.

    Returns:
        ``(tensor, raw_image)``: a float tensor with a leading batch axis and
        channels first, on ``device``, and the resized image as uint8.
    """
    # Resize: both axes are scaled by 640 / (larger of the first two axes)
    longest_side = max(image.shape[:2])
    scale = 640 / longest_side
    image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
    raw_image = image.astype(np.uint8)

    # Subtract per-channel mean values
    channel_means = np.array([104.008, 116.669, 122.675])
    image = image.astype(np.float32)
    image -= channel_means

    # Channels first, then add the "batch" axis
    channels_first = image.transpose(2, 0, 1)
    tensor = torch.from_numpy(channels_first).float().unsqueeze(0)
    return tensor.to(device), raw_image
# Model setup
def seger(ckpt_path='models/deeplabv2_resnet101_msc-cocostuff164k-100000.pth'):
    """
    Build the multi-scale DeepLabV2 segmenter and load its weights.

    Args:
        ckpt_path: path of the state-dict checkpoint to load. Defaults to the
            path this function previously hard-coded.

    Returns:
        The ``MSC`` model with weights loaded. It is neither moved to a
        device nor put in eval mode here; callers do that.
    """
    model = MSC(
        base=DeepLabV2(
            n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]
        ),
        scales=[0.5, 0.75],
    )
    # Map storages to CPU so a checkpoint saved from a GPU also loads on a
    # machine without CUDA; load_state_dict copies the values into the
    # model's own parameters, so the resulting model is the same.
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # Strict loading: every key, including the ASPP head, must match.
    model.load_state_dict(state_dict)
    return model
if __name__ == '__main__':
    # Smoke test: segment one image and write the label map to mask.png.
    device = 'cuda'
    image_path = '/group/30042/chongmou/ft_local/Diffusion/baselines/SPADE/datasets/coco_stuff/val_img/000000000785.jpg'

    model = seger()
    model.to(device)
    model.eval()

    im = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if im is None:
        # cv2.imread signals failure by returning None rather than raising
        raise FileNotFoundError('Could not read image: {}'.format(image_path))

    with torch.no_grad():
        # Use the same device as the model instead of a second literal
        im, raw_im = preprocessing(im, device)
        _, _, H, W = im.shape

        # Image -> Probability map
        logits = model(im)
        logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
        probs = F.softmax(logits, dim=1)[0]
        probs = probs.cpu().data.numpy()
        labelmap = np.argmax(probs, axis=0)
        print(labelmap.shape, np.max(labelmap), np.min(labelmap))
        # argmax yields a platform integer (typically int64), which the
        # OpenCV writers do not accept; the 182 class ids fit in uint8.
        cv2.imwrite('mask.png', labelmap.astype(np.uint8))