Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import torch | |
import torchaudio.functional | |
import torchaudio | |
from . import utils | |
import pdb | |
def stft_frame_length(pr):
    """Return the STFT window length in samples.

    The length is ``pr.frame_length_ms * pr.samp_sr * 0.001`` truncated to
    an int (i.e. a duration in milliseconds scaled by the sample rate).
    """
    length_in_samples = pr.frame_length_ms * pr.samp_sr * 0.001
    return int(length_in_samples)
def stft_frame_step(pr):
    """Return the STFT hop size in samples.

    The hop is ``pr.frame_step_ms * pr.samp_sr * 0.001`` truncated to an int
    (i.e. a duration in milliseconds scaled by the sample rate).
    """
    step_in_samples = pr.frame_step_ms * pr.samp_sr * 0.001
    return int(step_in_samples)
def stft_num_fft(pr):
    """Return the FFT size: the smallest power of two >= the frame length."""
    frame_len = stft_frame_length(pr)
    exponent = np.ceil(np.log2(frame_len))
    return int(2 ** exponent)
def log10(x):
    """Return the element-wise base-10 logarithm of tensor ``x``.

    Computed as ``ln(x) / ln(10)`` using a float32 scalar for ``ln(10)``.
    """
    ln_ten = torch.log(torch.tensor(10.))
    return torch.log(x) / ln_ten
def db_from_amp(x, cuda=False):
    """Convert linear amplitudes to decibels: ``20 * log10(max(x, 1e-5))``.

    The input is cast to float32 and floored at 1e-5 (i.e. -100 dB) so that
    zeros do not produce ``-inf``.

    Args:
        x: tensor of amplitudes, on any device.
        cuda: kept for backward compatibility only and no longer consulted.
            The floor constant is created on ``x``'s own device, so the
            function works on CPU-only hosts and on any CUDA device index
            instead of depending on a hard-coded ``'cuda'`` target.

    Returns:
        A float32 tensor of decibel values on the same device as ``x``.
    """
    amp = x.float()
    # Build the floor where the data lives to avoid cross-device operands.
    floor = torch.tensor(1e-5, device=amp.device)
    return 20. * log10(torch.max(floor, amp))
def amp_from_db(x):
    """Convert decibel values back to linear amplitude: ``10 ** (x / 20)``."""
    exponent = x / 20.
    return torch.pow(10., exponent)
def norm_range(x, min_val, max_val):
    """Linearly map ``x`` so that ``min_val`` -> -1 and ``max_val`` -> 1.

    Values outside ``[min_val, max_val]`` are not clipped.
    """
    span = float(max_val - min_val)
    return 2. * (x - min_val) / span - 1.
def unnorm_range(y, min_val, max_val):
    """Linearly map ``y`` so that -1 -> ``min_val`` and 1 -> ``max_val``.

    This is the inverse of ``norm_range`` for the same bounds.
    """
    half_span = 0.5 * float(max_val - min_val)
    return half_span * (y + 1) + min_val
def normalize_spec(spec, pr):
    """Rescale a spectrogram from ``[pr.spec_min, pr.spec_max]`` to [-1, 1]."""
    return norm_range(spec, min_val=pr.spec_min, max_val=pr.spec_max)
def unnormalize_spec(spec, pr):
    """Rescale a spectrogram from [-1, 1] back to ``[pr.spec_min, pr.spec_max]``."""
    return unnorm_range(spec, min_val=pr.spec_min, max_val=pr.spec_max)
def normalize_phase(phase, pr):
    """Rescale phase values from [-pi, pi] to [-1, 1].

    ``pr`` is accepted but not used by this function.
    """
    return norm_range(phase, min_val=-np.pi, max_val=np.pi)
def unnormalize_phase(phase, pr):
    """Rescale phase values from [-1, 1] back to [-pi, pi].

    ``pr`` is accepted but not used by this function.
    """
    return unnorm_range(phase, min_val=-np.pi, max_val=np.pi)
def normalize_ims(im):
    """Cast an image to float32 and apply the affine map ``2 * im - 1``.

    The map sends 0 to -1 and 1 to 1.

    Args:
        im: a NumPy array (including ndarray subclasses) or a torch tensor.

    Returns:
        A float32 array/tensor of the same container type as the input.
    """
    # isinstance (not an exact type comparison) so ndarray subclasses take
    # the NumPy path instead of failing on the tensor-only ``.float()``.
    if isinstance(im, np.ndarray):
        im = im.astype('float32')
    else:
        im = im.float()
    return -1. + 2. * im
