|
r""" CHM 4D kernel (psi, iso, and full) generator """
|
|
|
|
import torch
|
|
|
|
from .geometry import Geometry
|
|
|
|
|
|
class KernelGenerator:
    """Group the positions of a 4D (ksz x ksz x ksz x ksz) CHM kernel by a sharing key.

    Every 4D kernel position ``(src_i, src_j, trg_i, trg_j)`` is mapped to a
    string key by :meth:`build_key`; positions with the same key are collected
    together. The key depends on ``ktype``:

    * ``'iso'``  -- only the source-to-target distance.
    * ``'psi'``  -- the larger and the smaller of the source/target distances
      to the kernel center, plus the source-to-target distance.
    * ``'full'`` -- no grouping; :meth:`generate` returns ``None``.
    """

    def __init__(self, ksz, ktype):
        """
        Args:
            ksz: Kernel side length; ``self.kernel`` has shape
                ``(ksz, ksz, ksz, ksz)``.
            ktype: Kernel type; ``'iso'``, ``'psi'`` and ``'full'`` are handled.
        """
        self.ksz = ksz
        # 4D kernel positions from Geometry.init_idx4d; each entry is unpacked
        # as (src_i, src_j, trg_i, trg_j) in generate_chm_kernel.
        self.idx4d = Geometry.init_idx4d(ksz)
        self.kernel = torch.zeros((ksz, ksz, ksz, ksz))
        # ksz // 2 on both axes: the true middle cell only when ksz is odd.
        self.center = (ksz // 2, ksz // 2)
        self.ktype = ktype

    @staticmethod
    def _axis_side(value, center):
        """Return -1, 1 or 0 when *value* is below, above or equal to *center*."""
        if value < center:
            return -1
        if value > center:
            return 1
        return 0

    def quadrant(self, crd):
        """Locate a 2D coordinate relative to the kernel center.

        Args:
            crd: 2-element coordinate; ``crd[0]`` is compared against
                ``self.center[0]`` and ``crd[1]`` against ``self.center[1]``.

        Returns:
            ``(horz_quad, vert_quad)``, each -1 (before the center), 1 (after
            the center) or 0 (on the center).
        """
        horz_quad = self._axis_side(crd[0], self.center[0])
        vert_quad = self._axis_side(crd[1], self.center[1])
        return horz_quad, vert_quad

    def generate(self):
        """Return the parameter groups for this kernel type.

        Returns:
            ``None`` when ``ktype == 'full'``; otherwise the dict produced by
            :meth:`generate_chm_kernel`.
        """
        if self.ktype == 'full':
            return None
        return self.generate_chm_kernel()

    def generate_chm_kernel(self):
        """Group every 4D kernel position by its sharing key.

        Returns:
            dict mapping each key from :meth:`build_key` to the list of 1D
            coordinates (from ``Geometry.get_coord1d``) of the kernel positions
            sharing that key, in ``self.idx4d`` iteration order.

        Raises:
            NotImplementedError: from :meth:`build_key` when ``ktype`` is
                neither ``'iso'`` nor ``'psi'`` (and ``self.idx4d`` is non-empty).
        """
        param_dict = {}
        for src_i, src_j, trg_i, trg_j in self.idx4d:
            src_crd = (src_i, src_j)
            trg_crd = (trg_i, trg_j)

            # Distances (as defined by Geometry.get_distance) of the source
            # ("tail") and target ("head") from the center, and between them.
            d_tail = Geometry.get_distance(src_crd, self.center)
            d_head = Geometry.get_distance(trg_crd, self.center)
            d_off = Geometry.get_distance(src_crd, trg_crd)

            # Quadrant of the source, passed as (column, row). Neither
            # implemented key type currently uses it.
            horz_quad, vert_quad = self.quadrant((src_j, src_i))

            key = self.build_key(horz_quad, vert_quad, d_head, d_tail,
                                 src_crd, trg_crd, d_off)
            coord1d = Geometry.get_coord1d((src_i, src_j, trg_i, trg_j), self.ksz)
            param_dict.setdefault(key, []).append(coord1d)

        return param_dict

    def build_key(self, horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off):
        """Build the sharing key for one kernel position.

        Only ``d_head``, ``d_tail`` and ``d_off`` are used by the implemented
        kernel types; ``horz_quad``, ``vert_quad``, ``src_crd`` and ``trg_crd``
        are currently ignored.

        Returns:
            ``'iso'``: ``'<d_off>'``.
            ``'psi'``: ``'<max>_<min>_<d_off>'``, where max/min are taken over
            ``(d_head, d_tail)``.
            Values are formatted with ``%d``, so float distances are truncated
            toward zero.

        Raises:
            NotImplementedError: for any other ``ktype`` (including ``'full'``).
        """
        if self.ktype == 'iso':
            return '%d' % d_off
        if self.ktype == 'psi':
            d_max = max(d_head, d_tail)
            d_min = min(d_head, d_tail)
            return '%d_%d_%d' % (d_max, d_min, d_off)
        raise NotImplementedError('not implemented.')
|
|
|
|
|