IBYDMT / app_lib /ckde.py
jacopoteneggi's picture
Update
8e05eba verified
raw
history blame
2.3 kB
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import gaussian_kde
class cKDE:
def __init__(
self, embedding, semantics, metric="euclidean", scale_method="neff", scale=2000
):
self.metric = metric
self.scale_method = scale_method
self.scale = scale
self.H = embedding
self.Z = semantics
def _quantile_scale(self, Z_cond_dist):
return np.quantile(Z_cond_dist, self.scale)
def _neff_scale(self, Z_cond_dist):
scales = np.linspace(1e-02, 0.4, 100)[:, None]
_Z_cond_dist = np.tile(Z_cond_dist, (len(scales), 1))
weights = np.exp(-(_Z_cond_dist**2) / (2 * scales**2))
neff = (np.sum(weights, axis=1) ** 2) / np.sum(weights**2, axis=1)
diff = np.abs(neff - self.scale)
scale_idx = np.argmin(diff)
return scales[scale_idx].item()
def _sample(self, z, cond_idx, m):
sample_idx = list(set(range(len(z))) - set(cond_idx))
sample_z = np.tile(z, (m, 1))
if len(sample_idx) > 0:
kde, _ = self.kde(z, cond_idx)
sample_z[:, sample_idx] = kde.resample(m).T
return sample_z
def kde(self, z, cond_idx):
sample_idx = list(set(range(len(z))) - set(cond_idx))
Z_sample = self.Z[:, sample_idx]
Z_cond = self.Z[:, cond_idx]
z_cond = z[cond_idx]
Z_cond_dist = cdist(z_cond.reshape(1, -1), Z_cond, self.metric).squeeze()
if self.scale_method == "constant":
scale = self.scale
if self.scale_method == "quantile":
scale = self._quantile_scale(Z_cond_dist)
elif self.scale_method == "neff":
scale = self._neff_scale(Z_cond_dist)
weights = np.exp(-(Z_cond_dist**2) / (2 * scale**2))
return gaussian_kde(Z_sample.T, weights=weights), scale
def nearest_neighbor(self, z):
dist = cdist(z, self.Z, metric=self.metric)
return np.argmin(dist, axis=-1)
def sample(self, z, cond_idx, m=1):
if z.ndim == 1:
z = z.reshape(1, -1)
sample_z = np.concatenate([self._sample(_z, cond_idx, m) for _z in z], axis=0)
nn_idx = self.nearest_neighbor(sample_z)
sample_h = self.H[nn_idx]
return sample_z, sample_h