|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
|
|
|
|
|
|
|
|
|
|
class SBR(nn.Module):
    """Spatial boundary refinement block.

    Runs a 1x3 conv then a 3x1 conv (each followed by BatchNorm + ReLU)
    and adds the result back onto the input as a residual.
    """

    def __init__(self, in_ch):
        super(SBR, self).__init__()

        def _branch(kernel, pad):
            # Conv -> BN -> ReLU with an asymmetric kernel; channels unchanged.
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch, kernel_size=kernel, stride=1, padding=pad),
                nn.BatchNorm2d(in_ch),
                nn.ReLU(True),
            )

        self.conv1x3 = _branch((1, 3), (0, 1))
        self.conv3x1 = _branch((3, 1), (1, 0))

    def forward(self, x):
        # Residual refinement: x + conv3x1(conv1x3(x)).
        refined = self.conv1x3(x)
        refined = self.conv3x1(refined)
        return refined + x
|
|
|
|
|
|
|
class c_stage123(nn.Module):
    """CNN-branch stem stage used for stages 1-3.

    Main path: stride-2 3x3 conv + 3x3 conv (each with BN + ReLU),
    halving the spatial resolution. Shortcut path: stride-2 max-pool
    projected to ``out_chans`` with a 1x1 conv. The two are summed.

    Args:
        in_chans: input channel count.
        out_chans: output channel count (both paths project to this).
    """

    def __init__(self, in_chans, out_chans):
        super().__init__()
        self.stage123 = nn.Sequential(
            nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_chans),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_chans),
            nn.ReLU(),
        )
        self.conv1x1_123 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Return features at half the input resolution with ``out_chans`` channels."""
        main = self.stage123(x)
        # Renamed from ``max`` to avoid shadowing the builtin.
        shortcut = self.conv1x1_123(self.maxpool(x))
        return main + shortcut
|
|
|
|
|
|
|
class c_stage45(nn.Module):
    """CNN-branch stage used for stages 4-5.

    Like :class:`c_stage123` but with a third 3x3 conv in the main path:
    stride-2 3x3 conv + two 3x3 convs (each with BN + ReLU), plus a
    shortcut of stride-2 max-pool followed by a 1x1 projection.

    Args:
        in_chans: input channel count.
        out_chans: output channel count (both paths project to this).
    """

    def __init__(self, in_chans, out_chans):
        super().__init__()
        self.stage45 = nn.Sequential(
            nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_chans),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_chans),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_chans),
            nn.ReLU(),
        )
        self.conv1x1_45 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Return features at half the input resolution with ``out_chans`` channels."""
        main = self.stage45(x)
        # Renamed from ``max`` to avoid shadowing the builtin.
        shortcut = self.conv1x1_45(self.maxpool(x))
        return main + shortcut
|
|
|
|
|
class Identity(nn.Module):
    """No-op module: returns its input unchanged.

    Used in place of a projection head when none is needed.
    """

    def forward(self, x):
        return x
|
|
|
|
|
|
|
class DepthwiseConv2d(nn.Module):
    """Depthwise-separable convolution.

    A per-channel (depthwise) convolution, BatchNorm, then a 1x1
    pointwise convolution that mixes channels and maps to ``out_chans``.
    """

    def __init__(self, in_chans, out_chans, kernel_size=1, stride=1, padding=0, dilation=1):
        super().__init__()
        # groups=in_chans gives every input channel its own spatial filter.
        self.depthwise = nn.Conv2d(
            in_channels=in_chans,
            out_channels=in_chans,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_chans,
        )
        self.bn = nn.BatchNorm2d(num_features=in_chans)
        # 1x1 conv performs the cross-channel mixing.
        self.pointwise = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.bn(self.depthwise(x)))