def stft(samples, pr, cuda=False):
    """Compute the magnitude and phase spectrograms of ``samples``.

    Args:
        samples: waveform tensor passed straight to ``torch.stft``; the
            ``transpose(1, 2)`` below assumes a batched input, so presumably
            (batch, num_samples) -- TODO confirm against callers.
        pr: parameter object providing the frame length/step (via the
            ``stft_*`` helpers) and the ``log_spec`` flag.
        cuda: forwarded to ``db_from_amp`` when ``pr.log_spec`` is set.

    Returns:
        ``(mag, phase)``. ``mag`` is in decibels when ``pr.log_spec`` is
        truthy, otherwise linear amplitude. Axes 1 and 2 of the
        ``torch.stft`` output are swapped before the real/imag split.
    """
    n_fft = stft_num_fft(pr)
    framing = dict(hop_length=stft_frame_step(pr),
                   win_length=stft_frame_length(pr))
    try:
        # Current PyTorch requires ``return_complex`` for real inputs; ask for
        # complex output and view it as a trailing (real, imag) axis so the
        # rest of the function sees the legacy layout.
        spec_complex = torch.view_as_real(
            torch.stft(samples, n_fft, return_complex=True, **framing))
    except TypeError:
        # Older PyTorch has no ``return_complex`` keyword and already returns
        # the (real, imag) layout.
        spec_complex = torch.stft(samples, n_fft, **framing)
    spec_complex = spec_complex.transpose(1, 2)

    real = spec_complex[..., 0]
    imag = spec_complex[..., 1]
    mag = torch.sqrt((real**2) + (imag**2))
    phase = utils.angle(real, imag)
    if pr.log_spec:
        mag = db_from_amp(mag, cuda=cuda)
    return mag, phase
def make_complex(mag, phase):
    """Build a real-valued (..., 2) tensor of (real, imag) from polar form.

    The last axis holds ``mag * cos(phase)`` and ``mag * sin(phase)``.
    """
    real_part = mag * torch.cos(phase)
    imag_part = mag * torch.sin(phase)
    return torch.stack((real_part, imag_part), dim=-1)
def istft(mag, phase, pr):
    """Reconstruct waveforms from magnitude and phase spectrograms.

    Args:
        mag: magnitude spectrogram; converted from decibels to linear
            amplitude first when ``pr.log_spec`` is truthy.
        phase: phase spectrogram with the same shape as ``mag``.
        pr: parameter object providing the frame length/step (via the
            ``stft_*`` helpers) and the ``log_spec`` flag.

    Returns:
        The time-domain samples produced by the inverse STFT.
    """
    if pr.log_spec:
        mag = amp_from_db(mag)
    # Swap axes 1 and 2 to undo the transpose applied in ``stft``.
    spec = make_complex(mag, phase).transpose(1, 2)
    n_fft = stft_num_fft(pr)
    framing = dict(hop_length=stft_frame_step(pr),
                   win_length=stft_frame_length(pr))

    legacy_istft = getattr(torchaudio.functional, 'istft', None)
    if legacy_istft is not None:
        # Old torchaudio: keep the original code path unchanged.
        return legacy_istft(spec, n_fft, **framing)
    # Newer torchaudio no longer ships ``functional.istft``; ``torch.istft``
    # is the replacement and expects a complex tensor, not a (..., 2) view.
    return torch.istft(torch.view_as_complex(spec.contiguous()), n_fft, **framing)
def aud2spec(sample, pr, stereo=False, norm=False, cuda=True):
    """Compute per-channel magnitude and phase spectrograms for two-channel audio.

    The input is cropped to ``pr.sample_len`` along axis 1, axes 1 and 2 are
    swapped, and the result is flattened to ``batch * 2`` rows before the
    STFT -- so this assumes ``sample`` is (batch, time, 2); TODO confirm.
    ``stereo`` and ``norm`` are accepted but not used.

    Returns:
        ``(spec, phase)``, each reshaped to (batch, 2, pr.spec_len, -1).
    """
    sample = sample[:, :pr.sample_len]
    batch = sample.shape[0]
    flat_channels = sample.transpose(1, 2).reshape((batch * 2, -1))
    spec, phase = stft(flat_channels, pr, cuda=cuda)
    out_shape = (batch, 2, pr.spec_len, -1)
    return spec.reshape(out_shape), phase.reshape(out_shape)
