Beatpoettes1 / tests /models /test_encodec_model.py
MantraDas's picture
Duplicate from facebook/MusicGen
f54eb92
raw
history blame
2.33 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import numpy as np
import torch
from audiocraft.models import EncodecModel
from audiocraft.modules import SEANetEncoder, SEANetDecoder
from audiocraft.quantization import DummyQuantizer
class TestEncodecModel:
def _create_encodec_model(self,
sample_rate: int,
channels: int,
dim: int = 5,
n_filters: int = 3,
n_residual_layers: int = 1,
ratios: list = [5, 4, 3, 2],
**kwargs):
frame_rate = np.prod(ratios)
encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
n_residual_layers=n_residual_layers, ratios=ratios)
decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
n_residual_layers=n_residual_layers, ratios=ratios)
quantizer = DummyQuantizer()
model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
sample_rate=sample_rate, channels=channels, **kwargs)
return model
def test_model(self):
random.seed(1234)
sample_rate = 24_000
channels = 1
model = self._create_encodec_model(sample_rate, channels)
for _ in range(10):
length = random.randrange(1, 10_000)
x = torch.randn(2, channels, length)
res = model(x)
assert res.x.shape == x.shape
def test_model_renorm(self):
random.seed(1234)
sample_rate = 24_000
channels = 1
model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)
for _ in range(10):
length = random.randrange(1, 10_000)
x = torch.randn(2, channels, length)
codes, scales = model_nonorm.encode(x)
codes, scales = model_renorm.encode(x)
assert scales is not None