|
|
|
|
|
|
|
class Residual(nn.Module):
    """Wraps ``fn`` with a skip connection: output = fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, input, **kwargs):
        return self.fn(input, **kwargs) + input
|
|
|
|
|
|
|
class PreNorm(nn.Module):
    """Applies LayerNorm before calling ``fn`` (pre-norm transformer style)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, input, **kwargs):
        normed = self.norm(input)
        return self.fn(normed, **kwargs)
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Transformer MLP block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps ``dim -> hidden_dim -> dim`` so it can sit inside a residual.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(in_features=dim, out_features=hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_dim, out_features=dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        return self.net(input)
|
|
|
|
|
class ConvAttnetion(nn.Module):
    '''
    Convolutional multi-head self-attention (CvT style).

    Uses depthwise-separable Conv2d projections (DepthwiseConv2d) to
    produce q, k, v from the 2-D feature map, instead of the Linear
    projections used in a standard ViT. When ``last_stage`` is True the
    first token is treated as a class token: it is split off before the
    conv projections and concatenated back onto q, k and v afterwards.
    '''

    def __init__(self, dim, img_size, heads=8, dim_head=64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1,
                 dropout=0., last_stage=False):
        super().__init__()
        self.last_stage = last_stage
        self.img_size = img_size
        inner_dim = dim_head * heads
        # Skip the output projection when a single head already matches dim.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** (-0.5)  # 1/sqrt(d) attention scaling

        # 'same'-style padding derived from the q stride (reused for k and v).
        pad = (kernel_size - q_stride) // 2

        self.to_q = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=q_stride,
                                    padding=pad)
        self.to_k = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=k_stride,
                                    padding=pad)
        self.to_v = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=v_stride,
                                    padding=pad)

        self.to_out = nn.Sequential(
            nn.Linear(
                in_features=inner_dim,
                out_features=dim
            ),
            nn.Dropout(dropout)
        ) if project_out else Identity()

    def forward(self, x):
        # x: (batch, tokens, channels); h = number of attention heads.
        b, n, c, h = *x.shape, self.heads

        if self.last_stage:
            # Split off the class token before reshaping tokens to a 2-D map.
            cls_token = x[:, 0]
            x = x[:, 1:]
            # (b, dim) -> (b, h, 1, d) so it can be concatenated onto q/k/v.
            cls_token = rearrange(torch.unsqueeze(cls_token, dim=1), 'b n (h d) -> b h n d', h=h)

        # Tokens back to a square feature map; assumes n == img_size**2
        # (after any cls-token removal).
        x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size)

        # Depthwise-separable conv projections, then split per-head.
        q = self.to_q(x)
        q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h)

        k = self.to_k(x)
        k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)

        v = self.to_v(x)
        v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)

        if self.last_stage:
            # Re-attach the class token as the first position of q, k, v.
            q = torch.cat([cls_token, q], dim=2)
            v = torch.cat([cls_token, v], dim=2)
            k = torch.cat([cls_token, k], dim=2)

        # Scaled dot-product attention: softmax(q k^T * scale) v.
        k = k.permute(0, 1, 3, 2)
        attention = (q.matmul(k))
        attention = attention * self.scale
        attention = F.softmax(attention, dim=-1)

        # Merge heads back to (b, n, c).
        # NOTE(review): this reshape assumes heads * dim_head == c; it holds
        # for the configurations DBNet instantiates — confirm for others.
        out = (attention.matmul(v)).permute(0, 2, 1, 3).reshape(b, n,
                                                                c)

        out = self.to_out(out)
        return out
