Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random | |
import torch | |
from audiocraft.losses import ( | |
MelSpectrogramL1Loss, | |
MultiScaleMelSpectrogramLoss, | |
MRSTFTLoss, | |
SISNR, | |
STFTLoss, | |
) | |
from audiocraft.losses.loudnessloss import TFLoudnessRatio | |
from audiocraft.losses.wmloss import WMMbLoss | |
from tests.common_utils.wav_utils import get_white_noise | |
def test_mel_l1_loss():
    """Check that MelSpectrogramL1Loss yields tensors and is exactly zero on identical inputs."""
    # Batch of 2, 2 channels, random length in [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    reference = torch.randn(*shape)
    other = torch.randn(*shape)
    criterion = MelSpectrogramL1Loss(sample_rate=22_050)

    loss_diff = criterion(reference, other)
    loss_identical = criterion(reference, reference)

    assert isinstance(loss_diff, torch.Tensor)
    assert isinstance(loss_identical, torch.Tensor)
    # Comparing a signal against itself must give a loss of exactly zero.
    assert loss_identical.item() == 0.0
def test_msspec_loss():
    """Check that MultiScaleMelSpectrogramLoss yields tensors and is exactly zero on identical inputs."""
    # Batch of 2, 2 channels, random length in [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    reference = torch.randn(*shape)
    other = torch.randn(*shape)
    criterion = MultiScaleMelSpectrogramLoss(sample_rate=22_050)

    loss_diff = criterion(reference, other)
    loss_identical = criterion(reference, reference)

    assert isinstance(loss_diff, torch.Tensor)
    assert isinstance(loss_identical, torch.Tensor)
    # Comparing a signal against itself must give a loss of exactly zero.
    assert loss_identical.item() == 0.0
def test_mrstft_loss():
    """Smoke test: a default-constructed MRSTFTLoss returns a torch.Tensor."""
    # Batch of 2, 2 channels, random length in [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    first = torch.randn(*shape)
    second = torch.randn(*shape)
    criterion = MRSTFTLoss()

    result = criterion(first, second)

    assert isinstance(result, torch.Tensor)
def test_sisnr_loss():
    """Smoke test: a default-constructed SISNR loss returns a torch.Tensor."""
    # Batch of 2, 2 channels, random length in [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    first = torch.randn(*shape)
    second = torch.randn(*shape)
    criterion = SISNR()

    result = criterion(first, second)

    assert isinstance(result, torch.Tensor)
def test_stft_loss():
    """Smoke test: a default-constructed STFTLoss returns a torch.Tensor."""
    # Batch of 2, 2 channels, random length in [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    first = torch.randn(*shape)
    second = torch.randn(*shape)
    stft_criterion = STFTLoss()

    result = stft_criterion(first, second)

    assert isinstance(result, torch.Tensor)
def test_wm_loss():
    """Smoke test: WMMbLoss(0.3, "mse") returns a torch.Tensor.

    The loss is called with ``None`` as its second positional argument.
    """
    batch, nbits = 2, 16
    length = random.randrange(1000, 100_000)

    # The first input carries 2 + nbits channels; the message has nbits entries
    # per batch item.
    positive = torch.randn(batch, 2 + nbits, length)
    mask = torch.randn(batch, 1, length)
    message = torch.randn(batch, nbits)

    criterion = WMMbLoss(0.3, "mse")
    result = criterion(positive, None, mask, message)

    assert isinstance(result, torch.Tensor)
def test_loudness_loss():
    """Smoke test: TFLoudnessRatio returns a torch.Tensor when a signal is compared with itself."""
    sample_rate, seconds = 16_000, 1.0
    num_samples = int(sample_rate * seconds)

    # unsqueeze(0) adds a leading dimension to the generated white noise.
    noise = get_white_noise(1, num_samples).unsqueeze(0)
    criterion = TFLoudnessRatio(sample_rate=sample_rate, n_bands=1)

    result = criterion(noise, noise)

    assert isinstance(result, torch.Tensor)