import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


class BatchedKDE(nn.Module):
    """Unnormalised kernel density estimate over a batch of independent point sets.

    ``fit`` stores reference points of shape ``(b, n, d)``. ``score`` then
    returns, for every query point of batch element ``i``, the sum over that
    element's reference points of ``exp(-||x - mu||^2 / h_i)``.

    Args:
        bandwith: Kernel bandwidth ``h``. Accepts a positive scalar, a tensor
            broadcastable to shape ``(b,)``, or ``0`` / ``None`` (default ``0``).
            With ``0`` / ``None``, a per-batch bandwidth is chosen with
            Silverman's rule of thumb on every call to ``fit``. The misspelt
            name is kept so existing keyword callers keep working.
    """

    def __init__(self, bandwith=0.0):
        super().__init__()
        self.bandwidth = bandwith
        self.X = None
        # Reference points and their squared norms; populated by ``fit``.
        self.mu = None
        self.nmu2 = None
        # The last bandwidth this object estimated itself. If ``self.bandwidth``
        # is still that exact object, the next ``fit`` re-estimates it rather
        # than reusing a value derived from the previous data.
        self._estimated_bandwidth = None

    @staticmethod
    def _is_auto(bandwidth):
        """Return True if ``bandwidth`` requests automatic selection (None or all zeros)."""
        if bandwidth is None:
            return True
        if torch.is_tensor(bandwidth):
            return bool(torch.all(bandwidth == 0))
        return bandwidth == 0

    @staticmethod
    def _silverman_bandwidth(X):
        """Per-batch Silverman rule of thumb: 0.9 * min(std, IQR / 1.34) * n ** -0.2.

        The std and IQR are taken over all ``n * d`` values of each batch
        element, matching the original pooled-over-dimensions estimate.
        """
        b, n, _ = X.shape
        # reshape (not view) so non-contiguous inputs such as transposes work.
        flat = X.reshape(b, -1)
        # dim=1 gives one IQR per batch element instead of one pooled scalar.
        iqr = torch.quantile(flat, 0.75, dim=1) - torch.quantile(flat, 0.25, dim=1)
        std = torch.std(flat, dim=1)
        spread = torch.minimum(std, iqr / 1.34)
        # A zero IQR (many repeated values) would give a zero bandwidth and a
        # division by zero in ``score``; fall back to the std as R's bw.nrd0 does.
        spread = torch.where(iqr > 0, spread, std)
        return 0.9 * spread / pow(n, 0.2)

    def fit(self, X: torch.Tensor):
        """Store the reference points and, if requested, estimate the bandwidth.

        Args:
            X: Reference points of shape ``(b, n, d)``.

        Returns:
            ``self``, so calls can be chained.
        """
        _, _, _ = X.shape  # enforce rank 3, as the original unpacking did
        self.mu = X
        self.nmu2 = torch.sum(X * X, dim=-1, keepdim=True)  # (b, n, 1)
        bw = self.bandwidth
        if bw is self._estimated_bandwidth or self._is_auto(bw):
            self._estimated_bandwidth = self._silverman_bandwidth(X)
            self.bandwidth = self._estimated_bandwidth
        return self

    def score(self, X):
        """Sum of kernel values between each query point and the fitted points.

        Args:
            X: Query points of shape ``(b, m, d)``; ``b`` and ``d`` must match
                the tensor passed to ``fit``.

        Returns:
            Tensor of shape ``(b, m)``.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.mu is None:
            raise RuntimeError("BatchedKDE.score() called before fit()")
        nx2 = torch.sum(X * X, dim=-1, keepdim=True)  # (b, m, 1)
        dot = torch.einsum("bnd, bmd -> bnm", X, self.mu)  # (b, m, n)
        # ||x||^2 + ||mu||^2 - 2 x.mu can dip slightly below 0 through
        # cancellation; clamp so no kernel value exceeds 1.
        dist = (nx2 + self.nmu2.transpose(1, 2) - 2 * dot).clamp(min=0)
        # Accept a python scalar or a (b,)-shaped tensor bandwidth.
        h = torch.as_tensor(self.bandwidth, dtype=dist.dtype, device=dist.device)
        # NOTE(review): a Gaussian kernel matching Silverman's rule would be
        # exp(-dist / (2 * h**2)); the original exp(-dist / h) scaling is kept
        # here to preserve existing outputs — confirm which is intended.
        return torch.exp(-dist / h.reshape(-1, 1, 1)).sum(dim=-1)