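# Provenance: Hugging Face Hub file page, uploaded by dappyx ("Upload 4 files"),
# commit 75c80a0 (verified), 15 kB. These page-header lines were not Python and
# are kept here as a comment so the module can be imported.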
from turtle import forward
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
import torch
import copy
import math
from transformers import DistilBertForQuestionAnswering, DistilBertConfig
from transformers import AutoModelForQuestionAnswering
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention (DistilBERT-style parameter layout).

    Args:
        n_heads: number of attention heads; must divide ``dim``.
        dim: feature size of the query/key/value inputs and of the output.
        dropout_prob: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, n_heads, dim, dropout_prob):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.dim = dim
        self.dropout = nn.Dropout(p=dropout_prob)
        self.q_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.k_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.v_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.out_lin = nn.Linear(in_features=self.dim, out_features=self.dim)

    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
        """
        Parameters:
            query: torch.tensor(bs, q_length, dim)
            key: torch.tensor(bs, k_length, dim)
            value: torch.tensor(bs, k_length, dim)
            mask: torch.tensor(bs, k_length); key positions where ``mask == 0`` are
                excluded from attention.
            head_mask: optional tensor multiplied into the (post-dropout) attention
                weights; must broadcast against (bs, n_heads, q_length, k_length).
            output_attentions: if True, also return the attention weights.

        Returns:
            context: torch.tensor(bs, q_length, dim), or
            (context, weights) with weights: torch.tensor(bs, n_heads, q_length, k_length)
            when ``output_attentions=True``.
        """
        bs, q_length, _ = query.size()
        k_length = key.size(1)
        dim_per_head = self.dim // self.n_heads
        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """Separate heads: (bs, len, dim) -> (bs, n_heads, len, dim_per_head)."""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """Group heads: (bs, n_heads, len, dim_per_head) -> (bs, len, dim)."""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)

        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # True = masked out
        # Fill with the most negative finite value rather than -inf: if every key in a
        # row is masked (e.g. a fully padded sequence), softmax over all -inf yields
        # NaN, which would propagate through the network and into the loss.
        scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)

        weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        return context
class FeedForward(nn.Module):
    """
    Position-wise two-layer MLP: ``Linear -> ReLU -> Linear`` (no dropout).

    Args:
        dim_input: feature size of the input and of the output.
        dim_feedforward: hidden (expanded) feature size.
    """

    def __init__(self, dim_input: int = 768, dim_feedforward: int = 4 * 768):
        super().__init__()
        # Expand to the hidden size, apply the non-linearity, project back.
        self.linear1 = nn.Linear(dim_input, dim_feedforward)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(dim_feedforward, dim_input)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        hidden = self.linear1(x)
        hidden = self.relu(hidden)
        return self.linear2(hidden)
class SwitchFeedForward(nn.Module):
    """
    ## Routing among multiple FFNs

    Top-1 (switch) mixture-of-experts feed-forward layer: every token is processed by
    the single expert with the highest routing probability.
    """

    def __init__(
        self,
        *,
        capacity_factor: float,
        drop_tokens: bool,
        is_scale_prob: bool,
        n_experts: int,
        expert: FeedForward,
        d_model: int
    ):
        """
        * `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
        * `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than the capacity
        * `is_scale_prob` specifies whether to multiply the expert output by the routing probability
        * `n_experts` is the number of experts
        * `expert` is the expert layer (e.g. a `FeedForward`); `n_experts` independent deep copies are made
        * `d_model` is the number of features in a token embedding
        """
        super().__init__()
        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens
        # make copies of the FFNs
        self.experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(n_experts)])
        # Routing layer and softmax
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is a 3-D tensor `[d0, d1, d_model]`. The two leading dimensions are only
          flattened and restored, so `[seq_len, batch_size, d_model]` and
          `[batch_size, seq_len, d_model]` are handled identically.

        Returns a tuple:

        * the output, same shape as `x`
        * number of tokens routed to each expert (before any dropping), shape `[n_experts]`
        * sum of routing probabilities for each expert over all tokens, shape `[n_experts]`
        * number of tokens dropped (int)
        * routing probability of the selected expert for each token, shape `[d0 * d1]`

        These are used for the load balancing loss and logging.
        """
        # Capture the shape to change shapes later
        seq_len, batch_size, d_model = x.shape
        # Flatten the leading dimensions. `reshape` (not `view`) so that
        # non-contiguous inputs, e.g. after a transpose, are accepted.
        x = x.reshape(-1, d_model)

        # Routing probabilities p_i(x) = softmax(h(x))_i, h being the linear router.
        route_prob = self.softmax(self.switch(x))
        # Route each token to the expert with the highest probability.
        route_prob_max, routes = torch.max(route_prob, dim=-1)
        # Indexes of the tokens going to each expert
        indexes_list = [
            torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)
        ]

        final_output = x.new_zeros(x.shape)
        # Expert capacity = tokens_per_batch / n_experts * capacity_factor
        capacity = int(self.capacity_factor * len(x) / self.n_experts)
        # Number of tokens routed to each expert (counted before dropping).
        counts = x.new_tensor([len(idx) for idx in indexes_list])

        dropped = []
        if self.drop_tokens:
            for i, idx in enumerate(indexes_list):
                if len(idx) <= capacity:
                    continue
                # Shuffle (on the same device as the indexes) before dropping, so
                # which tokens overflow is random rather than position-dependent.
                idx = idx[torch.randperm(len(idx), device=idx.device)]
                dropped.append(idx[capacity:])
                indexes_list[i] = idx[:capacity]

        # Run each expert on its tokens and scatter the results back.
        for expert, idx in zip(self.experts, indexes_list):
            final_output[idx, :] = expert(x[idx, :])

        # Dropped tokens pass through unchanged.
        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]

        if self.is_scale_prob:
            # y = p_i(x) * E_i(x)
            final_output = final_output * route_prob_max.view(-1, 1)
        else:
            # Multiply by p / p.detach() == 1: values are unchanged but gradients
            # still flow into the router.
            final_output = final_output * (
                route_prob_max / route_prob_max.detach()
            ).view(-1, 1)

        # Restore the original leading dimensions.
        final_output = final_output.view(seq_len, batch_size, d_model)

        return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
class SwitchTransformerLayer(nn.Module):
    """
    # Switch Transformer Block

    A pre-LayerNorm transformer block: a self-attention sub-layer followed by a
    feed-forward sub-layer, each normalized on input and added back to the residual
    stream after dropout. The feed-forward module is a switching (mixture-of-experts)
    module, and the routing statistics it returns are passed through to the caller.
    """

    def __init__(
        self,
        *,
        d_model: int,
        attn: MultiHeadAttention,
        feed_forward: SwitchFeedForward,
        dropout_prob: float
    ):
        """
        * `d_model` is the token embedding size (exposed as `self.size`)
        * `attn` is the self-attention module, called as `attn(query=, key=, value=, mask=)`
        * `feed_forward` is the switching feed-forward module; it must return a 5-tuple
        * `dropout_prob` is the dropout probability applied to both sub-layer outputs
        """
        super().__init__()
        self.size = d_model
        self.attn, self.feed_forward = attn, feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn, self.norm_ff = nn.LayerNorm([d_model]), nn.LayerNorm([d_model])

    def forward(self, *, x: torch.Tensor, mask: torch.Tensor):
        """
        Apply the block to `x`; `mask` is forwarded to the attention module.

        Returns `(x, counts, route_prob, n_dropped, route_prob_max)`, where the last
        four values come straight from `feed_forward`.
        """
        # Self-attention sub-layer: keys and values come from the same normalized input.
        normed = self.norm_self_attn(x)
        attended = self.attn(query=normed, key=normed, value=normed, mask=mask)
        residual = x + self.dropout(attended)

        # Switching feed-forward sub-layer.
        ff_out, expert_counts, prob_sums, dropped_count, top_probs = self.feed_forward(
            self.norm_ff(residual)
        )
        out = residual + self.dropout(ff_out)
        return out, expert_counts, prob_sums, dropped_count, top_probs
class SwitchTransformer(nn.Module):
    """
    ## Switch Transformer

    Extractive question-answering model composed of:

    1. a pretrained Hugging Face QA backbone, used only for its last hidden state;
    2. ``n_layers`` deep copies of a switch-transformer ``layer`` applied on top of it;
    3. a final LayerNorm and a linear span head producing start/end logits.
    """

    def __init__(
        self,
        layer,
        n_layers,
        n_experts,
        device,
        load_balancing_loss_ceof,
        base_model_name: str = "Kyrmasch/kaz-roberta-squad2-kaz",
    ):
        """
        Args:
            layer: template switch layer; ``n_layers`` deep copies are made. Its
                ``size`` must equal the backbone's hidden size, because the backbone's
                last hidden state is fed directly into these layers.
            n_layers: number of switch layers stacked on the backbone.
            n_experts: number of experts per switch layer (scales the load-balancing loss).
            device: device the backbone is moved to, and that batch tensors are moved
                to in ``forward``.
            load_balancing_loss_ceof: weight of the load-balancing loss when span
                labels are provided.
            base_model_name: Hugging Face model id or local path of the QA backbone.
        """
        super().__init__()
        # Make copies of the transformer layer
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
        # Span head: one start and one end logit per token. Sized from the layer's
        # d_model instead of a hard-coded 768 so other hidden sizes work.
        self.qa_outputs = nn.Linear(layer.size, 2)
        # NOTE: `from_pretrained` may download weights at construction time.
        self.base_model = AutoModelForQuestionAnswering.from_pretrained(
            base_model_name
        ).to(device)
        self.device = device
        self.load_balancing_loss_ceof = load_balancing_loss_ceof
        self.n_experts = n_experts  # used to calculate lb loss

    def freeze_base_model(self):
        """Disable gradients for every parameter of the pretrained backbone."""
        for param in self.base_model.parameters():
            param.requires_grad = False

    def freeze_experts(self):
        """
        Disable gradients for the expert FFNs of every switch layer.

        Only ``layer.feed_forward.experts`` is frozen. Everything else (e.g. the
        router, attention, layer norms and the span head) stays trainable. Layers
        without a ``feed_forward.experts`` attribute are skipped.
        """
        for layer in self.layers:
            experts = getattr(getattr(layer, "feed_forward", None), "experts", None)
            if experts is None:
                continue
            for param in experts.parameters():
                param.requires_grad = False

    def forward(self, batch):
        """
        Args:
            batch: mapping with ``input_ids`` and ``attention_mask`` tensors (expected
                shape (bs, seq_len)) and, optionally, ``start_positions`` and
                ``end_positions``.

        Returns:
            ``(start_logits, end_logits, loss)``. When both positions are present,
            ``loss`` is span cross-entropy + ``load_balancing_loss_ceof`` * load-balancing
            loss. Otherwise it is the *unweighted* load-balancing loss.
        """
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        start_positions = (
            batch["start_positions"].to(self.device) if "start_positions" in batch else None
        )
        end_positions = (
            batch["end_positions"].to(self.device) if "end_positions" in batch else None
        )

        # Labels are deliberately not passed: only the backbone's hidden states are used.
        outputs = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            start_positions=None,
            end_positions=None,
            output_hidden_states=True,
        )
        x = outputs.hidden_states[-1]

        # Run through each switch layer, collecting routing statistics.
        counts, route_prob = [], []
        for layer in self.layers:
            x, layer_counts, layer_prob, _n_dropped, _prob_max = layer(x=x, mask=attention_mask)
            counts.append(layer_counts)
            route_prob.append(layer_prob)

        # Finally, normalize and project to start/end logits.
        logits = self.qa_outputs(self.norm(x))
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # (bs, seq_len)
        end_logits = end_logits.squeeze(-1).contiguous()  # (bs, seq_len)

        loss = None
        if start_positions is not None and end_positions is not None:
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Positions outside the model inputs are clamped to `ignored_index`
            # and then ignored by the loss.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)
            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2

        # Load-balancing loss: n_experts * sum over layers and experts of
        # (fraction of tokens routed to expert i) * (mean routing prob. of expert i).
        counts = torch.stack(counts)  # (n_layers, n_experts)
        route_prob = torch.stack(route_prob)  # (n_layers, n_experts)
        total = counts.sum(dim=-1, keepdim=True)  # tokens per layer
        route_frac = counts / total
        route_prob = route_prob / total
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()

        # NOTE(review): without labels the load-balancing loss is returned unweighted,
        # unlike the training path; kept as-is for backward compatibility.
        loss = (
            load_balancing_loss
            if loss is None
            else loss + self.load_balancing_loss_ceof * load_balancing_loss
        )
        return start_logits, end_logits, loss