taesiri's picture
Initial Commit
8390f90
raw
history blame
8.81 kB
r""" 4D and 6D convolutional Hough matching layers """
from torch.nn.modules.conv import _ConvNd
import torch.nn.functional as F
import torch.nn as nn
import torch
from common.logger import Logger
from . import chm_kernel
def fast4d(corr, kernel, bias=None):
    r""" Optimized implementation of 4D convolution.

    Computes a stride-1, zero-padded ("same"-size) 4D cross-correlation by running
    one 3D convolution per kernel offset along the first spatial dimension and
    accumulating the shifted results.

    Args:
        corr: 6D tensor ``(bsz, ch, srch, srcw, trgh, trgw)``.
        kernel: 6D tensor ``(out_channels, ch, k, k, k, k)``; only the last spatial
            size is read, so the four spatial sizes are expected to be equal.
        bias: optional tensor viewable as ``(1, out_channels, 1, 1, 1, 1)``.

    Returns:
        Tensor ``(bsz, out_channels, srch, srcw, trgh, trgw)`` on the same device
        and with the same dtype as ``corr``.
    """
    bsz, ch, srch, srcw, trgh, trgw = corr.size()
    out_channels, _, _, _, _, kernel_size = kernel.size()
    psz = kernel_size // 2

    # Allocate the accumulator like `corr` (a bare torch.zeros would always be a
    # CPU tensor of the default dtype, breaking CUDA / float64 inputs).
    out_corr = corr.new_zeros((bsz, out_channels, srch, srcw, trgh, trgw))
    # Fold the first spatial dim into the batch so every slice is convolved in 3D
    corr = corr.transpose(1, 2).contiguous().view(bsz * srch, ch, srcw, trgh, trgw)

    for pidx, k3d in enumerate(kernel.permute(2, 0, 1, 3, 4, 5)):
        offset = pidx - psz
        # A shift of at least `srch` rows only ever reads zero padding
        if abs(offset) >= srch:
            continue

        inter_corr = F.conv3d(corr, k3d, bias=None, stride=1, padding=psz)
        inter_corr = inter_corr.view(bsz, srch, out_channels, srcw, trgh, trgw).transpose(1, 2)

        # out[:, :, s] += inter[:, :, s + offset] for every s whose source row exists
        add_sid, add_fid = max(-offset, 0), srch - max(offset, 0)
        slc_sid, slc_fid = max(offset, 0), srch - max(-offset, 0)
        out_corr[:, :, add_sid:add_fid] += inter_corr[:, :, slc_sid:slc_fid]

    if bias is not None:
        out_corr += bias.view(1, out_channels, 1, 1, 1, 1)

    return out_corr