|
|
|
|
|
|
|
class Rearrange(nn.Module):
    """Minimal einops-style reshape between (B, C, H, W) and (B, H*W, C).

    Supports exactly two patterns:
      * ``'b c h w -> b (h w) c'`` — flatten the spatial grid into tokens.
      * ``'b (h w) c -> b c h w'`` — restore the ``h`` x ``w`` grid given
        at construction time.

    Raises:
        ValueError: if ``string`` is neither supported pattern
            (previously this fell through and crashed with an
            UnboundLocalError on ``x``).
    """

    def __init__(self, string, h, w):
        super().__init__()
        self.string = string
        self.h = h
        self.w = w

    def forward(self, input):
        if self.string == 'b c h w -> b (h w) c':
            N, C, H, W = input.shape
            # (N, C, H*W) -> (N, H*W, C)
            return torch.reshape(input, shape=(N, C, self.h * self.w)).permute(0, 2, 1)

        if self.string == 'b (h w) c -> b c h w':
            N, _, C = input.shape
            # (N, h, w, C) -> (N, C, h, w)
            return torch.reshape(input, shape=(N, self.h, self.w, C)).permute(0, 3, 1, 2)

        raise ValueError(f'unsupported rearrange pattern: {self.string!r}')
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm transformer layers.

    Each layer is a ConvAttnetion followed by a FeedForward, both wrapped
    in PreNorm and applied with residual additions.
    """

    def __init__(self, dim, img_size, depth, heads, dim_head, mlp_dim, dropout=0., last_stage=False):
        super().__init__()
        layers = []
        for _ in range(depth):
            attn = PreNorm(dim=dim, fn=ConvAttnetion(dim, img_size, heads=heads, dim_head=dim_head,
                                                     dropout=dropout, last_stage=last_stage))
            mlp = PreNorm(dim=dim, fn=FeedForward(dim=dim, hidden_dim=mlp_dim, dropout=dropout))
            layers.append(nn.ModuleList([attn, mlp]))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
|
|
|
|
|
class DBNet(nn.Module):
    """Dual-branch segmentation network.

    Combines a transformer branch (``stage*_conv_embed`` +
    ``stage*_transformer``, operating on flattened tokens) with a CNN
    branch (``c_stage1`` .. ``c_stage5``). The branches are fused at each
    resolution by the ``CTmerge*`` blocks, then a U-Net-style decoder
    (``decoder4`` .. ``decoder1``, each followed by bilinear upsampling
    and SBR refinement) restores the input resolution before the 1x1
    classification head.

    Args:
        img_size: input spatial size (assumed square); the stage
            embeddings hard-code grids of img_size // 4 .. // 32.
        in_channels: channels of the input image.
        num_classes: output channels of the final 1x1 head.
        dim: base embedding dim of transformer stage 1; later stages
            scale it by the ratio of consecutive ``heads`` entries.
        kernels, strides: per-stage conv-embedding kernel sizes/strides.
        heads, depth: per-stage attention head counts / transformer depth.
        pool: 'cls' or 'mean'; validated and stored but not read in
            this forward pass.
        dropout: dropout used inside the transformer blocks.
        emb_dropout: accepted but unused in this implementation.
        scale_dim: MLP expansion ratio inside the transformer blocks.
    """

    # NOTE(review): the list defaults (kernels, strides, heads, depth) are
    # mutable defaults shared across instances; they are only read here,
    # but tuples would be safer.
    def __init__(self, img_size, in_channels, num_classes, dim=64, kernels=[7, 3, 3, 3], strides=[4, 2, 2, 2],
                 heads=[1, 3, 6, 6],
                 depth=[1, 2, 10, 10], pool='cls', dropout=0., emb_dropout=0., scale_dim=4, ):
        super().__init__()

        assert pool in ['cls', 'mean'], f'pool type must be either cls or mean pooling'
        self.pool = pool
        self.dim = dim

        # --- transformer branch, stage 1: tokens on an img/4 grid, ``dim`` channels ---
        self.stage1_conv_embed = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=dim,
                kernel_size=kernels[0],
                stride=strides[0],
                padding=2
            ),
            Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4),
            nn.LayerNorm(dim)
        )

        self.stage1_transformer = nn.Sequential(
            Transformer(
                dim=dim,
                img_size=img_size // 4,
                depth=depth[0],
                heads=heads[0],
                dim_head=self.dim,
                mlp_dim=dim * scale_dim,
                dropout=dropout,
            ),
            Rearrange('b (h w) c -> b c h w', h=img_size // 4, w=img_size // 4)
        )

        # Channel growth between stages follows the head ratio
        # (e.g. heads 1 -> 3 triples the embedding dim: 64 -> 192).
        in_channels = dim
        scale = heads[1] // heads[0]
        dim = scale * dim

        # --- transformer branch, stage 2: img/8 grid ---
        self.stage2_conv_embed = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=dim,
                kernel_size=kernels[1],
                stride=strides[1],
                padding=1
            ),
            Rearrange('b c h w -> b (h w) c', h=img_size // 8, w=img_size // 8),
            nn.LayerNorm(dim)
        )

        self.stage2_transformer = nn.Sequential(
            Transformer(
                dim=dim,
                img_size=img_size // 8,
                depth=depth[1],
                heads=heads[1],
                dim_head=self.dim,
                mlp_dim=dim * scale_dim,
                dropout=dropout
            ),
            Rearrange('b (h w) c -> b c h w', h=img_size // 8, w=img_size // 8)
        )

        in_channels = dim
        scale = heads[2] // heads[1]
        dim = scale * dim

        # --- transformer branch, stage 3: img/16 grid ---
        self.stage3_conv_embed = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=dim,
                kernel_size=kernels[2],
                stride=strides[2],
                padding=1
            ),
            Rearrange('b c h w -> b (h w) c', h=img_size // 16, w=img_size // 16),
            nn.LayerNorm(dim)
        )

        self.stage3_transformer = nn.Sequential(
            Transformer(
                dim=dim,
                img_size=img_size // 16,
                depth=depth[2],
                heads=heads[2],
                dim_head=self.dim,
                mlp_dim=dim * scale_dim,
                dropout=dropout
            ),
            Rearrange('b (h w) c -> b c h w', h=img_size // 16, w=img_size // 16)
        )

        in_channels = dim
        scale = heads[3] // heads[2]
        dim = scale * dim

        # --- transformer branch, stage 4: img/32 grid ---
        self.stage4_conv_embed = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=dim,
                kernel_size=kernels[3],
                stride=strides[3],
                padding=1
            ),
            Rearrange('b c h w -> b (h w) c', h=img_size // 32, w=img_size // 32),
            nn.LayerNorm(dim)
        )

        self.stage4_transformer = nn.Sequential(
            Transformer(
                dim=dim, img_size=img_size // 32,
                depth=depth[3],
                heads=heads[3],
                dim_head=self.dim,
                mlp_dim=dim * scale_dim,
                dropout=dropout,
            ),
            Rearrange('b (h w) c -> b c h w', h=img_size // 32, w=img_size // 32)
        )

        # --- CNN branch: each c_stage halves resolution (stride-2 convs) ---
        self.c_stage1 = c_stage123(in_chans=3, out_chans=64)
        self.c_stage2 = c_stage123(in_chans=64, out_chans=128)
        self.c_stage3 = c_stage123(in_chans=128, out_chans=384)
        self.c_stage4 = c_stage45(in_chans=384, out_chans=512)
        self.c_stage5 = c_stage45(in_chans=512, out_chans=1024)
        self.c_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # NOTE(review): up_conv1 and up_conv2 are defined but never used in
        # forward(); confirm whether they are dead code.
        self.up_conv1 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1)
        self.up_conv2 = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1)

        # --- CNN/transformer fusion blocks (channel counts match the
        # concatenations performed in forward()) ---
        self.CTmerge1 = nn.Sequential(  # in: 64 (t_s1) + 64 (pooled c_s1) = 128
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.CTmerge2 = nn.Sequential(  # in: 128 (c_s2) + 192 (t_s2) = 320
            nn.Conv2d(in_channels=320, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.CTmerge3 = nn.Sequential(  # in: 384 (t_s3) + 384 (pooled c_s3) = 768
            nn.Conv2d(in_channels=768, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )

        self.CTmerge4 = nn.Sequential(  # in: 512 (c_s4) + 384 (t_s4) = 896
            nn.Conv2d(in_channels=896, out_channels=640, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(640),
            nn.ReLU(),
            nn.Conv2d(in_channels=640, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )

        # --- decoder: depthwise-separable conv pairs, channel counts match
        # the skip concatenations in forward() ---
        self.decoder4 = nn.Sequential(  # in: 1024 (c_s5) + 384 (t_s4) = 1408
            DepthwiseConv2d(
                in_chans=1408,
                out_chans=1024,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            DepthwiseConv2d(
                in_chans=1024,
                out_chans=512,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.GELU()
        )
        self.decoder3 = nn.Sequential(  # in: 512 (decoder4) + 384 (c_s3) = 896
            DepthwiseConv2d(
                in_chans=896,
                out_chans=512,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            DepthwiseConv2d(
                in_chans=512,
                out_chans=384,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.GELU()
        )

        self.decoder2 = nn.Sequential(  # in: 384 (decoder3) + 192 (t_s2) = 576
            DepthwiseConv2d(
                in_chans=576,
                out_chans=256,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            DepthwiseConv2d(
                in_chans=256,
                out_chans=192,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.GELU()
        )

        self.decoder1 = nn.Sequential(  # in: 192 (decoder2) + 64 (c_s1) = 256
            DepthwiseConv2d(
                in_chans=256,
                out_chans=64,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            DepthwiseConv2d(
                in_chans=64,
                out_chans=16,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.GELU()
        )
        # Boundary refinement after each decoder upsampling step.
        self.sbr4 = SBR(512)
        self.sbr3 = SBR(384)
        self.sbr2 = SBR(192)
        self.sbr1 = SBR(16)

        # 1x1 classification head producing per-pixel class logits.
        self.head = nn.Conv2d(in_channels=16, out_channels=num_classes, kernel_size=1)

    def forward(self, input):
        """Return per-pixel logits of shape (B, num_classes, H, W)."""
        # --- stage 1: transformer at img/4, CNN at img/2, fuse at img/4 ---
        t_s1 = self.stage1_conv_embed(input)
        t_s1 = self.stage1_transformer(t_s1)
        c_s1 = self.c_stage1(input)
        # CNN features are max-pooled to img/4 to match the transformer map.
        stage1 = self.CTmerge1(torch.cat([t_s1, self.c_max(c_s1)], dim=1))

        # --- stage 2: transformer embeds the fused stage1 features ---
        t_s2 = self.stage2_conv_embed(stage1)
        t_s2 = self.stage2_transformer(t_s2)
        c_s2 = self.c_stage2(c_s1)
        # Fuse at the CNN resolution: upsample t_s2 to c_s2's spatial size.
        stage2 = self.CTmerge2(
            torch.cat([c_s2, F.interpolate(t_s2, size=c_s2.size()[2:], mode='bilinear', align_corners=True)],
                      dim=1))

        # --- stage 3 ---
        # NOTE(review): unlike stages 2 and 4, this embeds t_s2 rather than
        # the merged ``stage2`` features — confirm this asymmetry is intended.
        t_s3 = self.stage3_conv_embed(t_s2)
        t_s3 = self.stage3_transformer(t_s3)
        c_s3 = self.c_stage3(stage2)
        stage3 = self.CTmerge3(torch.cat([t_s3, self.c_max(c_s3)], dim=1))

        # --- stage 4 ---
        t_s4 = self.stage4_conv_embed(stage3)
        t_s4 = self.stage4_transformer(t_s4)
        c_s4 = self.c_stage4(c_s3)
        stage4 = self.CTmerge4(
            torch.cat([c_s4, F.interpolate(t_s4, size=c_s4.size()[2:], mode='bilinear', align_corners=True)],
                      dim=1))

        # Deepest CNN features at img/32.
        c_s5 = self.c_stage5(stage4)

        # --- decoder: concat skip, conv, upsample, SBR-refine at each level ---
        decoder4 = torch.cat([c_s5, t_s4], dim=1)
        decoder4 = self.decoder4(decoder4)
        decoder4 = F.interpolate(decoder4, size=c_s3.size()[2:], mode='bilinear',
                                 align_corners=True)
        decoder4 = self.sbr4(decoder4)

        decoder3 = torch.cat([decoder4, c_s3], dim=1)
        decoder3 = self.decoder3(decoder3)
        decoder3 = F.interpolate(decoder3, size=t_s2.size()[2:], mode='bilinear', align_corners=True)
        decoder3 = self.sbr3(decoder3)

        decoder2 = torch.cat([decoder3, t_s2], dim=1)
        decoder2 = self.decoder2(decoder2)
        decoder2 = F.interpolate(decoder2, size=c_s1.size()[2:], mode='bilinear', align_corners=True)
        decoder2 = self.sbr2(decoder2)

        decoder1 = torch.cat([decoder2, c_s1], dim=1)
        decoder1 = self.decoder1(decoder1)

        # Restore the full input resolution, then classify per pixel.
        final = F.interpolate(decoder1, size=input.size()[2:], mode='bilinear', align_corners=True)
        final = self.head(final)

        return final
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: one forward pass on a random image.
    # Pick CUDA when available instead of crashing on CPU-only machines
    # (the previous unconditional .cuda() raised on hosts without a GPU).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.rand(1, 3, 224, 224).to(device)
    model = DBNet(img_size=224, in_channels=3, num_classes=7).to(device)
    y = model(x)
    print(y.shape)
|
|