import copy
import math

import torch
import torch.nn as nn
from transformers import AutoModelForQuestionAnswering


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, dim, dropout_prob):
        super().__init__()

        self.n_heads = n_heads
        self.dim = dim
        self.dropout = nn.Dropout(p=dropout_prob)

        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.k_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.v_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.out_lin = nn.Linear(in_features=self.dim, out_features=self.dim)

    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            context: torch.tensor(bs, seq_length, dim) Contextualized layer.
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights.
                Only returned if `output_attentions=True`.
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)
        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads
        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """Separate heads."""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """Group heads."""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        scores = scores.masked_fill(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)

        weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return context


class FeedForward(nn.Module):
    def __init__(self, dim_input: int = 768, dim_feedforward: int = 4 * 768):
        super().__init__()
        self.linear1 = nn.Linear(dim_input, dim_feedforward)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(dim_feedforward, dim_input)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))
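
# Minimal shape sanity-check sketch (not part of the original model code).
# It builds a MultiHeadAttention with assumed hyper-parameters (n_heads=12,
# dim=768, dropout_prob=0.1) and verifies the shapes documented in `forward`.
def _demo_attention_shapes():
    bs, seq_len, dim, n_heads = 2, 16, 768, 12
    mha = MultiHeadAttention(n_heads=n_heads, dim=dim, dropout_prob=0.1)
    x = torch.rand(bs, seq_len, dim)
    mask = torch.ones(bs, seq_len)  # 1 = attend, 0 = masked out
    context, weights = mha(x, x, x, mask, output_attentions=True)
    assert context.shape == (bs, seq_len, dim)
    assert weights.shape == (bs, n_heads, seq_len, seq_len)
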
class SwitchFeedForward(nn.Module):
    """
    ## Routing among multiple FFNs
    """

    def __init__(
        self,
        *,
        capacity_factor: float,
        drop_tokens: bool,
        is_scale_prob: bool,
        n_experts: int,
        expert: FeedForward,
        d_model: int
    ):
        """
        * `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
        * `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than its capacity
        * `is_scale_prob` specifies whether to multiply the input to the FFN by the routing probability
        * `n_experts` is the number of experts
        * `expert` is the expert layer, a [FFN module](../feed_forward.html)
        * `d_model` is the number of features in a token embedding
        """
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens

        # Make copies of the FFNs
        self.experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(n_experts)])
        # Routing layer and softmax
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`
          (only the last dimension matters for routing; the first two are flattened)
        """
        # Capture the shape to change shapes later
        seq_len, batch_size, d_model = x.shape
        # Flatten the sequence and batch dimensions
        x = x.view(-1, d_model)

        # Get routing probabilities for each of the tokens.
        # $$p_i(x) = \frac{e^{h(x)_i}}{\sum^N_j e^{h(x)_j}}$$
        # where $N$ is the number of experts `n_experts` and
        # $h(\cdot)$ is the linear transformation of token embeddings.
        route_prob = self.softmax(self.switch(x))
        # Get the maximum routing probabilities and the routes.
        # We route to the expert with the highest probability.
        route_prob_max, routes = torch.max(route_prob, dim=-1)

        # Get indexes of tokens going to each expert
        indexes_list = [
            torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)
        ]

        # Initialize an empty tensor to store outputs
        final_output = x.new_zeros(x.shape)

        # Capacity of each expert.
        # $$\mathrm{expert\;capacity} =
        # \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}}
        # \times \mathrm{capacity\;factor}$$
        capacity = int(self.capacity_factor * len(x) / self.n_experts)
        # Number of tokens routed to each expert.
        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])

        # Initialize an empty list of dropped tokens
        dropped = []
        # Only drop tokens if `drop_tokens` is `True`.
        if self.drop_tokens:
            # Drop tokens in each of the experts
            for i in range(self.n_experts):
                # Ignore if the expert is not over capacity
                if len(indexes_list[i]) <= capacity:
                    continue
                # Shuffle indexes before dropping
                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
                # Collect the tokens over capacity as dropped tokens
                dropped.append(indexes_list[i][capacity:])
                # Keep only the tokens up to the capacity of the expert
                indexes_list[i] = indexes_list[i][:capacity]

        # Get outputs of the expert FFNs
        expert_output = [
            self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)
        ]

        # Assign to final output
        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = expert_output[i]

        # Pass through the dropped tokens unchanged
        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]

        if self.is_scale_prob:
            # Multiply the expert outputs by the routing probabilities: $y = p_i(x) E_i(x)$
            final_output = final_output * route_prob_max.view(-1, 1)
        else:
            # Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
            # (this is something we experimented with).
            final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)

        # Change the shape of the final output back to `[seq_len, batch_size, d_model]`
        final_output = final_output.view(seq_len, batch_size, d_model)

        # Return
        #
        # * the final output
        # * number of tokens routed to each expert
        # * sum of probabilities for each expert
        # * number of tokens dropped
        # * routing probabilities of the selected experts
        #
        # These are used for the load balancing loss and logging.
        return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
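
# Minimal routing sketch (illustrative only; hyper-parameters are assumptions,
# not values from the original code).  It routes a small batch through a
# SwitchFeedForward with 4 experts and inspects how many tokens each expert
# receives and how many are dropped when over capacity.
def _demo_switch_routing():
    seq_len, batch_size, d_model, n_experts = 8, 2, 768, 4
    switch_ffn = SwitchFeedForward(
        capacity_factor=1.0,
        drop_tokens=True,
        is_scale_prob=True,
        n_experts=n_experts,
        expert=FeedForward(dim_input=d_model, dim_feedforward=4 * d_model),
        d_model=d_model,
    )
    x = torch.rand(seq_len, batch_size, d_model)
    out, counts, route_prob_sum, n_dropped, route_prob_max = switch_ffn(x)
    assert out.shape == (seq_len, batch_size, d_model)
    assert counts.shape == (n_experts,) and route_prob_sum.shape == (n_experts,)
    print(f"tokens per expert: {counts.tolist()}, dropped: {n_dropped}")
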
class SwitchTransformerLayer(nn.Module):
    """
    # Switch Transformer Block

    This is the same as a [normal transformer block](../models.html#TransformerLayer),
    except that it also handles the extra outputs of the switch feed-forward module.
    """

    def __init__(
        self,
        *,
        d_model: int,
        attn: MultiHeadAttention,
        feed_forward: SwitchFeedForward,
        dropout_prob: float
    ):
        """
        * `d_model` is the token embedding size
        * `attn` is the attention module
        * `feed_forward` is the feed forward module (which is the switching module in this case)
        * `dropout_prob` is the probability of dropping out after self attention and FFN
        """
        super().__init__()
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *, x: torch.Tensor, mask: torch.Tensor):
        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # Run through self attention, i.e. keys and values are from self
        self_attn = self.attn(query=z, key=z, value=z, mask=mask)
        # Add the self attention results
        x = x + self.dropout(self_attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the switching feed-forward network
        ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x, counts, route_prob, n_dropped, route_prob_max
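
# Minimal assembly sketch (assumed hyper-parameters: d_model=768, 12 heads,
# 4 experts; not from the original training setup).  It wires MultiHeadAttention
# and SwitchFeedForward into one SwitchTransformerLayer and runs a dummy input.
def _demo_switch_layer():
    d_model, n_heads, n_experts = 768, 12, 4
    layer = SwitchTransformerLayer(
        d_model=d_model,
        attn=MultiHeadAttention(n_heads=n_heads, dim=d_model, dropout_prob=0.1),
        feed_forward=SwitchFeedForward(
            capacity_factor=1.0,
            drop_tokens=False,
            is_scale_prob=True,
            n_experts=n_experts,
            expert=FeedForward(dim_input=d_model, dim_feedforward=4 * d_model),
            d_model=d_model,
        ),
        dropout_prob=0.1,
    )
    x = torch.rand(2, 16, d_model)  # (bs, seq_len, d_model)
    mask = torch.ones(2, 16)        # attention mask, 1 = keep
    out, counts, route_prob, n_dropped, route_prob_max = layer(x=x, mask=mask)
    assert out.shape == x.shape
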
class SwitchTransformer(nn.Module):
    """
    ## Switch Transformer
    """

    def __init__(self, layer, n_layers, n_experts, device, load_balancing_loss_coef):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
        # Span prediction head for question answering
        self.qa_outputs = nn.Linear(768, 2)
        # Pre-trained QA model used as the base encoder
        self.base_model = AutoModelForQuestionAnswering.from_pretrained(
            "Kyrmasch/kaz-roberta-squad2-kaz"
        ).to(device)
        self.device = device
        self.load_balancing_loss_coef = load_balancing_loss_coef
        self.n_experts = n_experts  # used to calculate the load balancing loss

    def freeze_base_model(self):
        for param in self.base_model.parameters():
            param.requires_grad = False

    def freeze_experts(self):
        # TODO: find how to freeze the experts in the SwitchTransformer
        pass

    def forward(self, batch):
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        start_positions = (
            batch["start_positions"].to(self.device)
            if "start_positions" in batch.keys()
            else None
        )
        end_positions = (
            batch["end_positions"].to(self.device)
            if "end_positions" in batch.keys()
            else None
        )

        # Encode with the base model and take the last hidden states
        outputs = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            start_positions=None,
            end_positions=None,
            output_hidden_states=True,
        )
        x = outputs.hidden_states[-1]

        # Run through each switch transformer layer
        counts, route_prob, n_dropped, route_prob_max = [], [], [], []
        for layer in self.layers:
            x, f, p, n_d, p_max = layer(x=x, mask=attention_mask)
            counts.append(f)
            route_prob.append(p)
            n_dropped.append(n_d)
            route_prob_max.append(p_max)
        # Finally, normalize the vectors
        output = self.norm(x)

        logits = self.qa_outputs(output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1).contiguous()  # (bs, max_query_len)

        loss = None
        if start_positions is not None and end_positions is not None:
            # If an extra dimension was added (e.g. multi-GPU), squeeze it back
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2

        # Load balancing loss (Switch Transformer): encourage a uniform
        # distribution of routed tokens and routing probability mass across experts.
        counts = torch.stack(counts)
        route_prob = torch.stack(route_prob)
        route_prob_max = torch.stack(route_prob_max)
        total = counts.sum(dim=-1, keepdim=True)
        # Fraction of tokens routed to each expert and mean routing probability per expert
        route_frac = counts / total
        route_prob = route_prob / total
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()

        loss = (
            load_balancing_loss
            if loss is None
            else loss + self.load_balancing_loss_coef * load_balancing_loss
        )

        return start_logits, end_logits, loss
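
# Usage sketch (illustrative only): assembles the full model and runs one
# forward pass on a toy question/context pair.  It assumes the
# "Kyrmasch/kaz-roberta-squad2-kaz" checkpoint and its tokenizer can be
# downloaded from the Hugging Face Hub and that its hidden size is 768;
# all hyper-parameters below are assumptions, not the original training values.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    d_model, n_heads, n_experts = 768, 12, 4

    layer = SwitchTransformerLayer(
        d_model=d_model,
        attn=MultiHeadAttention(n_heads=n_heads, dim=d_model, dropout_prob=0.1),
        feed_forward=SwitchFeedForward(
            capacity_factor=1.0,
            drop_tokens=True,
            is_scale_prob=True,
            n_experts=n_experts,
            expert=FeedForward(dim_input=d_model, dim_feedforward=4 * d_model),
            d_model=d_model,
        ),
        dropout_prob=0.1,
    )
    model = SwitchTransformer(
        layer=layer,
        n_layers=2,
        n_experts=n_experts,
        device=device,
        load_balancing_loss_coef=0.01,
    ).to(device)
    model.freeze_base_model()

    tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")
    batch = tokenizer(
        "What is routed in a Switch Transformer?",
        "Each token is routed to a single expert FFN per layer.",
        return_tensors="pt",
    )
    # No start/end positions are provided, so the returned loss is only the load balancing term.
    start_logits, end_logits, loss = model(batch)
    print(start_logits.shape, end_logits.shape, loss)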