# ymzhang319's picture
# init
# 7f2690b
# raw
# history blame
# 8.24 kB
import numpy as np
import torch
import torchaudio.functional
import torchaudio
from . import utils
import pdb
def stft_frame_length(pr):
    """Return the STFT window length as a whole number of samples.

    The value is ``pr.frame_length_ms * pr.samp_sr * 0.001`` truncated with
    ``int`` (presumably milliseconds times a sampling rate in Hz).
    """
    length_in_samples = pr.frame_length_ms * pr.samp_sr * 0.001
    return int(length_in_samples)
def stft_frame_step(pr):
    """Return the STFT hop size as a whole number of samples.

    The value is ``pr.frame_step_ms * pr.samp_sr * 0.001`` truncated with
    ``int`` (presumably milliseconds times a sampling rate in Hz).
    """
    step_in_samples = pr.frame_step_ms * pr.samp_sr * 0.001
    return int(step_in_samples)
def stft_num_fft(pr):
    """Return the FFT size: the frame length rounded up to a power of two."""
    exponent = np.ceil(np.log2(stft_frame_length(pr)))
    return int(2 ** exponent)
def log10(x):
    """Return the element-wise base-10 logarithm of tensor ``x``.

    Uses ``torch.log10`` rather than ``log(x) / log(10)``: the hand-rolled
    form divided by a float32 constant, which capped the accuracy of
    float64 inputs at roughly float32 precision.
    """
    return torch.log10(x)
def db_from_amp(x, cuda=False):
    """Convert linear amplitudes to decibels: ``20 * log10(max(x, 1e-5))``.

    Args:
        x: Tensor of amplitudes. It is cast to float32 before conversion.
        cuda: Accepted for backward compatibility and ignored. The floor is
            applied with ``torch.clamp``, which runs on whatever device ``x``
            lives on, so no explicit device transfer is needed (the previous
            hard-coded ``.to('cuda')`` failed on CPU-only hosts).

    Returns:
        A float32 tensor on the same device as ``x``. Amplitudes at or below
        ``1e-5`` map to the floor of -100 dB.
    """
    floored = torch.clamp(x.float(), min=1e-5)
    return 20. * torch.log10(floored)
def amp_from_db(x):
    """Convert decibel values back to linear amplitudes: ``10 ** (x / 20)``."""
    exponent = x / 20.
    return torch.pow(10., exponent)
def norm_range(x, min_val, max_val):
    """Linearly map ``x`` from ``[min_val, max_val]`` onto ``[-1, 1]``.

    Values outside the input interval are not clipped; they simply land
    outside ``[-1, 1]``.
    """
    span = float(max_val - min_val)
    offset = x - min_val
    return 2. * offset / span - 1.
def unnorm_range(y, min_val, max_val):
    """Linearly map ``y`` from ``[-1, 1]`` back onto ``[min_val, max_val]``.

    This is the inverse of the ``2 * (x - min) / (max - min) - 1`` mapping.
    """
    half_span = 0.5 * float(max_val - min_val)
    return half_span * (y + 1) + min_val
def normalize_spec(spec, pr):
    """Rescale ``spec`` from ``[pr.spec_min, pr.spec_max]`` to ``[-1, 1]``."""
    lower, upper = pr.spec_min, pr.spec_max
    return norm_range(spec, lower, upper)
def unnormalize_spec(spec, pr):
    """Rescale ``spec`` from ``[-1, 1]`` back to ``[pr.spec_min, pr.spec_max]``."""
    lower, upper = pr.spec_min, pr.spec_max
    return unnorm_range(spec, lower, upper)
def normalize_phase(phase, pr):
    """Rescale ``phase`` from ``[-pi, pi]`` to ``[-1, 1]``.

    ``pr`` is accepted for signature symmetry with the spectrogram helpers
    but is not read.
    """
    lower, upper = -np.pi, np.pi
    return norm_range(phase, lower, upper)
def unnormalize_phase(phase, pr):
    """Rescale ``phase`` from ``[-1, 1]`` back to ``[-pi, pi]``.

    ``pr`` is accepted for signature symmetry with the spectrogram helpers
    but is not read.
    """
    lower, upper = -np.pi, np.pi
    return unnorm_range(phase, lower, upper)
def normalize_ims(im):
    """Map image values through ``-1 + 2 * im`` after a float32 cast.

    Args:
        im: Either a NumPy array (any ``np.ndarray`` subclass, e.g. a masked
            or memory-mapped array) or an object exposing ``.float()`` such
            as a torch tensor. Presumably the values lie in ``[0, 1]`` so the
            result lies in ``[-1, 1]`` — the function itself does not check.

    Returns:
        The rescaled array/tensor, of the same container kind as the input.
    """
    # isinstance (rather than an exact type comparison) lets ndarray
    # subclasses take the NumPy path instead of failing on ``.float()``.
    if isinstance(im, np.ndarray):
        im = im.astype('float32')
    else:
        im = im.float()
    return -1. + 2. * im
