Delete tests
Browse files- tests/__init__.py +0 -5
- tests/common_utils/__init__.py +0 -9
- tests/common_utils/temp_utils.py +0 -56
- tests/common_utils/wav_utils.py +0 -32
- tests/data/__init__.py +0 -5
- tests/data/test_audio.py +0 -239
- tests/data/test_audio_dataset.py +0 -352
- tests/data/test_audio_utils.py +0 -110
- tests/models/test_encodec_model.py +0 -60
- tests/models/test_musicgen.py +0 -58
- tests/modules/__init__.py +0 -5
- tests/modules/test_codebooks_patterns.py +0 -246
- tests/modules/test_conv.py +0 -203
- tests/modules/test_lstm.py +0 -32
- tests/modules/test_rope.py +0 -168
- tests/modules/test_seanet.py +0 -115
- tests/modules/test_transformer.py +0 -253
- tests/quantization/test_vq.py +0 -18
- tests/utils/__init__.py +0 -5
tests/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
tests/common_utils/__init__.py
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
# flake8: noqa
|
8 |
-
from .temp_utils import TempDirMixin
|
9 |
-
from .wav_utils import get_batch_white_noise, get_white_noise, save_wav
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/common_utils/temp_utils.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import os
|
8 |
-
import tempfile
|
9 |
-
|
10 |
-
|
11 |
-
class TempDirMixin:
    """Mixin giving test classes a scratch directory on disk.

    The base directory comes from the ``AUDIOCRAFT_TEST_DIR`` environment
    variable when it is set; otherwise a ``tempfile.TemporaryDirectory`` is
    created lazily and stored on the class. Paths returned by
    ``get_temp_path`` and ``get_temp_dir`` live in a sub-directory named
    after the concrete class (see ``id``).
    """

    # Lazily created ``tempfile.TemporaryDirectory``; ``None`` until first use
    # and again after a successful cleanup.
    temp_dir_ = None

    @classmethod
    def get_base_temp_dir(cls):
        """Return the root directory under which temp paths are created."""
        # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory.
        # this is handy for debugging.
        key = "AUDIOCRAFT_TEST_DIR"
        if key in os.environ:
            return os.environ[key]
        if cls.temp_dir_ is None:
            cls.temp_dir_ = tempfile.TemporaryDirectory()
        return cls.temp_dir_.name

    @classmethod
    def tearDownClass(cls):
        """Remove the class-level temporary directory, if one was created.

        A ``PermissionError`` during cleanup is ignored on purpose. The
        parent ``tearDownClass`` is then invoked when one exists.
        """
        if cls.temp_dir_ is not None:
            try:
                cls.temp_dir_.cleanup()
                cls.temp_dir_ = None
            except PermissionError:
                # On Windows there is a known issue with `shutil.rmtree`,
                # which fails intermittently.
                # https://github.com/python/cpython/issues/74168
                # Following the above thread, we ignore it.
                pass
        # The mixin may be used without a unittest.TestCase base, in which
        # case there is no parent tearDownClass to chain to.
        parent_teardown = getattr(super(), "tearDownClass", None)
        if parent_teardown is not None:
            parent_teardown()

    @property
    def id(self):
        """Name of the concrete class, used as the per-class sub-directory."""
        return self.__class__.__name__

    def get_temp_path(self, *paths):
        """Return a file path under the class temp dir, creating its parent directory."""
        temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
        path = os.path.join(temp_dir, *paths)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        return path

    def get_temp_dir(self, *paths):
        """Return a directory path under the class temp dir, creating the directory itself."""
        temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
        path = os.path.join(temp_dir, *paths)
        os.makedirs(path, exist_ok=True)
        return path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/common_utils/wav_utils.py
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from pathlib import Path
|
8 |
-
import typing as tp
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import torchaudio
|
12 |
-
|
13 |
-
|
14 |
-
def get_white_noise(chs: int = 1, num_frames: int = 1):
    """Return a ``(chs, num_frames)`` tensor drawn from ``torch.randn``."""
    return torch.randn(chs, num_frames)
|
17 |
-
|
18 |
-
|
19 |
-
def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
    """Return a ``(bs, chs, num_frames)`` tensor drawn from ``torch.randn``."""
    return torch.randn(bs, chs, num_frames)