def fast6d(corr, kernel, bias, diagonal_idx):
    r""" Optimized implementation of 6D convolutional Hough matching.

    NOTE: this function only supports kernel size of (3, 3, 5, 5, 5, 5).

    The 6D convolution is computed as a batched 4D convolution over the last four
    dimensions (``fast4d``, with one output channel per ks6d x ks6d scale-kernel
    entry), followed by two grouped sums driven by ``diagonal_idx`` and a crop.

    Args:
        corr: 8D tensor ``(bsz, ch, s6d, s6d, s4d, s4d, s4d, s4d)``. The views
            below only line up when ``ch == 1``.
        kernel: 8D tensor ``(_, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d)``; the two
            leading dims are presumably 1 (fast4d receives their product as its
            input-channel count) — TODO confirm.
        bias: tensor added after ``view(1, -1, 1, 1, 1, 1, 1, 1)``.
        diagonal_idx: list of index tensors into a flattened ``s6d * ks6d`` axis;
            each group is summed into one slot. Its length must equal
            ``s6d + 2 * (ks6d // 2)`` for the ``view`` after the first pass.

    Returns:
        Tensor of shape ``(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)``.
    """
    bsz, _, s6d, s6d, s4d, s4d, s4d, s4d = corr.size()
    _, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d = kernel.size()

    # Fold batch, both scale dims and the channel dim into one batch axis:
    # (bsz * s6d * s6d * ch, 1, s4d, s4d, s4d, s4d)
    corr = corr.permute(0, 2, 3, 1, 4, 5, 6, 7).contiguous().view(-1, 1, s4d, s4d, s4d, s4d)
    # Move the ks6d ** 2 scale-kernel entries to dim 0 so fast4d treats them as output channels
    kernel = kernel.view(-1, ks6d ** 2, ks4d, ks4d, ks4d, ks4d).transpose(0, 1)
    corr = fast4d(corr, kernel).view(bsz, s6d * s6d, ks6d * ks6d, s4d, s4d, s4d, s4d)
    # (bsz, s6d_a, s6d_b, ks6d_a, ks6d_b, ...) -> (bsz, s6d_a, ks6d_a, s6d_b, ks6d_b, ...),
    # then flatten to (bsz * s6d * ks6d, s6d * ks6d, ...) pairing scale and kernel-scale indices
    corr = corr.view(bsz, s6d, s6d, ks6d, ks6d, s4d, s4d, s4d, s4d).transpose(2, 3).\
        contiguous().view(-1, s6d * ks6d, s4d, s4d, s4d, s4d)

    # Number of index groups expected in diagonal_idx (equals s6d + ks6d - 1 for odd ks6d)
    ndiag = s6d + (ks6d // 2) * 2

    # First pass: sum each index group along the second (s6d_b * ks6d_b) axis
    first_sum = []
    for didx in diagonal_idx:
        first_sum.append(corr[:, didx, :, :, :, :].sum(dim=1))
    first_sum = torch.stack(first_sum).transpose(0, 1).view(bsz, s6d * ks6d, ndiag, s4d, s4d, s4d, s4d)

    # Second pass: apply the same grouping to the remaining (s6d_a * ks6d_a) axis
    corr = []
    for didx in diagonal_idx:
        corr.append(first_sum[:, didx, :, :, :, :, :].sum(dim=1))

    # Crop the (ndiag x ndiag) result to its central (s6d x s6d) block
    sidx = ks6d // 2
    eidx = ndiag - sidx
    corr = torch.stack(corr).transpose(0, 1)[:, sidx:eidx, sidx:eidx, :, :, :, :].unsqueeze(1).contiguous()
    corr += bias.view(1, -1, 1, 1, 1, 1, 1, 1)

    # Reverse the order of the s6d * s6d flattened scale positions
    # NOTE(review): presumably compensates for the ordering of diagonal_idx — confirm
    reverse_idx = torch.linspace(s6d * s6d - 1, 0, s6d * s6d).long()
    corr = corr.view(bsz, 1, s6d * s6d, s4d, s4d, s4d, s4d)[:, :, reverse_idx, :, :, :, :].\
        view(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)

    return corr
def init_param_idx4d(param_dict):
    r""" Convert every index list of a shared-parameter dictionary into a tensor.

    Args:
        param_dict: mapping whose keys end in ``_<int>`` and whose values are
            sequences of flat kernel positions.

    Returns:
        List of index tensors, in the dictionary's iteration order.

    Raises:
        ValueError: if a key's suffix after the last ``_`` is not an integer.
    """
    indices = []
    for name, flat_positions in param_dict.items():
        # The parsed offset is not needed, but parsing it rejects malformed keys
        int(name.split('_')[-1])
        indices.append(torch.tensor(flat_positions))
    return indices
class CHM4d(_ConvNd):
    r""" 4D convolutional Hough matching layer.

    NOTE: this layer only supports in_channels=1 and out_channels=1.

    When ``chm_kernel.KernelGenerator(ksz4d, ktype).generate()`` returns a
    parameter-sharing dictionary, the 4D kernel is rebuilt on every forward pass
    from one learnable weight per group of flat kernel positions; otherwise the
    full kernel created by ``_ConvNd`` is learned directly.
    """

    def __init__(self, in_channels, out_channels, ksz4d, ktype, bias=True):
        super(CHM4d, self).__init__(in_channels, out_channels, (ksz4d,) * 4,
                                    (1,) * 4, (0,) * 4, (1,) * 4, False, (0,) * 4,
                                    1, bias, padding_mode='zeros')

        # Zero kernel template; init_kernel only reads its shape
        self.zero_kernel4d = torch.zeros((in_channels, out_channels, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:
            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            self.weight = nn.Parameter(self._init_shared_weights(self.param_idx, self.nkernels))
        else:  # full kernel initialization
            self.param_idx = None
            self.weight = nn.Parameter(torch.abs(self.weight))

        # A single scalar bias replaces the per-channel bias created by _ConvNd
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        Logger.info('(%s) # params in CHM 4D: %d' % (ktype, len(self.weight.view(-1))))

    @staticmethod
    def _init_shared_weights(param_idx, nkernels):
        r""" Draw initial shared weights, each scaled by the size of its group.

        Draws ``len(param_idx) * nkernels`` small non-negative values, sorts them
        ascending and multiplies the weight of group ``i`` (for every kernel) by
        ``len(param_idx[i])``. ``init_kernel`` divides by the same count, so each
        kernel entry starts at the magnitude of the raw draw. (Previously the
        scaling was applied to the copy returned by ``sort()`` and was lost.)
        """
        nparams = len(param_idx)
        weights = torch.abs(torch.randn(nparams * nkernels)) * 1e-3
        weights = weights.sort()[0]
        share_counts = torch.tensor([len(pdx) for pdx in param_idx], dtype=weights.dtype)
        # The weight of (kernel j, group i) lives at flat index i + j * nparams
        return (weights.view(nkernels, nparams) * share_counts).view(-1)

    def forward(self, x):
        kernel = self.init_kernel()
        x = fast4d(x, kernel, self.bias)
        return x

    def init_kernel(self):
        r""" Return the 6D kernel used by ``forward``.

        For shared parameters the result has shape
        ``(in_channels, out_channels, k, k, k, k)``; each shared weight is divided
        by the number of kernel positions it fills.
        """
        ksz = self.kernel_size[-1]
        if self.param_idx is None:
            return self.weight

        nparams = len(self.param_idx)
        nrows = self.in_channels * self.out_channels
        # Allocate on the weight's device/dtype so the layer still works after .cuda()/.double()
        kernel = self.weight.new_zeros((nrows, ksz ** 4))
        # Row j holds the weights of kernel j (flat index idx + j * nparams)
        weights = self.weight.view(nrows, nparams)
        for idx, pdx in enumerate(self.param_idx):
            kernel[:, pdx] += weights[:, idx:idx + 1] / len(pdx)
        return kernel.view(self.in_channels, self.out_channels, ksz, ksz, ksz, ksz)
class CHM6d(_ConvNd):
    r""" 6D convolutional Hough matching layer with kernel (3, 3, 5, 5, 5, 5).

    NOTE: this layer only supports in_channels=1 and out_channels=1.

    For the shared kernel types ('psi', 'iso') the 6D kernel is rebuilt on every
    forward pass: each entry of ``self.param`` is a vector of shared 4D weights
    that is written into one group of positions of the 3x3 scale kernel.
    Otherwise the full kernel created by ``_ConvNd`` is learned directly.
    """

    def __init__(self, in_channels, out_channels, ksz6d, ksz4d, ktype):
        kernel_size = (ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d)
        super(CHM6d, self).__init__(in_channels, out_channels, kernel_size, (1,) * 6,
                                    (0,) * 6, (1,) * 6, False, (0,) * 6,
                                    1, bias=True, padding_mode='zeros')

        # Zero kernel templates; init_kernel only reads their shapes
        self.zero_kernel4d = torch.zeros((ksz4d, ksz4d, ksz4d, ksz4d))
        self.zero_kernel6d = torch.zeros((ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Indices in scale-space where 4D convolutions are performed (3 by 3 scale-space):
        # flat index r * 3 + c of the 3x3 grid, grouped by constant r - c (2, 1, 0, -1, -2)
        self.diagonal_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:  # psi & iso kernel initialization
            # Groups of flat positions in the 3x3 scale kernel that share one 4D kernel
            if ktype == 'psi':
                scale_groups = [[4], [0, 8], [2, 6], [1, 3, 5, 7]]
            elif ktype == 'iso':
                scale_groups = [[0, 4, 8], [2, 6], [1, 3, 5, 7]]
            else:
                raise ValueError("CHM6d: no scale-sharing pattern for kernel type '%s'" % ktype)
            self.param_dict6d = [torch.tensor(i) for i in scale_groups]

            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            share_counts = torch.tensor([len(pdx) for pdx in self.param_idx])
            params = []
            for scale_idx in self.param_dict6d:
                weights = torch.abs(torch.randn(len(self.param_idx))) * 1e-3
                params.append(nn.Parameter(weights * (share_counts * len(scale_idx))))
            self.param = nn.ParameterList(params)
        else:  # full kernel initialization
            self.param_idx = None
            self.param = nn.Parameter(torch.abs(self.weight) * 1e-3)

        Logger.info('(%s) # params in CHM 6D: %d' % (ktype, sum([len(x.view(-1)) for x in self.param])))
        # The kernel is produced by init_kernel, so the inherited weight is unregistered
        self.weight = None

    def forward(self, corr):
        kernel = self.init_kernel()
        corr = fast6d(corr, kernel, self.bias, self.diagonal_idx)
        return corr

    def init_kernel(self):
        r""" Return the 8D kernel ``(1, 1, ks6d, ks6d, k, k, k, k)`` used by ``forward``.

        Each shared weight is divided by the number of 4D positions it fills and by
        the number of scale positions its group covers. In the full-kernel case the
        learned parameter is returned as is.
        """
        if self.param_idx is None:
            return self.param

        ksz4d = self.kernel_size[-1]
        nscale = self.zero_kernel6d.shape[0] * self.zero_kernel6d.shape[1]
        # Allocate like the parameters so the layer still works after .cuda()/.double()
        template = self.param[0] if len(self.param) > 0 else self.zero_kernel6d
        kernel6d = template.new_zeros((nscale, ksz4d ** 4))
        for param, scale_idx in zip(self.param, self.param_dict6d):
            kernel4d = template.new_zeros(ksz4d ** 4)
            for jdx, pdx in enumerate(self.param_idx):
                kernel4d[pdx] += ((param[jdx] / len(pdx)) / len(scale_idx))
            kernel6d[scale_idx] += kernel4d
        return kernel6d.view(self.zero_kernel6d.shape).unsqueeze(0).unsqueeze(0)