def mix_sounds(samples0, pr, samples1=None, cuda=False, dominant=False, noise_ratio=0):
    """Mix two waveforms and return the mixture with per-source spectrograms.

    Both inputs are RMS-normalised to ``pr.input_rms`` and cropped to
    ``pr.sample_len`` along axis 1. When ``dominant`` is set, the second
    source is additionally scaled by ``noise_ratio`` before summation.
    Spectrograms and phases are cropped to ``pr.spec_len`` along axis 1.

    NOTE(review): ``samples1`` defaults to None, but it is sliced
    unconditionally below, so calling without a second source raises
    TypeError. Confirm the intended behaviour for that case.

    Returns:
        A ``utils.Struct`` with ``samples``, ``phase`` and ``spec`` for the
        mixture, plus ``sample_parts``, ``spec_parts`` and ``phase_parts``
        lists ordered as [source 0, source 1].
    """
    samples0 = utils.normalize_rms(samples0, pr.input_rms)
    if samples1 is not None:
        samples1 = utils.normalize_rms(samples1, pr.input_rms)

    # Both branches of the original crop identically; only the gain differs.
    samples0 = samples0[:, :pr.sample_len]
    samples1 = samples1[:, :pr.sample_len]
    if dominant:
        samples1 = samples1 * noise_ratio
    samples_mix = (samples0 + samples1)

    if cuda:
        samples0 = samples0.to('cuda')
        samples1 = samples1.to('cuda')
        samples_mix = samples_mix.to('cuda')

    def cropped_stft(wave):
        # STFT followed by a crop to the configured number of rows on axis 1.
        magnitude, angle = stft(wave, pr, cuda=cuda)
        return magnitude[:, :pr.spec_len], angle[:, :pr.spec_len]

    spec_mix, phase_mix = cropped_stft(samples_mix)
    spec0, phase0 = cropped_stft(samples0)
    spec1, phase1 = cropped_stft(samples1)

    return utils.Struct(
        samples=samples_mix.float(),
        phase=phase_mix.float(),
        spec=spec_mix.float(),
        sample_parts=[samples0, samples1],
        spec_parts=[spec0.float(), spec1.float()],
        phase_parts=[phase0.float(), phase1.float()])
def pit_loss(pred_spec_fg, pred_spec_bg, snd, pr, cuda=True, vis=False):
    """Permutation-invariant L1 loss between two predicted and two target spectrograms.

    For each of the two possible assignments of ``snd.spec_parts`` to the
    (fg, bg) predictions, the per-example loss is the ``pr.l1_weight``-scaled
    mean absolute error over axes 1 and 2, summed over both predictions.
    The smaller assignment loss is kept per example and averaged over the
    batch. When ``pr.norm`` is truthy all spectrograms are first passed
    through ``normalize_spec``.

    ``cuda`` is accepted but not used. With ``vis`` set, the per-assignment
    losses and the chosen assignment indices are printed.
    """
    target_a, target_b = snd.spec_parts[0], snd.spec_parts[1]
    if pr.norm:
        target_a = normalize_spec(target_a, pr)
        target_b = normalize_spec(target_b, pr)
        pred_spec_fg = normalize_spec(pred_spec_fg, pr)
        pred_spec_bg = normalize_spec(pred_spec_bg, pr)

    def mean_abs_err(pred, target):
        return torch.mean(torch.abs(pred - target), (1, 2))

    per_assignment = []
    # First the identity assignment, then the swapped one.
    for fg_target, bg_target in ((target_a, target_b), (target_b, target_a)):
        fg_term = pr.l1_weight * mean_abs_err(pred_spec_fg, fg_target)
        bg_term = pr.l1_weight * mean_abs_err(pred_spec_bg, bg_target)
        per_assignment.append(fg_term + bg_term)

    losses = torch.stack(per_assignment, dim=0)
    if vis:
        print(losses)
    best = torch.min(losses, dim=0)
    if vis:
        print(best[1])
    return torch.mean(best[0])
def diff_loss(spec_diff, phase_diff, snd, pr, device, norm=False, vis=False):
    """Weighted L1 loss on predicted spectrogram and phase differences.

    Targets are read from ``snd.spec_diff`` and ``snd.phase_diff``. When
    ``norm`` is set, predictions and targets are passed through
    ``normalize_spec`` / ``normalize_phase`` before comparison. The result is
    ``pr.l1_weight * spec_l1 + pr.phase_weight * phase_l1`` and is printed
    when ``vis`` is set.
    """
    l1 = torch.nn.L1Loss().to(device)

    target_spec, target_phase = snd.spec_diff, snd.phase_diff
    pred_spec, pred_phase = spec_diff, phase_diff
    if norm:
        target_spec = normalize_spec(target_spec, pr)
        target_phase = normalize_phase(target_phase, pr)
        pred_spec = normalize_spec(pred_spec, pr)
        pred_phase = normalize_phase(pred_phase, pr)

    spec_term = l1(pred_spec, target_spec)
    phase_term = l1(pred_phase, target_phase)
    loss = pr.l1_weight * spec_term + pr.phase_weight * phase_term
    if vis:
        print(loss)
    return loss
