Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random | |
import torch | |
from audiocraft.losses import ( | |
MelSpectrogramL1Loss, | |
MultiScaleMelSpectrogramLoss, | |
MRSTFTLoss, | |
SISNR, | |
STFTLoss, | |
) | |
def test_mel_l1_loss():
    """Check MelSpectrogramL1Loss output types and its value on identical inputs.

    Two random tensors of the same shape are compared, then one tensor is
    compared against itself; the latter loss must be exactly ``0.0``.
    """
    # (N, C, T) with N = C = 2 and a random T drawn from [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    reference = torch.randn(shape)
    other = torch.randn(shape)

    criterion = MelSpectrogramL1Loss(sample_rate=22_050)
    distinct_loss = criterion(reference, other)
    identical_loss = criterion(reference, reference)

    for value in (distinct_loss, identical_loss):
        assert isinstance(value, torch.Tensor)
    assert identical_loss.item() == 0.0
def test_msspec_loss():
    """Check MultiScaleMelSpectrogramLoss output types and its value on identical inputs.

    The loss between a tensor and itself must be exactly ``0.0``; the loss
    between two independent random tensors only needs to be a tensor.
    """
    # (N, C, T) with N = C = 2 and a random T drawn from [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    reference = torch.randn(shape)
    other = torch.randn(shape)

    criterion = MultiScaleMelSpectrogramLoss(sample_rate=22_050)
    distinct_loss = criterion(reference, other)
    identical_loss = criterion(reference, reference)

    for value in (distinct_loss, identical_loss):
        assert isinstance(value, torch.Tensor)
    assert identical_loss.item() == 0.0
def test_mrstft_loss():
    """Check that MRSTFTLoss, built with default arguments, returns a tensor."""
    # (N, C, T) with N = C = 2 and a random T drawn from [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    reference = torch.randn(shape)
    other = torch.randn(shape)

    criterion = MRSTFTLoss()
    result = criterion(reference, other)

    assert isinstance(result, torch.Tensor)
def test_sisnr_loss():
    """Check that SISNR, built with default arguments, returns a tensor."""
    # (N, C, T) with N = C = 2 and a random T drawn from [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    reference = torch.randn(shape)
    other = torch.randn(shape)

    criterion = SISNR()
    result = criterion(reference, other)

    assert isinstance(result, torch.Tensor)
def test_stft_loss():
    """Check that STFTLoss, built with default arguments, returns a tensor."""
    # (N, C, T) with N = C = 2 and a random T drawn from [1000, 100_000).
    shape = (2, 2, random.randrange(1000, 100_000))
    reference = torch.randn(shape)
    other = torch.randn(shape)

    stft_criterion = STFTLoss()
    result = stft_criterion(reference, other)

    assert isinstance(result, torch.Tensor)