|
22 |
-
|
23 |
-
|
24 |
-
def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
    """Write ``wav`` to ``path`` via ``torchaudio.save``.

    Extra save options depend on the file suffix: ``.wav`` uses
    ``encoding='PCM_S'`` with ``bits_per_sample=16``, ``.mp3`` uses
    ``compression=320``, and any other suffix passes no extra options.
    """
    target = Path(path)
    options_by_suffix: tp.Dict[str, tp.Dict[str, tp.Any]] = {
        '.wav': {'encoding': 'PCM_S', 'bits_per_sample': 16},
        '.mp3': {'compression': 320},
    }
    save_options = options_by_suffix.get(target.suffix, {})
    torchaudio.save(str(target), wav, sample_rate, **save_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/data/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
tests/data/test_audio.py
DELETED
@@ -1,239 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from itertools import product
|
8 |
-
import random
|
9 |
-
|
10 |
-
import numpy as np
|
11 |
-
import torch
|
12 |
-
import torchaudio
|
13 |
-
|
14 |
-
from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read
|
15 |
-
|
16 |
-
from ..common_utils import TempDirMixin, get_white_noise, save_wav
|
17 |
-
|
18 |
-
|
19 |
-
class TestInfo(TempDirMixin):
    """Checks the metadata that ``audio_info`` reports for generated noise files."""

    def test_info_mp3(self):
        """mp3: sample rate and channel count match what was written."""
        length_sec = 1.
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            noise = get_white_noise(n_ch, int(sr * length_sec))
            mp3_path = self.get_temp_path('sample_wav.mp3')
            save_wav(mp3_path, noise, sr)
            meta = audio_info(mp3_path)
            assert meta.sample_rate == sr
            assert meta.channels == n_ch
            # we cannot trust torchaudio for num_frames, so we don't check

    def _test_info_format(self, ext: str):
        """Shared check for one extension: sample rate, channels and duration."""
        length_sec = 1.
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            total_frames = int(sr * length_sec)
            noise = get_white_noise(n_ch, total_frames)
            file_path = self.get_temp_path(f'sample_wav{ext}')
            save_wav(file_path, noise, sr)
            meta = audio_info(file_path)
            assert meta.sample_rate == sr
            assert meta.channels == n_ch
            assert np.isclose(meta.duration, length_sec, atol=1e-5)

    def test_info_wav(self):
        """Run the shared format check for ``.wav``."""
        self._test_info_format('.wav')

    def test_info_flac(self):
        """Run the shared format check for ``.flac``."""
        self._test_info_format('.flac')

    def test_info_ogg(self):
        """Run the shared format check for ``.ogg``."""
        self._test_info_format('.ogg')

    def test_info_m4a(self):
        """Placeholder: m4a is not checked yet."""
        # TODO: generate m4a file programmatically
        # self._test_info_format('.m4a')
        pass
|
61 |
-
|
62 |
-
|
63 |
-
class TestRead(TempDirMixin):
    """Round-trip checks for ``audio_read`` on generated wav files."""

    def test_read_full_wav(self):
        """Reading the whole file returns the written samples."""
        length_sec = 1.
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            total_frames = int(sr * length_sec)
            noise = get_white_noise(n_ch, total_frames).clamp(-0.99, 0.99)
            wav_path = self.get_temp_path('sample_wav.wav')
            save_wav(wav_path, noise, sr)
            loaded, loaded_sr = audio_read(wav_path)
            assert loaded_sr == sr
            assert loaded.shape[0] == noise.shape[0]
            assert loaded.shape[1] == noise.shape[1]
            assert torch.allclose(loaded, noise, rtol=1e-03, atol=1e-04)

    def test_read_partial_wav(self):
        """Reading a random-length prefix returns exactly that prefix."""
        length_sec = 1.
        # One random read duration is shared by every (rate, channels) combination.
        read_sec = torch.rand(1).item()
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            total_frames = int(sr * length_sec)
            wanted_frames = int(sr * read_sec)
            noise = get_white_noise(n_ch, total_frames).clamp(-0.99, 0.99)
            wav_path = self.get_temp_path('sample_wav.wav')
            save_wav(wav_path, noise, sr)
            loaded, loaded_sr = audio_read(wav_path, 0, read_sec)
            assert loaded_sr == sr
            assert loaded.shape[0] == noise.shape[0]
            assert loaded.shape[1] == wanted_frames
            assert torch.allclose(loaded[..., 0:wanted_frames], noise[..., 0:wanted_frames], rtol=1e-03, atol=1e-04)

    def test_read_seek_time_wav(self):
        """Seeking to a random offset returns the remainder of the file."""
        length_sec = 1.
        read_sec = 1.
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            total_frames = int(sr * length_sec)
            noise = get_white_noise(n_ch, total_frames).clamp(-0.99, 0.99)
            wav_path = self.get_temp_path('sample_wav.wav')
            save_wav(wav_path, noise, sr)
            offset_sec = torch.rand(1).item()
            loaded, loaded_sr = audio_read(wav_path, offset_sec, read_sec)
            offset_frames = int(sr * offset_sec)
            remaining_frames = total_frames - offset_frames
            assert loaded_sr == sr
            assert loaded.shape[0] == noise.shape[0]
            assert loaded.shape[1] == remaining_frames
            assert torch.allclose(loaded, noise[..., offset_frames:], rtol=1e-03, atol=1e-04)

    def test_read_seek_time_wav_padded(self):
        """With ``pad=True`` the read is zero-padded up to the requested duration."""
        length_sec = 1.
        read_sec = 1.
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            total_frames = int(sr * length_sec)
            wanted_frames = int(sr * read_sec)
            noise = get_white_noise(n_ch, total_frames).clamp(-0.99, 0.99)
            wav_path = self.get_temp_path('sample_wav.wav')
            save_wav(wav_path, noise, sr)
            offset_sec = torch.rand(1).item()
            offset_frames = int(sr * offset_sec)
            remaining_frames = total_frames - offset_frames
            loaded, loaded_sr = audio_read(wav_path, offset_sec, read_sec, pad=True)
            zero_tail = torch.zeros(noise.shape[0], wanted_frames - remaining_frames)
            assert loaded_sr == sr
            assert loaded.shape[0] == noise.shape[0]
            assert loaded.shape[1] == wanted_frames
            assert torch.allclose(loaded[..., :remaining_frames], noise[..., offset_frames:], rtol=1e-03, atol=1e-04)
            assert torch.allclose(loaded[..., remaining_frames:], zero_tail)
|
137 |
-
|
138 |
-
|
139 |
-
class TestAvRead(TempDirMixin):
    """Shape checks for ``_av_read`` under different seek positions."""

    def test_avread_seek_base(self):
        """Seek windows that fit inside the file return the requested length."""
        length_sec = 2.
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            total_frames = int(sr * length_sec)
            noise = get_white_noise(n_ch, total_frames)
            wav_path = self.get_temp_path(f'reference_a_{sr}_{n_ch}.wav')
            save_wav(wav_path, noise, sr)
            for _ in range(100):
                # seek will always load a full duration segment in the file
                offset_sec = random.uniform(0.0, 1.0)
                window_sec = random.uniform(0.001, 1.0)
                loaded, loaded_sr = _av_read(wav_path, offset_sec, window_sec)
                assert loaded_sr == sr
                assert loaded.shape[0] == noise.shape[0]
                assert loaded.shape[-1] == int(window_sec * sr)

    def test_avread_seek_partial(self):
        """Seek windows running past the end return only the remaining frames."""
        length_sec = 1.
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            total_frames = int(sr * length_sec)
            noise = get_white_noise(n_ch, total_frames)
            wav_path = self.get_temp_path(f'reference_b_{sr}_{n_ch}.wav')
            save_wav(wav_path, noise, sr)
            for _ in range(100):
                # seek will always load a partial segment
                offset_sec = random.uniform(0.5, 1.)
                window_sec = 1.
                remaining_frames = total_frames - int(offset_sec * sr)
                loaded, loaded_sr = _av_read(wav_path, offset_sec, window_sec)
                assert loaded_sr == sr
                assert loaded.shape[0] == noise.shape[0]
                assert loaded.shape[-1] == remaining_frames

    def test_avread_seek_outofbound(self):
        """Seeking beyond the end of the file returns zero frames."""
        length_sec = 1.
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            total_frames = int(sr * length_sec)
            noise = get_white_noise(n_ch, total_frames)
            wav_path = self.get_temp_path(f'reference_c_{sr}_{n_ch}.wav')
            save_wav(wav_path, noise, sr)
            offset_sec = 1.5
            loaded, loaded_sr = _av_read(wav_path, offset_sec, 1.)
            assert loaded_sr == sr
            assert loaded.shape[0] == noise.shape[0]
            assert loaded.shape[-1] == 0

    def test_avread_seek_edge(self):
        """Seeking to the last frame returns the frames left after the seek."""
        # some of these values will have
        # int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1)
        frame_counts = [1000, 1001, 1002]
        for sr, n_ch, frames in product([8000, 16_000], [1, 2], frame_counts):
            length_sec = frames / sr
            noise = get_white_noise(n_ch, frames)
            wav_path = self.get_temp_path(f'reference_d_{sr}_{n_ch}.wav')
            save_wav(wav_path, noise, sr)
            offset_sec = (frames - 1) / sr
            offset_frames = int(offset_sec * sr)
            loaded, loaded_sr = _av_read(wav_path, offset_sec, length_sec)
            assert loaded_sr == sr
            assert loaded.shape[0] == noise.shape[0]
            assert loaded.shape[-1] == (frames - offset_frames)
|
210 |
-
|
211 |
-
|
212 |
-
class TestAudioWrite(TempDirMixin):
    """Checks ``audio_write`` output by loading it back with torchaudio."""

    def test_audio_write_wav(self):
        """Written wav files keep their shape and, after rescaling, their samples."""
        torch.manual_seed(1234)
        rates = [8000, 16_000]
        frame_counts = [1000, 1001, 1002]
        channel_counts = [1, 2]
        strategies = ["peak", "clip", "rms"]
        formats = ["wav", "mp3"]
        for sr, n_ch, frames in product(rates, channel_counts, frame_counts):
            for fmt, strategy in product(formats, strategies):
                noise = get_white_noise(n_ch, frames)
                stem = self.get_temp_path(f'pred_{sr}_{n_ch}')
                audio_write(stem, noise, sr, fmt, strategy=strategy)
                loaded, _ = torchaudio.load(f'{stem}.{fmt}')
                # Only wav output is compared against the input.
                if fmt != "wav":
                    continue
                assert loaded.shape == noise.shape

                if strategy not in ["peak", "rms"]:
                    continue
                rescaled = loaded / loaded.abs().max() * noise.abs().max()
                # for a Gaussian, the typical max scale will be less than ~5x the std.
                # The error when writing to disk will ~ 1/2**15, and when rescaling, 5x that.
                # For RMS target, rescaling leaves more headroom by default, leading
                # to a 20x rescaling typically
                atol = (5 if strategy == "peak" else 20) / 2**15
                delta = (rescaled - noise).abs().max()
                assert torch.allclose(noise, rescaled, rtol=0, atol=atol), (delta, atol)
            # After the first outer iteration only wav is exercised.
            formats = ["wav"]  # faster unit tests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/data/test_audio_dataset.py
DELETED
@@ -1,352 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from functools import partial
|
8 |
-
from itertools import product
|
9 |
-
import json
|
10 |
-
import math
|
11 |
-
import os
|
12 |
-
import random
|
13 |
-
import typing as tp
|
14 |
-
|
15 |
-
import pytest
|
16 |
-
import torch
|
17 |
-
from torch.utils.data import DataLoader
|
18 |
-
|
19 |
-
from audiocraft.data.audio_dataset import (
|
20 |
-
AudioDataset,
|
21 |
-
AudioMeta,
|
22 |
-
_get_audio_meta,
|
23 |
-
load_audio_meta,
|
24 |
-
save_audio_meta
|
25 |
-
)
|
26 |
-
from audiocraft.data.zip import PathInZip
|
27 |
-
|
28 |
-
from ..common_utils import TempDirMixin, get_white_noise, save_wav
|
29 |
-
|
30 |
-
|
31 |
-
class TestAudioMeta(TempDirMixin):
    """Checks extraction, saving and loading of ``AudioMeta`` records."""

    def test_get_audio_meta(self):
        """``_get_audio_meta`` with ``minimal=True`` fills path, rate and duration only."""
        length_sec = 1.
        for sr, n_ch in product([8000, 16_000], [1, 2]):
            total_frames = int(length_sec * sr)
            noise = get_white_noise(n_ch, total_frames)
            wav_path = self.get_temp_path('sample.wav')
            save_wav(wav_path, noise, sr)
            record = _get_audio_meta(wav_path, minimal=True)
            assert record.path == wav_path, 'path does not match'
            assert record.sample_rate == sr, 'sample rate does not match'
            assert record.duration == length_sec, 'duration does not match'
            assert record.amplitude is None
            assert record.info_path is None

    def test_save_audio_meta(self):
        """Records written by ``save_audio_meta`` parse back to equal records."""
        filled = [
            AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
            AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
        ]
        empty = []
        for idx, expected in enumerate([filled, empty]):
            jsonl_path = self.get_temp_path(f'data_{idx}_save.jsonl')
            save_audio_meta(jsonl_path, expected)
            with open(jsonl_path, 'r') as f:
                raw_lines = f.readlines()
            parsed = [AudioMeta.from_dict(json.loads(raw)) for raw in raw_lines]
            assert len(parsed) == len(expected)
            for want, got in zip(expected, parsed):
                assert want == got

    def test_load_audio_meta(self):
        """``load_audio_meta`` reads back records written one JSON object per line."""
        try:
            import dora
        except ImportError:
            dora = None  # type: ignore

        filled = [
            AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
            AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
        ]
        empty = []
        for idx, expected in enumerate([filled, empty]):
            jsonl_path = self.get_temp_path(f'data_{idx}_load.jsonl')
            with open(jsonl_path, 'w') as f:
                for record in expected:
                    f.write(json.dumps(record.to_dict()) + '\n')
            loaded = load_audio_meta(jsonl_path)
            assert len(loaded) == len(expected)
            for m, read_m in zip(expected, loaded):
                if dora:
                    # When dora is importable, the expected path is converted
                    # with dora's to_absolute_path before comparing.
                    m.path = dora.git_save.to_absolute_path(m.path)
                assert m == read_m, f'original={m}, read={read_m}'
|
88 |
-
|
89 |
-
|
90 |
-
class TestAudioDataset(TempDirMixin):
    """Checks ``AudioDataset`` construction, indexing, batching and sampling."""

    def _create_audio_files(self,
                            root_name: str,
                            num_examples: int,
                            durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
                            sample_rate: int = 16_000,
                            channels: int = 1):
        """Write ``num_examples`` white-noise wav files and return their directory.

        ``durations`` is either a fixed duration (a float or a 1-tuple) or a
        ``(min, max)`` pair from which each file's duration is drawn uniformly.
        """
        root_dir = self.get_temp_dir(root_name)
        for i in range(num_examples):
            if isinstance(durations, float):
                duration = durations
            elif isinstance(durations, tuple) and len(durations) == 1:
                duration = durations[0]
            elif isinstance(durations, tuple) and len(durations) == 2:
                duration = random.uniform(durations[0], durations[1])
            else:
                assert False
            n_frames = int(duration * sample_rate)
            wav = get_white_noise(channels, n_frames)
            path = os.path.join(root_dir, f'example_{i}.wav')
            save_wav(path, wav, sample_rate)
        return root_dir

    def _create_audio_dataset(self,
                              root_name: str,
                              total_num_examples: int,
                              durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
                              sample_rate: int = 16_000,
                              channels: int = 1,
                              segment_duration: tp.Optional[float] = None,
                              num_examples: int = 10,
                              shuffle: bool = True,
                              return_info: bool = False):
        """Generate audio files, then build an ``AudioDataset`` from their directory."""
        root_dir = self._create_audio_files(root_name, total_num_examples, durations, sample_rate, channels)
        dataset = AudioDataset.from_path(root_dir,
                                         minimal_meta=True,
                                         segment_duration=segment_duration,
                                         num_samples=num_examples,
                                         sample_rate=sample_rate,
                                         channels=channels,
                                         shuffle=shuffle,
                                         return_info=return_info)
        return dataset

    def test_dataset_full(self):
        """Without a segment duration the dataset yields one whole file per index."""
        total_examples = 10
        min_duration, max_duration = 1., 4.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration),
            sample_rate=sample_rate, channels=channels, segment_duration=None)
        assert len(dataset) == total_examples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] <= int(max_duration * sample_rate)
            assert sample.shape[1] >= int(min_duration * sample_rate)

    def test_dataset_segment(self):
        """With a segment duration every sample has exactly that length."""
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples)
        assert len(dataset) == num_samples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] == int(segment_duration * sample_rate)

    def test_dataset_equal_audio_and_segment_durations(self):
        """Audio and segment of equal duration still give differing samples."""
        total_examples = 1
        num_samples = 2
        audio_duration = 1.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples)
        assert len(dataset) == num_samples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] == int(segment_duration * sample_rate)
        # the random seek_time adds variability on audio read
        sample_1 = dataset[0]
        sample_2 = dataset[1]
        assert not torch.allclose(sample_1, sample_2)

    def test_dataset_samples(self):
        """Repeated reads of one index differ with shuffle and match without it."""
        total_examples = 1
        num_samples = 2
        audio_duration = 1.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1

        create_dataset = partial(
            self._create_audio_dataset,
            'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples,
        )

        dataset = create_dataset(shuffle=True)
        # when shuffle = True, we have different inputs for the same index across epoch
        sample_1 = dataset[0]
        sample_2 = dataset[0]
        assert not torch.allclose(sample_1, sample_2)

        dataset_noshuffle = create_dataset(shuffle=False)
        # when shuffle = False, we have same inputs for the same index across epoch
        sample_1 = dataset_noshuffle[0]
        sample_2 = dataset_noshuffle[0]
        assert torch.allclose(sample_1, sample_2)

    def test_dataset_return_info(self):
        """With ``return_info=True`` each item is a ``(sample, segment_info)`` pair."""
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
        assert len(dataset) == num_samples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample, segment_info = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] == int(segment_duration * sample_rate)
            assert segment_info.sample_rate == sample_rate
            assert segment_info.total_frames == int(segment_duration * sample_rate)
            assert segment_info.n_frames <= int(segment_duration * sample_rate)
            assert segment_info.seek_time >= 0

    def test_dataset_return_info_no_segment_duration(self):
        """Segment info is consistent with the sample when no segment duration is set."""
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        segment_duration = None
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
        assert len(dataset) == total_examples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample, segment_info = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] == segment_info.total_frames
            assert segment_info.sample_rate == sample_rate
            assert segment_info.n_frames <= segment_info.total_frames

    def test_dataset_collate_fn(self):
        """The default ``DataLoader`` collation yields batches of ``batch_size``."""
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=False)
        batch_size = 4
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=0
        )
        for batch in dataloader:
            assert batch.shape[0] == batch_size

    @pytest.mark.parametrize("segment_duration", [1.0, None])
    def test_dataset_with_meta_collate_fn(self, segment_duration):
        """``dataset.collater`` yields ``(wav, infos)`` batches of ``batch_size``."""
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        # NOTE(review): this overwrites the parametrized argument, so the
        # ``None`` case runs with 1.0 as well. Kept as-is: whether every batch
        # would still be full with ``None`` cannot be confirmed from this file.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
        batch_size = 4
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collater,
            num_workers=0
        )
        for batch in dataloader:
            wav, infos = batch
            assert wav.shape[0] == batch_size
            assert len(infos) == batch_size

    @pytest.mark.parametrize("segment_duration,sample_on_weight,sample_on_duration,a_hist,b_hist,c_hist", [
        [1, True, True, 0.5, 0.5, 0.0],
        [1, False, True, 0.25, 0.5, 0.25],
        [1, True, False, 0.666, 0.333, 0.0],
        [1, False, False, 0.333, 0.333, 0.333],
        [None, False, False, 0.333, 0.333, 0.333]])
    def test_sample_with_weight(self, segment_duration, sample_on_weight, sample_on_duration, a_hist, b_hist, c_hist):
        """Empirical ``sample_file`` frequencies match the expected histogram."""
        random.seed(1234)
        rng = torch.Generator()
        rng.manual_seed(1234)

        def _get_histogram(dataset, repetitions=20_000):
            # Fraction of draws landing on each file path.
            counts = {file_meta.path: 0. for file_meta in meta}
            for _ in range(repetitions):
                file_meta = dataset.sample_file(rng)
                counts[file_meta.path] += 1
            return {name: count / repetitions for name, count in counts.items()}

        meta = [
            AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
            AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
            AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
        ]
        dataset = AudioDataset(
            meta, segment_duration=segment_duration, sample_on_weight=sample_on_weight,
            sample_on_duration=sample_on_duration)
        hist = _get_histogram(dataset)
        assert math.isclose(hist['a'], a_hist, abs_tol=0.01)
        assert math.isclose(hist['b'], b_hist, abs_tol=0.01)
        assert math.isclose(hist['c'], c_hist, abs_tol=0.01)

    def test_meta_duration_filter_all(self):
        """Building a dataset whose segment is longer than every file must fail."""
        meta = [
            AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
            AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
            AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
        ]
        # pytest.raises fails the test when no AssertionError is raised; a
        # try/except with ``assert False`` inside the try would swallow its
        # own failure and always pass.
        with pytest.raises(AssertionError):
            AudioDataset(meta, segment_duration=11, min_segment_ratio=1)

    def test_meta_duration_filter_long(self):
        """Files longer than ``max_audio_duration`` are dropped from the dataset."""
        meta = [
            AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
            AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
            AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
        ]
        dataset = AudioDataset(meta, segment_duration=None, min_segment_ratio=1, max_audio_duration=7)
        assert len(dataset) == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/data/test_audio_utils.py
