IBYDMT / ibydmt /test.py
jacopoteneggi's picture
Update
7e207f0 verified
raw
history blame
4.88 kB
import functools
from abc import ABC, abstractmethod
from collections import deque
from typing import Callable, Tuple, Union
import numpy as np
import torch
from jaxtyping import Float
from ibydmt.payoff import HSIC, cMMD, xMMD
from ibydmt.wealth import get_wealth
# Container type accepted by the testers: either a NumPy array or a PyTorch
# tensor. Used as the array type inside jaxtyping ``Float[Array, ...]`` hints.
Array = Union[np.ndarray, torch.Tensor]
class Tester(ABC):
    """Abstract interface shared by the hypothesis testers in this module.

    Concrete testers implement :meth:`test`, which returns a pair
    ``(rejected, tau)``: whether the null hypothesis was rejected and the
    step index at which testing stopped.
    """

    def __init__(self):
        """Initialize the tester; the base class itself holds no state."""

    @abstractmethod
    def test(self, *args, **kwargs) -> Tuple[bool, int]:
        """Run the test and return ``(rejected, tau)``."""
class SequentialTester(Tester):
    """Base class for sequential testers driven by a wealth process.

    Holds the wealth process and the maximum number of test steps, both taken
    from ``config``. Subclasses supply the payoff and the ``test`` loop.

    Args:
        config: configuration object exposing at least ``wealth`` (passed to
            ``get_wealth`` to select the wealth implementation) and
            ``tau_max``. The whole object is also forwarded to the selected
            wealth constructor.
    """

    def __init__(self, config):
        super().__init__()
        make_wealth = get_wealth(config.wealth)
        # NOTE(review): the wealth process is built once per instance and
        # nothing here resets it, so reusing one tester for several tests
        # presumably carries wealth over -- confirm callers create a fresh
        # tester per test.
        self.wealth = make_wealth(config)
        self.tau_max = config.tau_max
class SKIT(SequentialTester):
    """Global Independence Tester

    Sequentially tests whether ``Y`` and ``Z`` are independent, using an HSIC
    payoff (``ibydmt.payoff.HSIC``) and the configured wealth process.
    """

    def __init__(self, config):
        super().__init__(config)
        self.payoff = HSIC(config)

    def test(self, Y: Float[Array, "N"], Z: Float[Array, "N"]) -> Tuple[bool, int]:
        """Run the sequential independence test on paired samples.

        Observations are consumed two at a time. At step ``t`` the rows
        ``2t`` and ``2t + 1`` form the new batch ``d``. Its null counterpart
        keeps ``Y`` and swaps the two ``Z`` values. Under independence, that
        swap leaves the pair's distribution unchanged. All earlier rows
        ``0 .. 2t - 1`` are passed to the payoff as history. The first pair,
        rows 0 and 1, only ever serves as history.

        Args:
            Y: first variable, one value per observation.
            Z: second variable, aligned with ``Y``.

        Returns:
            ``(True, t)`` at the first step ``t`` whose wealth update leads to
            rejection, otherwise ``(False, self.tau_max - 1)``.

        NOTE(review): ``Y``/``Z`` should hold at least ``2 * tau_max``
        entries. With fewer, later steps receive short or empty batches that
        are still passed to the payoff -- confirm this is intended.
        """
        D = np.stack([Y, Z], axis=1)
        for t in range(1, self.tau_max):
            d = D[2 * t : 2 * (t + 1)]
            prev_d = D[: 2 * t]
            # Reverse Z within the pair to build the null (swapped) sample.
            null_d = np.stack([d[:, 0], np.flip(d[:, 1])], axis=1)
            payoff = self.payoff.compute(d, null_d, prev_d)
            self.wealth.update(payoff)
            if self.wealth.rejected:
                return (True, t)
        # Do not return the loop variable: it is unbound when tau_max <= 1.
        # When the loop did run, its last value is tau_max - 1 anyway.
        return (False, self.tau_max - 1)
