import random

import torch
from torch import nn

# Default grid of candidate temperatures that ScaledDecoder mixes per position.
_DEFAULT_TEMPERATURES = (1.0, 1.4, 1.7, 2.0, 5.0, 10.0, 20.0, 40.0, 80.0, 160.0)


class ScaledDecoder(nn.Module):
    """Two-layer MLP decoder whose outputs are divided by a learned, input-dependent temperature.

    A shared hidden representation feeds two heads: ``linear1`` produces the
    ``nout`` outputs, and ``linear2`` produces softmax weights over a fixed
    grid of candidate temperatures. The temperature for each position is the
    weighted average of that grid, and the outputs are divided by it.

    Args:
        ninp: Size of the last dimension of the input.
        nhid: Hidden size.
        nout: Size of the last dimension of the output.
        temperatures: Optional sequence of positive candidate temperatures.
            Defaults to the original 10-value grid ``(1, 1.4, ..., 160)``.
            ``linear2`` has ``len(temperatures)`` outputs.

    Raises:
        ValueError: If ``temperatures`` is empty or contains a non-positive value.
    """

    def __init__(self, ninp, nhid, nout, temperatures=None):
        super().__init__()
        if temperatures is None:
            temperatures = _DEFAULT_TEMPERATURES
        grid = torch.tensor([float(t) for t in temperatures], dtype=torch.float32)
        if grid.numel() == 0:
            raise ValueError("temperatures must be a non-empty sequence")
        if not bool((grid > 0).all()):
            # A non-positive mixed temperature could divide by zero or flip signs.
            raise ValueError("temperatures must all be positive")
        self.linear = nn.Linear(ninp, nhid)
        self.linear1 = nn.Linear(nhid, nout)
        self.linear2 = nn.Linear(nhid, grid.numel())
        # Non-persistent buffer: follows .to()/.cuda()/.half() with the module,
        # but is not part of state_dict, so existing checkpoints still load.
        self.register_buffer("temperatures", grid, persistent=False)

    def forward(self, x):
        """Decode ``x`` and divide the result by the learned temperature.

        Args:
            x: Tensor whose last dimension has size ``ninp``. Leading dimensions
                are preserved (``nn.Linear`` acts on the last dimension).

        Returns:
            Tensor of shape ``x.shape[:-1] + (nout,)``.
        """
        h = nn.functional.gelu(self.linear(x))
        weights = self.linear2(h).softmax(-1)
        # Match the mixing weights' device and dtype so the matmul also works
        # for half- and double-precision models.
        grid = self.temperatures.to(device=weights.device, dtype=weights.dtype)
        temps = weights @ grid  # shape: x.shape[:-1]
        return self.linear1(h) / temps.unsqueeze(-1)


class FixedScaledDecoder(nn.Module):
    """MLP decoder whose outputs are divided by a single learned scale.

    The scale is ``T.sum()``. ``T`` is a 10000-element parameter initialised to
    ``1/10000`` per element, so the scale starts at 1.

    Args:
        ninp: Size of the last dimension of the input.
        nhid: Hidden size of the MLP.
        nout: Size of the last dimension of the output.
    """

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        self.mapper = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout))
        # Only the sum of T is used. Every element receives the same gradient,
        # so under plain SGD the sum moves 10000x faster than a scalar parameter
        # would. NOTE(review): presumably an intentional learning-rate boost;
        # confirm before changing (the shape is part of saved checkpoints).
        self.T = nn.Parameter(torch.ones(10000) / 10000)

    def forward(self, x):
        """Return ``mapper(x)`` divided by the learned scale ``T.sum()``."""
        return self.mapper(x) / self.T.sum()