Spaces:
Running
Running
File size: 983 Bytes
c4c7cee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
class BatchedKDE(nn.Module):
    """Batched kernel-sum scorer over sets of points.

    ``fit`` stores a reference tensor of points with shape ``(b, n, d)``.
    When no bandwidth was given, ``fit`` also estimates one bandwidth per
    batch element with Silverman's rule of thumb.

    ``score`` returns, for every query point, the sum of
    ``exp(-||x - mu||^2 / bandwidth)`` over the fitted points ``mu`` of the
    same batch element. The result is an unnormalized kernel sum, not a
    probability density.

    Args:
        bandwith: Kernel bandwidth. The misspelled name is kept for backward
            compatibility. Accepts a Python number, a 0-d tensor, or a tensor
            of shape ``(b,)`` with one value per batch element. ``0`` (the
            default) means "estimate it from the data in ``fit``".
    """

    def __init__(self, bandwith=0.0):
        super().__init__()
        self.bandwidth = bandwith
        # Record once whether the bandwidth is data-driven. After an auto fit,
        # ``self.bandwidth`` becomes a (b,) tensor; comparing that to 0 on a
        # later fit would raise for b > 1 (ambiguous truth value) and silently
        # keep a stale value for b == 1.
        self._auto_bandwidth = bool(torch.as_tensor(bandwith).eq(0).all())
        # Not used by this class; kept so existing code reading ``.X`` still works.
        self.X = None
        self.mu = None  # fitted points, shape (b, n, d)
        self.nmu2 = None  # squared norms of the fitted points, shape (b, n, 1)

    def fit(self, X: torch.Tensor) -> None:
        """Store reference points and, if requested, estimate the bandwidth.

        Args:
            X: Reference points of shape ``(b, n, d)``.
        """
        b, n, d = X.shape
        self.mu = X
        self.nmu2 = torch.sum(X * X, dim=-1, keepdim=True)
        if self._auto_bandwidth:
            # Silverman's rule of thumb, applied independently to each batch
            # element: 0.9 * min(std, IQR / 1.34) * n^(-1/5).
            # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
            flat = X.reshape(b, -1)
            # ``dim=1`` keeps the IQR per batch element; without it torch
            # pools every batch element into a single scalar quantile.
            iqr = torch.quantile(flat, 0.75, dim=1) - torch.quantile(
                flat, 0.25, dim=1
            )
            self.bandwidth = (
                0.9 * torch.min(torch.std(X, dim=(1, 2)), iqr / 1.34) / pow(n, 0.2)
            )

    def score(self, X: torch.Tensor) -> torch.Tensor:
        """Return the summed kernel values of ``X`` against the fitted points.

        Args:
            X: Query points of shape ``(b, m, d)``. ``b`` and ``d`` must match
                the tensor passed to ``fit``.

        Returns:
            Tensor of shape ``(b, m)``. Each entry is the sum, over all fitted
            points ``mu`` of the same batch element, of
            ``exp(-||x - mu||^2 / bandwidth)``.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if self.mu is None:
            raise RuntimeError("BatchedKDE.score() called before fit()")
        nx2 = torch.sum(X * X, dim=-1, keepdim=True)  # (b, m, 1)
        dot = torch.einsum("bnd,bmd->bnm", X, self.mu)  # (b, m, n)
        # ||x - mu||^2 = ||x||^2 + ||mu||^2 - 2 x.mu. Floating-point
        # cancellation can make this slightly negative, which would push
        # exp() above 1, so clamp at zero.
        dist = (nx2 + self.nmu2.transpose(1, 2) - 2 * dot).clamp_min(0)
        # Handles a Python number, a 0-d tensor, or a per-batch (b,) tensor:
        # reshape to (1, 1, 1) or (b, 1, 1) so it broadcasts over (b, m, n).
        bw = torch.as_tensor(self.bandwidth, device=X.device).reshape(-1, 1, 1)
        # NOTE(review): a Gaussian kernel is usually exp(-d^2 / (2 h^2)).
        # Dividing the squared distance by h itself is kept as-is to preserve
        # existing scores; confirm this is intended.
        return torch.sum(torch.exp(-dist / bw), dim=-1)
|