Spaces:
Sleeping
Sleeping
from abc import ABC, abstractmethod | |
from functools import reduce | |
import numpy as np | |
from sklearn.metrics import pairwise_distances | |
from sklearn.metrics.pairwise import linear_kernel, rbf_kernel | |
from ibydmt.bet import get_bet | |
class Payoff(ABC):
    """Abstract base class for payoffs that are driven by a bet.

    The constructor looks up a bet with ``get_bet(config.bet)`` and
    instantiates it with the same ``config``. Concrete subclasses must
    implement :meth:`compute`.
    """

    def __init__(self, config):
        """Build the bet selected by the configuration.

        Args:
            config: Configuration object. ``config.bet`` is passed to
                ``get_bet`` and the whole object is passed to the callable
                that ``get_bet`` returns.
        """
        self.bet = get_bet(config.bet)(config)

    @abstractmethod
    def compute(self, *args, **kwargs):
        """Compute the payoff.

        Declared abstract so that a subclass which forgets to override it
        fails at instantiation instead of silently returning ``None``.
        """
class Kernel:
    """Callable wrapper around scikit-learn's ``linear_kernel`` / ``rbf_kernel``.

    For the RBF kernel the bandwidth ``gamma`` is chosen according to
    ``scale_method``:

    * ``"constant"``: ``gamma`` is set to ``scale`` on every call.
    * ``"quantile"``: ``gamma = 1 / (2 * s**2)`` where ``s`` is the
      ``scale``-quantile of the pairwise distances (``pairwise_distances``
      with its default metric) between the rows accumulated in ``prev``.
      The estimate is refreshed on each call until ``prev`` holds more than
      100 rows, after which ``gamma`` is frozen.

    The object is stateful: ``gamma``, ``prev`` and ``recompute_gamma`` are
    updated by calls.
    """

    def __init__(self, kernel: str, scale_method: str, scale: float):
        """Select the base kernel and initialise the bandwidth state.

        Args:
            kernel: ``"linear"`` or ``"rbf"``.
            scale_method: ``"constant"`` or ``"quantile"``; only consulted
                for the RBF kernel, and validated at call time.
            scale: The constant ``gamma`` (``"constant"``) or the quantile
                level passed to ``np.quantile`` (``"quantile"``).

        Raises:
            NotImplementedError: If ``kernel`` is not a supported name.
        """
        if kernel == "linear":
            self.base_kernel = linear_kernel
        elif kernel == "rbf":
            self.base_kernel = rbf_kernel
        else:
            raise NotImplementedError(f"{kernel} is not implemented")

        # Initialise the same attribute set for every kernel so instances
        # are uniform; the linear kernel simply never reads them.
        self.scale_method = scale_method
        self.scale = scale
        self.gamma = None
        self.recompute_gamma = True
        self.prev = None

    def _update_quantile_gamma(self, x, y):
        """Refresh ``gamma`` from the quantile of distances among seen rows."""
        if self.prev is None:
            # First call: seed the history with the reference points.
            self.prev = y

        if not self.recompute_gamma:
            return

        dist = pairwise_distances(self.prev.reshape(-1, self.prev.shape[-1]))
        scale = np.quantile(dist, self.scale)
        # A zero quantile would divide by zero; fall back to gamma=None and
        # let rbf_kernel apply its own default.
        self.gamma = 1 / (2 * scale**2) if scale > 0 else None

        if len(self.prev) > 100:
            # Enough history collected: freeze gamma from now on.
            self.recompute_gamma = False

        # The history is only read while gamma is still being re-estimated,
        # so it is only extended here.
        self.prev = np.vstack([self.prev, x])

    def __call__(self, x, y):
        """Evaluate the kernel matrix between the rows of ``x`` and ``y``.

        Args:
            x: 2-D array of query points.
            y: 2-D array of reference points.

        Returns:
            The matrix returned by the underlying scikit-learn kernel.

        Raises:
            NotImplementedError: If the RBF kernel is used with an
                unsupported ``scale_method``.
        """
        if self.base_kernel == linear_kernel:
            return self.base_kernel(x, y)

        if self.base_kernel == rbf_kernel:
            if self.scale_method == "constant":
                self.gamma = self.scale
            elif self.scale_method == "quantile":
                self._update_quantile_gamma(x, y)
            else:
                # Report the offending method name, not the numeric scale.
                raise NotImplementedError(
                    f"{self.scale_method} is not implemented for rbf_kernel"
                )
            return self.base_kernel(x, y, gamma=self.gamma)
class KernelPayoff(Payoff):
    """Base class for payoffs computed from a kernel witness function.

    Reads the kernel settings from the configuration and turns the
    accumulated witness-function differences between real and null samples
    into a payoff via ``self.bet``.
    """

    def __init__(self, config):
        """Store the kernel settings found in ``config``.

        Args:
            config: Configuration object exposing ``kernel`` as an attribute
                and a ``get(key, default)`` method. ``kernel_scale_method``
                defaults to ``"quantile"`` and ``kernel_scale`` to ``0.5``.
        """
        super().__init__(config)
        self.kernel = config.kernel
        self.scale_method = config.get("kernel_scale_method", "quantile")
        self.scale = config.get("kernel_scale", 0.5)

    @abstractmethod
    def witness_function(self, d, prev_d):
        """Evaluate the witness function of sample ``d`` against ``prev_d``.

        Declared abstract: a silent ``None`` here would only surface later
        as an opaque arithmetic ``TypeError`` inside :meth:`compute`.
        """

    def compute(self, d, null_d, prev_d):
        """Compute the payoff for a batch of paired real/null samples.

        Args:
            d: Iterable of samples passed one by one to
                :meth:`witness_function`.
            null_d: Iterable of null samples, paired with ``d`` via ``zip``.
            prev_d: Previous observations forwarded to
                :meth:`witness_function`.

        Returns:
            Whatever ``self.bet.compute`` returns for the scalar sum of
            ``witness(sample) - witness(null_sample)`` over all pairs.
        """
        g = 0
        for sample, null_sample in zip(d, null_d):
            # Evaluate the real sample before the null one: kernels may be
            # stateful, so the call order is preserved.
            g = (
                g
                + self.witness_function(sample, prev_d)
                - self.witness_function(null_sample, prev_d)
            )

        # np.asarray lets an empty batch (g is still the int 0) through
        # instead of failing on the missing ``squeeze`` attribute.
        g = np.asarray(g).squeeze().item()
        return self.bet.compute(g)
