from typing import Optional

import torch
from torch import Tensor

from examples.simultaneous_translation.utils.functions import (
    exclusive_cumprod,
    prob_check,
    moving_sum,
)


def expected_alignment_from_p_choose(
    p_choose: Tensor,
    padding_mask: Optional[Tensor] = None,
    eps: float = 1e-6
):
    """
    Calculating expected alignment from stepwise probability

    Reference:
    Online and Linear-Time Attention by Enforcing Monotonic Alignments
    https://arxiv.org/pdf/1704.00784.pdf

    q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
    a_{i,j} = p_{i,j} * q_{i,j}

    Parallel solution:
    a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

    ============================================================
    Expected input size
    p_choose: bsz, tgt_len, src_len
    """
    prob_check(p_choose)

    # p_choose: bsz, tgt_len, src_len
    bsz, tgt_len, src_len = p_choose.size()
    dtype = p_choose.dtype

    p_choose = p_choose.float()

    if padding_mask is not None:
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0.0)

    # cumprod_1mp: bsz, tgt_len, src_len
    cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
    cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0)

    # Initial alignment: all probability mass on the first source token.
    alpha_0 = p_choose.new_zeros([bsz, 1, src_len])
    alpha_0[:, :, 0] = 1.0

    previous_alpha = [alpha_0]

    for i in range(tgt_len):
        # p_choose: bsz, tgt_len, src_len
        # cumprod_1mp_clamp: bsz, tgt_len, src_len
        # previous_alpha[i]: bsz, 1, src_len
        # alpha_i: bsz, src_len
        alpha_i = (
            p_choose[:, i]
            * cumprod_1mp[:, i]
            * torch.cumsum(
                previous_alpha[i][:, 0] / cumprod_1mp_clamp[:, i],
                dim=1
            )
        ).clamp(0, 1.0)

        previous_alpha.append(alpha_i.unsqueeze(1))

    # alpha: bsz, tgt_len, src_len
    alpha = torch.cat(previous_alpha[1:], dim=1)

    # Cast back to the original dtype; the computation above runs in fp32
    # to prevent overflow for fp16.
    alpha = alpha.type(dtype)

    prob_check(alpha)

    return alpha
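
# Illustrative sketch, not part of the original module: the helper below
# evaluates the q/alpha recurrence from the docstring above token by token,
# as a readable cross-check of the vectorized cumprod/cumsum computation in
# expected_alignment_from_p_choose.
def _sequential_alignment_reference(p_choose: Tensor) -> Tensor:
    bsz, tgt_len, src_len = p_choose.size()
    # Initial alignment: all probability mass on the first source token.
    alpha_prev = p_choose.new_zeros(bsz, src_len)
    alpha_prev[:, 0] = 1.0
    alphas = []
    for i in range(tgt_len):
        q = p_choose.new_zeros(bsz, src_len)
        for j in range(src_len):
            if j == 0:
                # q_{i,0} = a_{i-1,0}: there is no earlier source token to skip.
                q[:, j] = alpha_prev[:, j]
            else:
                # q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
                q[:, j] = (1 - p_choose[:, i, j - 1]) * q[:, j - 1] + alpha_prev[:, j]
        # a_{i,j} = p_{i,j} * q_{i,j}
        alpha_i = p_choose[:, i] * q
        alphas.append(alpha_i)
        alpha_prev = alpha_i
    # tgt_len tensors of shape (bsz, src_len) -> (bsz, tgt_len, src_len)
    return torch.stack(alphas, dim=1)
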
def expected_soft_attention(
    alpha: Tensor,
    soft_energy: Tensor,
    padding_mask: Optional[Tensor] = None,
    chunk_size: Optional[int] = None,
    eps: float = 1e-10
):
    """
    Function to compute expected soft attention for
    monotonic infinite lookback attention from
    expected alignment and soft energy.

    Reference:
    Monotonic Chunkwise Attention
    https://arxiv.org/abs/1712.05382

    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    alpha: bsz, tgt_len, src_len
    soft_energy: bsz, tgt_len, src_len
    padding_mask: bsz, src_len
    chunk_size: int, optional (None means infinite lookback)
    """
    if padding_mask is not None:
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), -float("inf")
        )

    prob_check(alpha)

    dtype = alpha.dtype

    alpha = alpha.float()
    soft_energy = soft_energy.float()

    # Subtract the per-row max for numerical stability before exponentiating.
    soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
    exp_soft_energy = torch.exp(soft_energy) + eps

    if chunk_size is not None:
        # Chunkwise attention
        beta = (
            exp_soft_energy
            * moving_sum(
                alpha / (eps + moving_sum(exp_soft_energy, chunk_size, 1)),
                1, chunk_size
            )
        )
    else:
        # Infinite lookback
        # Notice that infinite lookback is a special case of chunkwise
        # attention where chunk_size = inf
        inner_items = alpha / (eps + torch.cumsum(exp_soft_energy, dim=2))

        beta = (
            exp_soft_energy
            * torch.cumsum(inner_items.flip(dims=[2]), dim=2)
            .flip(dims=[2])
        )

    if padding_mask is not None:
        beta = beta.masked_fill(
            padding_mask.unsqueeze(1).to(torch.bool), 0.0
        )

    # Cast back to the original dtype; the computation above runs in fp32
    # to prevent overflow for fp16.
    beta = beta.type(dtype)

    beta = beta.clamp(0, 1)

    prob_check(beta)

    return beta


def mass_preservation(
    alpha: Tensor,
    padding_mask: Optional[Tensor] = None,
    left_padding: bool = False
):
    """
    Function to compute mass preservation for alpha.
    This means that the residual weights of alpha will be assigned
    to the last token.

    Reference:
    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    alpha: bsz, tgt_len, src_len
    padding_mask: bsz, src_len
    left_padding: bool
    """
    prob_check(alpha)

    if padding_mask is not None:
        if not left_padding:
            assert not padding_mask[:, 0].any(), (
                "Found padding at the beginning of the sequence."
            )
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)

    if left_padding or padding_mask is None:
        residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0, 1)
        alpha[:, :, -1] = residuals
    else:
        # Right padding: assign the residual mass to the last
        # non-padding source position of each sentence.
        _, tgt_len, src_len = alpha.size()
        residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0, 1)
        src_lens = src_len - padding_mask.sum(dim=1, keepdim=True)
        src_lens = src_lens.expand(-1, tgt_len).contiguous()

        # Add back the value currently stored at the last non-padding position.
        residuals += alpha.gather(2, src_lens.unsqueeze(2) - 1)
        alpha = alpha.scatter(2, src_lens.unsqueeze(2) - 1, residuals)

    prob_check(alpha)

    return alpha
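

# Minimal usage sketch with assumed shapes and random inputs; it only
# illustrates how the functions above compose in a monotonic attention
# layer, not how fairseq itself invokes them.
if __name__ == "__main__":
    bsz, tgt_len, src_len = 2, 3, 5

    # Stepwise selection probabilities, e.g. a sigmoid over monotonic energies.
    p_choose = torch.sigmoid(torch.randn(bsz, tgt_len, src_len))
    soft_energy = torch.randn(bsz, tgt_len, src_len)

    # Expected alignment, with residual mass pushed to the last source token.
    alpha = expected_alignment_from_p_choose(p_choose)
    alpha = mass_preservation(alpha)

    # Infinite lookback (chunk_size=None) vs. chunkwise attention (chunks of 2).
    beta_milk = expected_soft_attention(alpha, soft_energy)
    beta_mocha = expected_soft_attention(alpha, soft_energy, chunk_size=2)

    # The sequential reference should agree with the vectorized alignment
    # up to small numerical differences introduced by the eps clamping.
    alpha_ref = mass_preservation(_sequential_alignment_reference(p_choose))

    print(alpha.shape, beta_milk.shape, beta_mocha.shape)  # all: (bsz, tgt_len, src_len)
    print("max |alpha - alpha_ref| =", (alpha - alpha_ref).abs().max().item())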