Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,543 Bytes
9d0d223 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
from audiocraft.losses import (
MelSpectrogramL1Loss,
MultiScaleMelSpectrogramLoss,
MRSTFTLoss,
SISNR,
STFTLoss,
)
from audiocraft.losses.loudnessloss import TFLoudnessRatio
from audiocraft.losses.wmloss import WMMbLoss
from tests.common_utils.wav_utils import get_white_noise
def test_mel_l1_loss():
    """Check MelSpectrogramL1Loss on random multichannel signals.

    The loss must return tensors, be exactly zero when both arguments are the
    same tensor, and be strictly positive for two independent random signals.
    Without the last check, a loss that always returned 0 would pass.
    """
    # Random length in [1000, 100_000) exercises varying signal durations.
    N, C, T = 2, 2, random.randrange(1000, 100_000)
    t1 = torch.randn(N, C, T)
    t2 = torch.randn(N, C, T)

    mel_l1 = MelSpectrogramL1Loss(sample_rate=22_050)
    loss = mel_l1(t1, t2)
    loss_same = mel_l1(t1, t1)

    assert isinstance(loss, torch.Tensor)
    assert isinstance(loss_same, torch.Tensor)
    assert loss_same.item() == 0.0
    # Independent random inputs must be distinguishable by the loss.
    assert loss.item() > 0.0
def test_msspec_loss():
    """Check MultiScaleMelSpectrogramLoss on random multichannel signals.

    The loss must return tensors, be exactly zero when both arguments are the
    same tensor, and be strictly positive for two independent random signals.
    Without the last check, a loss that always returned 0 would pass.
    """
    # Random length in [1000, 100_000) exercises varying signal durations.
    N, C, T = 2, 2, random.randrange(1000, 100_000)
    t1 = torch.randn(N, C, T)
    t2 = torch.randn(N, C, T)

    msspec = MultiScaleMelSpectrogramLoss(sample_rate=22_050)
    loss = msspec(t1, t2)
    loss_same = msspec(t1, t1)

    assert isinstance(loss, torch.Tensor)
    assert isinstance(loss_same, torch.Tensor)
    assert loss_same.item() == 0.0
    # Independent random inputs must be distinguishable by the loss.
    assert loss.item() > 0.0
def test_mrstft_loss():
    """MRSTFTLoss should produce a torch.Tensor for two random 2x2xT signals."""
    length = random.randrange(1000, 100_000)
    shape = (2, 2, length)  # (N, C, T) as in the sibling tests
    prediction, reference = torch.randn(*shape), torch.randn(*shape)

    criterion = MRSTFTLoss()
    result = criterion(prediction, reference)

    assert isinstance(result, torch.Tensor)
def test_sisnr_loss():
    """SISNR should produce a torch.Tensor for two random 2x2xT signals."""
    batch, channels = 2, 2
    steps = random.randrange(1000, 100_000)
    estimate = torch.randn(batch, channels, steps)
    target = torch.randn(batch, channels, steps)

    result = SISNR()(estimate, target)

    assert isinstance(result, torch.Tensor)
def test_stft_loss():
    """STFTLoss (single-resolution) should produce a torch.Tensor for random signals."""
    N, C, T = 2, 2, random.randrange(1000, 100_000)
    t1 = torch.randn(N, C, T)
    t2 = torch.randn(N, C, T)

    # Named `stft`, not `mrstft`: this is the single-resolution STFTLoss, not MRSTFTLoss.
    stft = STFTLoss()
    loss = stft(t1, t2)

    assert isinstance(loss, torch.Tensor)
def test_wm_loss():
    """WMMbLoss(0.3, "mse") should return a torch.Tensor for random inputs.

    NOTE(review): argument meanings are inferred from shapes only. The first
    input has 2 + nbits channels (presumably 2 detection channels plus one per
    message bit), the second is None, and the third is a single-channel tensor
    of the same length. Confirm against WMMbLoss.forward.
    """
    batch = 2
    nbits = 16
    steps = random.randrange(1000, 100_000)

    outputs = torch.randn(batch, 2 + nbits, steps)
    single_channel = torch.randn(batch, 1, steps)
    bits = torch.randn(batch, nbits)

    criterion = WMMbLoss(0.3, "mse")
    result = criterion(outputs, None, single_channel, bits)

    assert isinstance(result, torch.Tensor)
def test_loudness_loss():
    """TFLoudnessRatio comparing a 1 s white-noise clip with itself returns a tensor."""
    sample_rate = 16_000
    seconds = 1.0
    n_samples = int(sample_rate * seconds)
    # unsqueeze(0) adds a leading (batch) axis; assumes get_white_noise returns
    # a (channels, time) tensor — TODO confirm against wav_utils.
    noise = get_white_noise(1, n_samples).unsqueeze(0)

    criterion = TFLoudnessRatio(sample_rate=sample_rate, n_bands=1)
    result = criterion(noise, noise)

    assert isinstance(result, torch.Tensor)
|