Spaces:
Build error
Build error
from collections import defaultdict | |
import torch | |
import torch.nn.functional as F | |
def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Positions are counted along dim 1 and begin at ``padding_idx + 1``;
    padding entries (equal to ``padding_idx``) are left as ``padding_idx``.
    The result is an int64 tensor with the same shape as ``tensor``.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    non_pad = tensor.ne(padding_idx).int()
    running_count = torch.cumsum(non_pad, dim=1).type_as(non_pad)
    # Zero out the running count at padding slots before shifting.
    positions = (running_count * non_pad).long()
    return positions + padding_idx
def softmax(x, dim):
    """Softmax of ``x`` along ``dim``, computed and returned in float32.

    ``x`` is cast to float32 before the operation (``dtype=torch.float32``),
    so the output is float32 even for half-precision inputs.
    """
    probs = F.softmax(x, dim=dim, dtype=torch.float32)
    return probs
def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    """Build a ``(B, maxlen)`` mask that is True/1 at valid (non-padded) steps.

    Args:
        lengths: 1-D tensor of sequence lengths, shape ``(B,)``.
        maxlen: width of the mask. Defaults to ``lengths.max()``
            (0 for an empty ``lengths``).
        dtype: dtype of the returned mask (default ``torch.bool``).

    Returns:
        Tensor of shape ``(B, maxlen)`` on ``lengths.device`` where
        ``mask[b, t]`` is set iff ``t + 1 <= lengths[b]``
        (i.e. ``t < lengths[b]`` for integer lengths).
    """
    if maxlen is None:
        # Guard the empty batch: .max() on an empty tensor raises.
        maxlen = int(lengths.max()) if lengths.numel() > 0 else 0
    # 1-based step indices, built directly on the target device.
    steps = torch.arange(1, int(maxlen) + 1, device=lengths.device)
    mask = steps.unsqueeze(0) <= lengths.unsqueeze(1)
    # The previous version called mask.type(dtype) and discarded the result,
    # so the requested dtype was silently ignored.
    return mask.to(dtype)
def weights_nonzero_speech(target):
    """Per-element weights that zero out all-zero frames.

    Assumes ``target`` is ``B x T x mel`` (the ``repeat(1, 1, ...)`` below
    requires a 3-D tensor). A frame gets weight 1.0 on every bin if the sum
    of its absolute values is non-zero, and 0.0 otherwise. The result is a
    float tensor with the same shape as ``target``.
    """
    n_bins = target.size(-1)
    frame_is_nonzero = target.abs().sum(-1, keepdim=True).ne(0)
    return frame_is_nonzero.float().repeat(1, 1, n_bins)
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) | |
def _get_full_incremental_state_key(module_instance, key): | |
module_name = module_instance.__class__.__name__ | |
# assign a unique ID to each module instance, so that incremental state is | |
# not shared across module instances | |
if not hasattr(module_instance, '_instance_id'): | |
INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1 | |
module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name] | |
return '{}.{}.{}'.format(module_name, module_instance._instance_id, key) | |
def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module.

    Returns the value stored under the module-specific key, or None when
    ``incremental_state`` is None or holds no entry for that key. The
    module's instance id is assigned even if nothing is found.
    """
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is not None and full_key in incremental_state:
        return incremental_state[full_key]
    return None
def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module.

    Stores ``value`` in ``incremental_state`` under the module-specific key.
    Does nothing (and assigns no instance id) when ``incremental_state`` is None.
    """
    if incremental_state is None:
        return
    incremental_state[_get_full_incremental_state_key(module, key)] = value
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf.

    The fill is done on a float32 view/copy and cast back to ``t``'s type.
    Note: when ``t`` is already float32, ``t.float()`` returns ``t`` itself,
    so ``t`` is filled in place and returned.
    """
    as_fp32 = t.float()
    as_fp32.fill_(float('-inf'))
    return as_fp32.type_as(t)
def fill_with_neg_inf2(t):
    """FP16-compatible function that fills a tensor with a large *finite* negative value.

    Unlike ``fill_with_neg_inf``, the fill value is -1e8 rather than -inf.
    For floating dtypes that cannot represent -1e8 (e.g. float16, whose
    lowest finite value is -65504), the fill value is clamped to the dtype's
    lowest finite value. Previously it overflowed to -inf when cast back.

    As before, the fill is performed on ``t.float()``; when ``t`` is already
    float32 that is ``t`` itself, so ``t`` is modified in place.
    """
    fill_value = -1e8
    if t.is_floating_point():
        fill_value = max(fill_value, torch.finfo(t.dtype).min)
    return t.float().fill_(fill_value).type_as(t)
def select_attn(attn_logits, type='best'):
    """Reduce per-layer, per-head encoder-decoder attention logits to one map.

    :param attn_logits: a sequence of ``n_layers`` tensors, each
        ``[B, n_head, T_sp, T_txt]``, or an already-stacked tensor
        ``[n_layers, B, n_head, T_sp, T_txt]``.
    :param type: ``'best'`` picks, per batch item, the (layer, head) whose
        softmaxed attention has the largest sum over ``T_sp`` of its per-row
        max. ``'mean'`` averages over all (layer, head) pairs. The name shadows
        the builtin but is kept for backward compatibility.
    :return: attention probabilities ``[B, T_sp, T_txt]`` (softmax over T_txt).
    :raises ValueError: if ``type`` is neither ``'best'`` nor ``'mean'``.
    """
    if torch.is_tensor(attn_logits):
        stacked = attn_logits
    else:
        stacked = torch.stack(list(attn_logits), 0)
    # [n_layers, B, n_head, T_sp, T_txt] -> [n_layers, n_head, B, T_sp, T_txt]
    stacked = stacked.transpose(1, 2)
    # Merge layers and heads: [n_layers * n_head, B, T_sp, T_txt]
    encdec_attn = stacked.reshape([-1, *stacked.shape[2:]]).softmax(-1)
    if type == 'best':
        # Score each (layer, head) per batch item; indices: [B]
        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
        batch_idx = torch.arange(encdec_attn.size(1), device=encdec_attn.device)
        # Pick encdec_attn[indices[b], b] for every b -> [B, T_sp, T_txt]
        return encdec_attn[indices, batch_idx]
    if type == 'mean':
        return encdec_attn.mean(0)
    raise ValueError(
        "Unknown attention selection type: {!r} (expected 'best' or 'mean')".format(type))
def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.
    Args:
        lengths (LongTensor, List or Tuple): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.
    Returns:
        Tensor: Mask tensor that is True/1 at padded positions
            (``t >= lengths[b]``).
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            When ``xs`` is None the mask is created on the CPU and has shape
            (B, max(lengths)); an empty batch yields a (0, 0) mask.
            Otherwise it is expanded to ``xs``'s shape and moved to ``xs.device``.
    Raises:
        ValueError: if ``length_dim`` is 0 (the batch dimension).
    Examples:
        (Outputs are shown as 0/1 with dtype=torch.uint8 as in PyTorch 1.2-;
        on PyTorch 1.2+ the values are False/True with dtype=torch.bool.)
        With only lengths.
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        tensor([[0, 0, 0, 0, 0],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 1, 1]], dtype=torch.uint8)
        With the reference tensor.
        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
        With the reference tensor and dimension indicator.
        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
    # Normalise lengths to a plain Python list: tensors / arrays via .tolist(),
    # any other iterable (e.g. a tuple) via list().
    if not isinstance(lengths, list):
        lengths = lengths.tolist() if hasattr(lengths, "tolist") else list(lengths)
    bs = len(lengths)
    if xs is None:
        # default=0 so an empty batch gives a (0, 0) mask instead of raising.
        maxlen = int(max(lengths, default=0))
    else:
        maxlen = xs.size(length_dim)
    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = torch.tensor(lengths, dtype=torch.int64).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)
        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.
    This is the element-wise negation of ``make_pad_mask`` with the same
    arguments.
    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.
    Returns:
        Tensor: mask tensor that is True/1 at non-padded (valid) positions
            (``t < lengths[b]``).
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)
    Examples:
        (Outputs are shown as 0/1 with dtype=torch.uint8 as in PyTorch 1.2-;
        on PyTorch 1.2+ the values are False/True with dtype=torch.bool.)
        With only lengths.
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        tensor([[1, 1, 1, 1, 1],
                [1, 1, 1, 0, 0],
                [1, 1, 0, 0, 0]], dtype=torch.uint8)
        With the reference tensor.
        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
        With the reference tensor and dimension indicator.
        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
    """
    return ~make_pad_mask(lengths, xs, length_dim)
