Spaces: Alexxggs / Configuration error

Alexxggs committed · Commit ac08452 · Parent: 60f415d

Delete tests

tests/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
tests/common_utils/__init__.py DELETED
@@ -1,9 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- # flake8: noqa
- from .temp_utils import TempDirMixin
- from .wav_utils import get_batch_white_noise, get_white_noise, save_wav
tests/common_utils/temp_utils.py DELETED
@@ -1,56 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import os
- import tempfile
-
-
- class TempDirMixin:
-     """Mixin to provide easy access to temp dir.
-     """
-
-     temp_dir_ = None
-
-     @classmethod
-     def get_base_temp_dir(cls):
-         # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory.
-         # This is handy for debugging.
-         key = "AUDIOCRAFT_TEST_DIR"
-         if key in os.environ:
-             return os.environ[key]
-         if cls.temp_dir_ is None:
-             cls.temp_dir_ = tempfile.TemporaryDirectory()
-         return cls.temp_dir_.name
-
-     @classmethod
-     def tearDownClass(cls):
-         if cls.temp_dir_ is not None:
-             try:
-                 cls.temp_dir_.cleanup()
-                 cls.temp_dir_ = None
-             except PermissionError:
-                 # On Windows there is a known issue with `shutil.rmtree`,
-                 # which fails intermittently.
-                 # https://github.com/python/cpython/issues/74168
-                 # Following the above thread, we ignore it.
-                 pass
-         super().tearDownClass()
-
-     @property
-     def id(self):
-         return self.__class__.__name__
-
-     def get_temp_path(self, *paths):
-         temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
-         path = os.path.join(temp_dir, *paths)
-         os.makedirs(os.path.dirname(path), exist_ok=True)
-         return path
-
-     def get_temp_dir(self, *paths):
-         temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
-         path = os.path.join(temp_dir, *paths)
-         os.makedirs(path, exist_ok=True)
-         return path
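Note: a minimal usage sketch of the mixin above, for anyone re-creating these utilities (the test class is hypothetical and assumes the deleted module is still importable as tests.common_utils; run under pytest like the rest of the suite):

from tests.common_utils import TempDirMixin  # hypothetical import of the module above


class TestScratchFiles(TempDirMixin):

    def test_write_tmp_file(self):
        # get_temp_path() nests paths under <base temp dir>/<class name> and
        # creates the parent directories on the fly; cleanup is attempted by
        # the mixin's tearDownClass.
        path = self.get_temp_path('out', 'sample.bin')
        with open(path, 'wb') as f:
            f.write(b'\x00' * 4)
        assert path.endswith('sample.bin')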
tests/common_utils/wav_utils.py DELETED
@@ -1,32 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from pathlib import Path
- import typing as tp
-
- import torch
- import torchaudio
-
-
- def get_white_noise(chs: int = 1, num_frames: int = 1):
-     wav = torch.randn(chs, num_frames)
-     return wav
-
-
- def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
-     wav = torch.randn(bs, chs, num_frames)
-     return wav
-
-
- def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
-     fp = Path(path)
-     kwargs: tp.Dict[str, tp.Any] = {}
-     if fp.suffix == '.wav':
-         kwargs['encoding'] = 'PCM_S'
-         kwargs['bits_per_sample'] = 16
-     elif fp.suffix == '.mp3':
-         kwargs['compression'] = 320
- torchaudio.save(str(fp), wav, sample_rate, **kwargs)
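Note: for the common WAV case, save_wav above boils down to a single torchaudio call; a standalone sketch (assumes only torch and torchaudio):

import torch
import torchaudio

# One second of stereo white noise, written as 16-bit PCM -- the same effect as
# save_wav('noise.wav', get_white_noise(2, 16_000), 16_000) with the helpers above.
wav = torch.randn(2, 16_000)
torchaudio.save('noise.wav', wav, 16_000, encoding='PCM_S', bits_per_sample=16)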
tests/data/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
tests/data/test_audio.py DELETED
@@ -1,239 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from itertools import product
- import random
-
- import numpy as np
- import torch
- import torchaudio
-
- from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read
-
- from ..common_utils import TempDirMixin, get_white_noise, save_wav
-
-
- class TestInfo(TempDirMixin):
-
-     def test_info_mp3(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         for sample_rate, ch in product(sample_rates, channels):
-             wav = get_white_noise(ch, int(sample_rate * duration))
-             path = self.get_temp_path('sample_wav.mp3')
-             save_wav(path, wav, sample_rate)
-             info = audio_info(path)
-             assert info.sample_rate == sample_rate
-             assert info.channels == ch
-             # we cannot trust torchaudio for num_frames, so we don't check
-
-     def _test_info_format(self, ext: str):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(sample_rate * duration)
-             wav = get_white_noise(ch, n_frames)
-             path = self.get_temp_path(f'sample_wav{ext}')
-             save_wav(path, wav, sample_rate)
-             info = audio_info(path)
-             assert info.sample_rate == sample_rate
-             assert info.channels == ch
-             assert np.isclose(info.duration, duration, atol=1e-5)
-
-     def test_info_wav(self):
-         self._test_info_format('.wav')
-
-     def test_info_flac(self):
-         self._test_info_format('.flac')
-
-     def test_info_ogg(self):
-         self._test_info_format('.ogg')
-
-     def test_info_m4a(self):
-         # TODO: generate m4a file programmatically
-         # self._test_info_format('.m4a')
-         pass
-
-
- class TestRead(TempDirMixin):
-
-     def test_read_full_wav(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(sample_rate * duration)
-             wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
-             path = self.get_temp_path('sample_wav.wav')
-             save_wav(path, wav, sample_rate)
-             read_wav, read_sr = audio_read(path)
-             assert read_sr == sample_rate
-             assert read_wav.shape[0] == wav.shape[0]
-             assert read_wav.shape[1] == wav.shape[1]
-             assert torch.allclose(read_wav, wav, rtol=1e-03, atol=1e-04)
-
-     def test_read_partial_wav(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         read_duration = torch.rand(1).item()
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(sample_rate * duration)
-             read_frames = int(sample_rate * read_duration)
-             wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
-             path = self.get_temp_path('sample_wav.wav')
-             save_wav(path, wav, sample_rate)
-             read_wav, read_sr = audio_read(path, 0, read_duration)
-             assert read_sr == sample_rate
-             assert read_wav.shape[0] == wav.shape[0]
-             assert read_wav.shape[1] == read_frames
-             assert torch.allclose(read_wav[..., 0:read_frames], wav[..., 0:read_frames], rtol=1e-03, atol=1e-04)
-
-     def test_read_seek_time_wav(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         read_duration = 1.
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(sample_rate * duration)
-             wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
-             path = self.get_temp_path('sample_wav.wav')
-             save_wav(path, wav, sample_rate)
-             seek_time = torch.rand(1).item()
-             read_wav, read_sr = audio_read(path, seek_time, read_duration)
-             seek_frames = int(sample_rate * seek_time)
-             expected_frames = n_frames - seek_frames
-             assert read_sr == sample_rate
-             assert read_wav.shape[0] == wav.shape[0]
-             assert read_wav.shape[1] == expected_frames
-             assert torch.allclose(read_wav, wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
-
-     def test_read_seek_time_wav_padded(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         read_duration = 1.
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(sample_rate * duration)
-             read_frames = int(sample_rate * read_duration)
-             wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
-             path = self.get_temp_path('sample_wav.wav')
-             save_wav(path, wav, sample_rate)
-             seek_time = torch.rand(1).item()
-             seek_frames = int(sample_rate * seek_time)
-             expected_frames = n_frames - seek_frames
-             read_wav, read_sr = audio_read(path, seek_time, read_duration, pad=True)
-             expected_pad_wav = torch.zeros(wav.shape[0], read_frames - expected_frames)
-             assert read_sr == sample_rate
-             assert read_wav.shape[0] == wav.shape[0]
-             assert read_wav.shape[1] == read_frames
-             assert torch.allclose(read_wav[..., :expected_frames], wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
-             assert torch.allclose(read_wav[..., expected_frames:], expected_pad_wav)
-
-
- class TestAvRead(TempDirMixin):
-
-     def test_avread_seek_base(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 2.
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(sample_rate * duration)
-             wav = get_white_noise(ch, n_frames)
-             path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav')
-             save_wav(path, wav, sample_rate)
-             for _ in range(100):
-                 # seek will always load a full duration segment in the file
-                 seek_time = random.uniform(0.0, 1.0)
-                 seek_duration = random.uniform(0.001, 1.0)
-                 read_wav, read_sr = _av_read(path, seek_time, seek_duration)
-                 assert read_sr == sample_rate
-                 assert read_wav.shape[0] == wav.shape[0]
-                 assert read_wav.shape[-1] == int(seek_duration * sample_rate)
-
-     def test_avread_seek_partial(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(sample_rate * duration)
-             wav = get_white_noise(ch, n_frames)
-             path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav')
-             save_wav(path, wav, sample_rate)
-             for _ in range(100):
-                 # seek will always load a partial segment
-                 seek_time = random.uniform(0.5, 1.)
-                 seek_duration = 1.
-                 expected_num_frames = n_frames - int(seek_time * sample_rate)
-                 read_wav, read_sr = _av_read(path, seek_time, seek_duration)
-                 assert read_sr == sample_rate
-                 assert read_wav.shape[0] == wav.shape[0]
-                 assert read_wav.shape[-1] == expected_num_frames
-
-     def test_avread_seek_outofbound(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(sample_rate * duration)
-             wav = get_white_noise(ch, n_frames)
-             path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav')
-             save_wav(path, wav, sample_rate)
-             seek_time = 1.5
-             read_wav, read_sr = _av_read(path, seek_time, 1.)
-             assert read_sr == sample_rate
-             assert read_wav.shape[0] == wav.shape[0]
-             assert read_wav.shape[-1] == 0
-
-     def test_avread_seek_edge(self):
-         sample_rates = [8000, 16_000]
-         # some of these values will have
-         # int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1)
-         n_frames = [1000, 1001, 1002]
-         channels = [1, 2]
-         for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
-             duration = frames / sample_rate
-             wav = get_white_noise(ch, frames)
-             path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav')
-             save_wav(path, wav, sample_rate)
-             seek_time = (frames - 1) / sample_rate
-             seek_frames = int(seek_time * sample_rate)
-             read_wav, read_sr = _av_read(path, seek_time, duration)
-             assert read_sr == sample_rate
-             assert read_wav.shape[0] == wav.shape[0]
-             assert read_wav.shape[-1] == (frames - seek_frames)
-
-
- class TestAudioWrite(TempDirMixin):
-
-     def test_audio_write_wav(self):
-         torch.manual_seed(1234)
-         sample_rates = [8000, 16_000]
-         n_frames = [1000, 1001, 1002]
-         channels = [1, 2]
-         strategies = ["peak", "clip", "rms"]
-         formats = ["wav", "mp3"]
-         for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
-             for format_, strategy in product(formats, strategies):
-                 wav = get_white_noise(ch, frames)
-                 path = self.get_temp_path(f'pred_{sample_rate}_{ch}')
-                 audio_write(path, wav, sample_rate, format_, strategy=strategy)
-                 read_wav, read_sr = torchaudio.load(f'{path}.{format_}')
-                 if format_ == "wav":
-                     assert read_wav.shape == wav.shape
-
-                 if format_ == "wav" and strategy in ["peak", "rms"]:
-                     rescaled_read_wav = read_wav / read_wav.abs().max() * wav.abs().max()
-                     # for a Gaussian, the typical max scale will be less than ~5x the std.
-                     # The error when writing to disk will be ~ 1/2**15, and when rescaling, 5x that.
-                     # For RMS target, rescaling leaves more headroom by default, leading
-                     # to a 20x rescaling typically.
-                     atol = (5 if strategy == "peak" else 20) / 2**15
-                     delta = (rescaled_read_wav - wav).abs().max()
-                     assert torch.allclose(wav, rescaled_read_wav, rtol=0, atol=atol), (delta, atol)
-         formats = ["wav"]  # faster unit tests
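Note: the tolerance arithmetic in the comment above, spelled out as plain Python (it only restates the numbers already given there):

quant_step = 1 / 2**15        # 16-bit PCM quantization error, ~3.05e-5
atol_peak = 5 * quant_step    # peak strategy: |x| rarely exceeds ~5 std for Gaussian noise
atol_rms = 20 * quant_step    # rms strategy: ~20x headroom after rescaling
print(atol_peak, atol_rms)    # ~1.5e-4 and ~6.1e-4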
tests/data/test_audio_dataset.py DELETED
@@ -1,352 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from functools import partial
- from itertools import product
- import json
- import math
- import os
- import random
- import typing as tp
-
- import pytest
- import torch
- from torch.utils.data import DataLoader
-
- from audiocraft.data.audio_dataset import (
-     AudioDataset,
-     AudioMeta,
-     _get_audio_meta,
-     load_audio_meta,
-     save_audio_meta
- )
- from audiocraft.data.zip import PathInZip
-
- from ..common_utils import TempDirMixin, get_white_noise, save_wav
-
-
- class TestAudioMeta(TempDirMixin):
-
-     def test_get_audio_meta(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(duration * sample_rate)
-             wav = get_white_noise(ch, n_frames)
-             path = self.get_temp_path('sample.wav')
-             save_wav(path, wav, sample_rate)
-             m = _get_audio_meta(path, minimal=True)
-             assert m.path == path, 'path does not match'
-             assert m.sample_rate == sample_rate, 'sample rate does not match'
-             assert m.duration == duration, 'duration does not match'
-             assert m.amplitude is None
-             assert m.info_path is None
-
-     def test_save_audio_meta(self):
-         audio_meta = [
-             AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
-             AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
-         ]
-         empty_audio_meta = []
-         for idx, meta in enumerate([audio_meta, empty_audio_meta]):
-             path = self.get_temp_path(f'data_{idx}_save.jsonl')
-             save_audio_meta(path, meta)
-             with open(path, 'r') as f:
-                 lines = f.readlines()
-                 read_meta = [AudioMeta.from_dict(json.loads(line)) for line in lines]
-                 assert len(read_meta) == len(meta)
-                 for m, read_m in zip(meta, read_meta):
-                     assert m == read_m
-
-     def test_load_audio_meta(self):
-         try:
-             import dora
-         except ImportError:
-             dora = None  # type: ignore
-
-         audio_meta = [
-             AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
-             AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
-         ]
-         empty_meta = []
-         for idx, meta in enumerate([audio_meta, empty_meta]):
-             path = self.get_temp_path(f'data_{idx}_load.jsonl')
-             with open(path, 'w') as f:
-                 for m in meta:
-                     json_str = json.dumps(m.to_dict()) + '\n'
-                     f.write(json_str)
-             read_meta = load_audio_meta(path)
-             assert len(read_meta) == len(meta)
-             for m, read_m in zip(meta, read_meta):
-                 if dora:
-                     m.path = dora.git_save.to_absolute_path(m.path)
-                 assert m == read_m, f'original={m}, read={read_m}'
-
-
- class TestAudioDataset(TempDirMixin):
-
-     def _create_audio_files(self,
-                             root_name: str,
-                             num_examples: int,
-                             durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
-                             sample_rate: int = 16_000,
-                             channels: int = 1):
-         root_dir = self.get_temp_dir(root_name)
-         for i in range(num_examples):
-             if isinstance(durations, float):
-                 duration = durations
-             elif isinstance(durations, tuple) and len(durations) == 1:
-                 duration = durations[0]
-             elif isinstance(durations, tuple) and len(durations) == 2:
-                 duration = random.uniform(durations[0], durations[1])
-             else:
-                 assert False
-             n_frames = int(duration * sample_rate)
-             wav = get_white_noise(channels, n_frames)
-             path = os.path.join(root_dir, f'example_{i}.wav')
-             save_wav(path, wav, sample_rate)
-         return root_dir
-
-     def _create_audio_dataset(self,
-                               root_name: str,
-                               total_num_examples: int,
-                               durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
-                               sample_rate: int = 16_000,
-                               channels: int = 1,
-                               segment_duration: tp.Optional[float] = None,
-                               num_examples: int = 10,
-                               shuffle: bool = True,
-                               return_info: bool = False):
-         root_dir = self._create_audio_files(root_name, total_num_examples, durations, sample_rate, channels)
-         dataset = AudioDataset.from_path(root_dir,
-                                          minimal_meta=True,
-                                          segment_duration=segment_duration,
-                                          num_samples=num_examples,
-                                          sample_rate=sample_rate,
-                                          channels=channels,
-                                          shuffle=shuffle,
-                                          return_info=return_info)
-         return dataset
-
-     def test_dataset_full(self):
-         total_examples = 10
-         min_duration, max_duration = 1., 4.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration),
-             sample_rate=sample_rate, channels=channels, segment_duration=None)
-         assert len(dataset) == total_examples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] <= int(max_duration * sample_rate)
-             assert sample.shape[1] >= int(min_duration * sample_rate)
-
-     def test_dataset_segment(self):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples)
-         assert len(dataset) == num_samples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] == int(segment_duration * sample_rate)
-
-     def test_dataset_equal_audio_and_segment_durations(self):
-         total_examples = 1
-         num_samples = 2
-         audio_duration = 1.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples)
-         assert len(dataset) == num_samples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] == int(segment_duration * sample_rate)
-         # the random seek_time adds variability to audio reads
-         sample_1 = dataset[0]
-         sample_2 = dataset[1]
-         assert not torch.allclose(sample_1, sample_2)
-
-     def test_dataset_samples(self):
-         total_examples = 1
-         num_samples = 2
-         audio_duration = 1.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-
-         create_dataset = partial(
-             self._create_audio_dataset,
-             'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples,
-         )
-
-         dataset = create_dataset(shuffle=True)
-         # when shuffle = True, we have different inputs for the same index across epochs
-         sample_1 = dataset[0]
-         sample_2 = dataset[0]
-         assert not torch.allclose(sample_1, sample_2)
-
-         dataset_noshuffle = create_dataset(shuffle=False)
-         # when shuffle = False, we have the same inputs for the same index across epochs
-         sample_1 = dataset_noshuffle[0]
-         sample_2 = dataset_noshuffle[0]
-         assert torch.allclose(sample_1, sample_2)
-
-     def test_dataset_return_info(self):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
-         assert len(dataset) == num_samples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample, segment_info = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] == int(segment_duration * sample_rate)
-             assert segment_info.sample_rate == sample_rate
-             assert segment_info.total_frames == int(segment_duration * sample_rate)
-             assert segment_info.n_frames <= int(segment_duration * sample_rate)
-             assert segment_info.seek_time >= 0
-
-     def test_dataset_return_info_no_segment_duration(self):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = None
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
-         assert len(dataset) == total_examples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample, segment_info = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] == segment_info.total_frames
-             assert segment_info.sample_rate == sample_rate
-             assert segment_info.n_frames <= segment_info.total_frames
-
-     def test_dataset_collate_fn(self):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=False)
-         batch_size = 4
-         dataloader = DataLoader(
-             dataset,
-             batch_size=batch_size,
-             num_workers=0
-         )
-         for idx, batch in enumerate(dataloader):
-             assert batch.shape[0] == batch_size
-
-     @pytest.mark.parametrize("segment_duration", [1.0, None])
-     def test_dataset_with_meta_collate_fn(self, segment_duration):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
-         batch_size = 4
-         dataloader = DataLoader(
-             dataset,
-             batch_size=batch_size,
-             collate_fn=dataset.collater,
-             num_workers=0
-         )
-         for idx, batch in enumerate(dataloader):
-             wav, infos = batch
-             assert wav.shape[0] == batch_size
-             assert len(infos) == batch_size
-
-     @pytest.mark.parametrize("segment_duration,sample_on_weight,sample_on_duration,a_hist,b_hist,c_hist", [
-         [1, True, True, 0.5, 0.5, 0.0],
-         [1, False, True, 0.25, 0.5, 0.25],
-         [1, True, False, 0.666, 0.333, 0.0],
-         [1, False, False, 0.333, 0.333, 0.333],
-         [None, False, False, 0.333, 0.333, 0.333]])
-     def test_sample_with_weight(self, segment_duration, sample_on_weight, sample_on_duration, a_hist, b_hist, c_hist):
-         random.seed(1234)
-         rng = torch.Generator()
-         rng.manual_seed(1234)
-
-         def _get_histogram(dataset, repetitions=20_000):
-             counts = {file_meta.path: 0. for file_meta in meta}
-             for _ in range(repetitions):
-                 file_meta = dataset.sample_file(rng)
-                 counts[file_meta.path] += 1
-             return {name: count / repetitions for name, count in counts.items()}
-
-         meta = [
-             AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
-             AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
-             AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
-         ]
-         dataset = AudioDataset(
-             meta, segment_duration=segment_duration, sample_on_weight=sample_on_weight,
-             sample_on_duration=sample_on_duration)
-         hist = _get_histogram(dataset)
-         assert math.isclose(hist['a'], a_hist, abs_tol=0.01)
-         assert math.isclose(hist['b'], b_hist, abs_tol=0.01)
-         assert math.isclose(hist['c'], c_hist, abs_tol=0.01)
-
-     def test_meta_duration_filter_all(self):
-         meta = [
-             AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
-             AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
-             AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
-         ]
-         try:
-             AudioDataset(meta, segment_duration=11, min_segment_ratio=1)
-             assert False
-         except AssertionError:
-             assert True
-
-     def test_meta_duration_filter_long(self):
-         meta = [
-             AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
-             AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
-             AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
-         ]
-         dataset = AudioDataset(meta, segment_duration=None, min_segment_ratio=1, max_audio_duration=7)
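Note: the expected histograms in test_sample_with_weight follow directly from the metadata; a sketch of the arithmetic (it assumes a missing weight counts as 1, which the 0.5/0.5/0.0 row implies):

def expected_probs(weights, durations, on_weight, on_duration):
    # Score each file, then normalize to probabilities.
    scores = [(w if on_weight else 1.0) * (d if on_duration else 1.0)
              for w, d in zip(weights, durations)]
    total = sum(scores)
    return [s / total for s in scores]

weights, durations = [2, 1, 0], [5, 10, 5]               # files a, b, c (None -> 1)
print(expected_probs(weights, durations, True, True))    # [0.5, 0.5, 0.0]
print(expected_probs(weights, durations, False, True))   # [0.25, 0.5, 0.25]
print(expected_probs(weights, durations, True, False))   # [0.667, 0.333, 0.0]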
tests/data/test_audio_utils.py DELETED
@@ -1,110 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import julius
- import torch
- import pytest
-
- from audiocraft.data.audio_utils import (
-     _clip_wav,
-     convert_audio_channels,
-     convert_audio,
-     normalize_audio
- )
- from ..common_utils import get_batch_white_noise
-
-
- class TestConvertAudioChannels:
-
-     def test_convert_audio_channels_downmix(self):
-         b, c, t = 2, 3, 100
-         audio = get_batch_white_noise(b, c, t)
-         mixed = convert_audio_channels(audio, channels=2)
-         assert list(mixed.shape) == [b, 2, t]
-
-     def test_convert_audio_channels_nochange(self):
-         b, c, t = 2, 3, 100
-         audio = get_batch_white_noise(b, c, t)
-         mixed = convert_audio_channels(audio, channels=c)
-         assert list(mixed.shape) == list(audio.shape)
-
-     def test_convert_audio_channels_upmix(self):
-         b, c, t = 2, 1, 100
-         audio = get_batch_white_noise(b, c, t)
-         mixed = convert_audio_channels(audio, channels=3)
-         assert list(mixed.shape) == [b, 3, t]
-
-     def test_convert_audio_channels_upmix_error(self):
-         b, c, t = 2, 2, 100
-         audio = get_batch_white_noise(b, c, t)
-         with pytest.raises(ValueError):
-             convert_audio_channels(audio, channels=3)
-
-
- class TestConvertAudio:
-
-     def test_convert_audio_channels_downmix(self):
-         b, c, dur = 2, 3, 4.
-         sr = 128
-         audio = get_batch_white_noise(b, c, int(sr * dur))
-         out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2)
-         assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]]
-
-     def test_convert_audio_channels_upmix(self):
-         b, c, dur = 2, 1, 4.
-         sr = 128
-         audio = get_batch_white_noise(b, c, int(sr * dur))
-         out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3)
-         assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]]
-
-     def test_convert_audio_upsample(self):
-         b, c, dur = 2, 1, 4.
-         sr = 2
-         new_sr = 3
-         audio = get_batch_white_noise(b, c, int(sr * dur))
-         out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
-         out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
-         assert torch.allclose(out, out_j)
-
-     def test_convert_audio_resample(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         new_sr = 2
-         audio = get_batch_white_noise(b, c, int(sr * dur))
-         out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
-         out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
-         assert torch.allclose(out, out_j)
-
-
- class TestNormalizeAudio:
-
-     def test_clip_wav(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
-         _clip_wav(audio)
-         assert audio.abs().max() <= 1
-
-     def test_normalize_audio_clip(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
-         norm_audio = normalize_audio(audio, strategy='clip')
-         assert norm_audio.abs().max() <= 1
-
-     def test_normalize_audio_rms(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
-         norm_audio = normalize_audio(audio, strategy='rms')
-         assert norm_audio.abs().max() <= 1
-
-     def test_normalize_audio_peak(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
-         norm_audio = normalize_audio(audio, strategy='peak')
- assert norm_audio.abs().max() <= 1
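Note: the resampling tests above pin convert_audio to julius; the length relationship they rely on, as a standalone check:

import julius
import torch

# Resampling from old_sr to new_sr scales the frame count by new_sr / old_sr.
x = torch.randn(2, 1, 8)                                 # (batch, channels, time)
y = julius.resample.resample_frac(x, old_sr=2, new_sr=3)
assert y.shape[-1] == 12                                 # 8 * 3 / 2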
tests/models/test_encodec_model.py DELETED
@@ -1,60 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import random
-
- import numpy as np
- import torch
-
- from audiocraft.models import EncodecModel
- from audiocraft.modules import SEANetEncoder, SEANetDecoder
- from audiocraft.quantization import DummyQuantizer
-
-
- class TestEncodecModel:
-
-     def _create_encodec_model(self,
-                               sample_rate: int,
-                               channels: int,
-                               dim: int = 5,
-                               n_filters: int = 3,
-                               n_residual_layers: int = 1,
-                               ratios: list = [5, 4, 3, 2],
-                               **kwargs):
-         frame_rate = np.prod(ratios)
-         encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
-                                 n_residual_layers=n_residual_layers, ratios=ratios)
-         decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
-                                 n_residual_layers=n_residual_layers, ratios=ratios)
-         quantizer = DummyQuantizer()
-         model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
-                              sample_rate=sample_rate, channels=channels, **kwargs)
-         return model
-
-     def test_model(self):
-         random.seed(1234)
-         sample_rate = 24_000
-         channels = 1
-         model = self._create_encodec_model(sample_rate, channels)
-         for _ in range(10):
-             length = random.randrange(1, 10_000)
-             x = torch.randn(2, channels, length)
-             res = model(x)
-             assert res.x.shape == x.shape
-
-     def test_model_renorm(self):
-         random.seed(1234)
-         sample_rate = 24_000
-         channels = 1
-         model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
-         model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)
-
-         for _ in range(10):
-             length = random.randrange(1, 10_000)
-             x = torch.randn(2, channels, length)
-             codes, scales = model_nonorm.encode(x)
-             codes, scales = model_renorm.encode(x)
- assert scales is not None
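Note: for context on the toy model above, the SEANet ratios set the total downsampling factor; a quick sketch of that arithmetic (the ceil-based length mapping is an assumption about the encoder's padding, not something this test asserts):

import math

import numpy as np

ratios = [5, 4, 3, 2]
hop = int(np.prod(ratios))      # total stride: 120 samples per latent step
length = 9_999                  # a waveform length in the test's sampled range
print(math.ceil(length / hop))  # ~84 latent steps for this input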
tests/models/test_musicgen.py DELETED
@@ -1,58 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import pytest
- import torch
-
- from audiocraft.models import MusicGen
-
-
- class TestSEANetModel:
-     def get_musicgen(self):
-         mg = MusicGen.get_pretrained(name='debug', device='cpu')
-         mg.set_generation_params(duration=2.0, extend_stride=2.)
-         return mg
-
-     def test_base(self):
-         mg = self.get_musicgen()
-         assert mg.frame_rate == 25
-         assert mg.sample_rate == 32000
-         assert mg.audio_channels == 1
-
-     def test_generate_unconditional(self):
-         mg = self.get_musicgen()
-         wav = mg.generate_unconditional(3)
-         assert list(wav.shape) == [3, 1, 64000]
-
-     def test_generate_continuation(self):
-         mg = self.get_musicgen()
-         prompt = torch.randn(3, 1, 32000)
-         wav = mg.generate_continuation(prompt, 32000)
-         assert list(wav.shape) == [3, 1, 64000]
-
-         prompt = torch.randn(2, 1, 32000)
-         wav = mg.generate_continuation(
-             prompt, 32000, ['youpi', 'lapin dort'])
-         assert list(wav.shape) == [2, 1, 64000]
-
-         prompt = torch.randn(2, 1, 32000)
-         with pytest.raises(AssertionError):
-             wav = mg.generate_continuation(
-                 prompt, 32000, ['youpi', 'lapin dort', 'one too many'])
-
-     def test_generate(self):
-         mg = self.get_musicgen()
-         wav = mg.generate(
-             ['youpi', 'lapin dort'])
-         assert list(wav.shape) == [2, 1, 64000]
-
-     def test_generate_long(self):
-         mg = self.get_musicgen()
-         mg.max_duration = 3.
-         mg.set_generation_params(duration=4., extend_stride=2.)
-         wav = mg.generate(
-             ['youpi', 'lapin dort'])
- assert list(wav.shape) == [2, 1, 32000 * 4]
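Note: the public API these tests exercise, as a standalone sketch (the 'debug' checkpoint name comes from the test; outside the suite one would pass a real model name such as 'facebook/musicgen-small'):

import torch

from audiocraft.models import MusicGen

mg = MusicGen.get_pretrained(name='debug', device='cpu')
mg.set_generation_params(duration=2.0)
wav = mg.generate(['an upbeat tune'])        # -> shape (1, 1, 2.0 * sample_rate)
prompt = torch.randn(1, 1, mg.sample_rate)   # one second of audio to continue
cont = mg.generate_continuation(prompt, mg.sample_rate)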
tests/modules/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
tests/modules/test_codebooks_patterns.py DELETED
@@ -1,246 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import pytest
- import torch
-
- from audiocraft.modules.codebooks_patterns import (
-     DelayedPatternProvider,
-     ParallelPatternProvider,
-     Pattern,
-     UnrolledPatternProvider,
- )
-
-
- class TestParallelPatternProvider:
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
-     def test_get_pattern(self, n_q: int, timesteps: int):
-         provider = ParallelPatternProvider(n_q)
-         pattern = provider.get_pattern(timesteps)
-         # + 1 to account for 1st step
-         assert len(pattern.layout) == timesteps + 1
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [8, 16, 100])
-     def test_pattern_content(self, n_q: int, timesteps: int):
-         provider = ParallelPatternProvider(n_q)
-         pattern = provider.get_pattern(timesteps)
-         for s, v in enumerate(pattern.layout):
-             for i, code in enumerate(v):
-                 assert i == code.q
-                 assert code.t == s - 1  # account for the 1st empty step
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [8, 16, 100])
-     def test_pattern_max_delay(self, n_q: int, timesteps: int):
-         provider = ParallelPatternProvider(n_q)
-         pattern = provider.get_pattern(timesteps)
-         assert pattern.max_delay == 0
-         assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
-
-
- class TestDelayedPatternProvider:
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
-     def test_get_pattern(self, n_q: int, timesteps: int):
-         delays = [
-             list(range(n_q)),
-             [0] + [1] * (n_q - 1),
-             [0] + [4] * (n_q - 1),
-         ]
-         for delay in delays:
-             provider = DelayedPatternProvider(n_q, delay)
-             pattern = provider.get_pattern(timesteps)
-             # + 1 to account for 1st step
-             assert len(pattern.layout) == timesteps + max(delay) + 1
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [8, 16, 100])
-     def test_pattern_content(self, n_q: int, timesteps: int):
-         provider = DelayedPatternProvider(n_q)
-         pattern = provider.get_pattern(timesteps)
-         for s, v in enumerate(pattern.layout):
-             for i, code in enumerate(v):
-                 assert i == code.q
-                 assert code.t == max(0, s - code.q - 1)
-
-     @pytest.mark.parametrize("timesteps", [8, 16, 100])
-     @pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]])
-     def test_pattern_max_delay(self, timesteps: int, delay: list):
-         provider = DelayedPatternProvider(len(delay), delay)
-         pattern = provider.get_pattern(timesteps)
-         assert pattern.max_delay == max(delay)
-         assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
-
-
- class TestUnrolledPatternProvider:
-
-     @pytest.mark.parametrize("timesteps", [0, 1, 16])
-     @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
-     @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
-     def test_get_pattern(self, timesteps: int, flattening: list, delays: list):
-         n_q = len(flattening)
-         max_delay = max(delays)
-         provider = UnrolledPatternProvider(n_q, flattening, delays)
-         pattern = provider.get_pattern(timesteps)
-         assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay
-
-     @pytest.mark.parametrize("timesteps", [0, 1, 16])
-     @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
-     @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
-     def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list):
-         n_q = len(flattening)
-         max_delay = max(delays)
-         provider = UnrolledPatternProvider(n_q, flattening, delays)
-         pattern = provider.get_pattern(timesteps)
-         assert pattern.max_delay == max_delay
-
-
- class TestPattern:
-
-     def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
-         """Reference method to build the sequence from the pattern without using fancy scatter."""
-         bs, n_q, T = z.shape
-         z = z.cpu().numpy()
-         assert n_q == pattern.n_q
-         assert T <= pattern.timesteps
-         inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy()
-         inp[:] = special_token
-         for s, v in enumerate(pattern.layout):
-             for (t, q) in v:
-                 if t < T:
-                     inp[:, q, s] = z[:, q, t]
-         return torch.from_numpy(inp)
-
-     def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
-         """Reference method to revert the sequence from the pattern without using fancy scatter."""
-         z = z.cpu().numpy()
-         bs, n_q, S = z.shape
-         assert pattern.n_q == n_q
-         inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy()
-         inp[:] = special_token
-         for s, v in enumerate(pattern.layout):
-             for (t, q) in v:
-                 if t < pattern.timesteps:
-                     inp[:, q, t] = z[:, q, s]
-         return torch.from_numpy(inp)
-
-     def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float):
-         """Reference method to revert the logits from the pattern without using fancy scatter."""
-         z = z.cpu().numpy()
-         bs, card, n_q, S = z.shape
-         assert pattern.n_q == n_q
-         ref_layout = pattern.layout
-         inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy()
-         inp[:] = special_token
-         for s, v in enumerate(ref_layout[1:]):
-             if s < S:
-                 for (t, q) in v:
-                     if t < pattern.timesteps:
-                         inp[:, :, q, t] = z[:, :, q, s]
-         return torch.from_numpy(inp)
-
-     def _get_pattern_providers(self, n_q: int):
-         pattern_provider_1 = ParallelPatternProvider(n_q)
-         pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q)))
-         pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1))
-         pattern_provider_4 = UnrolledPatternProvider(
-             n_q, flattening=list(range(n_q)), delays=[0] * n_q
-         )
-         pattern_provider_5 = UnrolledPatternProvider(
-             n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q
-         )
-         pattern_provider_6 = UnrolledPatternProvider(
-             n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1)
-         )
-         return [
-             pattern_provider_1,
-             pattern_provider_2,
-             pattern_provider_3,
-             pattern_provider_4,
-             pattern_provider_5,
-             pattern_provider_6,
-         ]
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [16, 72])
-     def test_build_pattern_sequence(self, n_q: int, timesteps: int):
-         bs = 2
-         card = 256
-         special_token = card
-
-         pattern_providers = self._get_pattern_providers(n_q)
-         for pattern_provider in pattern_providers:
-             pattern = pattern_provider.get_pattern(timesteps)
-             # we can correctly build the sequence from the pattern
-             z = torch.randint(0, card, (bs, n_q, timesteps))
-             ref_res = self.ref_build_pattern_sequence(z, pattern, special_token)
-             res, indexes, mask = pattern.build_pattern_sequence(z, special_token)
-             assert (res == ref_res).float().mean() == 1.0
-
-             # expected assertion fails on the number of timesteps
-             invalid_timesteps = [timesteps + 1]
-             if pattern.num_sequence_steps != pattern.timesteps:
-                 invalid_timesteps.append(pattern.num_sequence_steps)
-             for i_timesteps in invalid_timesteps:
-                 z2 = torch.randint(0, card, (bs, n_q, i_timesteps))
-                 with pytest.raises(AssertionError):
-                     pattern.build_pattern_sequence(z2, special_token)
-
-             # expected assertion fails on the number of codebooks
-             invalid_qs = [0, n_q - 1, n_q + 1]
-             for i_q in invalid_qs:
-                 z3 = torch.randint(0, card, (bs, i_q, timesteps))
-                 with pytest.raises(AssertionError):
-                     pattern.build_pattern_sequence(z3, special_token)
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [16, 72])
-     def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
-         bs = 2
-         card = 256
-         special_token = card
-
-         pattern_providers = self._get_pattern_providers(n_q)
-         for pattern_provider in pattern_providers:
-             pattern = pattern_provider.get_pattern(timesteps)
-             # this works assuming previous tests are successful
-             z = torch.randint(0, card, (bs, n_q, timesteps))
-             s = self.ref_build_pattern_sequence(z, pattern, special_token)
-             ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token)
-             # ensure our reference script retrieves the original sequence
-             assert z.shape == ref_out.shape
-             assert (z == ref_out).float().mean() == 1.0
-             # now we can test the scatter version
-             out, indexes, mask = pattern.revert_pattern_sequence(s, special_token)
-             assert out.shape == ref_out.shape
-             assert (out == ref_out).float().mean() == 1.0
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [16, 72])
-     @pytest.mark.parametrize("card", [1, 2, 256, 1024])
-     def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int):
-         bs = 2
-         special_token = card
-         logits_special_token = float('nan')
-
-         pattern_providers = self._get_pattern_providers(n_q)
-         for pattern_provider in pattern_providers:
-             pattern = pattern_provider.get_pattern(timesteps)
-             # this works assuming previous tests are successful
-             z = torch.randint(0, card, (bs, n_q, timesteps))
-             s = self.ref_build_pattern_sequence(z, pattern, special_token)
-             logits = torch.randn((bs, card, n_q, s.shape[-1]))
-             ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token)
-             # ensure our reference script retrieves the original sequence
-             assert ref_out.shape == torch.Size([bs, card, n_q, timesteps])
-             # now we can test the scatter version
-             out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token)
-             assert out.shape == ref_out.shape
- assert (out == ref_out).float().mean() == 1.0
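Note: a toy rendering of the delayed pattern these tests verify, based on the relation checked in TestDelayedPatternProvider.test_pattern_content (with the default delays, codebook q at sequence step s holds timestep s - q - 1; '.' stands for the special token):

n_q, T = 4, 6
for q in range(n_q):
    row = ['.'] * (T + n_q)
    for s in range(T + n_q):
        t = s - q - 1            # shift by the delay q, plus the empty first step
        if 0 <= t < T:
            row[s] = str(t)
    print(f'q={q}:', ' '.join(row))
# q=0: . 0 1 2 3 4 5 . . .
# q=1: . . 0 1 2 3 4 5 . .
# q=2: . . . 0 1 2 3 4 5 .
# q=3: . . . . 0 1 2 3 4 5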
tests/modules/test_conv.py DELETED
@@ -1,203 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from itertools import product
- import math
- import random
-
- import pytest
- import torch
- from torch import nn
-
- from audiocraft.modules import (
-     NormConv1d,
-     NormConvTranspose1d,
-     StreamableConv1d,
-     StreamableConvTranspose1d,
-     pad1d,
-     unpad1d,
- )
-
-
- def test_get_extra_padding_for_conv1d():
-     # TODO: Implement me!
-     pass
-
-
- def test_pad1d_zeros():
-     x = torch.randn(1, 1, 20)
-
-     xp1 = pad1d(x, (0, 5), mode='constant', value=0.)
-     assert xp1.shape[-1] == 25
-     xp2 = pad1d(x, (5, 5), mode='constant', value=0.)
-     assert xp2.shape[-1] == 30
-     xp3 = pad1d(x, (0, 0), mode='constant', value=0.)
-     assert xp3.shape[-1] == 20
-     xp4 = pad1d(x, (10, 30), mode='constant', value=0.)
-     assert xp4.shape[-1] == 60
-
-     with pytest.raises(AssertionError):
-         pad1d(x, (-1, 0), mode='constant', value=0.)
-
-     with pytest.raises(AssertionError):
-         pad1d(x, (0, -1), mode='constant', value=0.)
-
-     with pytest.raises(AssertionError):
-         pad1d(x, (-1, -1), mode='constant', value=0.)
-
-
- def test_pad1d_reflect():
-     x = torch.randn(1, 1, 20)
-
-     xp1 = pad1d(x, (0, 5), mode='reflect', value=0.)
-     assert xp1.shape[-1] == 25
-     xp2 = pad1d(x, (5, 5), mode='reflect', value=0.)
-     assert xp2.shape[-1] == 30
-     xp3 = pad1d(x, (0, 0), mode='reflect', value=0.)
-     assert xp3.shape[-1] == 20
-     xp4 = pad1d(x, (10, 30), mode='reflect', value=0.)
-     assert xp4.shape[-1] == 60
-
-     with pytest.raises(AssertionError):
-         pad1d(x, (-1, 0), mode='reflect', value=0.)
-
-     with pytest.raises(AssertionError):
-         pad1d(x, (0, -1), mode='reflect', value=0.)
-
-     with pytest.raises(AssertionError):
-         pad1d(x, (-1, -1), mode='reflect', value=0.)
-
-
- def test_unpad1d():
-     x = torch.randn(1, 1, 20)
-
-     u1 = unpad1d(x, (5, 5))
-     assert u1.shape[-1] == 10
-     u2 = unpad1d(x, (0, 5))
-     assert u2.shape[-1] == 15
-     u3 = unpad1d(x, (5, 0))
-     assert u3.shape[-1] == 15
-     u4 = unpad1d(x, (0, 0))
-     assert u4.shape[-1] == x.shape[-1]
-
-     with pytest.raises(AssertionError):
-         unpad1d(x, (-1, 0))
-
-     with pytest.raises(AssertionError):
-         unpad1d(x, (0, -1))
-
-     with pytest.raises(AssertionError):
-         unpad1d(x, (-1, -1))
-
-
- class TestNormConv1d:
-
-     def test_norm_conv1d_modules(self):
-         N, C, T = 2, 2, random.randrange(1, 100_000)
-         t0 = torch.randn(N, C, T)
-
-         C_out, kernel_size, stride = 1, 4, 1
-         expected_out_length = int((T - kernel_size) / stride + 1)
-         wn_conv = NormConv1d(C, 1, kernel_size=4, norm='weight_norm')
-         gn_conv = NormConv1d(C, 1, kernel_size=4, norm='time_group_norm')
-         nn_conv = NormConv1d(C, 1, kernel_size=4, norm='none')
-
-         assert isinstance(wn_conv.norm, nn.Identity)
-         assert isinstance(wn_conv.conv, nn.Conv1d)
-
-         assert isinstance(gn_conv.norm, nn.GroupNorm)
-         assert isinstance(gn_conv.conv, nn.Conv1d)
-
-         assert isinstance(nn_conv.norm, nn.Identity)
-         assert isinstance(nn_conv.conv, nn.Conv1d)
-
-         for conv_layer in [wn_conv, gn_conv, nn_conv]:
-             out = conv_layer(t0)
-             assert isinstance(out, torch.Tensor)
-             assert list(out.shape) == [N, C_out, expected_out_length]
-
-
- class TestNormConvTranspose1d:
-
-     def test_normalizations(self):
-         N, C, T = 2, 2, random.randrange(1, 100_000)
-         t0 = torch.randn(N, C, T)
-
-         C_out, kernel_size, stride = 1, 4, 1
-         expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1
-
-         wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm')
-         gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm')
-         nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none')
-
-         assert isinstance(wn_convtr.norm, nn.Identity)
-         assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d)
-
-         assert isinstance(gn_convtr.norm, nn.GroupNorm)
-         assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d)
-
-         assert isinstance(nn_convtr.norm, nn.Identity)
-         assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d)
-
-         for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]:
-             out = convtr_layer(t0)
-             assert isinstance(out, torch.Tensor)
-             assert list(out.shape) == [N, C_out, expected_out_length]
-
-
- class TestStreamableConv1d:
-
-     def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation):
-         # StreamableConv1d internally pads to make sure that the last window is full
-         padding_total = (kernel_size - 1) * dilation - (stride - 1)
-         n_frames = (length - kernel_size + padding_total) / stride + 1
-         ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
-         return ideal_length // stride
-
-     def test_streamable_conv1d(self):
-         N, C, T = 2, 2, random.randrange(1, 100_000)
-         t0 = torch.randn(N, C, T)
-         C_out = 1
-
-         # conv params are [(kernel_size, stride, dilation)]
-         conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)]
-         for causal, (kernel_size, stride, dilation) in product([False, True], conv_params):
-             expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation)
-             sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal)
-             out = sconv(t0)
-             assert isinstance(out, torch.Tensor)
-             print(list(out.shape), [N, C_out, expected_out_length])
-             assert list(out.shape) == [N, C_out, expected_out_length]
-
-
- class TestStreamableConvTranspose1d:
-
-     def get_streamable_convtr1d_output_length(self, length, kernel_size, stride):
-         padding_total = (kernel_size - stride)
-         return (length - 1) * stride - padding_total + (kernel_size - 1) + 1
-
-     def test_streamable_convtr1d(self):
-         N, C, T = 2, 2, random.randrange(1, 100_000)
-         t0 = torch.randn(N, C, T)
-
-         C_out = 1
-
-         with pytest.raises(AssertionError):
-             StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=False, trim_right_ratio=0.5)
-             StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=-1.)
-             StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=2)
-
-         # causal params are [(causal, trim_right)]
-         causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)]
-         # conv params are [(kernel_size, stride)]
-         conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)]
-         for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params):
-             expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride)
-             sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride,
-                                                 causal=causal, trim_right_ratio=trim_right_ratio)
-             out = sconvtr(t0)
-             assert isinstance(out, torch.Tensor)
- assert list(out.shape) == [N, C_out, expected_out_length]
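Note: the length formula in get_streamable_conv1d_output_length above, checked on concrete values (plain Python, same arithmetic as the helper):

import math

def out_len(length, kernel_size, stride, dilation):
    # StreamableConv1d pads so the last window is full, then strides.
    padding_total = (kernel_size - 1) * dilation - (stride - 1)
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length // stride

assert out_len(100, 4, 2, 1) == 50     # stride 2 halves the length
assert out_len(100, 4, 1, 1) == 100    # stride 1 preserves it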
tests/modules/test_lstm.py DELETED
@@ -1,32 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import random
- import torch
-
- from audiocraft.modules.lstm import StreamableLSTM
-
-
- class TestStreamableLSTM:
-
-     def test_lstm(self):
-         B, C, T = 4, 2, random.randint(1, 100)
-
-         lstm = StreamableLSTM(C, 3, skip=False)
-         x = torch.randn(B, C, T)
-         y = lstm(x)
-
-         print(y.shape)
-         assert y.shape == torch.Size([B, C, T])
-
-     def test_lstm_skip(self):
-         B, C, T = 4, 2, random.randint(1, 100)
-
-         lstm = StreamableLSTM(C, 3, skip=True)
-         x = torch.randn(B, C, T)
-         y = lstm(x)
-
-         assert y.shape == torch.Size([B, C, T])
 
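The LSTM tests above only pin down shapes: StreamableLSTM maps a (B, C, T) tensor to a (B, C, T) tensor, with skip=True presumably adding the input back as a residual. A self-contained sketch of that behaviour under those assumptions (TinyStreamableLSTM is an illustrative stand-in, not the audiocraft class):

import torch
from torch import nn

class TinyStreamableLSTM(nn.Module):
    # Illustrative stand-in: an LSTM over the time axis of (B, C, T),
    # with an optional residual skip, as assumed from the deleted tests.
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(2, 0, 1)      # (B, C, T) -> (T, B, C); nn.LSTM is time-first
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x               # residual connection
        return y.permute(1, 2, 0)   # back to (B, C, T)

out = TinyStreamableLSTM(2, num_layers=3)(torch.randn(4, 2, 10))
assert out.shape == torch.Size([4, 2, 10])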
tests/modules/test_rope.py DELETED
@@ -1,168 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import torch
-
- from audiocraft.modules.rope import RotaryEmbedding
- from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend
-
-
- def test_rope():
-     set_efficient_attention_backend('xformers')
-     B, T, H, C = 8, 75, 16, 128
-
-     rope = RotaryEmbedding(dim=C)
-     xq = torch.rand((B, T, H, C))
-     xk = torch.rand((B, T, H, C))
-     xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
-
-     assert list(xq_out.shape) == [B, T, H, C]
-     assert list(xk_out.shape) == [B, T, H, C]
-
-
- def test_rope_io_dtypes():
-     set_efficient_attention_backend('xformers')
-     B, T, H, C = 8, 75, 16, 128
-
-     rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
-     rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)
-
-     # Test bfloat16 inputs w/ both 32 and 64 precision rope.
-     xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
-     xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
-     xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16)
-     assert xq_out.dtype == torch.bfloat16
-     xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16)
-     assert xq_out.dtype == torch.bfloat16
-
-     # Test float32 inputs w/ both 32 and 64 precision rope.
-     xq_32 = torch.rand((B, T, H, C)).to(torch.float32)
-     xk_32 = torch.rand((B, T, H, C)).to(torch.float32)
-     xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32)
-     assert xq_out.dtype == torch.float32
-     xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32)
-     assert xq_out.dtype == torch.float32
-
-
- def test_transformer_with_rope():
-     set_efficient_attention_backend('xformers')
-     torch.manual_seed(1234)
-     for pos in ['rope', 'sin_rope']:
-         tr = StreamingTransformer(
-             16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
-             positional_embedding=pos)
-         tr.eval()
-         steps = 12
-         x = torch.randn(3, steps, 16)
-
-         out = tr(x)
-         assert list(out.shape) == list(x.shape)
-
-
- @torch.no_grad()
- def test_rope_streaming():
-     set_efficient_attention_backend('xformers')
-     torch.manual_seed(1234)
-     tr = StreamingTransformer(
-         16, 4, 2, causal=True, dropout=0.,
-         custom=True, positional_embedding='rope')
-     tr.eval()
-     steps = 12
-     x = torch.randn(3, steps, 16)
-
-     ref = tr(x)
-
-     with tr.streaming():
-         outs = []
-         frame_sizes = [1] * steps
-
-         for frame_size in frame_sizes:
-             frame = x[:, :frame_size]
-             x = x[:, frame_size:]
-             outs.append(tr(frame))
-
-     out = torch.cat(outs, dim=1)
-     assert list(out.shape) == [3, steps, 16]
-     delta = torch.norm(out - ref) / torch.norm(out)
-     assert delta < 1e-6, delta
-
-
- @torch.no_grad()
- def test_rope_streaming_past_context():
-     set_efficient_attention_backend('xformers')
-     torch.manual_seed(1234)
-
-     for context in [None, 10]:
-         tr = StreamingTransformer(
-             16, 4, 1 if context else 2,
-             causal=True, past_context=context, custom=True,
-             dropout=0., positional_embedding='rope')
-         tr.eval()
-
-         steps = 20
-         x = torch.randn(3, steps, 16)
-         ref = tr(x)
-
-         with tr.streaming():
-             outs = []
-             frame_sizes = [1] * steps
-
-             for frame_size in frame_sizes:
-                 frame = x[:, :frame_size]
-                 x = x[:, frame_size:]
-                 outs.append(tr(frame))
-
-         out = torch.cat(outs, dim=1)
-         assert list(out.shape) == [3, steps, 16]
-         delta = torch.norm(out - ref) / torch.norm(out)
-         assert delta < 1e-6, delta
-
-
- def test_rope_memory_efficient():
-     set_efficient_attention_backend('xformers')
-     torch.manual_seed(1234)
-     tr = StreamingTransformer(
-         16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
-         positional_embedding='rope')
-     tr_mem_efficient = StreamingTransformer(
-         16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
-         positional_embedding='rope')
-     tr_mem_efficient.load_state_dict(tr.state_dict())
-     tr.eval()
-     steps = 12
-     x = torch.randn(3, steps, 16)
-
-     with torch.no_grad():
-         y = tr(x)
-         y2 = tr_mem_efficient(x)
-         # Check at float precision b/c this is the rope default.
-         assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()
-
-
- def test_rope_with_xpos():
-     set_efficient_attention_backend('xformers')
-     B, T, H, C = 8, 75, 16, 128
-
-     rope = RotaryEmbedding(dim=C, xpos=True)
-     xq = torch.rand((B, T, H, C))
-     xk = torch.rand((B, T, H, C))
-     xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
-
-     assert list(xq_out.shape) == [B, T, H, C]
-     assert list(xk_out.shape) == [B, T, H, C]
-
-
- def test_positional_scale():
-     set_efficient_attention_backend('xformers')
-     B, T, H, C = 8, 75, 16, 128
-
-     rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
-     xq = torch.rand((B, T, H, C))
-     xk = torch.rand((B, T, H, C))
-     xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
-
-     assert torch.allclose(xq, xq_out)
-     assert torch.allclose(xk, xk_out)
 
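The shape and dtype assertions above follow from how rotary embeddings work: each consecutive feature pair of the queries and keys is rotated by a position-dependent angle, which preserves the tensor shape. A self-contained sketch of the idea (rotate_pairs is illustrative and assumes the common 10000-base frequency schedule; it is not audiocraft's RotaryEmbedding):

import torch

def rotate_pairs(x: torch.Tensor, start: int = 0) -> torch.Tensor:
    # Rotate consecutive feature pairs of x with shape (B, T, H, C)
    # by position-dependent angles, RoPE-style, starting at `start`.
    B, T, H, C = x.shape
    inv_freq = 1.0 / (10000 ** (torch.arange(0, C, 2).float() / C))  # (C/2,)
    pos = torch.arange(start, start + T).float()                     # (T,)
    angles = torch.outer(pos, inv_freq)                              # (T, C/2)
    cos = angles.cos()[None, :, None, :]
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

xq = torch.rand(8, 75, 16, 128)
assert rotate_pairs(xq, start=7).shape == xq.shape  # rotation preserves shape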
tests/modules/test_seanet.py DELETED
@@ -1,115 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from itertools import product
-
- import pytest
- import torch
-
- from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
- from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
-
-
- class TestSEANetModel:
-
-     def test_base(self):
-         encoder = SEANetEncoder()
-         decoder = SEANetDecoder()
-
-         x = torch.randn(1, 1, 24000)
-         z = encoder(x)
-         assert list(z.shape) == [1, 128, 75], z.shape
-         y = decoder(z)
-         assert y.shape == x.shape, (x.shape, y.shape)
-
-     def test_causal(self):
-         encoder = SEANetEncoder(causal=True)
-         decoder = SEANetDecoder(causal=True)
-         x = torch.randn(1, 1, 24000)
-
-         z = encoder(x)
-         assert list(z.shape) == [1, 128, 75], z.shape
-         y = decoder(z)
-         assert y.shape == x.shape, (x.shape, y.shape)
-
-     def test_conv_skip_connection(self):
-         encoder = SEANetEncoder(true_skip=False)
-         decoder = SEANetDecoder(true_skip=False)
-
-         x = torch.randn(1, 1, 24000)
-         z = encoder(x)
-         assert list(z.shape) == [1, 128, 75], z.shape
-         y = decoder(z)
-         assert y.shape == x.shape, (x.shape, y.shape)
-
-     def test_seanet_encoder_decoder_final_act(self):
-         encoder = SEANetEncoder(true_skip=False)
-         decoder = SEANetDecoder(true_skip=False, final_activation='Tanh')
-
-         x = torch.randn(1, 1, 24000)
-         z = encoder(x)
-         assert list(z.shape) == [1, 128, 75], z.shape
-         y = decoder(z)
-         assert y.shape == x.shape, (x.shape, y.shape)
-
-     def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
-         n_blocks = 0
-         for layer in encoder.model:
-             if isinstance(layer, StreamableConv1d):
-                 n_blocks += 1
-                 # Parenthesized so we assert the comparison, not the truthy string `norm`.
-                 assert layer.conv.norm_type == ('none' if n_blocks <= n_disable_blocks else norm)
-             elif isinstance(layer, SEANetResnetBlock):
-                 for resnet_layer in layer.block:
-                     if isinstance(resnet_layer, StreamableConv1d):
-                         # here we add + 1 to n_blocks as we increment n_blocks just after the block
-                         assert resnet_layer.conv.norm_type == ('none' if (n_blocks + 1) <= n_disable_blocks else norm)
-
-     def test_encoder_disable_norm(self):
-         n_residuals = [0, 1, 3]
-         disable_blocks = [0, 1, 2, 3, 4, 5, 6]
-         norms = ['weight_norm', 'none']
-         for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
-             encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
-                                     disable_norm_outer_blocks=n_disable)
-             self._check_encoder_blocks_norm(encoder, n_disable, norm)
-
-     def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
-         n_blocks = 0
-         for layer in decoder.model:
-             if isinstance(layer, StreamableConv1d):
-                 n_blocks += 1
-                 assert layer.conv.norm_type == ('none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm)
-             elif isinstance(layer, StreamableConvTranspose1d):
-                 n_blocks += 1
-                 assert layer.convtr.norm_type == ('none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm)
-             elif isinstance(layer, SEANetResnetBlock):
-                 for resnet_layer in layer.block:
-                     if isinstance(resnet_layer, StreamableConv1d):
-                         assert resnet_layer.conv.norm_type == \
-                             ('none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm)
-
-     def test_decoder_disable_norm(self):
-         n_residuals = [0, 1, 3]
-         disable_blocks = [0, 1, 2, 3, 4, 5, 6]
-         norms = ['weight_norm', 'none']
-         for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
-             decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
-                                     disable_norm_outer_blocks=n_disable)
-             self._check_decoder_blocks_norm(decoder, n_disable, norm)
-
-     def test_disable_norm_raises_exception(self):
-         # Invalid disable_norm_outer_blocks values raise exceptions.
-         with pytest.raises(AssertionError):
-             SEANetEncoder(disable_norm_outer_blocks=-1)
-
-         with pytest.raises(AssertionError):
-             SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
-
-         with pytest.raises(AssertionError):
-             SEANetDecoder(disable_norm_outer_blocks=-1)
-
-         with pytest.raises(AssertionError):
-             SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
 
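The repeated assertion z.shape == [1, 128, 75] above follows from the encoder's total stride. Assuming the default SEANet ratios are (8, 5, 4, 2), as in EnCodec, the hop size is their product, which is a quick check to run by hand:

import math

ratios = (8, 5, 4, 2)           # assumed SEANetEncoder defaults (EnCodec-style)
hop_size = math.prod(ratios)    # total downsampling factor of the encoder
assert hop_size == 320
assert 24000 // hop_size == 75  # 24000 input samples -> 75 latent frames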
tests/modules/test_transformer.py DELETED
@@ -1,253 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from itertools import product
-
- import pytest
- import torch
-
- from audiocraft.modules.transformer import (
-     StreamingMultiheadAttention, StreamingTransformer, set_efficient_attention_backend)
-
-
- def test_transformer_causal_streaming():
-     torch.manual_seed(1234)
-
-     for context, custom in product([None, 10], [False, True]):
-         # Test that causality and receptive fields are properly handled
-         # by looking at the gradients.
-         tr = StreamingTransformer(
-             16, 4, 1 if context else 2,
-             causal=True, past_context=context, custom=custom,
-             dropout=0.)
-         steps = 20
-         for k in [0, 10, 15, 19]:
-             x = torch.randn(4, steps, 16, requires_grad=True)
-             y = tr(x)
-             y[:, k].abs().sum().backward()
-             if k + 1 < steps:
-                 assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm()
-             assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm()
-             if context is not None and k > context:
-                 limit = k - context - 1
-                 assert torch.allclose(x.grad[:, :limit],
-                                       torch.tensor(0.)), x.grad[:, :limit].norm()
-
-         # Now check that streaming gives the same result as batch eval.
-         x = torch.randn(4, steps, 16)
-         y = tr(x)
-         ys = []
-         with tr.streaming():
-             for k in range(steps):
-                 chunk = x[:, k:k + 1, :]
-                 ys.append(tr(chunk))
-         y_stream = torch.cat(ys, dim=1)
-         delta = torch.norm(y_stream - y) / torch.norm(y)
-         assert delta < 1e-6, delta
-
-
- def test_transformer_vs_pytorch():
-     torch.manual_seed(1234)
-     # Check that in the non-causal setting, we get the same result as
-     # the PyTorch Transformer encoder.
-     for custom in [False, True]:
-         tr = StreamingTransformer(
-             16, 4, 2,
-             causal=False, custom=custom, dropout=0., positional_scale=0.)
-         layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True)
-         tr_ref = torch.nn.TransformerEncoder(layer, 2)
-         tr.load_state_dict(tr_ref.state_dict())
-
-         x = torch.randn(4, 20, 16)
-         y = tr(x)
-         y2 = tr_ref(x)
-         delta = torch.norm(y2 - y) / torch.norm(y)
-         assert delta < 1e-6, delta
-
-
- def test_streaming_api():
-     tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.)
-     tr.eval()
-     steps = 12
-     x = torch.randn(1, steps, 16)
-
-     with torch.no_grad():
-         with tr.streaming():
-             _ = tr(x[:, :1])
-             state = {k: v.clone() for k, v in tr.get_streaming_state().items()}
-             y = tr(x[:, 1:2])
-             tr.set_streaming_state(state)
-             y2 = tr(x[:, 1:2])
-             assert torch.allclose(y, y2), (y - y2).norm()
-             assert tr.flush() is None
-
-
- def test_memory_efficient():
-     for backend in ['torch', 'xformers']:
-         torch.manual_seed(1234)
-         set_efficient_attention_backend(backend)
-
-         tr = StreamingTransformer(
-             16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
-         tr_mem_efficient = StreamingTransformer(
-             16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
-         tr_mem_efficient.load_state_dict(tr.state_dict())
-         tr.eval()
-         steps = 12
-         x = torch.randn(3, steps, 16)
-
-         with torch.no_grad():
-             y = tr(x)
-             y2 = tr_mem_efficient(x)
-             assert torch.allclose(y, y2), ((y - y2).norm(), backend)
-
-
- def test_attention_as_float32():
-     torch.manual_seed(1234)
-     cases = [
-         {'custom': True},
-         {'custom': False},
-     ]
-     for case in cases:
-         tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case)
-         tr_float32 = StreamingTransformer(
-             16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case)
-         if not case['custom']:
-             # We are not using autocast here because it doesn't really
-             # work as expected on CPU, so we have to manually cast the weights of the MHA.
-             for layer in tr_float32.layers:
-                 layer.self_attn.mha.to(torch.float32)
-         tr_float32.load_state_dict(tr.state_dict())
-         steps = 12
-         x = torch.randn(3, steps, 16, dtype=torch.bfloat16)
-
-         with torch.no_grad():
-             y = tr(x)
-             y2 = tr_float32(x)
-             assert not torch.allclose(y, y2), (y - y2).norm()
-
-
- @torch.no_grad()
- def test_streaming_memory_efficient():
-     for backend in ['torch', 'xformers']:
-         torch.manual_seed(1234)
-         set_efficient_attention_backend(backend)
-         tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
-         tr_mem_efficient = StreamingTransformer(
-             16, 4, 2, dropout=0., memory_efficient=True, causal=True)
-         tr.load_state_dict(tr_mem_efficient.state_dict())
-         tr.eval()
-         tr_mem_efficient.eval()
-         steps = 12
-         x = torch.randn(3, steps, 16)
-
-         ref = tr(x)
-
-         with tr_mem_efficient.streaming():
-             outs = []
-             frame_sizes = [1] * steps
-
-             for frame_size in frame_sizes:
-                 frame = x[:, :frame_size]
-                 x = x[:, frame_size:]
-                 outs.append(tr_mem_efficient(frame))
-
-         out = torch.cat(outs, dim=1)
-         delta = torch.norm(out - ref) / torch.norm(out)
-         assert delta < 1e-6, delta
-
-
- def test_cross_attention():
-     torch.manual_seed(1234)
-     for norm_first in [True, False]:
-         m = StreamingTransformer(
-             16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True)
-         m_cross = StreamingTransformer(
-             16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True)
-         m_cross.load_state_dict(m.state_dict(), strict=False)
-         x = torch.randn(2, 5, 16)
-         cross_x = torch.randn(2, 3, 16)
-         y_ref = m(x)
-         y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x)
-         # With norm_first, the two should be exactly the same,
-         # but with norm_first=False, we get 2 normalizations in a row
-         # and the epsilon value leads to a tiny change.
-         atol = 0. if norm_first else 1e-6
-         print((y_ref - y_cross_zero).norm() / y_ref.norm())
-         assert torch.allclose(y_ref, y_cross_zero, atol=atol)
-
-         # We now expect a difference even with a generous atol of 1e-2.
-         y_cross = m_cross(x, cross_attention_src=cross_x)
-         assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2)
-
-         # Each misuse in its own raises block: statements after the first
-         # raise inside a single `pytest.raises` would never be executed.
-         with pytest.raises(AssertionError):
-             _ = m_cross(x)
-         with pytest.raises(AssertionError):
-             _ = m(x, cross_attention_src=cross_x)
-
-
- def test_cross_attention_compat():
-     torch.manual_seed(1234)
-     num_heads = 2
-     dim = num_heads * 64
-     with pytest.raises(AssertionError):
-         StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True)
-
-     cross_attn = StreamingMultiheadAttention(
-         dim, num_heads, dropout=0, cross_attention=True, custom=True)
-     ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True)
-
-     # We can load the regular attention state dict
-     # so we have compat when loading old checkpoints.
-     cross_attn.load_state_dict(ref_attn.state_dict())
-
-     queries = torch.randn(3, 7, dim)
-     keys = torch.randn(3, 9, dim)
-     values = torch.randn(3, 9, dim)
-
-     y = cross_attn(queries, keys, values)[0]
-     y_ref = ref_attn(queries, keys, values)[0]
-     assert torch.allclose(y, y_ref, atol=1e-7), (y - y_ref).norm() / y_ref.norm()
-
-     # Now let's check that streaming is working properly.
-     with cross_attn.streaming():
-         ys = []
-         for step in range(queries.shape[1]):
-             ys.append(cross_attn(queries[:, step: step + 1], keys, values)[0])
-     y_streaming = torch.cat(ys, dim=1)
-     assert torch.allclose(y_streaming, y, atol=1e-7)
-
-
- def test_repeat_kv():
-     torch.manual_seed(1234)
-     num_heads = 8
-     kv_repeat = 4
-     dim = num_heads * 64
-     # Invalid configurations, each in its own raises block so both are exercised.
-     with pytest.raises(AssertionError):
-         StreamingMultiheadAttention(
-             dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True)
-     with pytest.raises(AssertionError):
-         StreamingMultiheadAttention(
-             dim, num_heads, causal=True, kv_repeat=kv_repeat)
-     mha = StreamingMultiheadAttention(
-         dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True)
-     x = torch.randn(4, 18, dim)
-     y = mha(x, x, x)[0]
-     assert x.shape == y.shape
-
-
- def test_qk_layer_norm():
-     torch.manual_seed(1234)
-     tr = StreamingTransformer(
-         16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False)
-     steps = 12
-     x = torch.randn(3, steps, 16)
-     y = tr(x)
-
-     tr = StreamingTransformer(
-         16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True)
-     z = torch.randn(3, 21, 16)
-     y = tr(x, cross_attention_src=z)
-     assert y.shape == x.shape
 
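Several of the tests above share one harness: compute a reference output on the full sequence, then feed the model frame by frame in streaming mode and bound the relative error. A minimal version of that check, with prefix slicing standing in for audiocraft's streaming() state (streaming_matches_batch and CausalCumsum are illustrative names, not library API):

import torch

class CausalCumsum(torch.nn.Module):
    # A trivially causal module: output at step t depends only on steps <= t.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.cumsum(dim=1)

def streaming_matches_batch(module: torch.nn.Module, x: torch.Tensor, tol: float = 1e-6) -> bool:
    # Reference: one pass over the whole (B, T, C) sequence.
    ref = module(x)
    # "Streaming": re-run on each prefix and keep only the newest frame.
    outs = [module(x[:, :t + 1])[:, -1:] for t in range(x.shape[1])]
    out = torch.cat(outs, dim=1)
    return (torch.norm(out - ref) / torch.norm(ref)).item() < tol

assert streaming_matches_batch(CausalCumsum(), torch.randn(3, 12, 16))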
tests/quantization/test_vq.py DELETED
@@ -1,18 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import torch
-
- from audiocraft.quantization.vq import ResidualVectorQuantizer
-
-
- class TestResidualVectorQuantizer:
-
-     def test_rvq(self):
-         x = torch.randn(1, 16, 2048)
-         vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
-         res = vq(x, 1.)
-         assert res.x.shape == torch.Size([1, 16, 2048])
 
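For context on the single RVQ assertion above: residual vector quantization runs several codebooks in sequence, each quantizing what the previous stages left over, and sums the stages, so the quantized output keeps the input shape. A sketch with random, untrained codebooks (residual_vq is illustrative, not audiocraft's ResidualVectorQuantizer):

import torch

def residual_vq(x: torch.Tensor, codebooks: list) -> torch.Tensor:
    # x: (B, D, T); each codebook: (bins, D). Each stage snaps the current
    # residual to its nearest code and the next stage quantizes what remains.
    residual = x
    quantized = torch.zeros_like(x)
    for cb in codebooks:
        flat = residual.permute(0, 2, 1).reshape(-1, x.shape[1])  # (B*T, D)
        idx = torch.cdist(flat, cb).argmin(dim=1)                  # nearest code per frame
        q = cb[idx].reshape(x.shape[0], x.shape[2], x.shape[1]).permute(0, 2, 1)
        quantized = quantized + q
        residual = residual - q
    return quantized

x = torch.randn(1, 16, 2048)
codebooks = [torch.randn(8, 16) for _ in range(8)]  # n_q=8, bins=8, dimension=16
assert residual_vq(x, codebooks).shape == x.shape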
tests/utils/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.