DELETED
@@ -1,110 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import julius
|
8 |
-
import torch
|
9 |
-
import pytest
|
10 |
-
|
11 |
-
from audiocraft.data.audio_utils import (
|
12 |
-
_clip_wav,
|
13 |
-
convert_audio_channels,
|
14 |
-
convert_audio,
|
15 |
-
normalize_audio
|
16 |
-
)
|
17 |
-
from ..common_utils import get_batch_white_noise
|
18 |
-
|
19 |
-
|
20 |
-
class TestConvertAudioChannels:
|
21 |
-
|
22 |
-
def test_convert_audio_channels_downmix(self):
|
23 |
-
b, c, t = 2, 3, 100
|
24 |
-
audio = get_batch_white_noise(b, c, t)
|
25 |
-
mixed = convert_audio_channels(audio, channels=2)
|
26 |
-
assert list(mixed.shape) == [b, 2, t]
|
27 |
-
|
28 |
-
def test_convert_audio_channels_nochange(self):
|
29 |
-
b, c, t = 2, 3, 100
|
30 |
-
audio = get_batch_white_noise(b, c, t)
|
31 |
-
mixed = convert_audio_channels(audio, channels=c)
|
32 |
-
assert list(mixed.shape) == list(audio.shape)
|
33 |
-
|
34 |
-
def test_convert_audio_channels_upmix(self):
|
35 |
-
b, c, t = 2, 1, 100
|
36 |
-
audio = get_batch_white_noise(b, c, t)
|
37 |
-
mixed = convert_audio_channels(audio, channels=3)
|
38 |
-
assert list(mixed.shape) == [b, 3, t]
|
39 |
-
|
40 |
-
def test_convert_audio_channels_upmix_error(self):
|
41 |
-
b, c, t = 2, 2, 100
|
42 |
-
audio = get_batch_white_noise(b, c, t)
|
43 |
-
with pytest.raises(ValueError):
|
44 |
-
convert_audio_channels(audio, channels=3)
|
45 |
-
|
46 |
-
|
47 |
-
class TestConvertAudio:
|
48 |
-
|
49 |
-
def test_convert_audio_channels_downmix(self):
|
50 |
-
b, c, dur = 2, 3, 4.
|
51 |
-
sr = 128
|
52 |
-
audio = get_batch_white_noise(b, c, int(sr * dur))
|
53 |
-
out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2)
|
54 |
-
assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]]
|
55 |
-
|
56 |
-
def test_convert_audio_channels_upmix(self):
|
57 |
-
b, c, dur = 2, 1, 4.
|
58 |
-
sr = 128
|
59 |
-
audio = get_batch_white_noise(b, c, int(sr * dur))
|
60 |
-
out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3)
|
61 |
-
assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]]
|
62 |
-
|
63 |
-
def test_convert_audio_upsample(self):
|
64 |
-
b, c, dur = 2, 1, 4.
|
65 |
-
sr = 2
|
66 |
-
new_sr = 3
|
67 |
-
audio = get_batch_white_noise(b, c, int(sr * dur))
|
68 |
-
out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
|
69 |
-
out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
|
70 |
-
assert torch.allclose(out, out_j)
|
71 |
-
|
72 |
-
def test_convert_audio_resample(self):
|
73 |
-
b, c, dur = 2, 1, 4.
|
74 |
-
sr = 3
|
75 |
-
new_sr = 2
|
76 |
-
audio = get_batch_white_noise(b, c, int(sr * dur))
|
77 |
-
out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
|
78 |
-
out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
|
79 |
-
assert torch.allclose(out, out_j)
|
80 |
-
|
81 |
-
|
82 |
-
class TestNormalizeAudio:
|
83 |
-
|
84 |
-
def test_clip_wav(self):
|
85 |
-
b, c, dur = 2, 1, 4.
|
86 |
-
sr = 3
|
87 |
-
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
|
88 |
-
_clip_wav(audio)
|
89 |
-
assert audio.abs().max() <= 1
|
90 |
-
|
91 |
-
def test_normalize_audio_clip(self):
|
92 |
-
b, c, dur = 2, 1, 4.
|
93 |
-
sr = 3
|
94 |
-
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
|
95 |
-
norm_audio = normalize_audio(audio, strategy='clip')
|
96 |
-
assert norm_audio.abs().max() <= 1
|
97 |
-
|
98 |
-
def test_normalize_audio_rms(self):
|
99 |
-
b, c, dur = 2, 1, 4.
|
100 |
-
sr = 3
|
101 |
-
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
|
102 |
-
norm_audio = normalize_audio(audio, strategy='rms')
|
103 |
-
assert norm_audio.abs().max() <= 1
|
104 |
-
|
105 |
-
def test_normalize_audio_peak(self):
|
106 |
-
b, c, dur = 2, 1, 4.
|
107 |
-
sr = 3
|
108 |
-
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
|
109 |
-
norm_audio = normalize_audio(audio, strategy='peak')
|
110 |
-
assert norm_audio.abs().max() <= 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/models/test_encodec_model.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import random
|
8 |
-
|
9 |
-
import numpy as np
|
10 |
-
import torch
|
11 |
-
|
12 |
-
from audiocraft.models import EncodecModel
|
13 |
-
from audiocraft.modules import SEANetEncoder, SEANetDecoder
|
14 |
-
from audiocraft.quantization import DummyQuantizer
|
15 |
-
|
16 |
-
|
17 |
-
class TestEncodecModel:
|
18 |
-
|
19 |
-
def _create_encodec_model(self,
|
20 |
-
sample_rate: int,
|
21 |
-
channels: int,
|
22 |
-
dim: int = 5,
|
23 |
-
n_filters: int = 3,
|
24 |
-
n_residual_layers: int = 1,
|
25 |
-
ratios: list = [5, 4, 3, 2],
|
26 |
-
**kwargs):
|
27 |
-
frame_rate = np.prod(ratios)
|
28 |
-
encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
|
29 |
-
n_residual_layers=n_residual_layers, ratios=ratios)
|
30 |
-
decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
|
31 |
-
n_residual_layers=n_residual_layers, ratios=ratios)
|
32 |
-
quantizer = DummyQuantizer()
|
33 |
-
model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
|
34 |
-
sample_rate=sample_rate, channels=channels, **kwargs)
|
35 |
-
return model
|
36 |
-
|
37 |
-
def test_model(self):
|
38 |
-
random.seed(1234)
|
39 |
-
sample_rate = 24_000
|
40 |
-
channels = 1
|
41 |
-
model = self._create_encodec_model(sample_rate, channels)
|
42 |
-
for _ in range(10):
|
43 |
-
length = random.randrange(1, 10_000)
|
44 |
-
x = torch.randn(2, channels, length)
|
45 |
-
res = model(x)
|
46 |
-
assert res.x.shape == x.shape
|
47 |
-
|
48 |
-
def test_model_renorm(self):
|
49 |
-
random.seed(1234)
|
50 |
-
sample_rate = 24_000
|
51 |
-
channels = 1
|
52 |
-
model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
|
53 |
-
model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)
|
54 |
-
|
55 |
-
for _ in range(10):
|
56 |
-
length = random.randrange(1, 10_000)
|
57 |
-
x = torch.randn(2, channels, length)
|
58 |
-
codes, scales = model_nonorm.encode(x)
|
59 |
-
codes, scales = model_renorm.encode(x)
|
60 |
-
assert scales is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/models/test_musicgen.py
DELETED
@@ -1,58 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import pytest
|
8 |
-
import torch
|
9 |
-
|
10 |
-
from audiocraft.models import MusicGen
|
11 |
-
|
12 |
-
|
13 |
-
class TestSEANetModel:
|
14 |
-
def get_musicgen(self):
|
15 |
-
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
16 |
-
mg.set_generation_params(duration=2.0, extend_stride=2.)
|
17 |
-
return mg
|
18 |
-
|
19 |
-
def test_base(self):
|
20 |
-
mg = self.get_musicgen()
|
21 |
-
assert mg.frame_rate == 25
|
22 |
-
assert mg.sample_rate == 32000
|
23 |
-
assert mg.audio_channels == 1
|
24 |
-
|
25 |
-
def test_generate_unconditional(self):
|
26 |
-
mg = self.get_musicgen()
|
27 |
-
wav = mg.generate_unconditional(3)
|
28 |
-
assert list(wav.shape) == [3, 1, 64000]
|
29 |
-
|
30 |
-
def test_generate_continuation(self):
|
31 |
-
mg = self.get_musicgen()
|
32 |
-
prompt = torch.randn(3, 1, 32000)
|
33 |
-
wav = mg.generate_continuation(prompt, 32000)
|
34 |
-
assert list(wav.shape) == [3, 1, 64000]
|
35 |
-
|
36 |
-
prompt = torch.randn(2, 1, 32000)
|
37 |
-
wav = mg.generate_continuation(
|
38 |
-
prompt, 32000, ['youpi', 'lapin dort'])
|
39 |
-
assert list(wav.shape) == [2, 1, 64000]
|
40 |
-
|
41 |
-
prompt = torch.randn(2, 1, 32000)
|
42 |
-
with pytest.raises(AssertionError):
|
43 |
-
wav = mg.generate_continuation(
|
44 |
-
prompt, 32000, ['youpi', 'lapin dort', 'one too many'])
|
45 |
-
|
46 |
-
def test_generate(self):
|
47 |
-
mg = self.get_musicgen()
|
48 |
-
wav = mg.generate(
|
49 |
-
['youpi', 'lapin dort'])
|
50 |
-
assert list(wav.shape) == [2, 1, 64000]
|
51 |
-
|
52 |
-
def test_generate_long(self):
|
53 |
-
mg = self.get_musicgen()
|
54 |
-
mg.max_duration = 3.
|
55 |
-
mg.set_generation_params(duration=4., extend_stride=2.)
|
56 |
-
wav = mg.generate(
|
57 |
-
['youpi', 'lapin dort'])
|
58 |
-
assert list(wav.shape) == [2, 1, 32000 * 4]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_codebooks_patterns.py
DELETED
@@ -1,246 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import pytest
|
8 |
-
import torch
|
9 |
-
|
10 |
-
from audiocraft.modules.codebooks_patterns import (
|
11 |
-
DelayedPatternProvider,
|
12 |
-
ParallelPatternProvider,
|
13 |
-
Pattern,
|
14 |
-
UnrolledPatternProvider,
|
15 |
-
)
|
16 |
-
|
17 |
-
|
18 |
-
class TestParallelPatternProvider:
|
19 |
-
|
20 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
21 |
-
@pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
|
22 |
-
def test_get_pattern(self, n_q: int, timesteps: int):
|
23 |
-
provider = ParallelPatternProvider(n_q)
|
24 |
-
pattern = provider.get_pattern(timesteps)
|
25 |
-
# + 1 to account for 1st step
|
26 |
-
assert len(pattern.layout) == timesteps + 1
|
27 |
-
|
28 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
29 |
-
@pytest.mark.parametrize("timesteps", [8, 16, 100])
|
30 |
-
def test_pattern_content(self, n_q: int, timesteps: int):
|
31 |
-
provider = ParallelPatternProvider(n_q)
|
32 |
-
pattern = provider.get_pattern(timesteps)
|
33 |
-
for s, v in enumerate(pattern.layout):
|
34 |
-
for i, code in enumerate(v):
|
35 |
-
assert i == code.q
|
36 |
-
assert code.t == s - 1 # account for the 1st empty step
|
37 |
-
|
38 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
39 |
-
@pytest.mark.parametrize("timesteps", [8, 16, 100])
|
40 |
-
def test_pattern_max_delay(self, n_q: int, timesteps: int):
|
41 |
-
provider = ParallelPatternProvider(n_q)
|
42 |
-
pattern = provider.get_pattern(timesteps)
|
43 |
-
assert pattern.max_delay == 0
|
44 |
-
assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
|
45 |
-
|
46 |
-
|
47 |
-
class TestDelayedPatternProvider:
|
48 |
-
|
49 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
50 |
-
@pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
|
51 |
-
def test_get_pattern(self, n_q: int, timesteps: int):
|
52 |
-
delays = [
|
53 |
-
list(range(n_q)),
|
54 |
-
[0] + [1] * (n_q - 1),
|
55 |
-
[0] + [4] * (n_q - 1),
|
56 |
-
]
|
57 |
-
for delay in delays:
|
58 |
-
provider = DelayedPatternProvider(n_q, delay)
|
59 |
-
pattern = provider.get_pattern(timesteps)
|
60 |
-
# + 1 to account for 1st step
|
61 |
-
assert len(pattern.layout) == timesteps + max(delay) + 1
|
62 |
-
|
63 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
64 |
-
@pytest.mark.parametrize("timesteps", [8, 16, 100])
|
65 |
-
def test_pattern_content(self, n_q: int, timesteps: int):
|
66 |
-
provider = DelayedPatternProvider(n_q)
|
67 |
-
pattern = provider.get_pattern(timesteps)
|
68 |
-
for s, v in enumerate(pattern.layout):
|
69 |
-
for i, code in enumerate(v):
|
70 |
-
assert i == code.q
|
71 |
-
assert code.t == max(0, s - code.q - 1)
|
72 |
-
|
73 |
-
@pytest.mark.parametrize("timesteps", [8, 16, 100])
|
74 |
-
@pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]])
|
75 |
-
def test_pattern_max_delay(self, timesteps: int, delay: list):
|
76 |
-
provider = DelayedPatternProvider(len(delay), delay)
|
77 |
-
pattern = provider.get_pattern(timesteps)
|
78 |
-
assert pattern.max_delay == max(delay)
|
79 |
-
assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
|
80 |
-
|
81 |
-
|
82 |
-
class TestUnrolledPatternProvider:
|
83 |
-
|
84 |
-
@pytest.mark.parametrize("timesteps", [0, 1, 16])
|
85 |
-
@pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
|
86 |
-
@pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
|
87 |
-
def test_get_pattern(self, timesteps: int, flattening: list, delays: list):
|
88 |
-
n_q = len(flattening)
|
89 |
-
max_delay = max(delays)
|
90 |
-
provider = UnrolledPatternProvider(n_q, flattening, delays)
|
91 |
-
pattern = provider.get_pattern(timesteps)
|
92 |
-
assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay
|
93 |
-
|
94 |
-
@pytest.mark.parametrize("timesteps", [0, 1, 16])
|
95 |
-
@pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
|
96 |
-
@pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
|
97 |
-
def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list):
|
98 |
-
n_q = len(flattening)
|
99 |
-
max_delay = max(delays)
|
100 |
-
provider = UnrolledPatternProvider(n_q, flattening, delays)
|
101 |
-
pattern = provider.get_pattern(timesteps)
|
102 |
-
assert pattern.max_delay == max_delay
|
103 |
-
|
104 |
-
|
105 |
-
class TestPattern:
|
106 |
-
|
107 |
-
def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
|
108 |
-
"""Reference method to build the sequence from the pattern without using fancy scatter."""
|
109 |
-
bs, n_q, T = z.shape
|
110 |
-
z = z.cpu().numpy()
|
111 |
-
assert n_q == pattern.n_q
|
112 |
-
assert T <= pattern.timesteps
|
113 |
-
inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy()
|
114 |
-
inp[:] = special_token
|
115 |
-
for s, v in enumerate(pattern.layout):
|
116 |
-
for (t, q) in v:
|
117 |
-
if t < T:
|
118 |
-
inp[:, q, s] = z[:, q, t]
|
119 |
-
return torch.from_numpy(inp)
|
120 |
-
|
121 |
-
def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
|
122 |
-
"""Reference method to revert the sequence from the pattern without using fancy scatter."""
|
123 |
-
z = z.cpu().numpy()
|
124 |
-
bs, n_q, S = z.shape
|
125 |
-
assert pattern.n_q == n_q
|
126 |
-
inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy()
|
127 |
-
inp[:] = special_token
|
128 |
-
for s, v in enumerate(pattern.layout):
|
129 |
-
for (t, q) in v:
|
130 |
-
if t < pattern.timesteps:
|
131 |
-
inp[:, q, t] = z[:, q, s]
|
132 |
-
return torch.from_numpy(inp)
|
133 |
-
|
134 |
-
def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float):
|
135 |
-
"""Reference method to revert the logits from the pattern without using fancy scatter."""
|
136 |
-
z = z.cpu().numpy()
|
137 |
-
bs, card, n_q, S = z.shape
|
138 |
-
assert pattern.n_q == n_q
|
139 |
-
ref_layout = pattern.layout
|
140 |
-
inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy()
|
141 |
-
inp[:] = special_token
|
142 |
-
for s, v in enumerate(ref_layout[1:]):
|
143 |
-
if s < S:
|
144 |
-
for (t, q) in v:
|
145 |
-
if t < pattern.timesteps:
|
146 |
-
inp[:, :, q, t] = z[:, :, q, s]
|
147 |
-
return torch.from_numpy(inp)
|
148 |
-
|
149 |
-
def _get_pattern_providers(self, n_q: int):
|
150 |
-
pattern_provider_1 = ParallelPatternProvider(n_q)
|
151 |
-
pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q)))
|
152 |
-
pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1))
|
153 |
-
pattern_provider_4 = UnrolledPatternProvider(
|
154 |
-
n_q, flattening=list(range(n_q)), delays=[0] * n_q
|
155 |
-
)
|
156 |
-
pattern_provider_5 = UnrolledPatternProvider(
|
157 |
-
n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q
|
158 |
-
)
|
159 |
-
pattern_provider_6 = UnrolledPatternProvider(
|
160 |
-
n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1)
|
161 |
-
)
|
162 |
-
return [
|
163 |
-
pattern_provider_1,
|
164 |
-
pattern_provider_2,
|
165 |
-
pattern_provider_3,
|
166 |
-
pattern_provider_4,
|
167 |
-
pattern_provider_5,
|
168 |
-
pattern_provider_6,
|
169 |
-
]
|
170 |
-
|
171 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
172 |
-
@pytest.mark.parametrize("timesteps", [16, 72])
|
173 |
-
def test_build_pattern_sequence(self, n_q: int, timesteps: int):
|
174 |
-
bs = 2
|
175 |
-
card = 256
|
176 |
-
special_token = card
|
177 |
-
|
178 |
-
pattern_providers = self._get_pattern_providers(n_q)
|
179 |
-
for pattern_provider in pattern_providers:
|
180 |
-
pattern = pattern_provider.get_pattern(timesteps)
|
181 |
-
# we can correctly build the sequence from the pattern
|
182 |
-
z = torch.randint(0, card, (bs, n_q, timesteps))
|
183 |
-
ref_res = self.ref_build_pattern_sequence(z, pattern, special_token)
|
184 |
-
res, indexes, mask = pattern.build_pattern_sequence(z, special_token)
|
185 |
-
assert (res == ref_res).float().mean() == 1.0
|
186 |
-
|
187 |
-
# expected assertion fails on the number of timesteps
|
188 |
-
invalid_timesteps = [timesteps + 1]
|
189 |
-
if pattern.num_sequence_steps != pattern.timesteps:
|
190 |
-
invalid_timesteps.append(pattern.num_sequence_steps)
|
191 |
-
for i_timesteps in invalid_timesteps:
|
192 |
-
z2 = torch.randint(0, card, (bs, n_q, i_timesteps))
|
193 |
-
with pytest.raises(AssertionError):
|
194 |
-
pattern.build_pattern_sequence(z2, special_token)
|
195 |
-
|
196 |
-
# expected assertion fails on the number of codebooks
|
197 |
-
invalid_qs = [0, n_q - 1, n_q + 1]
|
198 |
-
for i_q in invalid_qs:
|
199 |
-
z3 = torch.randint(0, card, (bs, i_q, timesteps))
|
200 |
-
with pytest.raises(AssertionError):
|
201 |
-
pattern.build_pattern_sequence(z3, special_token)
|
202 |
-
|
203 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
204 |
-
@pytest.mark.parametrize("timesteps", [16, 72])
|
205 |
-
def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
|
206 |
-
bs = 2
|
207 |
-
card = 256
|
208 |
-
special_token = card
|
209 |
-
|
210 |
-
pattern_providers = self._get_pattern_providers(n_q)
|
211 |
-
for pattern_provider in pattern_providers:
|
212 |
-
pattern = pattern_provider.get_pattern(timesteps)
|
213 |
-
# this works assuming previous tests are successful
|
214 |
-
z = torch.randint(0, card, (bs, n_q, timesteps))
|
215 |
-
s = self.ref_build_pattern_sequence(z, pattern, special_token)
|
216 |
-
ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token)
|
217 |
-
# ensure our reference script retrieve the original sequence
|
218 |
-
assert z.shape == ref_out.shape
|
219 |
-
assert (z == ref_out).float().mean() == 1.0
|
220 |
-
# now we can test the scatter version
|
221 |
-
out, indexes, mask = pattern.revert_pattern_sequence(s, special_token)
|
222 |
-
assert out.shape == ref_out.shape
|
223 |
-
assert (out == ref_out).float().mean() == 1.0
|
224 |
-
|
225 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
226 |
-
@pytest.mark.parametrize("timesteps", [16, 72])
|
227 |
-
@pytest.mark.parametrize("card", [1, 2, 256, 1024])
|
228 |
-
def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int):
|
229 |
-
bs = 2
|
230 |
-
special_token = card
|
231 |
-
logits_special_token = float('nan')
|
232 |
-
|
233 |
-
pattern_providers = self._get_pattern_providers(n_q)
|
234 |
-
for pattern_provider in pattern_providers:
|
235 |
-
pattern = pattern_provider.get_pattern(timesteps)
|
236 |
-
# this works assuming previous tests are successful
|
237 |
-
z = torch.randint(0, card, (bs, n_q, timesteps))
|
238 |
-
s = self.ref_build_pattern_sequence(z, pattern, special_token)
|
239 |
-
logits = torch.randn((bs, card, n_q, s.shape[-1]))
|
240 |
-
ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token)
|
241 |
-
# ensure our reference script retrieve the original sequence
|
242 |
-
assert ref_out.shape == torch.Size([bs, card, n_q, timesteps])
|
243 |
-
# now we can test the scatter version
|
244 |
-
out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token)
|
245 |
-
assert out.shape == ref_out.shape
|
246 |
-
assert (out == ref_out).float().mean() == 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_conv.py
DELETED
@@ -1,203 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from itertools import product
|
8 |
-
import math
|
9 |
-
import random
|
10 |
-
|
11 |
-
import pytest
|
12 |
-
import torch
|
13 |
-
from torch import nn
|
14 |
-
|
15 |
-
from audiocraft.modules import (
|
16 |
-
NormConv1d,
|
17 |
-
NormConvTranspose1d,
|
18 |
-
StreamableConv1d,
|
19 |
-
StreamableConvTranspose1d,
|
20 |
-
pad1d,
|
21 |
-
unpad1d,
|
22 |
-
)
|
23 |
-
|
24 |
-
|
25 |
-
def test_get_extra_padding_for_conv1d():
|
26 |
-
# TODO: Implement me!
|
27 |
-
pass
|
28 |
-
|
29 |
-
|
30 |
-
def test_pad1d_zeros():
|
31 |
-
x = torch.randn(1, 1, 20)
|
32 |
-
|
33 |
-
xp1 = pad1d(x, (0, 5), mode='constant', value=0.)
|
34 |
-
assert xp1.shape[-1] == 25
|
35 |
-
xp2 = pad1d(x, (5, 5), mode='constant', value=0.)
|
36 |
-
assert xp2.shape[-1] == 30
|
37 |
-
xp3 = pad1d(x, (0, 0), mode='constant', value=0.)
|
38 |
-
assert xp3.shape[-1] == 20
|
39 |
-
xp4 = pad1d(x, (10, 30), mode='constant', value=0.)
|
40 |
-
assert xp4.shape[-1] == 60
|
41 |
-
|
42 |
-
with pytest.raises(AssertionError):
|
43 |
-
pad1d(x, (-1, 0), mode='constant', value=0.)
|
44 |
-
|
45 |
-
with pytest.raises(AssertionError):
|
46 |
-
pad1d(x, (0, -1), mode='constant', value=0.)
|
47 |
-
|
48 |
-
with pytest.raises(AssertionError):
|
49 |
-
pad1d(x, (-1, -1), mode='constant', value=0.)
|
50 |
-
|
51 |
-
|
52 |
-
def test_pad1d_reflect():
|
53 |
-
x = torch.randn(1, 1, 20)
|
54 |
-
|
55 |
-
xp1 = pad1d(x, (0, 5), mode='reflect', value=0.)
|
56 |
-
assert xp1.shape[-1] == 25
|
57 |
-
xp2 = pad1d(x, (5, 5), mode='reflect', value=0.)
|
58 |
-
assert xp2.shape[-1] == 30
|
59 |
-
xp3 = pad1d(x, (0, 0), mode='reflect', value=0.)
|
60 |
-
assert xp3.shape[-1] == 20
|
61 |
-
xp4 = pad1d(x, (10, 30), mode='reflect', value=0.)
|
62 |
-
assert xp4.shape[-1] == 60
|
63 |
-
|
64 |
-
with pytest.raises(AssertionError):
|
65 |
-
pad1d(x, (-1, 0), mode='reflect', value=0.)
|
66 |
-
|
67 |
-
with pytest.raises(AssertionError):
|
68 |
-
pad1d(x, (0, -1), mode='reflect', value=0.)
|
69 |
-
|
70 |
-
with pytest.raises(AssertionError):
|
71 |
-
pad1d(x, (-1, -1), mode='reflect', value=0.)
|
72 |
-
|
73 |
-
|
74 |
-
def test_unpad1d():
|
75 |
-
x = torch.randn(1, 1, 20)
|
76 |
-
|
77 |
-
u1 = unpad1d(x, (5, 5))
|
78 |
-
assert u1.shape[-1] == 10
|
79 |
-
u2 = unpad1d(x, (0, 5))
|
80 |
-
assert u2.shape[-1] == 15
|
81 |
-
u3 = unpad1d(x, (5, 0))
|
82 |
-
assert u3.shape[-1] == 15
|
83 |
-
u4 = unpad1d(x, (0, 0))
|
84 |
-
assert u4.shape[-1] == x.shape[-1]
|
85 |
-
|
86 |
-
with pytest.raises(AssertionError):
|
87 |
-
unpad1d(x, (-1, 0))
|
88 |
-
|
89 |
-
with pytest.raises(AssertionError):
|
90 |
-
unpad1d(x, (0, -1))
|
91 |
-
|
92 |
-
with pytest.raises(AssertionError):
|
93 |
-
unpad1d(x, (-1, -1))
|
94 |
-
|
95 |
-
|
96 |
-
class TestNormConv1d:
|
97 |
-
|
98 |
-
def test_norm_conv1d_modules(self):
|
99 |
-
N, C, T = 2, 2, random.randrange(1, 100_000)
|
100 |
-
t0 = torch.randn(N, C, T)
|
101 |
-
|
102 |
-
C_out, kernel_size, stride = 1, 4, 1
|
103 |
-
expected_out_length = int((T - kernel_size) / stride + 1)
|
104 |
-
wn_conv = NormConv1d(C, 1, kernel_size=4, norm='weight_norm')
|
105 |
-
gn_conv = NormConv1d(C, 1, kernel_size=4, norm='time_group_norm')
|
106 |
-
nn_conv = NormConv1d(C, 1, kernel_size=4, norm='none')
|
107 |
-
|
108 |
-
assert isinstance(wn_conv.norm, nn.Identity)
|
109 |
-
assert isinstance(wn_conv.conv, nn.Conv1d)
|
110 |
-
|
111 |
-
assert isinstance(gn_conv.norm, nn.GroupNorm)
|
112 |
-
assert isinstance(gn_conv.conv, nn.Conv1d)
|
113 |
-
|
114 |
-
assert isinstance(nn_conv.norm, nn.Identity)
|
115 |
-
assert isinstance(nn_conv.conv, nn.Conv1d)
|
116 |
-
|
117 |
-
for conv_layer in [wn_conv, gn_conv, nn_conv]:
|
118 |
-
out = conv_layer(t0)
|
119 |
-
assert isinstance(out, torch.Tensor)
|
120 |
-
assert list(out.shape) == [N, C_out, expected_out_length]
|
121 |
-
|
122 |
-
|
123 |
-
class TestNormConvTranspose1d:
|
124 |
-
|
125 |
-
def test_normalizations(self):
|
126 |
-
N, C, T = 2, 2, random.randrange(1, 100_000)
|
127 |
-
t0 = torch.randn(N, C, T)
|
128 |
-
|
129 |
-
C_out, kernel_size, stride = 1, 4, 1
|
130 |
-
expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1
|
131 |
-
|
132 |
-
wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm')
|
133 |
-
gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm')
|
134 |
-
nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none')
|
135 |
-
|
136 |
-
assert isinstance(wn_convtr.norm, nn.Identity)
|
137 |
-
assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d)
|
138 |
-
|
139 |
-
assert isinstance(gn_convtr.norm, nn.GroupNorm)
|
140 |
-
assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d)
|
141 |
-
|
142 |
-
assert isinstance(nn_convtr.norm, nn.Identity)
|
143 |
-
assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d)
|
144 |
-
|
145 |
-
for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]:
|
146 |
-
out = convtr_layer(t0)
|
147 |
-
assert isinstance(out, torch.Tensor)
|
148 |
-
assert list(out.shape) == [N, C_out, expected_out_length]
|
149 |
-
|
150 |
-
|
151 |
-
class TestStreamableConv1d:
|
152 |
-
|
153 |
-
def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation):
|
154 |
-
# StreamableConv1d internally pads to make sure that the last window is full
|
155 |
-
padding_total = (kernel_size - 1) * dilation - (stride - 1)
|
156 |
-
n_frames = (length - kernel_size + padding_total) / stride + 1
|
157 |
-
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
158 |
-
return ideal_length // stride
|
159 |
-
|
160 |
-
def test_streamable_conv1d(self):
|
161 |
-
N, C, T = 2, 2, random.randrange(1, 100_000)
|
162 |
-
t0 = torch.randn(N, C, T)
|
163 |
-
C_out = 1
|
164 |
-
|
165 |
-
# conv params are [(kernel_size, stride, dilation)]
|
166 |
-
conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)]
|
167 |
-
for causal, (kernel_size, stride, dilation) in product([False, True], conv_params):
|
168 |
-
expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation)
|
169 |
-
sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal)
|
170 |
-
out = sconv(t0)
|
171 |
-
assert isinstance(out, torch.Tensor)
|
172 |
-
print(list(out.shape), [N, C_out, expected_out_length])
|
173 |
-
assert list(out.shape) == [N, C_out, expected_out_length]
|
174 |
-
|
175 |
-
|
176 |
-
class TestStreamableConvTranspose1d:
    """Constructor-validation and output-shape checks for ``StreamableConvTranspose1d``."""

    def get_streamable_convtr1d_output_length(self, length, kernel_size, stride):
        """Return the output length the test expects for ``length`` input frames."""
        padding_total = (kernel_size - stride)
        return (length - 1) * stride - padding_total + (kernel_size - 1) + 1

    def test_streamable_convtr1d(self):
        """Invalid ``trim_right_ratio`` settings are rejected; valid ones give the expected shape."""
        N, C, T = 2, 2, random.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)

        C_out = 1

        # Each invalid configuration needs its own `pytest.raises` context:
        # a context exits at the first exception, so grouping the three
        # constructors in one block only ever exercised the first of them.
        # invalid params are [(causal, trim_right_ratio)]
        invalid_params = [(False, 0.5), (True, -1.), (True, 2)]
        for causal, trim_right_ratio in invalid_params:
            with pytest.raises(AssertionError):
                StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=causal,
                                          trim_right_ratio=trim_right_ratio)

        # causal params are [(causal, trim_right)]
        causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)]
        # conv params are [(kernel_size, stride)]
        conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)]
        for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params):
            expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride)
            sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride,
                                                causal=causal, trim_right_ratio=trim_right_ratio)
            out = sconvtr(t0)
            assert isinstance(out, torch.Tensor)
            assert list(out.shape) == [N, C_out, expected_out_length]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_lstm.py
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import random
|
8 |
-
import torch
|
9 |
-
|
10 |
-
from audiocraft.modules.lstm import StreamableLSTM
|
11 |
-
|
12 |
-
|
13 |
-
class TestStreamableLSTM:
    """Shape checks for ``StreamableLSTM`` with and without the skip option."""

    def _check_output_shape(self, skip):
        """Run one forward pass and assert the output shape equals the input shape."""
        B, C, T = 4, 2, random.randint(1, 100)

        lstm = StreamableLSTM(C, 3, skip=skip)
        x = torch.randn(B, C, T)
        y = lstm(x)

        # The shape is reported through the assertion message on failure
        # rather than printed on every run.
        assert y.shape == torch.Size([B, C, T]), y.shape

    def test_lstm(self):
        self._check_output_shape(skip=False)

    def test_lstm_skip(self):
        self._check_output_shape(skip=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_rope.py
DELETED
@@ -1,168 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import torch
|
8 |
-
|
9 |
-
from audiocraft.modules.rope import RotaryEmbedding
|
10 |
-
from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend
|
11 |
-
|
12 |
-
|
13 |
-
def test_rope():
    """Rotating queries and keys with ``RotaryEmbedding`` must keep their shape."""
    set_efficient_attention_backend('xformers')
    batch, steps, heads, channels = 8, 75, 16, 128
    expected_shape = [batch, steps, heads, channels]

    rotary = RotaryEmbedding(dim=channels)
    queries = torch.rand((batch, steps, heads, channels))
    keys = torch.rand((batch, steps, heads, channels))
    rotated_q, rotated_k = rotary.rotate_qk(queries, keys, start=7)

    for rotated in (rotated_q, rotated_k):
        assert list(rotated.shape) == expected_shape
|
24 |
-
|
25 |
-
|
26 |
-
def test_rope_io_dtypes():
    """The rotated queries keep the input dtype for float32 and float64 rope."""
    set_efficient_attention_backend('xformers')
    shape = (8, 75, 16, 128)

    rope_32 = RotaryEmbedding(dim=shape[-1], dtype=torch.float32)
    rope_64 = RotaryEmbedding(dim=shape[-1], dtype=torch.float64)

    # bfloat16 inputs first, then float32 inputs, each against both rope precisions.
    for io_dtype in (torch.bfloat16, torch.float32):
        queries = torch.rand(shape).to(io_dtype)
        keys = torch.rand(shape).to(io_dtype)
        for rope in (rope_32, rope_64):
            rotated_q, _ = rope.rotate_qk(queries, keys)
            assert rotated_q.dtype == io_dtype
|
48 |
-
|
49 |
-
|
50 |
-
def test_transformer_with_rope():
    """A transformer using 'rope' or 'sin_rope' positions keeps the input shape."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    n_steps = 12
    for positional_embedding in ('rope', 'sin_rope'):
        model = StreamingTransformer(
            16, 4, 2,
            custom=True,
            dropout=0.,
            layer_scale=0.1,
            positional_embedding=positional_embedding,
        )
        model.eval()
        inputs = torch.randn(3, n_steps, 16)

        outputs = model(inputs)
        assert list(outputs.shape) == list(inputs.shape)
|
63 |
-
|
64 |
-
|
65 |
-
@torch.no_grad()
def test_rope_streaming():
    """Frame-by-frame streaming with rope must match the full-sequence output."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    model = StreamingTransformer(
        16, 4, 2,
        causal=True,
        dropout=0.,
        custom=True,
        positional_embedding='rope',
    )
    model.eval()
    n_steps = 12
    remaining = torch.randn(3, n_steps, 16)

    reference = model(remaining)

    streamed_frames = []
    with model.streaming():
        for frame_size in [1] * n_steps:
            # Peel the next frame off the front of what is left to process.
            frame, remaining = remaining[:, :frame_size], remaining[:, frame_size:]
            streamed_frames.append(model(frame))

    streamed = torch.cat(streamed_frames, dim=1)
    assert list(streamed.shape) == [3, n_steps, 16]
    delta = torch.norm(streamed - reference) / torch.norm(streamed)
    assert delta < 1e-6, delta
