# enhg-parsing/benepar/partitioned_transformer.py
"""
Transformer with partitioned content and position features.
See section 3 of https://arxiv.org/pdf/1805.01052.pdf
"""
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
    """Autograd function backing FeatureDropout: masks whole feature channels."""
@staticmethod
def forward(ctx, input, p=0.5, train=False, inplace=False):
if p < 0 or p > 1:
raise ValueError(
"dropout probability has to be between 0 and 1, but got {}".format(p)
)
ctx.p = p
ctx.train = train
ctx.inplace = inplace
if ctx.inplace:
ctx.mark_dirty(input)
output = input
else:
output = input.clone()
        if ctx.p > 0 and ctx.train:
            # One Bernoulli sample per (batch element, feature channel); the
            # mask is broadcast across the time dimension below.
            ctx.noise = torch.empty(
                (input.size(0), input.size(-1)),
                dtype=input.dtype,
                layout=input.layout,
                device=input.device,
            )
            if ctx.p == 1:
                ctx.noise.fill_(0)
            else:
                # Inverted dropout: rescale surviving channels by 1 / (1 - p).
                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
            ctx.noise = ctx.noise[:, None, :]
            output.mul_(ctx.noise)
return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.p > 0 and ctx.train:
            # Reuse the mask sampled in forward: dropout is a fixed
            # elementwise scaling, so its gradient is the same scaling.
            return grad_output.mul(ctx.noise), None, None, None
        else:
            return grad_output, None, None, None

class FeatureDropout(nn.Dropout):
    """
    Feature-level dropout: takes an input of size batch x len x num_features
    and drops each feature with probability p. A feature is dropped across
    the full portion of the input that corresponds to a single batch element.
    """
    def forward(self, x):
        if isinstance(x, tuple):
            # Partitioned input: drop features in the content and position
            # halves independently (each half samples its own mask).
            x_c, x_p = x
            x_c = FeatureDropoutFunction.apply(x_c, self.p, self.training, self.inplace)
            x_p = FeatureDropoutFunction.apply(x_p, self.p, self.training, self.inplace)
            return x_c, x_p
        else:
            return FeatureDropoutFunction.apply(x, self.p, self.training, self.inplace)
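
# A minimal sketch (not part of the original module) of what FeatureDropout
# does differently from ordinary dropout: all timesteps of a batch element
# share one mask, so a dropped feature channel is zero across the sequence.
def _example_feature_dropout():
    dropout = FeatureDropout(0.5)
    dropout.train()
    x = torch.ones(2, 5, 8)  # batch x len x features
    y = dropout(x)
    # Every (batch, feature) column is either all 0 or all 2.0 (= 1 / (1 - p)).
    assert torch.all((y == 0) | (y == 2.0))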

class PartitionedReLU(nn.ReLU):
    def forward(self, x):
        if isinstance(x, tuple):
            x_c, x_p = x
        else:
            # A fused tensor is split into its content and position halves.
            x_c, x_p = torch.chunk(x, 2, dim=-1)
        return super().forward(x_c), super().forward(x_p)

class PartitionedLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        # Independent weights for the content and position halves; in_features
        # and out_features are assumed to be even.
        self.linear_c = nn.Linear(in_features // 2, out_features // 2, bias)
        self.linear_p = nn.Linear(in_features // 2, out_features // 2, bias)

def forward(self, x):
if isinstance(x, tuple):
x_c, x_p = x
else:
x_c, x_p = torch.chunk(x, 2, dim=-1)
out_c = self.linear_c(x_c)
out_p = self.linear_p(x_p)
return out_c, out_p
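
# A minimal sketch (feature sizes are illustrative and assumed even) of the
# calling convention shared by the partitioned layers: they accept either a
# fused tensor or a (content, position) tuple and return a tuple of halves.
def _example_partitioned_linear():
    linear = PartitionedLinear(16, 32)
    x = torch.randn(2, 5, 16)  # fused: content and position halves concatenated
    out_c, out_p = linear(x)
    assert out_c.shape == out_p.shape == (2, 5, 16)
    out_c2, _ = linear((x[..., :8], x[..., 8:]))  # equivalent tuple input
    assert torch.allclose(out_c, out_c2)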

class PartitionedMultiHeadAttention(nn.Module):
    def __init__(
        self, d_model, n_head, d_qkv, attention_dropout=0.1, initializer_range=0.02
    ):
        super().__init__()
        # Separate per-head projections for the content and position halves;
        # the axis of size 3 packs the query, key, and value projections.
        self.w_qkv_c = nn.Parameter(torch.empty(n_head, d_model // 2, 3, d_qkv // 2))
        self.w_qkv_p = nn.Parameter(torch.empty(n_head, d_model // 2, 3, d_qkv // 2))
        self.w_o_c = nn.Parameter(torch.empty(n_head, d_qkv // 2, d_model // 2))
        self.w_o_p = nn.Parameter(torch.empty(n_head, d_qkv // 2, d_model // 2))
        # Uniform init with the same standard deviation as a normal init of
        # scale initializer_range.
        bound = math.sqrt(3.0) * initializer_range
        for param in [self.w_qkv_c, self.w_qkv_p, self.w_o_c, self.w_o_p]:
            nn.init.uniform_(param, -bound, bound)
        # Queries/keys have total width d_qkv once the halves are concatenated.
        self.scaling_factor = 1 / d_qkv ** 0.5
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, x, mask=None):
        if isinstance(x, tuple):
            x_c, x_p = x
        else:
            x_c, x_p = torch.chunk(x, 2, dim=-1)
        # Per-head projections: b=batch, h=head, t=time, f=input features,
        # c=the query/key/value axis, a=attention feature dim.
        qkv_c = torch.einsum("btf,hfca->bhtca", x_c, self.w_qkv_c)
        qkv_p = torch.einsum("btf,hfca->bhtca", x_p, self.w_qkv_p)
        q_c, k_c, v_c = [c.squeeze(dim=3) for c in torch.chunk(qkv_c, 3, dim=3)]
        q_p, k_p, v_p = [c.squeeze(dim=3) for c in torch.chunk(qkv_p, 3, dim=3)]
        # Concatenating the halves makes each attention score the sum of a
        # content-content term and a position-position term.
        q = torch.cat([q_c, q_p], dim=-1) * self.scaling_factor
        k = torch.cat([k_c, k_p], dim=-1)
        v = torch.cat([v_c, v_p], dim=-1)
        dots = torch.einsum("bhqa,bhka->bhqk", q, k)
        if mask is not None:
            # Out-of-place masked_fill keeps the op on autograd's tape rather
            # than mutating values through .data.
            dots = dots.masked_fill(~mask[:, None, None, :], float("-inf"))
        probs = F.softmax(dots, dim=-1)
        probs = self.dropout(probs)
        o = torch.einsum("bhqk,bhka->bhqa", probs, v)
        # Split the attended values back into halves and project each with
        # its own output weights.
        o_c, o_p = torch.chunk(o, 2, dim=-1)
        out_c = torch.einsum("bhta,haf->btf", o_c, self.w_o_c)
        out_p = torch.einsum("bhta,haf->btf", o_p, self.w_o_p)
        return out_c, out_p
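
# A minimal shape check (dimensions are illustrative, not from the original
# configuration) for the partitioned attention layer; note that each returned
# half has d_model // 2 features.
def _example_partitioned_attention():
    attn = PartitionedMultiHeadAttention(d_model=64, n_head=4, d_qkv=16)
    x = torch.randn(2, 7, 64)
    mask = torch.ones(2, 7, dtype=torch.bool)  # True = position may be attended to
    out_c, out_p = attn(x, mask=mask)
    assert out_c.shape == out_p.shape == (2, 7, 32)
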
class PartitionedTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model,
n_head,
d_qkv,
d_ff,
ff_dropout=0.1,
residual_dropout=0.1,
attention_dropout=0.1,
activation=PartitionedReLU(),
):
super().__init__()
self.self_attn = PartitionedMultiHeadAttention(
d_model, n_head, d_qkv, attention_dropout=attention_dropout
)
self.linear1 = PartitionedLinear(d_model, d_ff)
self.ff_dropout = FeatureDropout(ff_dropout)
self.linear2 = PartitionedLinear(d_ff, d_model)
self.norm_attn = nn.LayerNorm(d_model)
self.norm_ff = nn.LayerNorm(d_model)
self.residual_dropout_attn = FeatureDropout(residual_dropout)
self.residual_dropout_ff = FeatureDropout(residual_dropout)
self.activation = activation

    def forward(self, x, mask=None):
        # Self-attention sublayer with a post-norm residual connection.
        residual = self.self_attn(x, mask=mask)
        residual = torch.cat(residual, dim=-1)
        residual = self.residual_dropout_attn(residual)
        x = self.norm_attn(x + residual)
        # Position-wise feed-forward sublayer, likewise post-norm.
        residual = self.linear2(self.ff_dropout(self.activation(self.linear1(x))))
        residual = torch.cat(residual, dim=-1)
        residual = self.residual_dropout_ff(residual)
        x = self.norm_ff(x + residual)
        return x
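
# A minimal sketch (hyperparameters are illustrative) of a single encoder
# layer: the fused content + position width d_model is preserved end to end.
def _example_encoder_layer():
    layer = PartitionedTransformerEncoderLayer(d_model=64, n_head=4, d_qkv=16, d_ff=128)
    x = torch.randn(2, 7, 64)
    assert layer(x).shape == (2, 7, 64)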

class PartitionedTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, n_layers):
        super().__init__()
        # Stack n_layers independent (deep-copied) instances of the prototype layer.
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(n_layers)]
        )

def forward(self, x, mask=None):
for layer in self.layers:
x = layer(x, mask=mask)
return x

class ConcatPositionalEncoding(nn.Module):
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        # Learned position embeddings that will form the position half
        # (d_model // 2 features) of the partitioned representation.
        self.timing_table = nn.Parameter(torch.empty(max_len, d_model // 2))
        nn.init.normal_(self.timing_table)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Concatenate (rather than add) position features onto the content
        # features, then normalize the fused d_model-sized vectors.
        timing = self.timing_table[None, : x.shape[1], :]
        x, timing = torch.broadcast_tensors(x, timing)
        out = torch.cat([x, timing], dim=-1)
        out = self.norm(out)
        return out
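
# A minimal end-to-end sketch (layer sizes are illustrative, not from the
# original benepar configuration): concatenate learned position features onto
# content embeddings, then run the partitioned encoder stack.
if __name__ == "__main__":
    d_model = 64
    encoding = ConcatPositionalEncoding(d_model=d_model, max_len=512)
    layer = PartitionedTransformerEncoderLayer(
        d_model=d_model, n_head=4, d_qkv=16, d_ff=128
    )
    encoder = PartitionedTransformerEncoder(layer, n_layers=2)
    content = torch.randn(2, 10, d_model // 2)  # content features only
    x = encoding(content)  # -> batch x len x d_model (content + position)
    mask = torch.ones(2, 10, dtype=torch.bool)
    print(encoder(x, mask=mask).shape)  # torch.Size([2, 10, 64])

    # Run the illustrative self-checks defined above.
    _example_feature_dropout()
    _example_partitioned_linear()
    _example_partitioned_attention()
    _example_encoder_layer()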