taesiri's picture
Initial Commit
8390f90
raw
history blame
2.12 kB
r""" CHM 4D kernel (psi, iso, and full) generator """
import torch
from .geometry import Geometry
class KernelGenerator:
    """Build the parameter-sharing layout for a 4D CHM kernel.

    The kernel spans ``ksz ** 4`` positions indexed by
    ``(src_i, src_j, trg_i, trg_j)``. For the ``'iso'`` and ``'psi'`` kernel
    types, :meth:`generate` groups those positions by the string key returned
    from :meth:`build_key` (presumably so that positions sharing a key share
    one weight -- NOTE(review): confirm against the consumer of this dict).
    For ``'full'`` kernels :meth:`generate` returns ``None``.
    """

    def __init__(self, ksz, ktype):
        """
        Args:
            ksz: side length of each of the four kernel dimensions.
            ktype: kernel type; 'iso', 'psi' or 'full'.
        """
        self.ksz = ksz
        # Iterable whose items unpack to (src_i, src_j, trg_i, trg_j); built by
        # the project's Geometry helper.
        self.idx4d = Geometry.init_idx4d(ksz)
        # Zero-initialised dense 4D kernel; not read by the methods of this class.
        self.kernel = torch.zeros((ksz, ksz, ksz, ksz))
        # Centre cell of the ksz x ksz window (rounded down for even ksz).
        self.center = (ksz // 2, ksz // 2)
        self.ktype = ktype

    @staticmethod
    def _side(value, pivot):
        """Return -1, 1 or 0 as ``value`` is below, above or equal to ``pivot``."""
        if value < pivot:
            return -1
        if value > pivot:
            return 1
        return 0

    def quadrant(self, crd):
        """Locate ``crd`` relative to ``self.center`` along each axis.

        Args:
            crd: 2-tuple; ``crd[0]`` is compared with ``center[0]`` and
                ``crd[1]`` with ``center[1]``.

        Returns:
            ``(horz_quad, vert_quad)``, each ``-1`` when the coordinate is
            below the centre value, ``1`` when above, and ``0`` when equal.
        """
        horz_quad = self._side(crd[0], self.center[0])
        vert_quad = self._side(crd[1], self.center[1])
        return horz_quad, vert_quad

    def generate(self):
        """Return the key -> positions dict, or ``None`` for a 'full' kernel."""
        return None if self.ktype == 'full' else self.generate_chm_kernel()

    def generate_chm_kernel(self):
        """Group every 4D kernel position by its :meth:`build_key` key.

        Returns:
            dict mapping each key string to a list of flattened indices (as
            computed by ``Geometry.get_coord1d``) of the positions with that
            key, in ``self.idx4d`` iteration order.

        Raises:
            NotImplementedError: from :meth:`build_key` when ``self.ktype`` is
                neither 'iso' nor 'psi' and ``self.idx4d`` is non-empty.
        """
        param_dict = {}
        for src_i, src_j, trg_i, trg_j in self.idx4d:
            src_crd = (src_i, src_j)
            trg_crd = (trg_i, trg_j)
            # Distances (as defined by Geometry.get_distance) of the source
            # ("tail") and target ("head") cells from the centre, and between them.
            d_tail = Geometry.get_distance(src_crd, self.center)
            d_head = Geometry.get_distance(trg_crd, self.center)
            d_off = Geometry.get_distance(src_crd, trg_crd)
            # Swapped order: horz_quad is derived from src_j (presumably the
            # column index). Neither key type currently uses the quadrants.
            horz_quad, vert_quad = self.quadrant((src_j, src_i))
            key = self.build_key(horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off)
            coord1d = Geometry.get_coord1d((src_i, src_j, trg_i, trg_j), self.ksz)
            param_dict.setdefault(key, []).append(coord1d)
        return param_dict

    def build_key(self, horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off):
        """Return the sharing key for one 4D kernel position.

        'iso' keys depend only on ``d_off``. 'psi' keys depend on the
        unordered pair ``(d_head, d_tail)`` plus ``d_off``. ``horz_quad``,
        ``vert_quad``, ``src_crd`` and ``trg_crd`` are accepted but used by
        neither type. Values are formatted with ``%d``, so non-integer
        distances are truncated.

        Raises:
            NotImplementedError: for any other ``self.ktype`` (a subclass of
                ``Exception``, so broad handlers keep working).
        """
        if self.ktype == 'iso':
            return '%d' % d_off
        if self.ktype == 'psi':
            d_max = max(d_head, d_tail)
            d_min = min(d_head, d_tail)
            return '%d_%d_%d' % (d_max, d_min, d_off)
        raise NotImplementedError('kernel type %r not implemented.' % (self.ktype,))