|
91 |
-
|
92 |
-
|
93 |
-
@torch.no_grad()
def test_rope_streaming_past_context():
    """Streaming with rope matches batch evaluation, with and without ``past_context``."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    n_steps = 20

    for past_context in (None, 10):
        # One layer when a past context is set, two layers otherwise.
        num_layers = 1 if past_context else 2
        model = StreamingTransformer(
            16, 4, num_layers,
            causal=True,
            past_context=past_context,
            custom=True,
            dropout=0.,
            positional_embedding='rope',
        )
        model.eval()

        remaining = torch.randn(3, n_steps, 16)
        reference = model(remaining)

        streamed_frames = []
        with model.streaming():
            for frame_size in [1] * n_steps:
                frame, remaining = remaining[:, :frame_size], remaining[:, frame_size:]
                streamed_frames.append(model(frame))

        streamed = torch.cat(streamed_frames, dim=1)
        assert list(streamed.shape) == [3, n_steps, 16]
        delta = torch.norm(streamed - reference) / torch.norm(streamed)
        assert delta < 1e-6, delta
|
122 |
-
|
123 |
-
|
124 |
-
def test_rope_memory_efficient():
    """The custom and memory-efficient rope transformers agree on the same weights."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    shared_kwargs = dict(dropout=0., layer_scale=0.1, positional_embedding='rope')
    custom_model = StreamingTransformer(16, 4, 2, custom=True, **shared_kwargs)
    efficient_model = StreamingTransformer(16, 4, 2, memory_efficient=True, **shared_kwargs)
    efficient_model.load_state_dict(custom_model.state_dict())
    custom_model.eval()
    n_steps = 12
    inputs = torch.randn(3, n_steps, 16)

    with torch.no_grad():
        custom_out = custom_model(inputs)
        efficient_out = efficient_model(inputs)
        # Check at float precision b/c this is the rope default.
        assert torch.allclose(custom_out, efficient_out, atol=1e-7), (custom_out - efficient_out).norm()
