|
|
|
|
|
|
|
import torch |
|
from torch import nn as nn |
|
from torch.nn import functional as F |
|
import os, sys |
|
import numpy as np |
|
from time import time as ttime, sleep |
|
|
|
|
|
class UNet_Full(nn.Module):
    """Two-stage network: ``UNet1`` (transposed-conv head) feeding ``UNet2``.

    The result is the element-wise sum of the second stage's output and a
    cropped copy of the first stage's output.
    """

    def __init__(self):
        super().__init__()
        self.unet1 = UNet1(3, 3, deconv=True)
        self.unet2 = UNet2(3, 3, deconv=False)

    def forward(self, x):
        """Pad ``x``, run both stages and return their sum.

        Args:
            x: Tensor whose shape unpacks into exactly four values; the last
                two are treated as height and width.

        Returns:
            The summed stage outputs. When the input height or width is odd,
            the last two dimensions are sliced to ``height * 2`` and
            ``width * 2``.
        """
        _, _, height, width = x.shape

        # Round each spatial size up to the nearest even number.
        even_h = ((height - 1) // 2 + 1) * 2
        even_w = ((width - 1) // 2 + 1) * 2

        # F.pad order is (left, right, top, bottom): an 18-pixel reflected
        # border on every side, plus the evening-out pixel on right/bottom.
        pad_spec = (18, 18 + even_w - width, 18, 18 + even_h - height)
        x = F.pad(x, pad_spec, 'reflect')

        stage1 = self.unet1(x)
        stage2 = self.unet2(stage1)

        # Negative padding trims 20 pixels per side before the residual add.
        stage1 = F.pad(stage1, (-20, -20, -20, -20))
        output = torch.add(stage2, stage1)

        # Drop the rows/columns that only exist because of the even-size pad.
        was_padded_to_even = width != even_w or height != even_h
        if was_padded_to_even:
            output = output[:, :, :height * 2, :width * 2]

        return output
|
|
|
|
|
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate: rescales each channel by a learned weight.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``in_channels`` for the bottleneck width.
        bias: Whether the two 1x1 convolutions carry a bias term.
    """

    def __init__(self, in_channels, reduction=8, bias=False):
        super().__init__()
        squeezed_channels = in_channels // reduction
        self.conv1 = nn.Conv2d(in_channels, squeezed_channels, 1, 1, 0, bias=bias)
        self.conv2 = nn.Conv2d(squeezed_channels, in_channels, 1, 1, 0, bias=bias)

    def forward(self, x):
        """Return ``x`` multiplied by per-channel sigmoid gates.

        The gates come from the mean over dimensions 2 and 3, passed through
        ``conv1`` -> ReLU -> ``conv2`` -> sigmoid.
        """
        is_half = "Half" in x.type()

        # Half-precision inputs are averaged in float32, then cast back.
        pooled = x.float() if is_half else x
        pooled = torch.mean(pooled, dim=(2, 3), keepdim=True)
        if is_half:
            pooled = pooled.half()

        gate = F.relu(self.conv1(pooled), inplace=True)
        gate = torch.sigmoid(self.conv2(gate))
        return torch.mul(x, gate)
|
|
|
class UNetConv(nn.Module):
    """Two unpadded 3x3 convolutions with LeakyReLU(0.1), plus an optional SE gate.

    Args:
        in_channels: Channels of the input tensor.
        mid_channels: Channels produced by the first convolution.
        out_channels: Channels produced by the second convolution.
        se: When truthy, an ``SEBlock`` (reduction 8, with bias) is applied
            to the convolution output.
    """

    def __init__(self, in_channels, mid_channels, out_channels, se):
        super().__init__()
        # Both convolutions use padding 0, so each one trims 2 pixels per axis.
        layers = [
            nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        self.seblock = SEBlock(out_channels, reduction=8, bias=True) if se else None

    def forward(self, x):
        """Apply the convolution stack, then the SE gate if one is configured."""
        features = self.conv(x)
        if self.seblock is None:
            return features
        return self.seblock(features)
|
|
|
class UNet1(nn.Module):
    """One-level U-shaped network with a single skip connection.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        deconv: When truthy, the output layer is a stride-2
            ``ConvTranspose2d``; otherwise an unpadded 3x3 ``Conv2d``.
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        # Construction order is kept as-is: it fixes both the state_dict
        # layout and the order in which the RNG is consumed.
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)

        self.conv_bottom = (
            nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
            if deconv
            else nn.Conv2d(64, out_channels, 3, 1, 0)
        )

        for module in self.modules():
            if isinstance(module, nn.Linear):
                # No nn.Linear is created in this class; branch kept for parity.
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Encode, decode, merge with the cropped skip path and project out."""
        skip = self.conv1(x)

        inner = F.leaky_relu(self.conv1_down(skip), 0.1, inplace=True)
        inner = self.conv2(inner)
        inner = F.leaky_relu(self.conv2_up(inner), 0.1, inplace=True)

        # Negative padding trims 4 pixels per side from the skip path so it
        # can be added to the decoded path.
        skip = F.pad(skip, (-4, -4, -4, -4))
        merged = F.leaky_relu(self.conv3(skip + inner), 0.1, inplace=True)
        return self.conv_bottom(merged)
|
|
|
|
|
class UNet2(nn.Module):
    """Two-level U-shaped network with two skip connections.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        deconv: When truthy, the output layer is a stride-2
            ``ConvTranspose2d``; otherwise an unpadded 3x3 ``Conv2d``.
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        # Construction order is kept as-is: it fixes both the state_dict
        # layout and the order in which the RNG is consumed.
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 64, 128, se=True)
        self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
        self.conv3 = UNetConv(128, 256, 128, se=True)
        self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
        self.conv4 = UNetConv(128, 64, 64, se=True)
        self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)

        self.conv_bottom = (
            nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
            if deconv
            else nn.Conv2d(64, out_channels, 3, 1, 0)
        )

        for module in self.modules():
            if isinstance(module, nn.Linear):
                # No nn.Linear is created in this class; branch kept for parity.
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Run two down/up levels, merging each with its cropped skip path."""
        skip_outer = self.conv1(x)

        skip_inner = F.leaky_relu(self.conv1_down(skip_outer), 0.1, inplace=True)
        skip_inner = self.conv2(skip_inner)

        deep = F.leaky_relu(self.conv2_down(skip_inner), 0.1, inplace=True)
        deep = self.conv3(deep)
        deep = F.leaky_relu(self.conv3_up(deep), 0.1, inplace=True)

        # Negative padding trims 4 pixels per side from the inner skip path.
        skip_inner = F.pad(skip_inner, (-4, -4, -4, -4))
        mid = self.conv4(skip_inner + deep)
        mid = F.leaky_relu(self.conv4_up(mid), 0.1, inplace=True)

        # The outer skip path is trimmed by 16 pixels per side.
        skip_outer = F.pad(skip_outer, (-16, -16, -16, -16))
        merged = F.leaky_relu(self.conv5(skip_outer + mid), 0.1, inplace=True)

        return self.conv_bottom(merged)
|
|
|
|
|
|
|
def main():
    """Smoke-test ``UNet_Full``: print parameter count, output shape and forward time.

    Runs on CUDA when it is available and falls back to CPU otherwise. The
    forward pass is executed without gradient tracking, and on CUDA the
    device is synchronized around the timed region so the printed duration
    covers the actual kernel execution rather than just the kernel launches.
    """
    root_path = os.path.abspath('.')
    sys.path.append(root_path)

    # Kept from the original script (the sys.path tweak above exists for it);
    # `opt` itself is not used below.
    from opt import opt  # noqa: F401

    import time

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    model = UNet_Full().to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(f"CuNet has param {pytorch_total_params//1000} K params")

    x = torch.randn((1, 3, 180, 180)).to(device)

    with torch.no_grad():
        # CUDA kernels launch asynchronously: drain pending work before
        # starting the clock and again before stopping it.
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        x = model(x)
        if use_cuda:
            torch.cuda.synchronize()
        total = time.perf_counter() - start

    print("output size is ", x.shape)
    print(total)


if __name__ == "__main__":
    main()