class cSKIT(SequentialTester):
    """Global Conditional Independence Tester

    Sequentially tests whether ``Y`` is independent of feature ``Z[:, j]``
    given the remaining features of ``Z``, using a cMMD payoff
    (``ibydmt.payoff.cMMD``) and the configured wealth process. At every step
    the observed ``Z[:, j]`` is compared against a draw of that feature
    produced by the user-supplied sampler ``cond_p``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.payoff = cMMD(config)

    def _sample(
        self,
        z: Float[Array, "N D"],
        j: Union[int, None] = None,
        cond_p: Union[
            Callable[[Float[Array, "N D"], list[int]], Float[Array, "N D"]], None
        ] = None,
    ) -> Tuple[Float[Array, "N"], Float[Array, "N"], Float[Array, "N D-1"]]:
        """Split ``z`` into the tested column, a null draw of it, and the rest.

        Although ``j`` and ``cond_p`` default to ``None``, both are required:
        the body indexes with ``j`` and calls ``cond_p``.

        Args:
            z: batch of feature rows, indexed as ``z[:, k]``.
            j: index of the tested column.
            cond_p: called as ``cond_p(z, C)``, where ``C`` lists every other
                column index in ascending order. Column ``j`` of its output
                is taken as the null draw. Presumably it samples feature ``j``
                conditionally on ``z[:, C]`` -- confirm against the sampler.

        Returns:
            ``(zj, null_zj, cond_z)``. ``zj`` and ``null_zj`` keep a trailing
            column axis because they are indexed with ``[j]``. ``cond_z``
            holds the columns in ``C``.
        """
        # Explicit ascending order rather than relying on set iteration order.
        C = [k for k in range(z.shape[1]) if k != j]
        zj, cond_z = z[:, [j]], z[:, C]
        samples = cond_p(z, C)
        null_zj = samples[:, [j]]
        return zj, null_zj, cond_z

    def test(
        self,
        Y: Float[Array, "N"],
        Z: Float[Array, "N D"],
        j: int,
        cond_p: Callable[[Float[Array, "N D"], list[int]], Float[Array, "N D"]],
    ) -> Tuple[bool, int]:
        """Run the sequential conditional independence test for feature ``j``.

        Row 0 only seeds the history. At step ``t`` the payoff compares
        ``u = [y, zj, cond_z]`` against ``null_u = [y, null_zj, cond_z]`` for
        row ``t``, given the history of earlier rows stored as
        ``[y, zj, null_zj, cond_z]``.

        Args:
            Y: response, one value per row. Must have at least ``tau_max``
                entries, otherwise the row indexing below fails.
            Z: features, one row per observation, aligned with ``Y``.
            j: index of the tested feature column.
            cond_p: conditional sampler, see :meth:`_sample`.

        Returns:
            ``(True, t)`` at the first step ``t`` whose wealth update leads to
            rejection, otherwise ``(False, self.tau_max - 1)``.
        """
        sample = functools.partial(self._sample, j=j, cond_p=cond_p)

        # Seed the history with the first observation.
        prev_y, prev_z = Y[:1][:, None], Z[:1]
        prev_zj, prev_null_zj, prev_cond_z = sample(prev_z)
        prev_d = np.concatenate([prev_y, prev_zj, prev_null_zj, prev_cond_z], axis=-1)

        for t in range(1, self.tau_max):
            # Index with [t] (not t) to keep a leading row axis of length 1.
            y, z = Y[[t]][:, None], Z[[t]]
            zj, null_zj, cond_z = sample(z)
            u = np.concatenate([y, zj, cond_z], axis=-1)
            null_u = np.concatenate([y, null_zj, cond_z], axis=-1)
            payoff = self.payoff.compute(u, null_u, prev_d)
            self.wealth.update(payoff)
            if self.wealth.rejected:
                return (True, t)
            d = np.concatenate([y, zj, null_zj, cond_z], axis=-1)
            prev_d = np.vstack([prev_d, d])
        # Do not return the loop variable: it is unbound when tau_max <= 1.
        # When the loop did run, its last value is tau_max - 1 anyway.
        return (False, self.tau_max - 1)
class xSKIT(SequentialTester):
    """Local Conditional Independence Tester

    Sequentially tests, for a single input ``z``, whether adding feature
    ``j`` to the conditioning set ``C`` changes the distribution of
    ``model``'s output. Each step compares the model output on a sample drawn
    by ``cond_p`` given ``C + [j]`` against the output on a sample drawn
    given ``C`` alone. It uses an xMMD payoff (``ibydmt.payoff.xMMD``) and
    the configured wealth process.
    """

    def __init__(self, config):
        super().__init__(config)
        self.payoff = xMMD(config)
        # Buffer of pre-computed (y, null_y) model outputs, refilled by
        # ``_sample`` whenever it runs empty.
        self._queue = deque()

    def _sample(
        self,
        z: Float[Array, "D"],
        j: int,
        C: list[int],
        cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
        model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
    ) -> Tuple[Float[Array, "1"], Float[Array, "1"]]:
        """Return one ``(y, null_y)`` pair, refilling the buffer when empty.

        On refill, ``cond_p`` is called twice with ``self.tau_max`` as its
        third argument (presumably the number of samples): once conditioning
        on ``C + [j]`` and once on ``C`` alone. ``model`` is applied to both
        results. Each returned element keeps a trailing axis of length 1,
        added by ``[:, None]``.
        """
        if not self._queue:
            h = cond_p(z, C + [j], self.tau_max)
            null_h = cond_p(z, C, self.tau_max)
            y = model(h)[:, None]
            null_y = model(null_h)[:, None]
            self._queue.extend(zip(y, null_y))
        # pop() takes the most recently buffered pair. The buffered draws are
        # generated together, so their order is presumably immaterial.
        return self._queue.pop()

    def test(
        self,
        z: Float[Array, "D"],
        j: int,
        C: list[int],
        cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
        model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
        interrupt_on: str = "rejection",
        max_wealth: Union[float, None] = None,
    ) -> Tuple[bool, int]:
        """Run the local conditional independence test of feature ``j``.

        Args:
            z: the input under test, forwarded to ``cond_p``.
            j: index of the tested feature.
            C: indices of the features conditioned on in both arms.
            cond_p: sampler called as ``cond_p(z, S, self.tau_max)``.
            model: maps a batch of samples to one output per sample.
            interrupt_on: ``"rejection"`` stops at the first rejection.
                ``"max_wealth"`` stops once the wealth reaches
                ``max_wealth``. Any other value runs all ``tau_max - 1``
                steps.
            max_wealth: wealth threshold, required when
                ``interrupt_on == "max_wealth"``.

        Returns:
            ``(rejected, tau)``. ``rejected`` is ``self.wealth.rejected``
            after the last step. ``tau`` is the first step at which rejection
            was observed, or ``self.tau_max - 1`` if it never was.

        Raises:
            ValueError: if ``interrupt_on == "max_wealth"`` and
                ``max_wealth`` is ``None``.
        """
        if interrupt_on == "max_wealth" and max_wealth is None:
            raise ValueError(
                "max_wealth must be provided when interrupt_on='max_wealth'"
            )

        # Discard draws buffered by a previous call. They were generated for
        # that call's (z, j, C) and must not leak into this test.
        self._queue.clear()

        sample = functools.partial(self._sample, z, j, C, cond_p, model)
        tau = self.tau_max - 1
        prev_d = np.stack(sample(), axis=1)
        for t in range(1, self.tau_max):
            y, null_y = sample()
            payoff = self.payoff.compute(y, null_y, prev_d)
            self.wealth.update(payoff)
            d = np.stack([y, null_y], axis=1)
            prev_d = np.vstack([prev_d, d])
            if self.wealth.rejected:
                # Keep the earliest rejection time even if testing continues.
                tau = min(tau, t)
                if interrupt_on == "rejection":
                    break
            # NOTE(review): reads the wealth object's private ``_w`` value.
            if interrupt_on == "max_wealth" and self.wealth._w >= max_wealth:
                break
        return (self.wealth.rejected, tau)