|
143 |
-
|
144 |
-
|
145 |
-
def test_rope_with_xpos():
    """``RotaryEmbedding`` with ``xpos=True`` keeps the query and key shapes."""
    set_efficient_attention_backend('xformers')
    batch, steps, heads, channels = 8, 75, 16, 128
    expected_shape = [batch, steps, heads, channels]

    rotary = RotaryEmbedding(dim=channels, xpos=True)
    queries = torch.rand((batch, steps, heads, channels))
    keys = torch.rand((batch, steps, heads, channels))
    rotated_q, rotated_k = rotary.rotate_qk(queries, keys, start=7)

    for rotated in (rotated_q, rotated_k):
        assert list(rotated.shape) == expected_shape
|
156 |
-
|
157 |
-
|
158 |
-
def test_positional_scale():
    """With ``scale=0.0`` the rotation must leave queries and keys unchanged."""
    set_efficient_attention_backend('xformers')
    batch, steps, heads, channels = 8, 75, 16, 128

    rotary = RotaryEmbedding(dim=channels, xpos=True, scale=0.0)
    queries = torch.rand((batch, steps, heads, channels))
    keys = torch.rand((batch, steps, heads, channels))
    rotated_q, rotated_k = rotary.rotate_qk(queries, keys, start=7)

    for original, rotated in ((queries, rotated_q), (keys, rotated_k)):
        assert torch.allclose(original, rotated)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_seanet.py
