Samuel Mueller
working locally
f50f696
raw
history blame contribute delete
No virus
3.88 kB
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn.modules.transformer import MultiheadAttention, _get_activation_fn
from utils import SeqBN
class TransformerModel(nn.Module):
    """Transformer encoder applied to a sequence of training points and query points.

    The input sequence is embedded with ``encoder`` (and ``y_encoder`` for the
    targets), optionally normalized and position-encoded, passed through a
    stack of ``nn.TransformerEncoderLayer`` modules and finally mapped to
    ``n_out`` values per position by ``decoder``.

    The first dimension of every input is treated as the sequence position
    (the encoder layers are created with PyTorch's default ``batch_first=False``).
    """

    def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, y_encoder=None,
                 pos_encoder=None, decoder=None, input_normalization=False):
        """Build the model.

        Args:
            encoder: Callable module embedding the x inputs into ``ninp`` features.
            n_out: Number of outputs produced per position by the default decoder.
            ninp: Embedding size used by the transformer layers.
            nhead: Number of attention heads per layer.
            nhid: Hidden size of the feed-forward sub-layers and of the default decoder.
            nlayers: Number of stacked encoder layers.
            dropout: Dropout probability of the encoder layers.
            y_encoder: Callable module embedding the targets; it is called
                whenever ``forward`` receives an ``(x, y)`` tuple.
            pos_encoder: Optional module applied to the embedded sequence.
            decoder: Optional factory called as ``decoder(ninp, nhid, n_out)``;
                when omitted a Linear-GELU-Linear head is used.
            input_normalization: If true, apply ``SeqBN(ninp)`` to the embedded sequence.
        """
        super().__init__()
        self.model_type = 'Transformer'
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation='gelu')
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.ninp = ninp
        self.encoder = encoder
        self.y_encoder = y_encoder
        self.pos_encoder = pos_encoder
        if decoder is not None:
            self.decoder = decoder(ninp, nhid, n_out)
        else:
            self.decoder = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
        self.input_ln = SeqBN(ninp) if input_normalization else None

        self.init_weights()

    @staticmethod
    def generate_square_subsequent_mask(sz, device=None):
        """Return a ``(sz, sz)`` additive causal attention mask.

        Entry ``[i, j]`` is ``0.0`` when ``j <= i`` (position ``i`` may attend to
        ``j``) and ``-inf`` otherwise.

        Args:
            sz: Sequence length.
            device: Optional device on which the mask is created directly
                (defaults to the current default device, as before).
        """
        blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float32, device=device)
        # triu(diagonal=1) keeps the -inf entries strictly above the diagonal
        # and writes zeros everywhere else.
        return torch.triu(blocked, diagonal=1)

    @staticmethod
    def generate_D_q_matrix(sz, query_size):
        raise NotImplementedError  # placeholder replaced below

    @staticmethod
    def generate_D_q_matrix(sz, query_size, device=None):  # noqa: F811
        """Return a ``(sz, sz)`` additive mask for a train-set / query-set split.

        The last ``query_size`` positions are the query points. Every position
        may attend to all of the first ``sz - query_size`` (training) positions
        and to itself; attention to any *other* query position is blocked with
        ``-inf``.

        Args:
            sz: Total sequence length (training plus query positions).
            query_size: Number of trailing query positions.
            device: Optional device on which the mask is created directly
                (defaults to the current default device, as before).
        """
        train_size = sz - query_size
        mask = torch.zeros(sz, sz, dtype=torch.float32, device=device)
        mask[:, train_size:] = float('-inf')  # nobody attends to query columns ...
        mask.fill_diagonal_(0.0)              # ... except each position to itself
        return mask

    def init_weights(self):
        """Zero the output projections of every encoder layer.

        With the attention output projection and the second feed-forward
        linear layer set to zero, both residual branches of each layer
        contribute nothing at initialization.
        """
        for layer in self.transformer_encoder.layers:
            nn.init.zeros_(layer.linear2.weight)
            nn.init.zeros_(layer.linear2.bias)
            nn.init.zeros_(layer.self_attn.out_proj.weight)
            nn.init.zeros_(layer.self_attn.out_proj.bias)

    def forward(self, src, src_mask=None, single_eval_pos=None):
        """Run the model and return the decoded outputs for the evaluated positions.

        Args:
            src: Either an ``(x, y)`` tuple or a single, already fused tensor.
                Position is the first dimension; ``y`` gets a trailing feature
                dimension added before it is passed to ``y_encoder``.
            src_mask: Optional attention mask forwarded to the transformer
                encoder. When omitted it is derived from ``single_eval_pos``.
            single_eval_pos: Number of leading positions that act as training
                points (their embedded ``y`` is added to the embedded ``x``).
                The remaining positions are query points.

        Returns:
            Decoder output for the query positions (``output[single_eval_pos:]``).

        Raises:
            AssertionError: If ``single_eval_pos`` is ``None`` or ``src`` is not a tuple.

        Note:
            The two assertions below currently restrict the supported call to
            "tuple input with ``single_eval_pos`` set". The fused-input and
            causal-mask branches are kept because they become reachable again
            when assertions are disabled (``python -O``).
        """
        assert single_eval_pos is not None, 'Single eval pos is required now.'
        fuse_x_y = not isinstance(src, tuple)
        assert not (fuse_x_y and single_eval_pos is not None), \
            "Don't use both fuse_x_y and single_eval_pos (permutation equivariant setup) at the same time."

        if src_mask is None:
            x_src = src if fuse_x_y else src[0]
            seq_len = len(x_src)
            if single_eval_pos is None:
                # Interleaved x/y sequences are twice as long as x alone.
                mask_len = seq_len if fuse_x_y else 2 * seq_len
                src_mask = self.generate_square_subsequent_mask(mask_len, device=x_src.device)
            else:
                src_mask = self.generate_D_q_matrix(seq_len, seq_len - single_eval_pos, device=x_src.device)

        if not fuse_x_y:
            x_src, y_src = src
            x_src = self.encoder(x_src)
            y_src = self.y_encoder(y_src.unsqueeze(-1))
            if single_eval_pos is None:
                # Interleave along the position axis: x_0, y_0, x_1, y_1, ...
                src = torch.stack([x_src, y_src], 1).view(-1, *x_src.shape[1:])
            else:
                # Training positions carry x + y; query positions carry x only.
                train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
                src = torch.cat([train_x, x_src[single_eval_pos:]], 0)
        else:
            src = self.encoder(src)

        if self.input_ln is not None:
            src = self.input_ln(src)

        if self.pos_encoder is not None:
            src = self.pos_encoder(src)

        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)

        if fuse_x_y:
            return output
        elif single_eval_pos is None:
            # Only the x positions of the interleaved sequence are returned.
            return output[0::2]
        else:
            return output[single_eval_pos:]