def get_mask_from_lengths(lengths):
    """Return a boolean ``(B, max_len)`` mask that is True at valid positions.

    ``mask[b, t]`` is True iff ``t < lengths[b]``, where ``max_len`` is
    ``lengths.max()``. The mask lives on ``lengths.device``. An empty
    ``lengths`` yields a ``(0, 0)`` mask instead of raising.
    """
    # torch.max on an empty tensor raises, so guard the empty batch.
    max_len = lengths.max().item() if lengths.numel() > 0 else 0
    # Build the index range directly on the target device (no host->device copy).
    ids = torch.arange(0, max_len, device=lengths.device)
    # The comparison already produces a bool tensor.
    return ids < lengths.unsqueeze(1)
def group_hidden_by_segs(h, seg_ids, max_len):
    """Average hidden states that share a segment id.

    :param h: [B, T, H] hidden states.
    :param seg_ids: [B, T] int64 segment ids in ``[0, max_len]``; id 0 is
        accumulated but then dropped from the output.
    :param max_len: largest segment id to keep (number of output segments).
    :return: tuple ``(h_ph, counts)``:
        ``h_ph`` is [B, max_len, H], the mean of ``h`` over each segment
        1..max_len (zeros for empty segments), and ``counts`` is [B, max_len],
        the number of frames in each segment.
    """
    batch, _, hidden = h.shape
    n_slots = max_len + 1
    scatter_idx = seg_ids.unsqueeze(-1).repeat(1, 1, hidden)
    seg_sums = h.new_zeros([batch, n_slots, hidden]).scatter_add_(1, scatter_idx, h)
    frame_ones = h.new_ones(h.shape[:2])
    seg_counts = h.new_zeros([batch, n_slots]).scatter_add_(1, seg_ids, frame_ones).contiguous()
    # Discard slot 0 (frames with seg id 0).
    seg_sums, seg_counts = seg_sums[:, 1:], seg_counts[:, 1:]
    # clamp(min=1) avoids division by zero for empty segments.
    seg_means = seg_sums / torch.clamp(seg_counts[:, :, None], min=1)
    return seg_means, seg_counts