from collections import defaultdict

import torch
import torch.nn.functional as F


def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.

    Args:
        tensor: Tensor of symbol ids; the cumulative sum runs along dim 1.
        padding_idx: Id that marks padding symbols.

    Returns:
        LongTensor of the same shape where padding entries hold
        ``padding_idx`` and every other entry holds its 1-based position
        among the non-padding symbols of its row, offset by ``padding_idx``.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    return (
        torch.cumsum(mask, dim=1).type_as(mask) * mask
    ).long() + padding_idx


def softmax(x, dim):
    """Apply softmax along ``dim``, computing the result in float32."""
    return F.softmax(x, dim=dim, dtype=torch.float32)


def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    """Build a mask that marks the valid (non-padded) steps of each sequence.

    Args:
        lengths: 1-D tensor of sequence lengths, shape (B,).
        maxlen: Number of mask columns. When None, ``lengths.max()`` is used.
        dtype: Desired dtype of the returned mask.

    Returns:
        Tensor of shape (B, maxlen) and type ``dtype`` on ``lengths.device``
        whose entry [b, t] is set when ``t < lengths[b]``.
    """
    if maxlen is None:
        maxlen = lengths.max()
    maxlen = int(maxlen)
    # 1-based step numbers, created directly on the target device.
    steps = torch.arange(1, maxlen + 1, device=lengths.device)
    mask = ~(steps.unsqueeze(0) > lengths.unsqueeze(1))
    # Tensor.type() is not in-place: its result has to be returned, otherwise
    # the requested dtype is silently ignored.
    return mask.type(dtype)


def weights_nonzero_speech(target):
    """Return per-element weights that are 0.0 on all-zero frames, else 1.0.

    Args:
        target: Tensor laid out as B x T x mel.

    Returns:
        Float tensor with the same shape as ``target``. A frame (last-dim
        vector) whose absolute values sum to zero is treated as padding and
        gets weight 0.0 in every channel; all other frames get 1.0.
    """
    # target : B x T x mel
    # Assign weight 1.0 to all labels except for padding (id=0).
    dim = target.size(-1)
    return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)


# Per-class counter used to hand out a unique id to every module instance.
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(int)


def _get_full_incremental_state_key(module_instance, key):
    """Return the ``'<ClassName>.<instance id>.<key>'`` state-dict key.

    On first use the instance is tagged with an ``_instance_id`` attribute.
    """
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module.

    Returns None when ``incremental_state`` is None or holds no entry for
    this module/key pair.
    """
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module.

    Does nothing when ``incremental_state`` is None.
    """
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf.

    The fill is performed on ``t.float()`` and the result is cast back to the
    type of ``t``.
    NOTE(review): for float32 inputs ``t.float()`` is expected to return ``t``
    itself, in which case ``t`` is modified in place -- confirm if callers
    rely on either behaviour.
    """
    return t.float().fill_(float('-inf')).type_as(t)


def fill_with_neg_inf2(t):
    """FP16-compatible function that fills a tensor with -1e8.

    Unlike ``fill_with_neg_inf`` the fill value is the large finite negative
    number -1e8 rather than -inf. The fill is performed on ``t.float()`` and
    the result is cast back to the type of ``t``.
    NOTE(review): -1e8 is outside the finite range of float16, so the cast
    back presumably yields -inf for half-precision inputs -- confirm.
    """
    return t.float().fill_(-1e8).type_as(t)


def select_attn(attn_logits, type='best'):
    """Pick or average encoder-decoder attention over layers and heads.

    :param attn_logits: sequence of per-layer tensors, stacked to
        [n_layers, B, n_head, T_sp, T_txt]
    :param type: 'best' selects, per batch item, the layer/head whose
        attention has the largest sum over T_sp of the per-step maximum;
        'mean' averages over all layers and heads
    :return: attention of shape [B, T_sp, T_txt]; None for any other ``type``
    """
    encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
    # [n_layers * n_head, B, T_sp, T_txt]
    encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
    if type == 'best':
        # One winning layer/head index per batch item.
        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
        encdec_attn = encdec_attn.gather(
            0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
        return encdec_attn
    elif type == 'mean':
        return encdec_attn.mean(0)
    # Unknown selection type: keep the historical behaviour of returning None.
    return None


def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Raises:
        ValueError: If ``length_dim`` is 0 (the batch dimension).

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    # Compare every step index against the length of its sequence.
    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, , None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        # Insert singleton axes, then broadcast to the reference shape.
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
                    dtype=torch.uint8 in PyTorch 1.2-
                    dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def get_mask_from_lengths(lengths):
    """Return a bool mask of shape (B, max(lengths)) that is True where t < lengths[b].

    Args:
        lengths: 1-D tensor of sequence lengths; the mask is created on its
            device.
    """
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len).to(lengths.device)
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask


def group_hidden_by_segs(h, seg_ids, max_len):
    """Average hidden vectors that share the same segment id.

    :param h: [B, T, H]
    :param seg_ids: [B, T] integer segment index of every step; index 0 is
        accumulated into a slot that is dropped from the outputs
    :param max_len: largest segment index to allocate room for (T_ph)
    :return: h_ph: [B, T_ph, H] mean of the vectors in each segment (zeros
        for empty segments), and cnt: [B, T_ph] number of steps per segment
    """
    n_batch, _, n_hidden = h.shape
    # Sum the hidden vectors per segment; slot 0 collects the id-0 steps.
    h_gby_segs = h.new_zeros([n_batch, max_len + 1, n_hidden]).scatter_add_(
        1, seg_ids[:, :, None].repeat([1, 1, n_hidden]), h)
    all_ones = h.new_ones(h.shape[:2])
    cnt_gby_segs = h.new_zeros([n_batch, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
    # Drop slot 0 so that output index i corresponds to segment id i + 1.
    h_gby_segs = h_gby_segs[:, 1:]
    cnt_gby_segs = cnt_gby_segs[:, 1:]
    # Clamp the divisor so empty segments stay zero instead of becoming NaN.
    h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
    return h_gby_segs, cnt_gby_segs