from collections import defaultdict

import torch
import torch.nn.functional as F


def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    mask = tensor.ne(padding_idx).int()
    return (
        torch.cumsum(mask, dim=1).type_as(mask) * mask
    ).long() + padding_idx


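# Illustrative usage (a sketch added for clarity; the tensors below are hypothetical):
# >>> tokens = torch.LongTensor([[7, 8, 1], [9, 1, 1]])  # padding_idx = 1
# >>> make_positions(tokens, padding_idx=1)
# tensor([[2, 3, 1],
#         [2, 1, 1]])
# Non-padding symbols are numbered from padding_idx + 1; padding keeps padding_idx.

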
def softmax(x, dim):
    return F.softmax(x, dim=dim, dtype=torch.float32)


def sequence_mask(lengths, maxlen, dtype=torch.bool):
    if maxlen is None:
        maxlen = lengths.max()
    mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
    mask = mask.type(dtype)  # .type() is not in-place, so the result must be reassigned
    return mask


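# Illustrative usage (a sketch; values are hypothetical):
# >>> lengths = torch.LongTensor([3, 1])
# >>> sequence_mask(lengths, maxlen=4)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])

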
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)


def _get_full_incremental_state_key(module_instance, key):
    module_name = module_instance.__class__.__name__

    # Assign a unique ID to each module instance so that incremental state is
    # not shared across instances of the same module class.
    if not hasattr(module_instance, '_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module."""
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value


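# Illustrative usage inside an nn.Module during incremental decoding (a sketch;
# `self`, `incremental_state`, and compute_key() are hypothetical):
# >>> cached = get_incremental_state(self, incremental_state, 'prev_key')
# >>> if cached is None:
# ...     cached = compute_key()  # hypothetical helper
# >>> set_incremental_state(self, incremental_state, 'prev_key', cached)
# `incremental_state` is a plain dict shared across decoding steps; keys are namespaced
# per module class and instance by _get_full_incremental_state_key.

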
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)


def fill_with_neg_inf2(t):
    """FP16-compatible function that fills a tensor with a large negative value (-1e8)."""
    return t.float().fill_(-1e8).type_as(t)


def get_focus_rate(attn, src_padding_mask=None, tgt_padding_mask=None):
    '''
    attn: bs x L_t x L_s
    Returns, per sample, the sum over target steps of the maximum attention weight,
    normalized by the total attention mass (padding positions are zeroed out first).
    '''
    if src_padding_mask is not None:
        attn = attn * (1 - src_padding_mask.float())[:, None, :]

    if tgt_padding_mask is not None:
        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]

    focus_rate = attn.max(-1).values.sum(-1)
    focus_rate = focus_rate / attn.sum(-1).sum(-1)
    return focus_rate


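# Illustrative usage (a sketch; the attention matrix below is hypothetical):
# >>> attn = torch.FloatTensor([[[0.9, 0.1], [0.2, 0.8]]])  # bs=1, L_t=2, L_s=2
# >>> get_focus_rate(attn)
# tensor([0.8500])
# i.e. (0.9 + 0.8) / (0.9 + 0.1 + 0.2 + 0.8) = 0.85

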
def get_phone_coverage_rate(attn, src_padding_mask=None, src_seg_mask=None, tgt_padding_mask=None):
    '''
    attn: bs x L_t x L_s
    Returns, per sample, the average over valid source (phone) positions of the
    maximum attention each position receives across target steps.
    '''
    src_mask = attn.new(attn.size(0), attn.size(-1)).bool().fill_(False)
    if src_padding_mask is not None:
        src_mask |= src_padding_mask
    if src_seg_mask is not None:
        src_mask |= src_seg_mask

    attn = attn * (1 - src_mask.float())[:, None, :]
    if tgt_padding_mask is not None:
        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]

    phone_coverage_rate = attn.max(1).values.sum(-1)
    phone_coverage_rate = phone_coverage_rate / (1 - src_mask.float()).sum(-1)
    return phone_coverage_rate


def get_diagonal_focus_rate(attn, attn_ks, target_len, src_padding_mask=None, tgt_padding_mask=None,
                            band_mask_factor=5, band_width=50):
    '''
    attn: bs x L_t x L_s
    attn_ks: tensor with shape [batch_size], input_lens / output_lens

    diagonal: y = k * x (k = attn_ks, x: output, y: input)
        1 0 0
        0 1 0
        0 0 1
    y >= k * (x - width) and y <= k * (x + width): 1
    else: 0
    '''
    # half-width of the diagonal band, capped at band_width
    width1 = target_len / band_mask_factor
    width2 = target_len.new(target_len.size()).fill_(band_width)
    width = torch.where(width1 < width2, width1, width2).float()
    base = torch.ones(attn.size()).to(attn.device)
    zero = torch.zeros(attn.size()).to(attn.device)
    x = torch.arange(0, attn.size(1)).to(attn.device)[None, :, None].float() * base
    y = torch.arange(0, attn.size(2)).to(attn.device)[None, None, :].float() * base
    cond = (y - attn_ks[:, None, None] * x)
    cond1 = cond + attn_ks[:, None, None] * width[:, None, None]
    cond2 = cond - attn_ks[:, None, None] * width[:, None, None]
    mask1 = torch.where(cond1 < 0, zero, base)
    mask2 = torch.where(cond2 > 0, zero, base)
    mask = mask1 * mask2

    if src_padding_mask is not None:
        attn = attn * (1 - src_padding_mask.float())[:, None, :]
    if tgt_padding_mask is not None:
        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]

    diagonal_attn = attn * mask
    diagonal_focus_rate = diagonal_attn.sum(-1).sum(-1) / attn.sum(-1).sum(-1)
    return diagonal_focus_rate, mask


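# Geometry of the band mask above (an explanatory sketch, not part of the original code):
# with k = attn_ks[b] and w = width[b], mask[b, t, s] = 1 exactly when
# k * (t - w) <= s <= k * (t + w), i.e. a band of half-width k*w around the diagonal s = k*t.
# For example, with k = 1 and w = 1 on a 4x4 attention map the band keeps
#     1 1 0 0
#     1 1 1 0
#     0 1 1 1
#     0 0 1 1
# and diagonal_focus_rate is the share of (unmasked) attention mass inside this band.

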
def select_attn(attn_logits, type='best'):
    """
    Select an encoder-decoder attention map from multi-layer, multi-head attention logits.

    :param attn_logits: [n_layers, B, n_head, T_sp, T_txt]
    :param type: 'best' picks, per sample, the head (over all layers) with the highest focus;
        'mean' averages over all layers and heads.
    :return: [B, T_sp, T_txt]
    """
    encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)

    # merge layer and head dims: [n_layers * n_head, B, T_sp, T_txt], then softmax over T_txt
    encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
    if type == 'best':
        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
        encdec_attn = encdec_attn.gather(
            0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
        return encdec_attn
    elif type == 'mean':
        return encdec_attn.mean(0)


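# Illustrative usage (a sketch; shapes are hypothetical):
# >>> attn_logits = [torch.randn(2, 4, 100, 20) for _ in range(6)]  # 6 layers, B=2, 4 heads
# >>> select_attn(attn_logits, type='best').shape
# torch.Size([2, 100, 20])
# >>> select_attn(attn_logits, type='mean').shape
# torch.Size([2, 100, 20])

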
def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim

        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: Mask tensor containing indices of non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def get_mask_from_lengths(lengths):
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len).to(lengths.device)
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask


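# Illustrative usage (a sketch; values are hypothetical). Unlike make_pad_mask,
# this returns True for valid (non-padded) positions:
# >>> get_mask_from_lengths(torch.LongTensor([3, 1]))
# tensor([[ True,  True,  True],
#         [ True, False, False]])

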
def group_hidden_by_segs(h, seg_ids, max_len):
    """
    Average the hidden states h over the segments given by seg_ids.

    :param h: [B, T, H]
    :param seg_ids: [B, T], segment index per step (0 is reserved for padding and is dropped)
    :param max_len: number of segments (T_ph)
    :return: h_ph: [B, T_ph, H], cnt_ph: [B, T_ph]
    """
    B, T, H = h.shape
    h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
    all_ones = h.new_ones(h.shape[:2])
    cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
    h_gby_segs = h_gby_segs[:, 1:]
    cnt_gby_segs = cnt_gby_segs[:, 1:]
    h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
    return h_gby_segs, cnt_gby_segs


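# Illustrative usage (a sketch; values are hypothetical):
# >>> h = torch.arange(8, dtype=torch.float).reshape(1, 4, 2)   # [B=1, T=4, H=2]
# >>> seg_ids = torch.LongTensor([[1, 1, 2, 0]])                # last step is padding
# >>> h_segs, cnt = group_hidden_by_segs(h, seg_ids, max_len=2)
# >>> h_segs
# tensor([[[1., 2.],
#          [4., 5.]]])
# >>> cnt
# tensor([[2., 1.]])

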
def mel2token_to_dur(mel2token, T_txt=None, max_dur=None):
    """Convert a frame-to-token alignment into per-token durations (frame counts).

    mel2token: [T_mel] or [B, T_mel], 1-based token index per mel frame (0 = padding).
    Returns dur: [T_txt] or [B, T_txt].
    """
    is_torch = isinstance(mel2token, torch.Tensor)
    has_batch_dim = True
    if not is_torch:
        mel2token = torch.LongTensor(mel2token)
    if T_txt is None:
        T_txt = mel2token.max()
    if len(mel2token.shape) == 1:
        mel2token = mel2token[None, ...]
        has_batch_dim = False
    B, _ = mel2token.shape
    dur = mel2token.new_zeros(B, T_txt + 1).scatter_add(1, mel2token, torch.ones_like(mel2token))
    dur = dur[:, 1:]
    if max_dur is not None:
        dur = dur.clamp(max=max_dur)
    if not is_torch:
        dur = dur.numpy()
    if not has_batch_dim:
        dur = dur[0]
    return dur


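# Illustrative usage (a sketch; values are hypothetical):
# >>> mel2token_to_dur(torch.LongTensor([1, 1, 2, 3, 3, 3]), T_txt=4)
# tensor([2, 1, 3, 0])

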
def expand_word2ph(word_encoding, ph2word):
    """
    Gather the word-level encoding for each phone.

    :param word_encoding: [B, T_word, H]
    :param ph2word: [B, T_ph], 1-based word index per phone (0 = padding)
    :return: [B, T_ph, H]
    """
    word_encoding = F.pad(word_encoding, [0, 0, 1, 0])
    ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
    out = torch.gather(word_encoding, 1, ph2word_)
    return out
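

# Illustrative usage (a sketch; values are hypothetical):
# >>> word_encoding = torch.arange(4, dtype=torch.float).reshape(1, 2, 2)  # [B=1, T_word=2, H=2]
# >>> ph2word = torch.LongTensor([[1, 1, 2, 0]])  # last phone slot is padding
# >>> expand_word2ph(word_encoding, ph2word)
# tensor([[[0., 1.],
#          [0., 1.],
#          [2., 3.],
#          [0., 0.]]])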