def stft(samples, pr, cuda=False):
    """Compute magnitude and phase spectrograms of ``samples``.

    Args:
        samples: Waveform tensor handed directly to ``torch.stft``.
        pr: Parameter object; the FFT size, hop and window length are derived
            from it, and ``pr.log_spec`` selects decibel magnitudes.
        cuda: Forwarded to ``db_from_amp`` when ``pr.log_spec`` is set.

    Returns:
        ``(mag, phase)``, both taken after swapping axes 1 and 2 of the
        ``torch.stft`` output.
    """
    # NOTE(review): the [..., 0] / [..., 1] indexing below assumes torch.stft
    # returns a real tensor whose last axis holds (real, imag), i.e. the
    # legacy non-complex output — confirm against the installed torch version.
    transform = torch.stft(
        samples,
        stft_num_fft(pr),
        hop_length=stft_frame_step(pr),
        win_length=stft_frame_length(pr))
    transform = transform.transpose(1, 2)

    real, imag = transform[..., 0], transform[..., 1]
    mag = torch.sqrt(real ** 2 + imag ** 2)
    phase = utils.angle(real, imag)

    if pr.log_spec:
        mag = db_from_amp(mag, cuda=cuda)
    return mag, phase
def make_complex(mag, phase):
    """Build a (real, imag) tensor from polar ``mag`` and ``phase``.

    Returns a tensor with one extra trailing axis of size 2 holding
    ``mag * cos(phase)`` and ``mag * sin(phase)``.
    """
    real_part = mag * torch.cos(phase)
    imag_part = mag * torch.sin(phase)
    return torch.stack((real_part, imag_part), dim=-1)
def istft(mag, phase, pr):
    """Reconstruct waveform samples from magnitude and phase spectrograms.

    When ``pr.log_spec`` is set, ``mag`` is first converted from decibels to
    linear amplitude. Axes 1 and 2 of the assembled (real, imag) tensor are
    swapped before it is passed to ``torchaudio.functional.istft`` with the
    FFT size, hop and window length derived from ``pr``.
    """
    linear_mag = amp_from_db(mag) if pr.log_spec else mag
    complex_spec = make_complex(linear_mag, phase).transpose(1, 2)
    return torchaudio.functional.istft(
        complex_spec,
        stft_num_fft(pr),
        hop_length=stft_frame_step(pr),
        win_length=stft_frame_length(pr))
def aud2spec(sample, pr, stereo=False, norm=False, cuda=True):
    """Compute per-channel magnitude and phase spectrograms for a batch.

    Args:
        sample: Batched audio tensor; axis 1 is cropped to ``pr.sample_len``
            and axes 1 and 2 are then swapped, so it must be at least 3-D.
            The reshape hard-codes a factor of 2 — presumably two audio
            channels on the last axis; confirm against the caller.
        pr: Parameter object providing ``sample_len``, ``spec_len`` and the
            STFT settings.
        stereo: Unused.
        norm: Unused.
        cuda: Forwarded to ``stft``.

    Returns:
        ``(spec, phase)``, each reshaped to ``(batch, 2, pr.spec_len, -1)``.
    """
    clip = sample[:, :pr.sample_len]
    batch = clip.shape[0]

    # Fold the two channels into the batch axis so one STFT call covers both.
    flat = clip.transpose(1, 2).reshape((batch * 2, -1))
    spec, phase = stft(flat, pr, cuda=cuda)

    out_shape = (batch, 2, pr.spec_len, -1)
    spec = spec.reshape(*out_shape)
    phase = phase.reshape(*out_shape)
    return spec, phase
