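# NOTE(review): "Spaces: Paused" appears to be page-banner residue picked up
# during extraction; it is not part of the test module.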
import numpy as np | |
import torch | |
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator | |
from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator | |
def test_melgan_discriminator():
    """Check the output shape of a default ``MelganDiscriminator``.

    Feeds a random tensor of shape ``(4, 1, 256 * 10)`` through a
    default-constructed model and asserts that the first returned value
    has shape ``(4, 1, 10)``.
    """
    model = MelganDiscriminator()
    print(model)
    dummy_input = torch.rand((4, 1, 256 * 10))
    # The model returns a pair; only the first element is checked here.
    output, _ = model(dummy_input)
    # A plain tuple comparison already yields a single bool, so no
    # ``np.all`` wrapper is needed.
    actual_shape = tuple(output.shape)
    assert actual_shape == (4, 1, 10), f"unexpected output shape: {actual_shape}"
def test_melgan_multi_scale_discriminator():
    """Check the output structure of a default ``MelganMultiscaleDiscriminator``.

    Feeds a random tensor of shape ``(4, 1, 256 * 16)`` through a
    default-constructed model and asserts that:

    * three scores are returned, with one feature entry per score;
    * the first score has shape ``(4, 1, 64)``;
    * the first three feature maps of the first entry have shapes
      ``(4, 16, 4096)``, ``(4, 64, 1024)`` and ``(4, 256, 256)``.
    """
    model = MelganMultiscaleDiscriminator()
    print(model)
    dummy_input = torch.rand((4, 1, 256 * 16))
    scores, feats = model(dummy_input)

    assert len(scores) == 3, f"expected 3 scores, got {len(scores)}"
    assert len(scores) == len(feats), (
        f"scores/feats length mismatch: {len(scores)} != {len(feats)}"
    )

    # A plain tuple comparison already yields a single bool, so no
    # ``np.all`` wrapper is needed.
    first_score_shape = tuple(scores[0].shape)
    assert first_score_shape == (4, 1, 64), (
        f"unexpected shape for scores[0]: {first_score_shape}"
    )

    # Expected shapes of the first three feature maps of the first entry.
    expected_feat_shapes = [(4, 16, 4096), (4, 64, 1024), (4, 256, 256)]
    for idx, expected in enumerate(expected_feat_shapes):
        actual = tuple(feats[0][idx].shape)
        assert actual == expected, (
            f"unexpected shape for feats[0][{idx}]: {actual} != {expected}"
        )