File size: 983 Bytes
c4c7cee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


class BatchedKDE(nn.Module):
    """Batched kernel-density-style scorer over independent point sets.

    Each batch element holds its own set of reference points (set by
    :meth:`fit`). :meth:`score` returns, for every query point, the sum over
    that batch element's reference points of ``exp(-||x - mu||^2 / h)``,
    where ``h`` is that batch element's bandwidth. The sum is unnormalized
    (no ``1/n`` or kernel normalization constant).

    Args:
        bandwith: Kernel bandwidth. ``0`` (the default) means "estimate it
            per batch element in :meth:`fit`" using Silverman's rule of
            thumb. Any other float, or a tensor broadcastable to ``(b,)``,
            is used as given. (The parameter keeps its original spelling
            for backward compatibility.)
    """

    def __init__(self, bandwith=0.0):
        super().__init__()
        # Keep the caller's request separate from the fitted value so that
        # every call to ``fit`` re-estimates the bandwidth instead of
        # reusing (or failing on) the tensor left by a previous fit.
        self._requested_bandwidth = bandwith
        self.bandwidth = bandwith
        self.X = None

    def fit(self, X: torch.Tensor):
        """Store reference points and set the per-batch bandwidth.

        Args:
            X: Reference points of shape ``(b, n, d)``: ``b`` independent
                sets of ``n`` points in ``d`` dimensions (floating dtype).

        Sets ``self.mu`` (the points), ``self.nmu2`` (their squared norms,
        shape ``(b, n, 1)``) and ``self.bandwidth`` (shape ``(b,)``).
        """
        b, n, _ = X.shape
        self.mu = X
        self.nmu2 = torch.sum(X * X, dim=-1, keepdim=True)

        requested = torch.as_tensor(
            self._requested_bandwidth, dtype=X.dtype, device=X.device
        )
        if torch.all(requested == 0):
            # reshape (not view) so non-contiguous inputs are accepted.
            flat = X.reshape(b, -1)
            # Silverman's rule of thumb, computed over all n*d coordinates
            # of each batch element. ``dim=1`` keeps the IQR per batch
            # element; without it torch.quantile pools the whole batch.
            iqr = torch.quantile(flat, 0.75, dim=1) - torch.quantile(
                flat, 0.25, dim=1
            )
            sigma = torch.std(flat, dim=1)
            estimate = 0.9 * torch.minimum(sigma, iqr / 1.34) / pow(n, 0.2)
            # Constant data gives a zero estimate, and 0/0 in ``score``
            # would be NaN. Clamp to the smallest positive normal instead.
            self.bandwidth = estimate.clamp_min(torch.finfo(X.dtype).tiny)
        else:
            self.bandwidth = torch.broadcast_to(requested, (b,))

    def score(self, X):
        """Sum of kernel values between each query point and the fitted points.

        Args:
            X: Query points of shape ``(b, m, d)``, with the same ``b`` and
                ``d`` as the data passed to :meth:`fit`.

        Returns:
            Tensor of shape ``(b, m)``.
        """
        nx2 = torch.sum(X * X, dim=-1, keepdim=True)  # (b, m, 1)
        dot = torch.einsum("bnd,bmd->bnm", X, self.mu)  # (b, m, n_ref)
        # ||x||^2 + ||mu||^2 - 2 x.mu. Rounding can push near-identical
        # points slightly below zero, so clamp to keep exp(...) <= 1.
        dist = (nx2 + self.nmu2.transpose(1, 2) - 2 * dot).clamp_min(0)
        # Accept a fitted (b,) tensor or a plain number assigned by hand.
        h = torch.as_tensor(self.bandwidth, dtype=dist.dtype, device=dist.device)
        if h.ndim:
            h = h.reshape(-1, 1, 1)
        # NOTE(review): a Gaussian kernel is usually exp(-d^2 / (2 h^2)).
        # This divides the squared distance by h directly (kept as in the
        # original). Confirm that is intended.
        return torch.sum(torch.exp(-dist / h), dim=-1)