r""" 4D and 6D convolutional Hough matching layers """

from torch.nn.modules.conv import _ConvNd
import torch.nn.functional as F
import torch.nn as nn
import torch

from common.logger import Logger
from . import chm_kernel


def fast4d(corr, kernel, bias=None):
    r""" Optimized implementation of 4D convolution """
    bsz, ch, srch, srcw, trgh, trgw = corr.size()
    out_channels, _, kernel_size, kernel_size, kernel_size, kernel_size = kernel.size()
    psz = kernel_size // 2

    out_corr = torch.zeros((bsz, out_channels, srch, srcw, trgh, trgw))
    corr = corr.transpose(1, 2).contiguous().view(bsz * srch, ch, srcw, trgh, trgw)

    # Decompose the 4D convolution into kernel_size 3D convolutions: each 3D slice
    # of the kernel (indexed by pidx along the first spatial dimension) is convolved
    # over (srcw, trgh, trgw), and its response is accumulated into the output with
    # the corresponding shift along the srch dimension.
    for pidx, k3d in enumerate(kernel.permute(2, 0, 1, 3, 4, 5)):
        inter_corr = F.conv3d(corr, k3d, bias=None, stride=1, padding=psz)
        inter_corr = inter_corr.view(bsz, srch, out_channels, srcw, trgh, trgw).transpose(1, 2).contiguous()

        add_sid = max(psz - pidx, 0)
        add_fid = min(srch, srch + psz - pidx)
        slc_sid = max(pidx - psz, 0)
        slc_fid = min(srch, srch - psz + pidx)

        out_corr[:, :, add_sid:add_fid, :, :, :] += inter_corr[:, :, slc_sid:slc_fid, :, :, :]

    if bias is not None:
        out_corr += bias.view(1, out_channels, 1, 1, 1, 1)

    return out_corr

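
# Example (illustrative sketch, not part of the original code): fast4d takes a
# correlation tensor of shape (bsz, ch, srch, srcw, trgh, trgw) and a kernel of
# shape (out_channels, in_channels, ksz, ksz, ksz, ksz); the spatial resolution
# is preserved. The sizes below are arbitrary small values chosen for illustration.
#
# >>> corr = torch.randn(2, 1, 8, 8, 8, 8)      # (bsz, ch, srch, srcw, trgh, trgw)
# >>> kernel = torch.randn(1, 1, 3, 3, 3, 3)    # (out_ch, in_ch, ksz, ksz, ksz, ksz)
# >>> fast4d(corr, kernel).shape
# torch.Size([2, 1, 8, 8, 8, 8])
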
def fast6d(corr, kernel, bias, diagonal_idx):
    r""" Optimized implementation of 6D convolutional Hough matching
    NOTE: this function only supports a kernel size of (3, 3, 5, 5, 5, 5).
    """
    bsz, _, s6d, s6d, s4d, s4d, s4d, s4d = corr.size()
    _, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d = kernel.size()

    # Fold the two extra 6D dimensions into the batch dimension and run the 4D
    # convolution once, producing one output channel per (ks6d x ks6d) kernel slice
    # for every (s6d x s6d) input slice.
    corr = corr.permute(0, 2, 3, 1, 4, 5, 6, 7).contiguous().view(-1, 1, s4d, s4d, s4d, s4d)
    kernel = kernel.view(-1, ks6d ** 2, ks4d, ks4d, ks4d, ks4d).transpose(0, 1)
    corr = fast4d(corr, kernel).view(bsz, s6d * s6d, ks6d * ks6d, s4d, s4d, s4d, s4d)
    corr = corr.view(bsz, s6d, s6d, ks6d, ks6d, s4d, s4d, s4d, s4d).transpose(2, 3).\
        contiguous().view(-1, s6d * ks6d, s4d, s4d, s4d, s4d)

    # Summing the responses over diagonal index groups realizes the convolution
    # over the two remaining 6D dimensions.
    ndiag = s6d + (ks6d // 2) * 2
    first_sum = []
    for didx in diagonal_idx:
        first_sum.append(corr[:, didx, :, :, :, :].sum(dim=1))
    first_sum = torch.stack(first_sum).transpose(0, 1).view(bsz, s6d * ks6d, ndiag, s4d, s4d, s4d, s4d)

    corr = []
    for didx in diagonal_idx:
        corr.append(first_sum[:, didx, :, :, :, :, :].sum(dim=1))
    sidx = ks6d // 2
    eidx = ndiag - sidx
    corr = torch.stack(corr).transpose(0, 1)[:, sidx:eidx, sidx:eidx, :, :, :, :].unsqueeze(1).contiguous()
    corr += bias.view(1, -1, 1, 1, 1, 1, 1, 1)

    # Restore the original ordering of the two 6D dimensions.
    reverse_idx = torch.linspace(s6d * s6d - 1, 0, s6d * s6d).long()
    corr = corr.view(bsz, 1, s6d * s6d, s4d, s4d, s4d, s4d)[:, :, reverse_idx, :, :, :, :].\
        view(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)
    return corr

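
# Example (illustrative sketch): fast6d expects a correlation tensor of shape
# (bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d) with s6d = 3, a (1, 1, 3, 3, 5, 5, 5, 5)
# kernel, a bias of shape (1,), and the diagonal index groups built in
# CHM6d.__init__ below; the spatial resolution is preserved.
#
# >>> corr = torch.randn(1, 1, 3, 3, 8, 8, 8, 8)
# >>> kernel = torch.randn(1, 1, 3, 3, 5, 5, 5, 5)
# >>> bias = torch.zeros(1)
# >>> diag_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
# >>> fast6d(corr, kernel, bias, diag_idx).shape
# torch.Size([1, 1, 3, 3, 8, 8, 8, 8])
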
def init_param_idx4d(param_dict):
    r""" Gathers, per parameter group, the flattened 4D kernel positions that share a weight """
    param_idx = []
    for key in param_dict:
        curr_offset = int(key.split('_')[-1])  # keys encode the kernel offset as their last '_'-separated token
        param_idx.append(torch.tensor(param_dict[key]))
    return param_idx

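
# Example (illustrative; the actual key names come from chm_kernel.KernelGenerator,
# so the keys below are hypothetical): the generator is assumed to map each
# parameter group to the flattened kernel positions sharing its weight, e.g.
#
# >>> param_dict = {'offset_0': [0, 4, 8], 'offset_1': [1, 3, 5, 7]}
# >>> init_param_idx4d(param_dict)
# [tensor([0, 4, 8]), tensor([1, 3, 5, 7])]
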
class CHM4d(_ConvNd):
    r""" 4D convolutional Hough matching layer
    NOTE: this layer only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz4d, ktype, bias=True):
        super(CHM4d, self).__init__(in_channels, out_channels, (ksz4d,) * 4,
                                    (1,) * 4, (0,) * 4, (1,) * 4, False, (0,) * 4,
                                    1, bias, padding_mode='zeros')

        self.zero_kernel4d = torch.zeros((in_channels, out_channels, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Kernel positions that share a weight (None if the kernel type imposes no sharing)
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:
            # One small positive weight per shared-parameter group
            self.param_idx = init_param_idx4d(param_dict4d)
            weights = torch.abs(torch.randn(len(self.param_idx) * self.nkernels)) * 1e-3
            for weight, param_idx in zip(weights.sort()[0], self.param_idx):
                weight *= len(param_idx)
            self.weight = nn.Parameter(weights)
        else:
            self.param_idx = None
            self.weight = nn.Parameter(torch.abs(self.weight))
        if bias: self.bias = nn.Parameter(torch.tensor(0.0))
        Logger.info('(%s) # params in CHM 4D: %d' % (ktype, len(self.weight.view(-1))))

    def forward(self, x):
        kernel = self.init_kernel()
        x = fast4d(x, kernel, self.bias)
        return x

    def init_kernel(self):
        # Materialize the full 4D kernel from the (possibly shared) weights
        ksz = self.kernel_size[-1]
        if self.param_idx is None:
            kernel = self.weight
        else:
            kernel = torch.zeros_like(self.zero_kernel4d)
            for idx, pdx in enumerate(self.param_idx):
                kernel = kernel.view(-1, ksz, ksz, ksz, ksz)
                for jdx, kernel_single in enumerate(kernel):
                    # Spread each shared weight evenly over the kernel positions in its group
                    weight = self.weight[idx + jdx * len(self.param_idx)].repeat(len(pdx)) / len(pdx)
                    kernel_single.view(-1)[pdx] += weight
            kernel = kernel.view(self.in_channels, self.out_channels, ksz, ksz, ksz, ksz)
        return kernel

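
# Example (illustrative sketch): a CHM4d layer acts on a 4D correlation tensor of
# shape (bsz, 1, h, w, h, w) and preserves its resolution. The kernel type string
# is passed straight to chm_kernel.KernelGenerator; 'psi' and 'iso' are the types
# referenced elsewhere in this file, so any other value is an assumption.
#
# >>> chm = CHM4d(in_channels=1, out_channels=1, ksz4d=3, ktype='psi', bias=True)
# >>> corr = torch.randn(2, 1, 8, 8, 8, 8)
# >>> chm(corr).shape
# torch.Size([2, 1, 8, 8, 8, 8])
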
class CHM6d(_ConvNd):
    r""" 6D convolutional Hough matching layer with kernel (3, 3, 5, 5, 5, 5)
    NOTE: this layer only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz6d, ksz4d, ktype):
        kernel_size = (ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d)
        super(CHM6d, self).__init__(in_channels, out_channels, kernel_size, (1,) * 6,
                                    (0,) * 6, (1,) * 6, False, (0,) * 6,
                                    1, bias=True, padding_mode='zeros')

        self.zero_kernel4d = torch.zeros((ksz4d, ksz4d, ksz4d, ksz4d))
        self.zero_kernel6d = torch.zeros((ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Diagonal index groups used by fast6d for the convolution over the two 6D dimensions
        self.diagonal_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:
            # Weight-sharing groups over the flattened (ksz6d x ksz6d) kernel positions
            if ktype == 'psi':
                self.param_dict6d = [[4], [0, 8], [2, 6], [1, 3, 5, 7]]
            elif ktype == 'iso':
                self.param_dict6d = [[0, 4, 8], [2, 6], [1, 3, 5, 7]]
            self.param_dict6d = [torch.tensor(i) for i in self.param_dict6d]

            # One set of shared 4D weights per 6D group, scaled by the group sizes
            self.param_idx = init_param_idx4d(param_dict4d)
            self.param = []
            for param_dict6d in self.param_dict6d:
                weights = torch.abs(torch.randn(len(self.param_idx))) * 1e-3
                for weight, param_idx in zip(weights, self.param_idx):
                    weight *= (len(param_idx) * len(param_dict6d))
                self.param.append(nn.Parameter(weights))
            self.param = nn.ParameterList(self.param)
        else:
            self.param_idx = None
            self.param = nn.Parameter(torch.abs(self.weight) * 1e-3)
        Logger.info('(%s) # params in CHM 6D: %d' % (ktype, sum([len(x.view(-1)) for x in self.param])))
        self.weight = None

    def forward(self, corr):
        kernel = self.init_kernel()
        corr = fast6d(corr, kernel, self.bias, self.diagonal_idx)
        return corr

    def init_kernel(self):
        # Materialize the full 6D kernel from the (possibly shared) weights
        if self.param_idx is None:
            return self.param

        kernel6d = torch.zeros_like(self.zero_kernel6d)
        for idx, (param, param_dict6d) in enumerate(zip(self.param, self.param_dict6d)):
            ksz4d = self.kernel_size[-1]
            kernel4d = torch.zeros_like(self.zero_kernel4d)
            # Spread each shared 4D weight evenly over the kernel positions in its group
            for jdx, pdx in enumerate(self.param_idx):
                kernel4d.view(-1)[pdx] += ((param[jdx] / len(pdx)) / len(param_dict6d))
            # Place the resulting 4D kernel at every 6D position of this group
            kernel6d.view(-1, ksz4d, ksz4d, ksz4d, ksz4d)[param_dict6d] += kernel4d.view(ksz4d, ksz4d, ksz4d, ksz4d)
        kernel6d = kernel6d.unsqueeze(0).unsqueeze(0)

        return kernel6d
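
# Example (illustrative sketch): a CHM6d layer acts on a 6D correlation tensor of
# shape (bsz, 1, 3, 3, h, w, h, w) and preserves its resolution; ksz6d=3 and
# ksz4d=5 match the fixed diagonal indices and the kernel size noted in the class
# docstring. 'psi' is one of the kernel types referenced above; other values are
# assumptions.
#
# >>> chm = CHM6d(in_channels=1, out_channels=1, ksz6d=3, ksz4d=5, ktype='psi')
# >>> corr = torch.randn(1, 1, 3, 3, 8, 8, 8, 8)
# >>> chm(corr).shape
# torch.Size([1, 1, 3, 3, 8, 8, 8, 8])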