class HSIC(KernelPayoff):
    """Kernel payoff whose witness compares a joint term with a product term.

    Two independent ``Kernel`` objects are kept, one for each coordinate of
    a sample.
    """

    def __init__(self, config):
        """Create one kernel per coordinate from the shared settings."""
        super().__init__(config)
        kernel_args = (self.kernel, self.scale_method, self.scale)
        self.kernel_y = Kernel(*kernel_args)
        self.kernel_z = Kernel(*kernel_args)

    def witness_function(self, d, prev_d):
        """Return the joint-minus-product witness value for one sample.

        Args:
            d: Pair ``(y, z)``; each element is reshaped to a column.
            prev_d: 2-D array whose columns 0 and 1 hold the previous
                ``y`` and ``z`` values.

        Returns:
            The mean of the elementwise product of the two kernel matrices
            minus the dot product of their row means.
        """
        y, z = d
        past_y = prev_d[:, 0]
        past_z = prev_d[:, 1]

        gram_y = self.kernel_y(y.reshape(-1, 1), past_y.reshape(-1, 1))
        gram_z = self.kernel_z(z.reshape(-1, 1), past_z.reshape(-1, 1))

        joint_term = np.mean(gram_y * gram_z)
        row_mean_y = np.mean(gram_y, axis=1)
        row_mean_z = np.mean(gram_z, axis=1)
        product_term = row_mean_y @ row_mean_z
        return joint_term - product_term
class cMMD(KernelPayoff):
    """Kernel payoff contrasting observed and null values of one feature.

    Keeps three ``Kernel`` objects: one for the first coordinate, one for
    the second coordinate (shared by its observed and null versions), and
    one for the remaining coordinates.
    """

    def __init__(self, config):
        """Create the three kernels from the shared settings."""
        super().__init__(config)
        kernel_args = (self.kernel, self.scale_method, self.scale)
        self.kernel_y = Kernel(*kernel_args)
        self.kernel_zj = Kernel(*kernel_args)
        self.kernel_cond_z = Kernel(*kernel_args)

    def witness_function(self, u, prev_d):
        """Return the difference of the observed and null kernel means.

        Args:
            u: 1-D sample laid out as ``[y, zj, cond_z...]``.
            prev_d: 2-D array of previous observations laid out as
                ``[y, zj, null_zj, cond_z...]`` along axis 1.

        Returns:
            ``mean(Ky * Kzj * Kcond) - mean(Ky * Kzj_null * Kcond)``.
        """
        y = u[0]
        zj = u[1]
        cond_z = u[2:]

        past_y = prev_d[:, 0]
        past_zj = prev_d[:, 1]
        past_null_zj = prev_d[:, 2]
        past_cond_z = prev_d[:, 3:]
        cond_width = past_cond_z.shape[1]

        # The kernels are stateful, so they are called in the order
        # y, zj, cond_z, null zj.
        gram_y = self.kernel_y(y.reshape(-1, 1), past_y.reshape(-1, 1))
        gram_zj = self.kernel_zj(zj.reshape(-1, 1), past_zj.reshape(-1, 1))
        gram_cond = self.kernel_cond_z(
            cond_z.reshape(-1, cond_width),
            past_cond_z.reshape(-1, cond_width),
        )
        gram_null_zj = self.kernel_zj(
            zj.reshape(-1, 1), past_null_zj.reshape(-1, 1)
        )

        observed_mean = np.mean(gram_y * gram_zj * gram_cond)
        null_mean = np.mean(gram_y * gram_null_zj * gram_cond)
        return observed_mean - null_mean
class xMMD(KernelPayoff):
    """Kernel payoff comparing a sample against two columns of past data."""

    def __init__(self, config):
        """Replace the kernel name stored by the base class with a ``Kernel``."""
        super().__init__(config)
        kernel_name = self.kernel
        self.kernel = Kernel(kernel_name, self.scale_method, self.scale)

    def witness_function(self, u, prev_d):
        """Return the difference of mean kernel values against both columns.

        Args:
            u: Array of values, reshaped to a single column.
            prev_d: 2-D array whose column 0 holds previous values and
                column 1 the corresponding null values.

        Returns:
            One value per row of the reshaped ``u``: the row mean of the
            kernel against column 0 minus the row mean against column 1.
        """
        past_y = prev_d[:, 0]
        past_y_null = prev_d[:, 1]

        gram_y = self.kernel(u.reshape(-1, 1), past_y.reshape(-1, 1))
        mean_y = np.mean(gram_y, axis=1)

        gram_y_null = self.kernel(u.reshape(-1, 1), past_y_null.reshape(-1, 1))
        mean_y_null = np.mean(gram_y_null, axis=1)

        return mean_y - mean_y_null