def mix_sounds(samples0, pr, samples1=None, cuda=False, dominant=False, noise_ratio=0):
    """Mix two batches of waveforms and compute spectrograms of all three.

    Args:
        samples0: First batch of waveforms.
        pr: Parameter object providing ``input_rms``, ``sample_len``,
            ``spec_len`` and the STFT settings.
        samples1: Second batch of waveforms. Required in practice: although
            the signature defaults it to ``None``, the mix cannot be formed
            without it.
        cuda: When true, the cropped waveforms and their mix are moved to
            ``'cuda'`` before the STFT, and the flag is forwarded to ``stft``.
        dominant: When true, ``samples1`` is scaled by ``noise_ratio`` before
            mixing.
        noise_ratio: Scale applied to ``samples1`` when ``dominant`` is set.
            Note that the default of 0 silences ``samples1`` entirely.

    Returns:
        ``utils.Struct`` with the mix (``samples``, ``spec``, ``phase``) and
        the per-source lists ``sample_parts``, ``spec_parts``, ``phase_parts``.

    Raises:
        TypeError: If ``samples1`` is ``None``.
    """
    # Fail fast with an explicit message; previously this surfaced later as
    # an opaque "'NoneType' object is not subscriptable" TypeError.
    if samples1 is None:
        raise TypeError("mix_sounds requires samples1; got None")

    # Bring both sources to the same RMS level before mixing.
    samples0 = utils.normalize_rms(samples0, pr.input_rms)
    samples1 = utils.normalize_rms(samples1, pr.input_rms)

    samples0 = samples0[:, :pr.sample_len]
    samples1 = samples1[:, :pr.sample_len]
    if dominant:
        samples1 = samples1 * noise_ratio

    samples_mix = samples0 + samples1

    if cuda:
        samples0 = samples0.to('cuda')
        samples1 = samples1.to('cuda')
        samples_mix = samples_mix.to('cuda')

    spec_mix, phase_mix = stft(samples_mix, pr, cuda=cuda)
    spec0, phase0 = stft(samples0, pr, cuda=cuda)
    spec1, phase1 = stft(samples1, pr, cuda=cuda)

    def crop(spectrogram):
        # Keep at most pr.spec_len entries along axis 1.
        return spectrogram[:, :pr.spec_len]

    spec_mix, phase_mix = crop(spec_mix), crop(phase_mix)
    spec0, spec1 = crop(spec0), crop(spec1)
    phase0, phase1 = crop(phase0), crop(phase1)

    return utils.Struct(
        samples=samples_mix.float(),
        phase=phase_mix.float(),
        spec=spec_mix.float(),
        sample_parts=[samples0, samples1],
        spec_parts=[spec0.float(), spec1.float()],
        phase_parts=[phase0.float(), phase1.float()])
def pit_loss(pred_spec_fg, pred_spec_bg, snd, pr, cuda=True, vis=False):
    """Permutation-invariant L1 loss between two predictions and two targets.

    Both ways of pairing the predictions with ``snd.spec_parts[0]`` and
    ``snd.spec_parts[1]`` are scored; for every batch item the cheaper
    pairing is kept, and the result is averaged over the batch.

    Args:
        pred_spec_fg: First predicted spectrogram.
        pred_spec_bg: Second predicted spectrogram.
        snd: Object whose ``spec_parts`` holds the two target spectrograms.
        pr: Parameter object providing ``norm`` and ``l1_weight`` (and the
            spectrogram range when ``norm`` is set).
        cuda: Unused.
        vis: When true, print the per-pairing losses and the chosen pairing
            index for each batch item.

    Returns:
        Scalar tensor with the mean of the per-item minimum loss.
    """
    if pr.norm:
        target_a = normalize_spec(snd.spec_parts[0], pr)
        target_b = normalize_spec(snd.spec_parts[1], pr)
        pred_fg = normalize_spec(pred_spec_fg, pr)
        pred_bg = normalize_spec(pred_spec_bg, pr)
    else:
        target_a, target_b = snd.spec_parts[0], snd.spec_parts[1]
        pred_fg, pred_bg = pred_spec_fg, pred_spec_bg

    def mean_abs_err(pred, target):
        # Average over axes 1 and 2 so one value remains per batch item.
        return torch.mean(torch.abs(pred - target), (1, 2))

    per_pairing = []
    for fg_target, bg_target in ((target_a, target_b), (target_b, target_a)):
        fg_term = pr.l1_weight * mean_abs_err(pred_fg, fg_target)
        bg_term = pr.l1_weight * mean_abs_err(pred_bg, bg_target)
        per_pairing.append(fg_term + bg_term)

    losses = torch.stack(per_pairing, dim=0)
    if vis:
        print(losses)

    best = torch.min(losses, dim=0)
    if vis:
        print(best.indices)

    return torch.mean(best.values)
def diff_loss(spec_diff, phase_diff, snd, pr, device, norm=False, vis=False):
    """Weighted L1 loss on predicted spectrogram and phase differences.

    Args:
        spec_diff: Predicted spectrogram difference.
        phase_diff: Predicted phase difference.
        snd: Object providing the targets ``spec_diff`` and ``phase_diff``.
        pr: Parameter object providing ``l1_weight`` and ``phase_weight``
            (and the spectrogram range when ``norm`` is set).
        device: Device the ``L1Loss`` module is moved to.
        norm: When true, predictions and targets are range-normalized before
            the loss is computed.
        vis: When true, print the resulting loss.

    Returns:
        ``pr.l1_weight * spec_l1 + pr.phase_weight * phase_l1``.
    """
    l1 = torch.nn.L1Loss()
    target_spec = snd.spec_diff
    target_phase = snd.phase_diff
    l1 = l1.to(device)

    if norm:
        target_spec = normalize_spec(target_spec, pr)
        target_phase = normalize_phase(target_phase, pr)
        spec_diff = normalize_spec(spec_diff, pr)
        phase_diff = normalize_phase(phase_diff, pr)

    spec_term = l1(spec_diff, target_spec)
    phase_term = l1(phase_diff, target_phase)
    total = pr.l1_weight * spec_term + pr.phase_weight * phase_term

    if vis:
        print(total)
    return total
