import torch
import torch.nn.functional as F


def build_word_mask(x2word: torch.Tensor, y2word: torch.Tensor) -> torch.Tensor:
    """Return a pairwise equality mask between two word-index tensors.

    Entry ``[b, i, j]`` of the result is 1 when ``x2word[b, i]`` equals
    ``y2word[b, j]`` and 0 otherwise.

    Args:
        x2word: Index tensor; used here as 2-D (presumably ``[B, T_x]``).
        y2word: Index tensor; used here as 2-D (presumably ``[B, T_y]``).

    Returns:
        An int64 tensor, ``[B, T_x, T_y]`` for 2-D inputs.
    """
    return (x2word[:, :, None] == y2word[:, None, :]).long()


def mel2ph_to_mel2word(mel2ph: torch.Tensor, ph2word: torch.Tensor) -> torch.Tensor:
    """Compose a frame->phoneme map with a phoneme->word map.

    Positive entries of ``mel2ph`` are treated as 1-based positions into
    ``ph2word`` along dim 1; entries that are not positive produce 0 in the
    output.

    Args:
        mel2ph: Integer tensor of 1-based phoneme positions per frame, with 0
            for frames that map to nothing (presumably ``[B, T_mel]``).
        ph2word: Integer tensor of word ids per phoneme (presumably
            ``[B, T_ph]``).

    Returns:
        Tensor with the same shape as ``mel2ph`` holding the word id of each
        frame, or 0 where ``mel2ph`` is not positive.
    """
    # Shift to 0-based for the gather; the clamp keeps the 0 entries in range,
    # and their (meaningless) gathered values are zeroed out by the mask below.
    mel2word = (ph2word - 1).gather(1, (mel2ph - 1).clamp(min=0)) + 1
    mel2word = mel2word * (mel2ph > 0).long()
    return mel2word


def clip_mel2token_to_multiple(mel2token: torch.Tensor, frames_multiple: int) -> torch.Tensor:
    """Truncate dim 1 of ``mel2token`` down to a multiple of ``frames_multiple``.

    Args:
        mel2token: Tensor with at least two dimensions; dim 1 is the one
            that gets clipped.
        frames_multiple: Divisor the length of dim 1 must be a multiple of.

    Returns:
        ``mel2token`` itself when dim 1 is already a multiple of
        ``frames_multiple``, otherwise a slice with the trailing remainder
        dropped.
    """
    if mel2token.shape[1] % frames_multiple > 0:
        max_frames = mel2token.shape[1] // frames_multiple * frames_multiple
        mel2token = mel2token[:, :max_frames]
    return mel2token


def expand_states(h: torch.Tensor, mel2token: torch.Tensor) -> torch.Tensor:
    """Gather per-token states for every frame according to ``mel2token``.

    A zero row is prepended to ``h`` along dim 1, so index 0 in ``mel2token``
    selects an all-zero vector and index ``k > 0`` selects ``h[:, k - 1]``.

    Args:
        h: Token states (presumably ``[B, T_token, H]``).
        mel2token: Integer index tensor (presumably ``[B, T]``) of 1-based
            token positions, with 0 selecting the zero row.

    Returns:
        Gathered states, ``[B, T, H]``.
    """
    # Pad one zero step at the front of dim 1 so that index 0 maps to zeros.
    h = F.pad(h, [0, 0, 1, 0])
    # Replicate the index across the feature dimension for the gather.
    mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
    h = torch.gather(h, 1, mel2token_)  # [B, T, H]
    return h