DELETED
@@ -1,115 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from itertools import product
|
8 |
-
|
9 |
-
import pytest
|
10 |
-
import torch
|
11 |
-
|
12 |
-
from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
|
13 |
-
from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
|
14 |
-
|
15 |
-
|
16 |
-
class TestSEANetModel:
    """Shape and per-block normalization checks for ``SEANetEncoder`` / ``SEANetDecoder``."""

    def _check_roundtrip(self, encoder, decoder):
        """Encode a random 24000-sample mono signal and decode it back, checking both shapes."""
        x = torch.randn(1, 1, 24000)
        z = encoder(x)
        assert list(z.shape) == [1, 128, 75], z.shape
        y = decoder(z)
        assert y.shape == x.shape, (x.shape, y.shape)

    def test_base(self):
        self._check_roundtrip(SEANetEncoder(), SEANetDecoder())

    def test_causal(self):
        self._check_roundtrip(SEANetEncoder(causal=True), SEANetDecoder(causal=True))

    def test_conv_skip_connection(self):
        self._check_roundtrip(SEANetEncoder(true_skip=False), SEANetDecoder(true_skip=False))

    def test_seanet_encoder_decoder_final_act(self):
        self._check_roundtrip(SEANetEncoder(true_skip=False),
                              SEANetDecoder(true_skip=False, final_activation='Tanh'))

    def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
        """Assert that the first ``n_disable_blocks`` encoder blocks use no norm and the rest use ``norm``."""
        n_blocks = 0
        for layer in encoder.model:
            if isinstance(layer, StreamableConv1d):
                n_blocks += 1
                # The expected value is computed first: written inline,
                # `assert a == 'none' if cond else norm` parses as
                # `assert ((a == 'none') if cond else norm)`, which always
                # passes on the else branch because `norm` is a non-empty string.
                expected = 'none' if n_blocks <= n_disable_blocks else norm
                assert layer.conv.norm_type == expected, (n_blocks, layer.conv.norm_type, expected)
            elif isinstance(layer, SEANetResnetBlock):
                # here we add + 1 to n_blocks as we increment n_blocks just after the block
                expected = 'none' if (n_blocks + 1) <= n_disable_blocks else norm
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected, \
                            (n_blocks, resnet_layer.conv.norm_type, expected)

    def test_encoder_disable_norm(self):
        n_residuals = [0, 1, 3]
        disable_blocks = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        # The loop variable gets its own name so it no longer shadows the list.
        for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
            encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_encoder_blocks_norm(encoder, n_disable, norm)

    def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
        """Assert that the last ``n_disable_blocks`` decoder blocks use no norm and the rest use ``norm``."""
        n_blocks = 0
        for layer in decoder.model:
            if isinstance(layer, StreamableConv1d):
                n_blocks += 1
                # Same precedence fix as in the encoder check.
                expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
                assert layer.conv.norm_type == expected, (n_blocks, layer.conv.norm_type, expected)
            elif isinstance(layer, StreamableConvTranspose1d):
                n_blocks += 1
                expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
                assert layer.convtr.norm_type == expected, (n_blocks, layer.convtr.norm_type, expected)
            elif isinstance(layer, SEANetResnetBlock):
                expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected, \
                            (n_blocks, resnet_layer.conv.norm_type, expected)

    def test_decoder_disable_norm(self):
        n_residuals = [0, 1, 3]
        disable_blocks = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
            decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_decoder_blocks_norm(decoder, n_disable, norm)

    def test_disable_norm_raises_exception(self):
        # Invalid disable_norm_outer_blocks values raise exceptions
        with pytest.raises(AssertionError):
            SEANetEncoder(disable_norm_outer_blocks=-1)

        with pytest.raises(AssertionError):
            SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)

        with pytest.raises(AssertionError):
            SEANetDecoder(disable_norm_outer_blocks=-1)

        with pytest.raises(AssertionError):
            SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_transformer.py