# def pit_loss(out, snd, pr, cuda=False, vis=False):
# def ns(x): return normalize_spec(x, pr)
# def np(x): return normalize_phase(x, pr)
# if cuda:
# snd['spec_part0'] = snd['spec_part0'].to('cuda')
# snd['phase_part0'] = snd['phase_part0'].to('cuda')
# snd['spec_part1'] = snd['spec_part1'].to('cuda')
# snd['phase_part1'] = snd['phase_part1'].to('cuda')
# # gts_ = [[ns(snd['spec_part0'][:, 0, :, :]), np(snd['phase_part0'][:, 0, :, :])],
# # [ns(snd['spec_part1'][:, 0, :, :]), np(snd['phase_part1'][:, 0, :, :])]]
# gts_ = [[ns(snd.spec_parts[0]), np(snd.phase_parts[0])],
# [ns(snd.spec_parts[1]), np(snd.phase_parts[1])]]
# preds = [[ns(out.pred_spec_fg), np(out.pred_phase_fg)],
# [ns(out.pred_spec_bg), np(out.pred_phase_bg)]]
# def l1(x, y): return torch.mean(torch.abs(x - y), (1, 2))
# losses = []
# for i in range(2):
# gt = [gts_[i % 2], gts_[(i+1) % 2]]
# # print 'preds[0][0] shape =', shape(preds[0][0])
# # fg_spec = pr.l1_weight * l1(preds[0][0], gt[0][0])
# # fg_phase = pr.phase_weight * l1(preds[0][1], gt[0][1])
# # bg_spec = pr.l1_weight * l1(preds[1][0], gt[1][0])
# # bg_phase = pr.phase_weight * l1(preds[1][1], gt[1][1])
# # losses.append(fg_spec + fg_phase + bg_spec + bg_phase)
# fg_spec = pr.l1_weight * l1(preds[0][0], gt[0][0])
# bg_spec = pr.l1_weight * l1(preds[1][0], gt[1][0])
# losses.append(fg_spec + bg_spec)
# # pdb.set_trace()
# # pdb.set_trace()
# losses = torch.cat([x.unsqueeze(0) for x in losses], dim=0)
# if vis:
# print(losses)
# loss_val = torch.min(losses, dim=0)
# if vis:
# print(loss_val[1])
# loss = torch.mean(loss_val[0])
# return loss
# def stereo_mel()
def audio_stft(stft, audio, pr):
    """Turn multi-channel audio into a normalized decibel spectrogram.

    Args:
        stft: Callable applied to the audio flattened to ``(N * C, A)``. Its
            output must be 3-D; the last two axes are swapped afterwards.
            Note that this parameter shadows any module-level ``stft``.
        audio: Tensor of shape ``(N, C, A)``.
        pr: Parameter object providing the spectrogram range used by
            ``normalize_spec``.

    Returns:
        Tensor of shape ``(N, C, T, F)``, where ``T`` and ``F`` are the last
        two axes of the transposed transform output.
    """
    N, C, A = audio.size()
    audio = audio.view(N * C, A)
    spec = stft(audio)
    spec = spec.transpose(-1, -2)
    # Follow the tensor's actual device instead of assuming CUDA, so the
    # function also works on CPU tensors and CPU-only hosts.
    spec = db_from_amp(spec, cuda=spec.is_cuda)
    spec = normalize_spec(spec, pr)
    _, T, F = spec.size()
    spec = spec.view(N, C, T, F)
    return spec
def normalize_audio(samples, desired_rms=0.1, eps=1e-4):
    """Scale ``samples`` so that their RMS equals ``desired_rms``.

    Args:
        samples: Array of audio samples (any numeric dtype).
        desired_rms: Target root-mean-square level.
        eps: Lower bound on the measured RMS. It keeps near-silent input
            from being amplified without limit or divided by zero.

    Returns:
        ``samples`` multiplied by ``desired_rms / max(eps, rms)``.
        Floating-point inputs keep their dtype.
    """
    # Measure power in float64: squaring integer PCM (e.g. int16) in its own
    # dtype wraps around and yields a wrong RMS.
    as_float = np.asarray(samples, dtype=np.float64)
    rms = np.maximum(eps, np.sqrt(np.mean(np.square(as_float))))
    # A plain Python float keeps the gain from promoting float32 input.
    gain = float(desired_rms / rms)
    return samples * gain