# -*- coding: utf-8 -*- # @Time : 2024/7/26 上午11:19 # @Author : xiaoshun # @Email : 3038523973@qq.com # @File : dbnet.py # @Software: PyCharm import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange # from models.Transformer.ViT import truncated_normal_ # Decoder细化卷积模块 class SBR(nn.Module): def __init__(self, in_ch): super(SBR, self).__init__() self.conv1x3 = nn.Sequential( nn.Conv2d(in_ch, in_ch, kernel_size=(1, 3), stride=1, padding=(0, 1)), nn.BatchNorm2d(in_ch), nn.ReLU(True) ) self.conv3x1 = nn.Sequential( nn.Conv2d(in_ch, in_ch, kernel_size=(3, 1), stride=1, padding=(1, 0)), nn.BatchNorm2d(in_ch), nn.ReLU(True) ) def forward(self, x): out = self.conv3x1(self.conv1x3(x)) # 先进行1x3的卷积,得到结果并将结果再进行3x1的卷积 return out + x # 下采样卷积模块 stage 1,2,3 class c_stage123(nn.Module): def __init__(self, in_chans, out_chans): super().__init__() self.stage123 = nn.Sequential( nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(out_chans), nn.ReLU(), nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_chans), nn.ReLU(), ) self.conv1x1_123 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) def forward(self, x): stage123 = self.stage123(x) # 3*3卷积,两倍下采样 3*224*224-->64*112*112 max = self.maxpool(x) # 最大值池化,两倍下采样 3*224*224-->3*112*112 max = self.conv1x1_123(max) # 1*1卷积 3*112*112-->64*112*112 stage123 = stage123 + max # 残差结构,广播机制 return stage123 # 下采样卷积模块 stage4,5 class c_stage45(nn.Module): def __init__(self, in_chans, out_chans): super().__init__() self.stage45 = nn.Sequential( nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(out_chans), nn.ReLU(), nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_chans), nn.ReLU(), nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_chans), nn.ReLU(), ) self.conv1x1_45 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) def forward(self, x): stage45 = self.stage45(x) # 3*3卷积模块 2倍下采样 max = self.maxpool(x) # 最大值池化,两倍下采样 max = self.conv1x1_45(max) # 1*1卷积模块 调整通道数 stage45 = stage45 + max # 残差结构 return stage45 class Identity(nn.Module): # 恒等映射 def __init__(self): super().__init__() def forward(self, x): return x # 轻量卷积模块 class DepthwiseConv2d(nn.Module): # 用于自注意力机制 def __init__(self, in_chans, out_chans, kernel_size=1, stride=1, padding=0, dilation=1): super().__init__() # depthwise conv self.depthwise = nn.Conv2d( in_channels=in_chans, out_channels=in_chans, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, # 深层卷积的膨胀率 groups=in_chans # 指定分组卷积的组数 ) # batch norm self.bn = nn.BatchNorm2d(num_features=in_chans) # pointwise conv 逐点卷积 self.pointwise = nn.Conv2d( in_channels=in_chans, out_channels=out_chans, kernel_size=1 ) def forward(self, x): x = self.depthwise(x) x = self.bn(x) x = self.pointwise(x) return x # residual skip connection 残差跳跃连接 class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, input, **kwargs): x = self.fn(input, **kwargs) return (x + input) # layer norm plus 层归一化 class PreNorm(nn.Module): # 代表神经网络层 def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, input, **kwargs): return self.fn(self.norm(input), **kwargs) # FeedForward层使得representation的表达能力更强 class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout=0.): super().__init__() self.net = nn.Sequential( nn.Linear(in_features=dim, out_features=hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(in_features=hidden_dim, out_features=dim), nn.Dropout(dropout) ) def forward(self, input): return self.net(input) class ConvAttnetion(nn.Module): ''' using the Depth_Separable_Wise Conv2d to produce the q, k, v instead of using Linear Project in ViT ''' def __init__(self, dim, img_size, heads=8, dim_head=64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout=0., last_stage=False): super().__init__() self.last_stage = last_stage self.img_size = img_size inner_dim = dim_head * heads # 512 project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** (-0.5) pad = (kernel_size - q_stride) // 2 self.to_q = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=q_stride, padding=pad) # 自注意力机制 self.to_k = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=k_stride, padding=pad) self.to_v = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=v_stride, padding=pad) self.to_out = nn.Sequential( nn.Linear( in_features=inner_dim, out_features=dim ), nn.Dropout(dropout) ) if project_out else Identity() def forward(self, x): b, n, c, h = *x.shape, self.heads # * 星号的作用大概是去掉 tuple 属性吧 # print(x.shape) # print('+++++++++++++++++++++++++++++++++') # if语句内容没有使用 if self.last_stage: cls_token = x[:, 0] # print(cls_token.shape) # print('+++++++++++++++++++++++++++++++++') x = x[:, 1:] # 去掉每个数组的第一个元素 cls_token = rearrange(torch.unsqueeze(cls_token, dim=1), 'b n (h d) -> b h n d', h=h) # rearrange:用于对张量的维度进行重新变换排序,可用于替换pytorch中的reshape,view,transpose和permute等操作 x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size) # [1, 3136, 64]-->1*64*56*56 # batch_size,N(通道数),h,w q = self.to_q(x) # 1*64*56*56-->1*64*56*56 # print(q.shape) # print('++++++++++++++') q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h) # 1*64*56*56-->1*1*3136*64 # print(q.shape) # print('=====================') # batch_size,head,h*w,dim_head k = self.to_k(x) # 操作和q一样 k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h) # batch_size,head,h*w,dim_head v = self.to_v(x) ##操作和q一样 # print(v.shape) # print('[[[[[[[[[[[[[[[[[[[[[[[[[[[[') v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h) # print(v.shape) # print(']]]]]]]]]]]]]]]]]]]]]]]]]]]') # batch_size,head,h*w,dim_head if self.last_stage: # print(q.shape) # print('================') q = torch.cat([cls_token, q], dim=2) # print(q.shape) # print('++++++++++++++++++') v = torch.cat([cls_token, v], dim=2) k = torch.cat([cls_token, k], dim=2) # calculate attention by matmul + scale # permute:(batch_size,head,dim_head,h*w # print(k.shape) # print('++++++++++++++++++++') k = k.permute(0, 1, 3, 2) # 1*1*3136*64-->1*1*64*3136 # print(k.shape) # print('====================') attention = (q.matmul(k)) # 1*1*3136*3136 # print(attention.shape) # print('--------------------') attention = attention * self.scale # 可以得到一个logit的向量,避免出现梯度下降和梯度爆炸 # print(attention.shape) # print('####################') # pass a softmax attention = F.softmax(attention, dim=-1) # print(attention.shape) # print('********************') # matmul v # attention.matmul(v):(batch_size,head,h*w,dim_head) # permute:(batch_size,h*w,head,dim_head) out = (attention.matmul(v)).permute(0, 2, 1, 3).reshape(b, n, c) # 1*3136*64 这些操作的目的是将注意力权重和值向量相乘后得到的结果进行重塑,得到一个形状为 (batch size, 序列长度, 值向量或矩阵的维度) 的张量 # linear project out = self.to_out(out) return out # Reshape Layers class Rearrange(nn.Module): def __init__(self, string, h, w): super().__init__() self.string = string self.h = h self.w = w def forward(self, input): if self.string == 'b c h w -> b (h w) c': N, C, H, W = input.shape # print(input.shape) x = torch.reshape(input, shape=(N, -1, self.h * self.w)).permute(0, 2, 1) # print(x.shape) # print('+++++++++++++++++++') if self.string == 'b (h w) c -> b c h w': N, _, C = input.shape # print(input.shape) x = torch.reshape(input, shape=(N, self.h, self.w, -1)).permute(0, 3, 1, 2) # print(x.shape) # print('=====================') return x # Transformer layers class Transformer(nn.Module): def __init__(self, dim, img_size, depth, heads, dim_head, mlp_dim, dropout=0., last_stage=False): super().__init__() self.layers = nn.ModuleList([ # 管理子模块,参数注册 nn.ModuleList([ PreNorm(dim=dim, fn=ConvAttnetion(dim, img_size, heads=heads, dim_head=dim_head, dropout=dropout, last_stage=last_stage)), # 归一化,重参数化 PreNorm(dim=dim, fn=FeedForward(dim=dim, hidden_dim=mlp_dim, dropout=dropout)) ]) for _ in range(depth) ]) def forward(self, x): for attn, ff in self.layers: x = x + attn(x) x = x + ff(x) return x class DBNet(nn.Module): # 最主要的大函数 def __init__(self, img_size, in_channels, num_classes, dim=64, kernels=[7, 3, 3, 3], strides=[4, 2, 2, 2], heads=[1, 3, 6, 6], depth=[1, 2, 10, 10], pool='cls', dropout=0., emb_dropout=0., scale_dim=4, ): super().__init__() assert pool in ['cls', 'mean'], f'pool type must be either cls or mean pooling' self.pool = pool self.dim = dim # stage1 # k:7 s:4 in: 1, 64, 56, 56 out: 1, 3136, 64 self.stage1_conv_embed = nn.Sequential( nn.Conv2d( # 1*3*224*224-->[1, 64, 56, 56] in_channels=in_channels, out_channels=dim, kernel_size=kernels[0], stride=strides[0], padding=2 ), Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4), # [1, 64, 56, 56]-->[1, 3136, 64] nn.LayerNorm(dim) # 对每个batch归一化 ) self.stage1_transformer = nn.Sequential( Transformer( # dim=dim, img_size=img_size // 4, depth=depth[0], # Transformer层中的编码器和解码器层数。 heads=heads[0], dim_head=self.dim, # 它是每个注意力头的维度大小,通常是嵌入维度除以头数。 mlp_dim=dim * scale_dim, # mlp_dim:它是Transformer中前馈神经网络的隐藏层维度大小,通常是嵌入维度乘以一个缩放因子。 dropout=dropout, # last_stage=last_stage #它是一个标志位,用于表示该Transformer层是否是最后一层。 ), Rearrange('b (h w) c -> b c h w', h=img_size // 4, w=img_size // 4) ) # stage2 # k:3 s:2 in: 1, 192, 28, 28 out: 1, 784, 192 in_channels = dim scale = heads[1] // heads[0] dim = scale * dim self.stage2_conv_embed = nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=dim, kernel_size=kernels[1], stride=strides[1], padding=1 ), Rearrange('b c h w -> b (h w) c', h=img_size // 8, w=img_size // 8), nn.LayerNorm(dim) ) self.stage2_transformer = nn.Sequential( Transformer( dim=dim, img_size=img_size // 8, depth=depth[1], heads=heads[1], dim_head=self.dim, mlp_dim=dim * scale_dim, dropout=dropout ), Rearrange('b (h w) c -> b c h w', h=img_size // 8, w=img_size // 8) ) # stage3 in_channels = dim scale = heads[2] // heads[1] dim = scale * dim self.stage3_conv_embed = nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=dim, kernel_size=kernels[2], stride=strides[2], padding=1 ), Rearrange('b c h w -> b (h w) c', h=img_size // 16, w=img_size // 16), nn.LayerNorm(dim) ) self.stage3_transformer = nn.Sequential( Transformer( dim=dim, img_size=img_size // 16, depth=depth[2], heads=heads[2], dim_head=self.dim, mlp_dim=dim * scale_dim, dropout=dropout ), Rearrange('b (h w) c -> b c h w', h=img_size // 16, w=img_size // 16) ) # stage4 in_channels = dim scale = heads[3] // heads[2] dim = scale * dim self.stage4_conv_embed = nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=dim, kernel_size=kernels[3], stride=strides[3], padding=1 ), Rearrange('b c h w -> b (h w) c', h=img_size // 32, w=img_size // 32), nn.LayerNorm(dim) ) self.stage4_transformer = nn.Sequential( Transformer( dim=dim, img_size=img_size // 32, depth=depth[3], heads=heads[3], dim_head=self.dim, mlp_dim=dim * scale_dim, dropout=dropout, ), Rearrange('b (h w) c -> b c h w', h=img_size // 32, w=img_size // 32) ) ### CNN Branch ### self.c_stage1 = c_stage123(in_chans=3, out_chans=64) self.c_stage2 = c_stage123(in_chans=64, out_chans=128) self.c_stage3 = c_stage123(in_chans=128, out_chans=384) self.c_stage4 = c_stage45(in_chans=384, out_chans=512) self.c_stage5 = c_stage45(in_chans=512, out_chans=1024) self.c_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.up_conv1 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1) self.up_conv2 = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1) ### CTmerge ### self.CTmerge1 = nn.Sequential( nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(), ) self.CTmerge2 = nn.Sequential( nn.Conv2d(in_channels=320, out_channels=128, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(128), nn.ReLU(), ) self.CTmerge3 = nn.Sequential( nn.Conv2d(in_channels=768, out_channels=512, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(512), nn.ReLU(), nn.Conv2d(in_channels=512, out_channels=384, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(384), nn.ReLU(), nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(384), nn.ReLU(), ) self.CTmerge4 = nn.Sequential( nn.Conv2d(in_channels=896, out_channels=640, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(640), nn.ReLU(), nn.Conv2d(in_channels=640, out_channels=512, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(512), nn.ReLU(), nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(512), nn.ReLU(), ) # decoder self.decoder4 = nn.Sequential( DepthwiseConv2d( in_chans=1408, out_chans=1024, kernel_size=3, stride=1, padding=1 ), DepthwiseConv2d( in_chans=1024, out_chans=512, kernel_size=3, stride=1, padding=1 ), nn.GELU() ) self.decoder3 = nn.Sequential( DepthwiseConv2d( in_chans=896, out_chans=512, kernel_size=3, stride=1, padding=1 ), DepthwiseConv2d( in_chans=512, out_chans=384, kernel_size=3, stride=1, padding=1 ), nn.GELU() ) self.decoder2 = nn.Sequential( DepthwiseConv2d( in_chans=576, out_chans=256, kernel_size=3, stride=1, padding=1 ), DepthwiseConv2d( in_chans=256, out_chans=192, kernel_size=3, stride=1, padding=1 ), nn.GELU() ) self.decoder1 = nn.Sequential( DepthwiseConv2d( in_chans=256, out_chans=64, kernel_size=3, stride=1, padding=1 ), DepthwiseConv2d( in_chans=64, out_chans=16, kernel_size=3, stride=1, padding=1 ), nn.GELU() ) self.sbr4 = SBR(512) self.sbr3 = SBR(384) self.sbr2 = SBR(192) self.sbr1 = SBR(16) self.head = nn.Conv2d(in_channels=16, out_channels=num_classes, kernel_size=1) def forward(self, input): ### encoder ### # stage1 = ts1 cat cs1 # t_s1 = self.t_stage1(input) # print(input.shape) # print('++++++++++++++++++++++') t_s1 = self.stage1_conv_embed(input) # 1*3*224*224-->1*3136*64 # print(t_s1.shape) # print('======================') t_s1 = self.stage1_transformer(t_s1) # 1*3136*64-->1*64*56*56 # print(t_s1.shape) # print('----------------------') c_s1 = self.c_stage1(input) # 1*3*224*224-->1*64*112*112 # print(c_s1.shape) # print('!!!!!!!!!!!!!!!!!!!!!!!') stage1 = self.CTmerge1(torch.cat([t_s1, self.c_max(c_s1)], dim=1)) # 1*64*56*56 # 拼接两条分支 # print(stage1.shape) # print('[[[[[[[[[[[[[[[[[[[[[[[') # stage2 = ts2 up cs2 # t_s2 = self.t_stage2(stage1) t_s2 = self.stage2_conv_embed(stage1) # 1*64*56*56-->1*784*192 # stage2_conv_embed是转化为序列操作 # print(t_s2.shape) # print('[[[[[[[[[[[[[[[[[[[[[[[') t_s2 = self.stage2_transformer(t_s2) # 1*784*192-->1*192*28*28 # print(t_s2.shape) # print('+++++++++++++++++++++++++') c_s2 = self.c_stage2(c_s1) # 1*64*112*112-->1*128*56*56 stage2 = self.CTmerge2( torch.cat([c_s2, F.interpolate(t_s2, size=c_s2.size()[2:], mode='bilinear', align_corners=True)], dim=1)) # mode='bilinear'表示使用双线性插值 1*128*56*56 # stage3 = ts3 cat cs3 # t_s3 = self.t_stage3(t_s2) t_s3 = self.stage3_conv_embed(t_s2) # 1*192*28*28-->1*196*384 # print(t_s3.shape) # print('///////////////////////') t_s3 = self.stage3_transformer(t_s3) # 1*196*384-->1*384*14*14 # print(t_s3.shape) # print('....................') c_s3 = self.c_stage3(stage2) # 1*128*56*56-->1*384*28*28 stage3 = self.CTmerge3(torch.cat([t_s3, self.c_max(c_s3)], dim=1)) # 1*384*14*14 # stage4 = ts4 up cs4 # t_s4 = self.t_stage4(stage3) t_s4 = self.stage4_conv_embed(stage3) # 1*384*14*14-->1*49*384 # print(t_s4.shape) # print(';;;;;;;;;;;;;;;;;;;;;;;') t_s4 = self.stage4_transformer(t_s4) # 1*49*384-->1*384*7*7 # print(t_s4.shape) # print('::::::::::::::::::::') c_s4 = self.c_stage4(c_s3) # 1*384*28*28-->1*512*14*14 stage4 = self.CTmerge4( torch.cat([c_s4, F.interpolate(t_s4, size=c_s4.size()[2:], mode='bilinear', align_corners=True)], dim=1)) # 1*512*14*14 # cs5 c_s5 = self.c_stage5(stage4) # 1*512*14*14-->1*1024*7*7 ### decoder ### decoder4 = torch.cat([c_s5, t_s4], dim=1) # 1*1408*7*7 decoder4 = self.decoder4(decoder4) # 1*1408*7*7-->1*512*7*7 decoder4 = F.interpolate(decoder4, size=c_s3.size()[2:], mode='bilinear', align_corners=True) # 1*512*7*7-->1*512*28*28 decoder4 = self.sbr4(decoder4) # 1*512*28*28 # print(decoder4.shape) decoder3 = torch.cat([decoder4, c_s3], dim=1) # 1*896*28*28 decoder3 = self.decoder3(decoder3) # 1*384*28*28 decoder3 = F.interpolate(decoder3, size=t_s2.size()[2:], mode='bilinear', align_corners=True) # 1*384*28*28 decoder3 = self.sbr3(decoder3) # 1*384*28*28 # print(decoder3.shape) decoder2 = torch.cat([decoder3, t_s2], dim=1) # 1*576*28*28 decoder2 = self.decoder2(decoder2) # 1*192*28*28 decoder2 = F.interpolate(decoder2, size=c_s1.size()[2:], mode='bilinear', align_corners=True) # 1*192*112*112 decoder2 = self.sbr2(decoder2) # 1*192*112*112 # print(decoder2.shape) decoder1 = torch.cat([decoder2, c_s1], dim=1) # 1*256*112*112 decoder1 = self.decoder1(decoder1) # 1*16*112*112 # print(decoder1.shape) final = F.interpolate(decoder1, size=input.size()[2:], mode='bilinear', align_corners=True) # 1*16*224*224 # print(final.shape) # final = self.sbr1(decoder1) # print(final.shape) final = self.head(final) # 1*3*224*224 return final if __name__ == '__main__': x = torch.rand(1, 3, 224, 224).cuda() model = DBNet(img_size=224, in_channels=3, num_classes=7).cuda() y = model(x) print(y.shape) # torch.Size([1, 7, 224, 224])