File size: 2,258 Bytes
4f55ca2
 
 
 
a40e67a
4f55ca2
a40e67a
 
 
 
 
 
4f55ca2
a40e67a
 
4f55ca2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a40e67a
4f55ca2
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import gaussian_kde


class cKDE:
    """Conditional kernel density estimator over a set of reference points.

    The estimator stores two row-aligned arrays: ``semantics`` (kept as
    ``self.Z``) and ``embedding`` (kept as ``self.H``). Given a query vector
    ``z`` and the indices of the dimensions to condition on, every reference
    row of ``Z`` is weighted with a Gaussian kernel of its distance to ``z``
    in the conditioning dimensions, and a weighted ``gaussian_kde`` is fitted
    on the remaining dimensions.

    Parameters
    ----------
    embedding : array-like
        Indexed by row with the nearest-neighbour indices found in ``Z``.
    semantics : array-like
        2-D array indexed as ``Z[:, idx]``; one row per reference point.
    metric : str
        Distance metric forwarded to ``scipy.spatial.distance.cdist``.
    scale_method : {"constant", "quantile", "neff"}
        How the kernel bandwidth is chosen (see :meth:`kde`).
    scale : float
        Meaning depends on ``scale_method``: the bandwidth itself
        ("constant"), the quantile ``q`` passed to ``np.quantile``
        ("quantile", so it must lie in [0, 1]), or the target effective
        sample size ("neff").
    """

    # Bandwidth-selection strategies understood by ``kde``.
    _SCALE_METHODS = ("constant", "quantile", "neff")

    def __init__(
        self, embedding, semantics, metric="euclidean", scale_method="neff", scale=2000
    ):
        self.metric = metric
        self.scale_method = scale_method
        self.scale = scale

        self.H = embedding
        self.Z = semantics

    @staticmethod
    def _free_idx(n_dims, cond_idx):
        """Return the sorted indices in ``range(n_dims)`` not in ``cond_idx``."""
        conditioned = set(cond_idx)
        # Sorted so the column order never depends on set iteration order.
        return [i for i in range(n_dims) if i not in conditioned]

    def _quantile_scale(self, Z_cond_dist):
        """Return the ``self.scale``-quantile of the conditioning distances."""
        return np.quantile(Z_cond_dist, self.scale)

    def _neff_scale(self, Z_cond_dist):
        """Pick the bandwidth whose effective sample size is closest to ``self.scale``.

        A fixed grid of 100 candidate bandwidths in [0.01, 0.4] is evaluated.
        For each one the Gaussian kernel weights ``w`` are computed and the
        effective sample size ``(sum w)^2 / sum w^2`` is compared against the
        target ``self.scale``. The best grid value is returned as a float.
        """
        scales = np.linspace(1e-02, 0.4, 100)[:, None]
        distances = np.atleast_1d(Z_cond_dist)[None, :]

        # Broadcasting (100, 1) against (1, n) yields one weight row per
        # candidate bandwidth without materialising a tiled copy.
        weights = np.exp(-(distances**2) / (2 * scales**2))
        neff = (np.sum(weights, axis=1) ** 2) / np.sum(weights**2, axis=1)
        scale_idx = np.argmin(np.abs(neff - self.scale))
        return scales[scale_idx].item()

    def _sample(self, z, cond_idx, m):
        """Draw ``m`` samples for one query vector ``z``.

        Conditioned dimensions keep the values of ``z``; all other dimensions
        are replaced by draws from the conditional KDE. Returns an array of
        shape ``(m, len(z))``.
        """
        z = np.asarray(z)
        if not np.issubdtype(z.dtype, np.inexact):
            # An integer (or bool) buffer would silently truncate the
            # real-valued KDE draws written into it below.
            z = z.astype(float)

        sample_idx = self._free_idx(len(z), cond_idx)

        kde, _ = self.kde(z, cond_idx)

        sample_z = np.tile(z, (m, 1))
        # ``resample`` returns (n_free_dims, m); transpose to (m, n_free_dims).
        sample_z[:, sample_idx] = kde.resample(m).T

        return sample_z

    def kde(self, z, cond_idx):
        """Fit the conditional KDE for query ``z`` given ``cond_idx``.

        Parameters
        ----------
        z : array-like, 1-D
            Query vector with one entry per column of ``self.Z``.
        cond_idx : sequence of int
            Indices of the dimensions to condition on.

        Returns
        -------
        (gaussian_kde, float)
            The weighted KDE over the non-conditioned dimensions and the
            kernel bandwidth that was used.

        Raises
        ------
        ValueError
            If ``self.scale_method`` is not one of "constant", "quantile"
            or "neff".
        """
        if self.scale_method not in self._SCALE_METHODS:
            raise ValueError(
                f"unknown scale_method {self.scale_method!r}; "
                f"expected one of {self._SCALE_METHODS}"
            )

        z = np.asarray(z)
        sample_idx = self._free_idx(len(z), cond_idx)

        Z_sample = self.Z[:, sample_idx]
        Z_cond = self.Z[:, cond_idx]

        # Distance from the query to every reference row, measured only in
        # the conditioning dimensions. ``cdist`` returns shape (1, n).
        z_cond = z[cond_idx]
        Z_cond_dist = cdist(z_cond.reshape(1, -1), Z_cond, self.metric)[0]

        if self.scale_method == "constant":
            scale = self.scale
        elif self.scale_method == "quantile":
            scale = self._quantile_scale(Z_cond_dist)
        else:  # "neff"
            scale = self._neff_scale(Z_cond_dist)

        # Gaussian kernel: reference rows close to the query dominate the fit.
        weights = np.exp(-(Z_cond_dist**2) / (2 * scale**2))

        return gaussian_kde(Z_sample.T, weights=weights), scale

    def nearest_neighbor(self, z):
        """Return, for each row of ``z``, the index of the closest row of ``self.Z``.

        A single 1-D query is treated as one row.
        """
        dist = cdist(np.atleast_2d(z), self.Z, metric=self.metric)
        return np.argmin(dist, axis=-1)

    def sample(self, z, cond_idx, m=1):
        """Sample ``m`` conditional points per query and look up their embeddings.

        Parameters
        ----------
        z : array-like, 1-D or 2-D
            One query vector or a batch of query vectors (one per row).
        cond_idx : sequence of int
            Indices of the dimensions held fixed at the query's values.
        m : int
            Number of samples drawn per query.

        Returns
        -------
        (ndarray, ndarray)
            ``sample_z`` with ``m`` rows per query (queries concatenated in
            order) and ``sample_h``, the rows of ``self.H`` belonging to the
            nearest neighbour in ``self.Z`` of each sampled point.
        """
        z = np.asarray(z)
        if z.ndim == 1:
            z = z.reshape(1, -1)

        sample_z = np.concatenate([self._sample(_z, cond_idx, m) for _z in z], axis=0)

        nn_idx = self.nearest_neighbor(sample_z)
        sample_h = self.H[nn_idx]

        return sample_z, sample_h