# -*- coding: utf-8 -*-
r"""
Layer-Wise Attention Mechanism
==============================

Computes a parameterised scalar mixture of N tensors,
`mixture = gamma * sum(s_k * tensor_k)`
where `s = softmax(w)`, with `w` and `gamma` scalar parameters.

If `do_layer_norm=True`, layer normalization is applied to each tensor before
weighting.

If `dropout > 0`, then each scalar weight has its unnormalized value set to
`dropout_value` (effectively `-inf`) with probability `dropout`, driving its
softmax mass to 0. The dropped probability mass is redistributed over the
remaining weights.

Original implementation:
    - https://github.com/Hyperparticle/udify
"""
from typing import List, Optional

import torch
from torch.nn import Parameter, ParameterList


class ScalarMixWithDropout(torch.nn.Module):
    def __init__(
        self,
        mixture_size: int,
        do_layer_norm: bool = False,
        initial_scalar_parameters: Optional[List[float]] = None,
        trainable: bool = True,
        dropout: Optional[float] = None,
        dropout_value: float = -1e20,
    ) -> None:
        super().__init__()
        self.mixture_size = mixture_size
        self.do_layer_norm = do_layer_norm
        self.dropout = dropout

        if initial_scalar_parameters is None:
            initial_scalar_parameters = [0.0] * mixture_size
        elif len(initial_scalar_parameters) != mixture_size:
            raise ValueError(
                "Length of initial_scalar_parameters {} differs from "
                "mixture_size {}".format(initial_scalar_parameters, mixture_size)
            )

        # One scalar mixing weight w_k per tensor, plus a global scale gamma.
        self.scalar_parameters = ParameterList(
            [
                Parameter(
                    torch.FloatTensor([initial_scalar_parameters[i]]),
                    requires_grad=trainable,
                )
                for i in range(mixture_size)
            ]
        )
        self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable)

        if self.dropout:
            # Scratch buffer for sampling uniform noise and a constant buffer
            # filled with `dropout_value`, used to mask out dropped weights
            # before the softmax.
            dropout_mask = torch.zeros(len(self.scalar_parameters))
            dropout_fill = torch.empty(len(self.scalar_parameters)).fill_(dropout_value)
            self.register_buffer("dropout_mask", dropout_mask)
            self.register_buffer("dropout_fill", dropout_fill)

    def forward(
        self,
        tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute a weighted average of the `tensors`.

        The input tensors can be any shape with at least two dimensions, but
        must all have the same shape.

        When `do_layer_norm=True`, the `mask` is a required input. If the
        `tensors` are dimensioned `(dim_0, ..., dim_{n-1}, dim_n)`, then the
        `mask` is dimensioned `(dim_0, ..., dim_{n-1})`, as in the typical
        case with `tensors` of shape `(batch_size, timesteps, dim)` and
        `mask` of shape `(batch_size, timesteps)`.

        When `do_layer_norm=False` the `mask` is ignored.
        """
        if len(tensors) != self.mixture_size:
            raise ValueError(
                "{} tensors were passed, but the module was initialized to "
                "mix {} tensors.".format(len(tensors), self.mixture_size)
            )

        def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
            # Layer normalization computed over the unmasked positions only.
            tensor_masked = tensor * broadcast_mask
            mean = torch.sum(tensor_masked) / num_elements_not_masked
            variance = (
                torch.sum(((tensor_masked - mean) * broadcast_mask) ** 2)
                / num_elements_not_masked
            )
            return (tensor - mean) / torch.sqrt(variance + 1e-12)

        weights = torch.cat([parameter for parameter in self.scalar_parameters])

        if self.training and self.dropout:
            # With probability `dropout`, replace an unnormalized weight with
            # `dropout_value` so its softmax mass is driven to ~0. The uniform
            # noise is sampled in place into the `dropout_mask` buffer.
            weights = torch.where(
                self.dropout_mask.uniform_() > self.dropout, weights, self.dropout_fill
            )

        normed_weights = torch.nn.functional.softmax(weights, dim=0)
        normed_weights = torch.split(normed_weights, split_size_or_sections=1)

        if not self.do_layer_norm:
            pieces = []
            for weight, tensor in zip(normed_weights, tensors):
                pieces.append(weight * tensor)
            return self.gamma * sum(pieces)
        else:
            mask_float = mask.float()
            broadcast_mask = mask_float.unsqueeze(-1)
            input_dim = tensors[0].size(-1)
            num_elements_not_masked = torch.sum(mask_float) * input_dim

            pieces = []
            for weight, tensor in zip(normed_weights, tensors):
                pieces.append(
                    weight
                    * _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked)
                )
            return self.gamma * sum(pieces)