import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class _PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model
        # dummy parameter whose only purpose is to track the module's current device
        self.device_test_tensor = nn.Parameter(torch.tensor(1.))

    def forward(self, x):
        # each input feature gets d_model // num_features channels, split into sin/cos pairs,
        # so d_model must be divisible by 2 * num_features
        assert self.d_model % (x.shape[-1] * 2) == 0
        d_per_feature = self.d_model // x.shape[-1]
        pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
        interval_size = 10
        div_term = (1. / interval_size) * 2 * math.pi * torch.exp(
            torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()
            * math.log(math.sqrt(2)))
        pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
        pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
        return self.dropout(pe).view(x.shape[0], x.shape[1], self.d_model)


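# Illustrative usage sketch for _PositionalEncoding (shapes and sizes here are assumptions, not
# taken from the original code): each scalar feature receives d_model // num_features sin/cos
# channels, so d_model must be divisible by 2 * num_features.
def _demo_positional_encoding():
    enc = _PositionalEncoding(d_model=16)   # 8 channels for each of the 2 features below
    x = torch.rand(10, 4, 2)                # assumed (seq_len=10, batch=4, num_features=2) layout
    out = enc(x)
    assert out.shape == (10, 4, 16)
    return out

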
class EmbeddingEncoder(nn.Module):
    def __init__(self, num_features, em_size, num_embs=100):
        super().__init__()
        self.num_embs = num_embs
        # a single table with one row per (feature, bin) pair
        self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
        self.init_weights(.1)
        self.min_max = (-2, +2)

    @property
    def width(self):
        return self.min_max[1] - self.min_max[0]

    def init_weights(self, initrange):
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def discretize(self, x):
        # map values in [min, max] to integer bin indices in [0, num_embs - 1]
        split_size = self.width / self.num_embs
        return ((x - self.min_max[0]) / split_size).int().clamp(0, self.num_embs - 1)

    def forward(self, x):
        x_idxs = self.discretize(x)
        # offset each feature's bin indices into that feature's slice of the embedding table
        x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
        # average the per-feature embeddings into a single em_size vector per position
        return self.embeddings(x_idxs).mean(-2)


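# Illustrative usage sketch for EmbeddingEncoder (shapes and sizes are assumptions): each feature
# value in roughly (-2, 2) is binned into one of num_embs buckets, looked up in that feature's
# slice of the embedding table, and the per-feature embeddings are averaged into one em_size vector.
def _demo_embedding_encoder():
    enc = EmbeddingEncoder(num_features=3, em_size=32, num_embs=100)
    x = torch.rand(10, 4, 3) * 4 - 2        # values in (-2, 2), assumed (seq, batch, features) layout
    out = enc(x)
    assert out.shape == (10, 4, 32)
    return out

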
Linear = nn.Linear
MLP = lambda num_features, emsize: nn.Sequential(nn.Linear(num_features + 1, emsize * 2),
                                                 nn.ReLU(),
                                                 nn.Linear(emsize * 2, emsize))


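# Illustrative usage sketch for the MLP factory (sizes are assumptions): the first Linear layer
# is built with num_features + 1 inputs, so the last dimension of the input must be num_features + 1.
def _demo_mlp_encoder():
    mlp = MLP(9, 32)                        # Sequential(Linear(10, 64), ReLU, Linear(64, 32))
    x = torch.rand(10, 4, 10)               # last dim = num_features + 1 = 10
    out = mlp(x)
    assert out.shape == (10, 4, 32)
    return out

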
class Conv(nn.Module):
    def __init__(self, input_size, emsize):
        super().__init__()
        self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
        self.linear = nn.Linear(64, emsize)

    def forward(self, x):
        # interpret the feature dimension as a flattened square image
        size = math.isqrt(x.shape[-1])
        assert size * size == x.shape[-1]
        x = x.reshape(*x.shape[:-1], 1, size, size)
        for conv in self.convs:
            # stop convolving once the spatial resolution drops below 4
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        # global average pool to a 64-dim vector, then project to emsize
        x = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1)
        return self.linear(x)


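# Illustrative usage sketch for Conv (layout is an assumption): each row of features is treated as
# a flattened size x size image, and nn.Conv2d needs a 4D (N, 1, size, size) tensor after the
# reshape, so a (seq, batch, features) input would have to be flattened to (seq * batch, features)
# beforehand.
def _demo_conv_encoder():
    enc = Conv(input_size=64, emsize=32)
    x = torch.rand(40, 64)                  # 40 flattened 8x8 "images"
    out = enc(x)
    assert out.shape == (40, 32)
    return out

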
Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)


class CanEmb(nn.Embedding):
    def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
        assert embedding_dim % num_features == 0
        embedding_dim = embedding_dim // num_features
        super().__init__(num_embeddings, embedding_dim, *args, **kwargs)

    def forward(self, x):
        # embed each categorical feature separately, then concatenate along the last dimension
        x = super().forward(x)
        return x.view(*x.shape[:-2], -1)


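# Illustrative usage sketch for CanEmb (sizes are assumptions): each of the num_features integer
# columns gets embedding_dim // num_features channels, and the per-feature embeddings are
# concatenated along the last dimension.
def _demo_canonical_encoder():
    enc = CanEmb(num_features=4, num_embeddings=10, embedding_dim=32)
    x = torch.randint(0, 10, (10, 4, 4))    # integer class indices, assumed (seq, batch, features)
    out = enc(x)
    assert out.shape == (10, 4, 32)
    return out

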
def get_Canonical(num_classes):
    return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize)


def get_Embedding(num_embs_per_feature=100):
    return lambda num_features, emsize: EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)


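# Illustrative usage sketch for the factory functions (numbers are arbitrary): both factories
# return constructors with the same (num_features, emsize) -> module interface used by the other
# encoders in this file.
def _demo_encoder_factories():
    canonical = get_Canonical(num_classes=10)(4, 32)
    binned = get_Embedding(num_embs_per_feature=50)(4, 32)
    x_cat = torch.randint(0, 10, (10, 4, 4))
    x_num = torch.rand(10, 4, 4) * 4 - 2
    assert canonical(x_cat).shape == (10, 4, 32)
    assert binned(x_num).shape == (10, 4, 32)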