DELETED
@@ -1,253 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from itertools import product
|
8 |
-
|
9 |
-
import pytest
|
10 |
-
import torch
|
11 |
-
|
12 |
-
from audiocraft.modules.transformer import (
|
13 |
-
StreamingMultiheadAttention, StreamingTransformer, set_efficient_attention_backend)
|
14 |
-
|
15 |
-
|
16 |
-
def test_transformer_causal_streaming():
    """Check causality and receptive field via gradients, then streaming vs. batch outputs."""
    torch.manual_seed(1234)
    n_steps = 20
    zero = torch.tensor(0.)

    for past_context, custom in product([None, 10], [False, True]):
        # Test that causality and receptive fields are properly handled,
        # looking at the gradients.
        num_layers = 1 if past_context else 2
        model = StreamingTransformer(
            16, 4, num_layers,
            causal=True, past_context=past_context, custom=custom,
            dropout=0.)

        for k in (0, 10, 15, 19):
            inputs = torch.randn(4, n_steps, 16, requires_grad=True)
            outputs = model(inputs)
            outputs[:, k].abs().sum().backward()

            # No gradient may reach steps located after step k.
            if k + 1 < n_steps:
                assert torch.allclose(inputs.grad[:, k + 1:], zero), inputs.grad[:, k + 1:].norm()
            # Steps up to and including k must receive some gradient.
            assert not torch.allclose(inputs.grad[:, :k + 1], zero), inputs.grad[:, :k + 1].norm()
            # Steps before index k - context - 1 must receive none.
            if past_context is not None and k > past_context:
                limit = k - past_context - 1
                assert torch.allclose(inputs.grad[:, :limit], zero), inputs.grad[:, :limit].norm()

        # Now check that streaming gives the same result as batch eval.
        inputs = torch.randn(4, n_steps, 16)
        batch_out = model(inputs)
        with model.streaming():
            streamed_chunks = [model(inputs[:, k:k + 1, :]) for k in range(n_steps)]
        streamed_out = torch.cat(streamed_chunks, dim=1)
        delta = torch.norm(streamed_out - batch_out) / torch.norm(batch_out)
        assert delta < 1e-6, delta
