import torch
import torch.nn as nn
import torch.nn.functional as F


def init_weight(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero ``layer.bias`` if present."""
    nn.init.xavier_uniform_(layer.weight)
    if layer.bias is not None:
        nn.init.constant_(layer.bias, 0)


def _upsample_like(x, target):
    """Bilinearly resize ``x`` to the spatial size (H, W) of ``target``.

    Resizing to an explicit size, rather than by ``scale_factor=2``, keeps skip
    connections aligned when a 2x max-pool floored an odd height/width. When
    the target is exactly twice the input size, the result is identical to a
    x2 bilinear upsample.
    """
    return F.interpolate(x, size=target.shape[2:], mode='bilinear', align_corners=False)


class ConvBlock(nn.Module):
    """3x3 Conv2d -> BatchNorm2d -> ReLU, preserving spatial size.

    ``padding`` equals ``dilation``, so H and W are unchanged for any dilation.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        dilation: dilation rate of the 3x3 convolution.
    """

    def __init__(self, in_channel, out_channel, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1,
                              padding=dilation, dilation=dilation)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        init_weight(self.conv)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class RSU(nn.Module):
    """Residual U-block of height ``L`` (RSU-L).

    An input ConvBlock maps ``C_in -> C_out``. An inner U-Net then maps that
    back to ``C_out``, and the two are summed (residual connection). The inner
    U-Net has:

    * ``L - 1`` encoder stages, each after the first preceded by a 2x max-pool;
    * a dilated bottleneck;
    * ``L - 1`` decoder stages.

    Output spatial size equals input spatial size for any H, W.

    Args:
        L: block height (number of encoder stages is ``L - 1``).
        C_in: input channels.
        C_out: output channels.
        M: channel width inside the inner U-Net.
    """

    def __init__(self, L, C_in, C_out, M):
        super().__init__()
        self.conv = ConvBlock(C_in, C_out)
        self.enc = nn.ModuleList([ConvBlock(C_out, M)])
        for _ in range(L - 2):
            self.enc.append(ConvBlock(M, M))
        self.mid = ConvBlock(M, M, dilation=2)
        # Each decoder stage consumes a [skip, upsampled] concatenation -> 2*M channels.
        self.dec = nn.ModuleList([ConvBlock(2 * M, M) for _ in range(L - 2)])
        self.dec.append(ConvBlock(2 * M, C_out))
        self.downsample = nn.MaxPool2d(2, stride=2)

    def forward(self, x):
        x = self.conv(x)

        # Encoder: first stage at full resolution, then pool before each stage.
        skips = [self.enc[0](x)]
        for enc in self.enc[1:]:
            skips.append(enc(self.downsample(skips[-1])))

        # Bottleneck runs at the deepest encoder resolution (no extra pooling).
        y = self.mid(skips[-1])

        # Decoder walks the skips deepest-first. The first stage is already at
        # the right size; later ones resize y to the exact skip size, which may
        # be odd after a floored max-pool.
        for i, (dec, skip) in enumerate(zip(self.dec, reversed(skips))):
            if i > 0:
                y = _upsample_like(y, skip)
            y = dec(torch.cat((skip, y), dim=1))
        return x + y


class RSU4F(nn.Module):
    """Dilated residual U-block (RSU-4F) with no pooling.

    Downsampling is replaced by increasing dilation (2, 4, then 8 in the
    bottleneck), so every feature map keeps the input resolution.

    Args:
        C_in: input channels.
        C_out: output channels.
        M: channel width inside the block.
    """

    def __init__(self, C_in, C_out, M):
        super().__init__()
        self.conv = ConvBlock(C_in, C_out)
        self.enc = nn.ModuleList([
            ConvBlock(C_out, M),
            ConvBlock(M, M, dilation=2),
            ConvBlock(M, M, dilation=4),
        ])
        self.mid = ConvBlock(M, M, dilation=8)
        self.dec = nn.ModuleList([
            ConvBlock(2 * M, M, dilation=4),
            ConvBlock(2 * M, M, dilation=2),
            ConvBlock(2 * M, C_out),
        ])

    def forward(self, x):
        x = self.conv(x)
        skips = []
        h = x
        for enc in self.enc:
            h = enc(h)
            skips.append(h)
        y = self.mid(h)
        # Pair decoder stages with encoder outputs deepest-first (mirrored dilations).
        for dec, skip in zip(self.dec, reversed(skips)):
            y = dec(torch.cat((skip, y), dim=1))
        return x + y


class U2Net(nn.Module):
    """U^2-Net: a U-Net whose stages are themselves residual U-blocks.

    There are six encoder stages, each after the first preceded by 2x
    max-pooling, and five decoder stages fed by [upsampled, skip]
    concatenations. A 3x3 "side" conv per level (six in total) produces a
    1-channel map, which is resized to the finest side map's resolution. A
    1x1 conv fuses the six side maps into a seventh map.

    Args:
        in_channels: channels of the input image (default 3).
    """

    def __init__(self, in_channels=3):
        super().__init__()
        self.enc = nn.ModuleList([
            RSU(L=7, C_in=in_channels, C_out=64, M=32),
            RSU(L=6, C_in=64, C_out=128, M=32),
            RSU(L=5, C_in=128, C_out=256, M=64),
            RSU(L=4, C_in=256, C_out=512, M=128),
            RSU4F(C_in=512, C_out=512, M=256),
            RSU4F(C_in=512, C_out=512, M=256),
        ])
        self.dec = nn.ModuleList([
            RSU4F(C_in=1024, C_out=512, M=256),
            RSU(L=4, C_in=1024, C_out=256, M=128),
            RSU(L=5, C_in=512, C_out=128, M=64),
            RSU(L=6, C_in=256, C_out=64, M=32),
            RSU(L=7, C_in=128, C_out=64, M=16),
        ])
        # Side heads, ordered from the finest decoder output to the deepest encoder output.
        self.convs = nn.ModuleList([
            nn.Conv2d(64, 1, 3, padding=1),
            nn.Conv2d(64, 1, 3, padding=1),
            nn.Conv2d(128, 1, 3, padding=1),
            nn.Conv2d(256, 1, 3, padding=1),
            nn.Conv2d(512, 1, 3, padding=1),
            nn.Conv2d(512, 1, 3, padding=1),
        ])
        self.lastconv = nn.Conv2d(6, 1, 1)
        self.downsample = nn.MaxPool2d(2, stride=2)
        init_weight(self.lastconv)
        for conv in self.convs:
            init_weight(conv)

    def upsample(self, x, target):
        """Bilinearly resize ``x`` to the spatial size of ``target``."""
        return _upsample_like(x, target)

    def forward(self, x):
        """Run the network.

        Args:
            x: input batch for ``nn.Conv2d``, i.e. (N, in_channels, H, W).

        Returns:
            A list of 7 tensors, each with the 1-channel dim squeezed away.
            The first six are the side outputs, finest level first. The last
            is the fused map. None has an activation applied (logits).
        """
        # Encoder: each stage after the first runs on a 2x max-pooled input.
        enc_out = [self.enc[0](x)]
        for enc in self.enc[1:]:
            enc_out.append(enc(self.downsample(enc_out[-1])))

        # Decoder: start from the deepest encoder output and climb back up.
        # Each stage sees [upsampled previous, matching encoder skip].
        dec_out = [enc_out[-1]]
        for i, dec in enumerate(self.dec):
            skip = enc_out[len(self.enc) - 2 - i]
            up = self.upsample(dec_out[-1], skip)
            dec_out.append(dec(torch.cat((up, skip), dim=1)))

        # Side outputs, finest level first. Coarser maps are resized to the
        # finest one before fusion.
        features = dec_out[::-1]
        side_out = [self.convs[0](features[0])]
        for conv, feat in zip(self.convs[1:], features[1:]):
            side_out.append(self.upsample(conv(feat), side_out[0]))
        side_out.append(self.lastconv(torch.cat(side_out, dim=1)))  # logits
        return [s.squeeze(1) for s in side_out]