# def pit_loss(out, snd, pr, cuda=False, vis=False): | |
# def ns(x): return normalize_spec(x, pr) | |
# def np(x): return normalize_phase(x, pr) | |
# if cuda: | |
# snd['spec_part0'] = snd['spec_part0'].to('cuda') | |
# snd['phase_part0'] = snd['phase_part0'].to('cuda') | |
# snd['spec_part1'] = snd['spec_part1'].to('cuda') | |
# snd['phase_part1'] = snd['phase_part1'].to('cuda') | |
# # gts_ = [[ns(snd['spec_part0'][:, 0, :, :]), np(snd['phase_part0'][:, 0, :, :])], | |
# # [ns(snd['spec_part1'][:, 0, :, :]), np(snd['phase_part1'][:, 0, :, :])]] | |
# gts_ = [[ns(snd.spec_parts[0]), np(snd.phase_parts[0])], | |
# [ns(snd.spec_parts[1]), np(snd.phase_parts[1])]] | |
# preds = [[ns(out.pred_spec_fg), np(out.pred_phase_fg)], | |
# [ns(out.pred_spec_bg), np(out.pred_phase_bg)]] | |
# def l1(x, y): return torch.mean(torch.abs(x - y), (1, 2)) | |
# losses = [] | |
# for i in range(2): | |
# gt = [gts_[i % 2], gts_[(i+1) % 2]] | |
# # print 'preds[0][0] shape =', shape(preds[0][0]) | |
# # fg_spec = pr.l1_weight * l1(preds[0][0], gt[0][0]) | |
# # fg_phase = pr.phase_weight * l1(preds[0][1], gt[0][1]) | |
# # bg_spec = pr.l1_weight * l1(preds[1][0], gt[1][0]) | |
# # bg_phase = pr.phase_weight * l1(preds[1][1], gt[1][1]) | |
# # losses.append(fg_spec + fg_phase + bg_spec + bg_phase) | |
# fg_spec = pr.l1_weight * l1(preds[0][0], gt[0][0]) | |
# bg_spec = pr.l1_weight * l1(preds[1][0], gt[1][0]) | |
# losses.append(fg_spec + bg_spec) | |
# # pdb.set_trace() | |
# # pdb.set_trace() | |
# losses = torch.cat([x.unsqueeze(0) for x in losses], dim=0) | |
# if vis: | |
# print(losses) | |
# loss_val = torch.min(losses, dim=0) | |
# if vis: | |
# print(loss_val[1]) | |
# loss = torch.mean(loss_val[0]) | |
# return loss | |
# def stereo_mel() | |
def audio_stft(stft, audio, pr):
    """Apply a spectrogram transform per channel and return normalised dB values.

    Args:
        stft: callable mapping a (N * C, A) waveform batch to a 3-D tensor
            whose last two axes are swapped below; presumably it returns
            (N * C, F, T) amplitudes -- TODO confirm. This parameter shadows
            the module-level ``stft`` function inside this body.
        audio: tensor of shape (N, C, A).
        pr: parameter object providing ``spec_min`` / ``spec_max`` for
            ``normalize_spec``.

    Returns:
        Tensor of shape (N, C, T, F) holding the dB spectrogram rescaled by
        ``normalize_spec``.
    """
    N, C, A = audio.size()
    # reshape (not view) so non-contiguous inputs are accepted as well.
    audio = audio.reshape(N * C, A)
    spec = stft(audio)
    spec = spec.transpose(-1, -2)
    # Follow the tensor's actual device rather than assuming CUDA, so the
    # function also runs on CPU.
    spec = db_from_amp(spec, cuda=spec.is_cuda)
    spec = normalize_spec(spec, pr)
    _, T, F = spec.size()
    spec = spec.reshape(N, C, T, F)
    return spec
def normalize_audio(samples, desired_rms=0.1, eps=1e-4):
    """Scale a NumPy waveform so that its RMS equals ``desired_rms``.

    The measured RMS is floored at ``eps`` so that near-silent input is not
    amplified without bound.
    """
    mean_power = np.mean(samples ** 2)
    rms = np.maximum(eps, np.sqrt(mean_power))
    gain = desired_rms / rms
    return samples * gain