|
import math

import torch
import torch.nn.functional as F
from torch import nn
|
|
|
|
|
class Conv2d_cd(nn.Module):
    """Central difference convolution (CDC).

    Combines a vanilla convolution with a central-difference term; theta
    weights the difference term, and theta = 0 recovers a plain convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.7):
        super(Conv2d_cd, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation,
                              groups=groups, bias=bias)
        self.theta = theta

    def forward(self, x):
        out_normal = self.conv(x)

        if math.fabs(self.theta - 0.0) < 1e-8:
            return out_normal

        # Difference term: a 1x1 convolution whose kernel is the spatial sum
        # of the learned kernel, i.e. the aggregated response taken at the
        # center location of each receptive field.
        kernel_diff = self.conv.weight.sum(dim=(2, 3), keepdim=True)
        out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias,
                            stride=self.conv.stride, padding=0,
                            groups=self.conv.groups)
        return out_normal - self.theta * out_diff
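# Illustrative sketch, not part of the original model: for the default
# stride=1, padding=1, dilation=1, bias=False, 3x3 configuration, the CDC
# forward pass above is equivalent to a single convolution whose kernel has
# theta times the spatially-summed weights subtracted from its center tap,
# because the 1x1 difference convolution only ever touches the (never-padded)
# center position. The helper name below is hypothetical.
def cdc_folded_weight(conv, theta):
    """Fold Conv2d_cd's central-difference term into one equivalent kernel."""
    w = conv.weight
    center = w.shape[2] // 2
    w_cd = w.clone()
    # Subtract theta * (spatial sum of W) from the center tap only.
    w_cd[:, :, center, center] -= theta * w.sum(dim=(2, 3))
    # For m = Conv2d_cd(...), F.conv2d(x, cdc_folded_weight(m.conv, m.theta),
    # padding=1) should match m(x) under the assumptions above.
    return w_cd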
|
|
|
|
|
class SpatialAttention(nn.Module):
    """CBAM-style spatial attention: produces a (N, 1, H, W) gate in [0, 1]."""

    def __init__(self, kernel=3):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size=kernel, padding=kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel-wise average and max descriptors, stacked into a 2-channel map.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
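# Usage sketch (an assumption; SpatialAttention is defined here but not wired
# into the networks below):
#
#     sa = SpatialAttention(kernel=7)
#     feat = feat * sa(feat)  # broadcast the (N, 1, H, W) gate over channels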
|
|
|
|
|
class CDCN_u(nn.Module):
    """CDC network with an uncertainty head: predicts a depth-map mean and
    log-variance, plus a reparameterized sample."""

    def __init__(self, basic_conv=Conv2d_cd, theta=0.7):
        super(CDCN_u, self).__init__()

        self.conv1 = nn.Sequential(
            basic_conv(3, 64, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.Block1 = nn.Sequential(
            basic_conv(64, 128, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            basic_conv(128, 196, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(196),
            nn.ReLU(),
            basic_conv(196, 128, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.Block2 = nn.Sequential(
            basic_conv(128, 128, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            basic_conv(128, 196, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(196),
            nn.ReLU(),
            basic_conv(196, 128, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.Block3 = nn.Sequential(
            basic_conv(128, 128, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            basic_conv(128, 196, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(196),
            nn.ReLU(),
            basic_conv(196, 128, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.lastconv1 = nn.Sequential(
            basic_conv(128 * 3, 128, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        self.lastconv2 = nn.Sequential(
            basic_conv(128, 64, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.mu_head = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
        )

        self.logvar_head = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
        )

        # Resizes each block's features to a common 32x32 grid (despite the
        # nn.Upsample name, this downsamples for inputs larger than 32x32).
        self.downsample32x32 = nn.Upsample(size=(32, 32), mode='bilinear', align_corners=False)

    def _reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + eps * sigma keeps the sample
        # differentiable w.r.t. mu and logvar.
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def forward(self, x):
        x_input = x
        x = self.conv1(x)

        x_Block1 = self.Block1(x)
        x_Block1_32x32 = self.downsample32x32(x_Block1)

        x_Block2 = self.Block2(x_Block1)
        x_Block2_32x32 = self.downsample32x32(x_Block2)

        x_Block3 = self.Block3(x_Block2)
        x_Block3_32x32 = self.downsample32x32(x_Block3)

        # Multi-scale fusion: 3 x 128 = 384 channels at 32x32.
        x_concat = torch.cat((x_Block1_32x32, x_Block2_32x32, x_Block3_32x32), dim=1)

        x = self.lastconv1(x_concat)
        x = self.lastconv2(x)

        mu = self.mu_head(x).squeeze(1)
        logvar = self.logvar_head(x).squeeze(1)
        embedding = self._reparameterize(mu, logvar)

        return mu, logvar, embedding, x_concat, x_Block1, x_Block2, x_Block3, x_input
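# Note on the uncertainty head (illustrative; the training loss lives outside
# this file): sampling z = mu + eps * exp(0.5 * logvar) keeps the draw
# differentiable w.r.t. mu and logvar. A Gaussian KL regularizer that commonly
# pairs with such heads would look like (an assumption, not the author's loss):
#
#     kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())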
|
|
|
|
|
class depthnet_u(nn.Module):
    """Plain-convolution counterpart of CDCN_u with the same uncertainty head."""

    def __init__(self):
        super(depthnet_u, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.Block1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 196, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(196),
            nn.ReLU(),
            nn.Conv2d(196, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.Block2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 196, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(196),
            nn.ReLU(),
            nn.Conv2d(196, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.Block3 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 196, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(196),
            nn.ReLU(),
            nn.Conv2d(196, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.lastconv1 = nn.Sequential(
            nn.Conv2d(128 * 3, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        self.lastconv2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.mu_head = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
        )

        self.logvar_head = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
        )

        self.downsample32x32 = nn.Upsample(size=(32, 32), mode='bilinear', align_corners=False)

    def _reparameterize(self, mu, logvar):
        # Same reparameterization trick as CDCN_u._reparameterize.
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def forward(self, x):
        x_input = x
        x = self.conv1(x)

        x_Block1 = self.Block1(x)
        x_Block1_32x32 = self.downsample32x32(x_Block1)

        x_Block2 = self.Block2(x_Block1)
        x_Block2_32x32 = self.downsample32x32(x_Block2)

        x_Block3 = self.Block3(x_Block2)
        x_Block3_32x32 = self.downsample32x32(x_Block3)

        x_concat = torch.cat((x_Block1_32x32, x_Block2_32x32, x_Block3_32x32), dim=1)

        x = self.lastconv1(x_concat)
        x = self.lastconv2(x)

        mu = self.mu_head(x).squeeze(1)
        logvar = self.logvar_head(x).squeeze(1)
        embedding = self._reparameterize(mu, logvar)

        return mu, logvar, embedding, x_concat, x_Block1, x_Block2, x_Block3, x_input
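if __name__ == "__main__":
    # Smoke test (an illustrative addition, not in the original file). The
    # 256x256 input size is an assumption; the block features are resized to
    # 32x32 regardless, so mu/logvar come out as (N, 32, 32) and the fused
    # features as (N, 384, 32, 32).
    x = torch.randn(2, 3, 256, 256)
    for net in (CDCN_u(theta=0.7), depthnet_u()):
        net.eval()
        with torch.no_grad():
            mu, logvar, embedding, x_concat, b1, b2, b3, x_in = net(x)
        assert mu.shape == (2, 32, 32)
        assert logvar.shape == (2, 32, 32)
        assert x_concat.shape == (2, 384, 32, 32)
        print(type(net).__name__, "OK", tuple(mu.shape), tuple(x_concat.shape))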