|
50 |
-
|
51 |
-
|
52 |
-
def test_transformer_vs_pytorch():
    """In the non-causal setting the output must match PyTorch's Transformer encoder."""
    torch.manual_seed(1234)
    for custom in (False, True):
        streaming_tr = StreamingTransformer(
            16, 4, 2,
            causal=False, custom=custom, dropout=0., positional_scale=0.)
        reference_layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True)
        reference_tr = torch.nn.TransformerEncoder(reference_layer, 2)
        # Both models share the reference weights before being compared.
        streaming_tr.load_state_dict(reference_tr.state_dict())

        inputs = torch.randn(4, 20, 16)
        ours = streaming_tr(inputs)
        theirs = reference_tr(inputs)
        delta = torch.norm(theirs - ours) / torch.norm(ours)
        assert delta < 1e-6, delta
|
69 |
-
|
70 |
-
|
71 |
-
def test_streaming_api():
    """Restoring a saved streaming state must reproduce the output of the next frame."""
    model = StreamingTransformer(16, 4, 2, causal=True, dropout=0.)
    model.eval()
    n_steps = 12
    inputs = torch.randn(1, n_steps, 16)
    first_frame = inputs[:, :1]
    second_frame = inputs[:, 1:2]

    with torch.no_grad():
        with model.streaming():
            model(first_frame)
            # Snapshot the state after the first frame so it can be replayed.
            snapshot = {}
            for name, value in model.get_streaming_state().items():
                snapshot[name] = value.clone()
            first_pass = model(second_frame)
            model.set_streaming_state(snapshot)
            second_pass = model(second_frame)
            assert torch.allclose(first_pass, second_pass), (first_pass - second_pass).norm()
        assert model.flush() is None
|
86 |
-
|
87 |
-
|
88 |
-
def test_memory_efficient():
    """The custom and memory-efficient transformers agree, for each attention backend."""
    n_steps = 12
    for backend in ('torch', 'xformers'):
        torch.manual_seed(1234)
        set_efficient_attention_backend(backend)

        shared_kwargs = dict(dropout=0., layer_scale=0.1)
        custom_model = StreamingTransformer(16, 4, 2, custom=True, **shared_kwargs)
        efficient_model = StreamingTransformer(16, 4, 2, memory_efficient=True, **shared_kwargs)
        efficient_model.load_state_dict(custom_model.state_dict())
        custom_model.eval()
        inputs = torch.randn(3, n_steps, 16)

        with torch.no_grad():
            custom_out = custom_model(inputs)
            efficient_out = efficient_model(inputs)
            assert torch.allclose(custom_out, efficient_out), ((custom_out - efficient_out).norm(), backend)
|
106 |
-
|
107 |
-
|
108 |
-
def test_attention_as_float32():
    """``attention_as_float32=True`` must change the output of a bfloat16 transformer."""
    torch.manual_seed(1234)
    n_steps = 12
    for custom in (True, False):
        shared_kwargs = dict(dropout=0., dtype=torch.bfloat16, custom=custom)
        baseline = StreamingTransformer(16, 4, 2, **shared_kwargs)
        float32_attention = StreamingTransformer(
            16, 4, 2, attention_as_float32=True, **shared_kwargs)
        if not custom:
            # we are not using autocast here because it doesn't really
            # work as expected on CPU, so we have to manually cast the weights of the MHA.
            for layer in float32_attention.layers:
                layer.self_attn.mha.to(torch.float32)
        float32_attention.load_state_dict(baseline.state_dict())
        inputs = torch.randn(3, n_steps, 16, dtype=torch.bfloat16)

        with torch.no_grad():
            baseline_out = baseline(inputs)
            float32_out = float32_attention(inputs)
            assert not torch.allclose(baseline_out, float32_out), (baseline_out - float32_out).norm()
|
131 |
-
|
132 |
-
|
133 |
-
@torch.no_grad()
def test_streaming_memory_efficient():
    """Streaming the memory-efficient model matches the custom model run on the full sequence."""
    n_steps = 12
    for backend in ('torch', 'xformers'):
        torch.manual_seed(1234)
        set_efficient_attention_backend(backend)
        custom_model = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
        efficient_model = StreamingTransformer(
            16, 4, 2, dropout=0., memory_efficient=True, causal=True)
        # The custom model takes the memory-efficient model's weights.
        custom_model.load_state_dict(efficient_model.state_dict())
        custom_model.eval()
        efficient_model.eval()
        remaining = torch.randn(3, n_steps, 16)

        reference = custom_model(remaining)

        streamed_frames = []
        with efficient_model.streaming():
            for frame_size in [1] * n_steps:
                frame, remaining = remaining[:, :frame_size], remaining[:, frame_size:]
                streamed_frames.append(efficient_model(frame))

        streamed = torch.cat(streamed_frames, dim=1)
        delta = torch.norm(streamed - reference) / torch.norm(streamed)
        assert delta < 1e-6, delta
|
162 |
-
|
163 |
-
|
164 |
-
def test_cross_attention():
    """Compare a transformer with and without cross attention sharing the same weights.

    A zero cross-attention source must give (almost) the same output as the
    model without cross attention, a non-zero source must change the output,
    and passing or omitting the source on the wrong model must be rejected.
    """
    torch.manual_seed(1234)
    for norm_first in [True, False]:
        m = StreamingTransformer(
            16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True)
        m_cross = StreamingTransformer(
            16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True)
        # strict=False: the cross-attention model has parameters `m` lacks.
        m_cross.load_state_dict(m.state_dict(), strict=False)
        x = torch.randn(2, 5, 16)
        cross_x = torch.randn(2, 3, 16)
        y_ref = m(x)
        y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x)
        # With norm_first, the two should be exactly the same,
        # but with norm_first=False, we get 2 normalizations in a row
        # and the epsilon value leads to a tiny change.
        atol = 0. if norm_first else 1e-6
        # The relative error is reported on failure instead of printed on every run.
        rel_error = (y_ref - y_cross_zero).norm() / y_ref.norm()
        assert torch.allclose(y_ref, y_cross_zero, atol=atol), rel_error

        # We now expect a difference even with a generous atol of 1e-2.
        y_cross = m_cross(x, cross_attention_src=cross_x)
        assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2)

        # Each misuse gets its own `pytest.raises` context: the context exits
        # at the first exception, so a second call in the same block never runs.
        with pytest.raises(AssertionError):
            _ = m_cross(x)
        with pytest.raises(AssertionError):
            _ = m(x, cross_attention_src=cross_x)
|
190 |
-
|
191 |
-
|
192 |
-
def test_cross_attention_compat():
    """Cross attention loads ``nn.MultiheadAttention`` weights and matches it, streaming included."""
    torch.manual_seed(1234)
    num_heads = 2
    dim = 64 * num_heads
    with pytest.raises(AssertionError):
        StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True)

    streaming_attn = StreamingMultiheadAttention(
        dim, num_heads, dropout=0, cross_attention=True, custom=True)
    reference_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True)

    # We can load the regular attention state dict
    # so we have compat when loading old checkpoints.
    streaming_attn.load_state_dict(reference_attn.state_dict())

    queries = torch.randn(3, 7, dim)
    keys = torch.randn(3, 9, dim)
    values = torch.randn(3, 9, dim)

    out = streaming_attn(queries, keys, values)[0]
    out_ref = reference_attn(queries, keys, values)[0]
    assert torch.allclose(out, out_ref, atol=1e-7), (out - out_ref).norm() / out_ref.norm()

    # Now let's check that streaming is working properly.
    n_query_steps = queries.shape[1]
    with streaming_attn.streaming():
        per_step = [
            streaming_attn(queries[:, step: step + 1], keys, values)[0]
            for step in range(n_query_steps)
        ]
    out_streaming = torch.cat(per_step, dim=1)
    assert torch.allclose(out_streaming, out, atol=1e-7)
|
222 |
-
|
223 |
-
|
224 |
-
def test_repeat_kv():
    """``kv_repeat`` is rejected in two configurations and keeps the shape with ``custom=True``."""
    torch.manual_seed(1234)
    num_heads = 8
    kv_repeat = 4
    dim = num_heads * 64
    # Each rejected configuration gets its own `pytest.raises` context: the
    # context exits at the first exception, so a second constructor in the
    # same block is never executed.
    with pytest.raises(AssertionError):
        StreamingMultiheadAttention(
            dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True)
    # NOTE(review): presumably kv_repeat is rejected without custom=True;
    # confirm against StreamingMultiheadAttention.
    with pytest.raises(AssertionError):
        StreamingMultiheadAttention(
            dim, num_heads, causal=True, kv_repeat=kv_repeat)
    mha = StreamingMultiheadAttention(
        dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True)
    x = torch.randn(4, 18, dim)
    y = mha(x, x, x)[0]
    assert x.shape == y.shape
|
239 |
-
|
240 |
-
|
241 |
-
def test_qk_layer_norm():
    """``qk_layer_norm=True`` runs without and with cross attention; the latter keeps the input shape."""
    torch.manual_seed(1234)
    shared_kwargs = dict(custom=True, dropout=0., qk_layer_norm=True)

    self_attention_tr = StreamingTransformer(16, 4, 2, bias_attn=False, **shared_kwargs)
    n_steps = 12
    inputs = torch.randn(3, n_steps, 16)
    # Forward pass only; the original discards this result as well.
    self_attention_tr(inputs)

    cross_attention_tr = StreamingTransformer(16, 4, 2, cross_attention=True, **shared_kwargs)
    cross_src = torch.randn(3, 21, 16)
    outputs = cross_attention_tr(inputs, cross_attention_src=cross_src)
    assert outputs.shape == inputs.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/quantization/test_vq.py
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import torch
|
8 |
-
|
9 |
-
from audiocraft.quantization.vq import ResidualVectorQuantizer
|
10 |
-
|
11 |
-
|
12 |
-
class TestResidualVectorQuantizer:
    """Shape check for ``ResidualVectorQuantizer``."""

    def test_rvq(self):
        """The ``x`` field of the quantizer result has the same shape as the input."""
        expected_shape = torch.Size([1, 16, 2048])
        inputs = torch.randn(1, 16, 2048)
        quantizer = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
        result = quantizer(inputs, 1.)
        assert result.x.shape == expected_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/utils/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|