# The following piece of code was adapted from https://github.com/kaituoxu/Conv-TasNet
# released under the MIT License.
# Author: Kaituo XU
# Created on 2018/12

from itertools import permutations

import torch

EPS = 1e-8


def cal_loss(source, estimate_source, source_lengths):
    """Compute the PIT loss (negative mean SI-SNR over the batch).

    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B]
    """
    max_snr, perms, max_snr_idx, snr_set = cal_si_snr_with_pit(
        source, estimate_source, source_lengths)
    # maximizing SI-SNR is minimizing its negative
    loss = -torch.mean(max_snr)
    reorder_estimate_source = reorder_source(
        estimate_source, perms, max_snr_idx)
    return loss, max_snr, estimate_source, reorder_estimate_source


def cal_si_snr_with_pit(source, estimate_source, source_lengths):
    """Calculate SI-SNR with PIT (permutation invariant training).

    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B], each item is between [0, T]
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()
    # mask padding positions along T (note: in-place, modifies the caller's tensor)
    mask = get_mask(source, source_lengths)
    estimate_source *= mask

    # Step 1. Zero-mean norm
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2,
                              keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding positions along T again (mean subtraction leaks into the padding)
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)      # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target,
                              dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(
        s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(
        pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)  # [B, C, C]
    pair_wise_si_snr = torch.transpose(pair_wise_si_snr, 1, 2)

    # Get max_snr of each utterance
    # permutations, [C!, C]
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
    snr_set = torch.einsum('bij,pij->bp', pair_wise_si_snr, perms_one_hot)
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)  # [B, 1]
    max_snr /= C  # average over the C sources
    return max_snr, perms, max_snr_idx, snr_set / C


def reorder_source(source, perms, max_snr_idx):
    """
    Args:
        source: [B, C, T]
        perms: [C!, C], permutations
        max_snr_idx: [B], each item is between [0, C!)
    Returns:
        reorder_source: [B, C, T]
    """
    B, C, *_ = source.size()
    # [B, C], permutation whose SI-SNR is max for each utterance;
    # reorder the estimated sources according to this permutation
    max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx)
    # maybe use torch.gather()/index_select()/scatter() to impl this?
    reorder_source = torch.zeros_like(source)
    for b in range(B):
        for c in range(C):
            reorder_source[b, c] = source[b, max_snr_perm[b][c]]
    return reorder_source


def get_mask(source, source_lengths):
    """
    Args:
        source: [B, C, T]
        source_lengths: [B]
    Returns:
        mask: [B, 1, T]
    """
    B, _, T = source.size()
    mask = source.new_ones((B, 1, T))
    for i in range(B):
        mask[i, :, source_lengths[i]:] = 0
    return mask
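

# A minimal usage sketch (added for illustration; not part of the original
# Conv-TasNet repo). It drives cal_loss with random tensors for a
# hypothetical 2-speaker setup: B=4 utterances of T=8000 samples each.
# The estimates are a channel-swapped, noisy copy of the targets, so the
# best permutation found by PIT should undo the swap.
if __name__ == '__main__':
    torch.manual_seed(0)
    B, C, T = 4, 2, 8000
    source = torch.randn(B, C, T)
    # swap the speaker channels and add noise to mimic a model's output
    estimate_source = source[:, [1, 0], :] + 0.1 * torch.randn(B, C, T)
    source_lengths = torch.full((B,), T, dtype=torch.long)  # no padding here
    loss, max_snr, _, reordered = cal_loss(
        source, estimate_source, source_lengths)
    print('loss:', loss.item())            # negative mean SI-SNR
    print('max_snr:', max_snr.squeeze(1))  # best-permutation SI-SNR per utterance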