Surn committed on
Commit
c542417
·
1 Parent(s): ab70d24

Initial Upload HF

CHANGELOG.md DELETED
@@ -1,18 +0,0 @@
1
- # Changelog
2
-
3
- All notable changes to this project will be documented in this file.
4
-
5
- The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
6
-
7
- ## [0.0.2a] - TBD
8
-
9
- Improved demo, fixed top p (thanks @jnordberg).
10
-
11
- Compressor tanh on output to avoid clipping with some style (especially piano).
12
- Now repeating the conditioning periodically if it is too short.
13
-
14
- More options when launching Gradio app locally (thanks @ashleykleynhans).
15
-
16
- ## [0.0.1] - 2023-06-09
17
-
18
- Initial release, with model evaluation only.
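Note on the 0.0.2a entry in the deleted changelog: it mentions a tanh compressor on the output to avoid clipping. The snippet below is only a minimal sketch of that idea, assuming samples are plain Python floats nominally in [-1, 1]. The name `tanh_compress` is hypothetical and this is not the project's own implementation.

```python
import math


def tanh_compress(samples):
    """Soft-limit samples with tanh so every output stays within [-1, 1]."""
    return [math.tanh(s) for s in samples]


# Small values pass through almost unchanged; large peaks are squashed
# instead of being hard-clipped.
quiet = tanh_compress([0.01, -0.02])
loud = tanh_compress([1.5, -3.0])
```

Because tanh is close to the identity near zero, quiet material is barely altered, while peaks beyond full scale are bent smoothly toward the limit.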
 
CODE_OF_CONDUCT.md DELETED
@@ -1,80 +0,0 @@
1
- # Code of Conduct
2
-
3
- ## Our Pledge
4
-
5
- In the interest of fostering an open and welcoming environment, we as
6
- contributors and maintainers pledge to make participation in our project and
7
- our community a harassment-free experience for everyone, regardless of age, body
8
- size, disability, ethnicity, sex characteristics, gender identity and expression,
9
- level of experience, education, socio-economic status, nationality, personal
10
- appearance, race, religion, or sexual identity and orientation.
11
-
12
- ## Our Standards
13
-
14
- Examples of behavior that contributes to creating a positive environment
15
- include:
16
-
17
- * Using welcoming and inclusive language
18
- * Being respectful of differing viewpoints and experiences
19
- * Gracefully accepting constructive criticism
20
- * Focusing on what is best for the community
21
- * Showing empathy towards other community members
22
-
23
- Examples of unacceptable behavior by participants include:
24
-
25
- * The use of sexualized language or imagery and unwelcome sexual attention or
26
- advances
27
- * Trolling, insulting/derogatory comments, and personal or political attacks
28
- * Public or private harassment
29
- * Publishing others' private information, such as a physical or electronic
30
- address, without explicit permission
31
- * Other conduct which could reasonably be considered inappropriate in a
32
- professional setting
33
-
34
- ## Our Responsibilities
35
-
36
- Project maintainers are responsible for clarifying the standards of acceptable
37
- behavior and are expected to take appropriate and fair corrective action in
38
- response to any instances of unacceptable behavior.
39
-
40
- Project maintainers have the right and responsibility to remove, edit, or
41
- reject comments, commits, code, wiki edits, issues, and other contributions
42
- that are not aligned to this Code of Conduct, or to ban temporarily or
43
- permanently any contributor for other behaviors that they deem inappropriate,
44
- threatening, offensive, or harmful.
45
-
46
- ## Scope
47
-
48
- This Code of Conduct applies within all project spaces, and it also applies when
49
- an individual is representing the project or its community in public spaces.
50
- Examples of representing a project or community include using an official
51
- project e-mail address, posting via an official social media account, or acting
52
- as an appointed representative at an online or offline event. Representation of
53
- a project may be further defined and clarified by project maintainers.
54
-
55
- This Code of Conduct also applies outside the project spaces when there is a
56
- reasonable belief that an individual's behavior may have a negative impact on
57
- the project or its community.
58
-
59
- ## Enforcement
60
-
61
- Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
- reported by contacting the project team at <opensource-conduct@fb.com>. All
63
- complaints will be reviewed and investigated and will result in a response that
64
- is deemed necessary and appropriate to the circumstances. The project team is
65
- obligated to maintain confidentiality with regard to the reporter of an incident.
66
- Further details of specific enforcement policies may be posted separately.
67
-
68
- Project maintainers who do not follow or enforce the Code of Conduct in good
69
- faith may face temporary or permanent repercussions as determined by other
70
- members of the project's leadership.
71
-
72
- ## Attribution
73
-
74
- This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
- available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
-
77
- [homepage]: https://www.contributor-covenant.org
78
-
79
- For answers to common questions about this code of conduct, see
80
- https://www.contributor-covenant.org/faq
 
CONTRIBUTING.md DELETED
@@ -1,35 +0,0 @@
1
- # Contributing to Audiocraft
2
-
3
- We want to make contributing to this project as easy and transparent as
4
- possible.
5
-
6
- ## Pull Requests
7
-
8
- Audiocraft is the implementation of a research paper.
9
- Therefore, we do not plan on accepting many pull requests for new features.
10
- We certainly welcome them for bug fixes.
11
-
12
- 1. Fork the repo and create your branch from `main`.
13
- 2. If you've added code that should be tested, add tests.
14
- 3. If you've changed APIs, update the documentation.
15
- 4. Ensure the test suite passes.
16
- 5. Make sure your code lints.
17
- 6. If you haven't already, complete the Contributor License Agreement ("CLA").
18
-
19
- ## Contributor License Agreement ("CLA")
20
- In order to accept your pull request, we need you to submit a CLA. You only need
21
- to do this once to work on any of Meta's open source projects.
22
-
23
- Complete your CLA here: <https://code.facebook.com/cla>
24
-
25
- ## Issues
26
- We use GitHub issues to track public bugs. Please ensure your description is
27
- clear and has sufficient instructions to be able to reproduce the issue.
28
-
29
- Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
30
- disclosure of security bugs. In those cases, please go through the process
31
- outlined on that page and do not file a public issue.
32
-
33
- ## License
34
- By contributing to encodec, you agree that your contributions will be licensed
35
- under the LICENSE file in the root directory of this source tree.
 
README.md CHANGED
@@ -1,3 +1,22 @@
1
  # Audiocraft
2
  ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
3
  ![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
@@ -129,18 +148,3 @@ Yes. We will soon release the training code for MusicGen and EnCodec.
129
  ## License
130
  * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
131
  * The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
132
- ---
133
- title: UnlimitedMusicGen
134
- emoji: 🚀
135
- colorFrom: pink
136
- colorTo: red
137
- sdk: gradio
138
- sdk_version: 3.34.0
139
- app_file: app.py
140
- pinned: false
141
- license: creativeml-openrail-m
142
- ---
143
-
144
- [arxiv]: https://arxiv.org/abs/2306.05284
145
- [musicgen_samples]: https://ai.honu.io/papers/musicgen/
146
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: UnlimitedMusicGen
3
+ emoji: 🚀
4
+ colorFrom: pink
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.34.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: creativeml-openrail-m
11
+ ---
12
+
13
+ [arxiv]: https://arxiv.org/abs/2306.05284
14
+ [musicgen_samples]: https://ai.honu.io/papers/musicgen/
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
16
+
17
+ # UnlimitedMusicGen
18
+ This is my modification of the Audiocraft project to enable unlimited Audio generation. I have added a few features to the original project to enable this. I have also added a few features to the gradio interface to make it easier to use.
19
+
20
  # Audiocraft
21
  ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
22
  ![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
 
148
  ## License
149
  * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
150
  * The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
 
app_batched.py DELETED
@@ -1,221 +0,0 @@
1
- """
2
- Copyright (c) Meta Platforms, Inc. and affiliates.
3
- All rights reserved.
4
-
5
- This source code is licensed under the license found in the
6
- LICENSE file in the root directory of this source tree.
7
- """
8
-
9
- import argparse
10
- from concurrent.futures import ProcessPoolExecutor
11
- import subprocess as sp
12
- from tempfile import NamedTemporaryFile
13
- import time
14
- import warnings
15
- import torch
16
- import gradio as gr
17
- from audiocraft.data.audio_utils import convert_audio
18
- from audiocraft.data.audio import audio_write
19
- from audiocraft.models import MusicGen
20
-
21
-
22
- MODEL = None
23
-
24
- _old_call = sp.call
25
-
26
-
27
- def _call_nostderr(*args, **kwargs):
28
- # Avoid ffmpeg vomiting on the logs.
29
- kwargs['stderr'] = sp.DEVNULL
30
- kwargs['stdout'] = sp.DEVNULL
31
- _old_call(*args, **kwargs)
32
-
33
-
34
- sp.call = _call_nostderr
35
- pool = ProcessPoolExecutor(3)
36
- pool.__enter__()
37
-
38
-
39
- def make_waveform(*args, **kwargs):
40
- be = time.time()
41
- with warnings.catch_warnings():
42
- warnings.simplefilter('ignore')
43
- out = gr.make_waveform(*args, **kwargs)
44
- print("Make a video took", time.time() - be)
45
- return out
46
-
47
-
48
- def load_model():
49
- print("Loading model")
50
- return MusicGen.get_pretrained("melody")
51
-
52
-
53
- def predict(texts, melodies):
54
- global MODEL
55
- if MODEL is None:
56
- MODEL = load_model()
57
-
58
- duration = 12
59
- max_text_length = 512
60
- texts = [text[:max_text_length] for text in texts]
61
- MODEL.set_generation_params(duration=duration)
62
-
63
- print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
64
- be = time.time()
65
- processed_melodies = []
66
- target_sr = 32000
67
- target_ac = 1
68
- for melody in melodies:
69
- if melody is None:
70
- processed_melodies.append(None)
71
- else:
72
- sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
73
- if melody.dim() == 1:
74
- melody = melody[None]
75
- melody = melody[..., :int(sr * duration)]
76
- melody = convert_audio(melody, sr, target_sr, target_ac)
77
- processed_melodies.append(melody)
78
-
79
- outputs = MODEL.generate_with_chroma(
80
- descriptions=texts,
81
- melody_wavs=processed_melodies,
82
- melody_sample_rate=target_sr,
83
- progress=False
84
- )
85
-
86
- outputs = outputs.detach().cpu().float()
87
- out_files = []
88
- for output in outputs:
89
- with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
90
- audio_write(
91
- file.name, output, MODEL.sample_rate, strategy="loudness",
92
- loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
93
- out_files.append(pool.submit(make_waveform, file.name))
94
- res = [[out_file.result() for out_file in out_files]]
95
- print("batch finished", len(texts), time.time() - be)
96
- return res
97
-
98
-
99
- def ui(**kwargs):
100
- with gr.Blocks() as demo:
101
- gr.Markdown(
102
- """
103
- # MusicGen
104
-
105
- This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
106
- presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
107
- <br/>
108
- <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
109
- <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
110
- for longer sequences, more control and no queue.</p>
111
- """
112
- )
113
- with gr.Row():
114
- with gr.Column():
115
- with gr.Row():
116
- text = gr.Text(label="Describe your music", lines=2, interactive=True)
117
- melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
118
- with gr.Row():
119
- submit = gr.Button("Generate")
120
- with gr.Column():
121
- output = gr.Video(label="Generated Music")
122
- submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=8)
123
- gr.Examples(
124
- fn=predict,
125
- examples=[
126
- [
127
- "An 80s driving pop song with heavy drums and synth pads in the background",
128
- "./assets/bach.mp3",
129
- ],
130
- [
131
- "A cheerful country song with acoustic guitars",
132
- "./assets/bolero_ravel.mp3",
133
- ],
134
- [
135
- "90s rock song with electric guitar and heavy drums",
136
- None,
137
- ],
138
- [
139
- "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
140
- "./assets/bach.mp3",
141
- ],
142
- [
143
- "lofi slow bpm electro chill with organic samples",
144
- None,
145
- ],
146
- ],
147
- inputs=[text, melody],
148
- outputs=[output]
149
- )
150
- gr.Markdown("""
151
- ### More details
152
-
153
- The model will generate 12 seconds of audio based on the description you provided.
154
- You can optionally provide a reference audio from which a broad melody will be extracted.
155
- The model will then try to follow both the description and melody provided.
156
- All samples are generated with the `melody` model.
157
-
158
- You can also use your own GPU or a Google Colab by following the instructions on our repo.
159
-
160
- See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
161
- for more details.
162
- """)
163
-
164
- # Show the interface
165
- launch_kwargs = {}
166
- username = kwargs.get('username')
167
- password = kwargs.get('password')
168
- server_port = kwargs.get('server_port', 0)
169
- inbrowser = kwargs.get('inbrowser', False)
170
- share = kwargs.get('share', False)
171
- server_name = kwargs.get('listen')
172
-
173
- launch_kwargs['server_name'] = server_name
174
-
175
- if username and password:
176
- launch_kwargs['auth'] = (username, password)
177
- if server_port > 0:
178
- launch_kwargs['server_port'] = server_port
179
- if inbrowser:
180
- launch_kwargs['inbrowser'] = inbrowser
181
- if share:
182
- launch_kwargs['share'] = share
183
- demo.queue(max_size=60).launch(**launch_kwargs)
184
-
185
- if __name__ == "__main__":
186
- parser = argparse.ArgumentParser()
187
- parser.add_argument(
188
- '--listen',
189
- type=str,
190
- default='127.0.0.1',
191
- help='IP to listen on for connections to Gradio',
192
- )
193
- parser.add_argument(
194
- '--username', type=str, default='', help='Username for authentication'
195
- )
196
- parser.add_argument(
197
- '--password', type=str, default='', help='Password for authentication'
198
- )
199
- parser.add_argument(
200
- '--server_port',
201
- type=int,
202
- default=0,
203
- help='Port to run the server listener on',
204
- )
205
- parser.add_argument(
206
- '--inbrowser', action='store_true', help='Open in browser'
207
- )
208
- parser.add_argument(
209
- '--share', action='store_true', help='Share the gradio UI'
210
- )
211
-
212
- args = parser.parse_args()
213
-
214
- ui(
215
- username=args.username,
216
- password=args.password,
217
- inbrowser=args.inbrowser,
218
- server_port=args.server_port,
219
- share=args.share,
220
- listen=args.listen
221
- )
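Note on the deleted `_call_nostderr` wrapper: as written it discards the value returned by the original `sp.call`, so callers that inspect the exit code would receive `None`. A minimal sketch of a wrapper that silences output and still returns the exit code follows. The name `call_quiet` is hypothetical, and unlike the deleted file this version does not monkeypatch `subprocess.call`.

```python
import subprocess as sp
import sys


def call_quiet(*args, **kwargs):
    """Run a command with stdout and stderr discarded and return its exit code."""
    kwargs["stderr"] = sp.DEVNULL
    kwargs["stdout"] = sp.DEVNULL
    return sp.call(*args, **kwargs)


# The child prints to both streams and exits with status 3; nothing reaches
# the console, but the exit code is preserved.
code = call_quiet(
    [sys.executable, "-c",
     "import sys; print('noise'); print('more', file=sys.stderr); sys.exit(3)"]
)
```

Keeping the wrapper as a separate function avoids changing behaviour for other code that imports `subprocess`; whether global patching is needed depends on how the library invoking ffmpeg calls it.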
 
tests/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
 
tests/common_utils/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- # flake8: noqa
8
- from .temp_utils import TempDirMixin
9
- from .wav_utils import get_batch_white_noise, get_white_noise, save_wav
 
tests/common_utils/temp_utils.py DELETED
@@ -1,56 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import os
8
- import tempfile
9
-
10
-
11
- class TempDirMixin:
12
- """Mixin to provide easy access to temp dir.
13
- """
14
-
15
- temp_dir_ = None
16
-
17
- @classmethod
18
- def get_base_temp_dir(cls):
19
- # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory.
20
- # this is handy for debugging.
21
- key = "AUDIOCRAFT_TEST_DIR"
22
- if key in os.environ:
23
- return os.environ[key]
24
- if cls.temp_dir_ is None:
25
- cls.temp_dir_ = tempfile.TemporaryDirectory()
26
- return cls.temp_dir_.name
27
-
28
- @classmethod
29
- def tearDownClass(cls):
30
- if cls.temp_dir_ is not None:
31
- try:
32
- cls.temp_dir_.cleanup()
33
- cls.temp_dir_ = None
34
- except PermissionError:
35
- # On Windows there is a known issue with `shutil.rmtree`,
36
- # which fails intermittently.
37
- # https://github.com/python/cpython/issues/74168
38
- # Following the above thread, we ignore it.
39
- pass
40
- super().tearDownClass()
41
-
42
- @property
43
- def id(self):
44
- return self.__class__.__name__
45
-
46
- def get_temp_path(self, *paths):
47
- temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
48
- path = os.path.join(temp_dir, *paths)
49
- os.makedirs(os.path.dirname(path), exist_ok=True)
50
- return path
51
-
52
- def get_temp_dir(self, *paths):
53
- temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
54
- path = os.path.join(temp_dir, *paths)
55
- os.makedirs(path, exist_ok=True)
56
- return path
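Note on the deleted `TempDirMixin`: its `tearDownClass` calls `super().tearDownClass()`, which suggests it is meant to be combined with a base class such as `unittest.TestCase`. The sketch below is a trimmed-down variant under that assumption; `ExampleTest` and the file names are hypothetical and only illustrate the intended usage.

```python
import os
import tempfile
import unittest


class TempDirMixin:
    """Trimmed-down variant of the mixin: one temp dir per test class."""

    temp_dir_ = None

    @classmethod
    def get_base_temp_dir(cls):
        # An explicit directory from the environment wins, which helps debugging.
        key = "AUDIOCRAFT_TEST_DIR"
        if key in os.environ:
            return os.environ[key]
        if cls.temp_dir_ is None:
            cls.temp_dir_ = tempfile.TemporaryDirectory()
        return cls.temp_dir_.name

    @classmethod
    def tearDownClass(cls):
        if cls.temp_dir_ is not None:
            cls.temp_dir_.cleanup()
            cls.temp_dir_ = None
        super().tearDownClass()

    def get_temp_path(self, *paths):
        # Files are grouped under a sub-directory named after the test class.
        path = os.path.join(self.get_base_temp_dir(), type(self).__name__, *paths)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        return path


class ExampleTest(TempDirMixin, unittest.TestCase):
    def test_write_and_read(self):
        path = self.get_temp_path("nested", "note.txt")
        with open(path, "w") as f:
            f.write("hello")
        with open(path) as f:
            self.assertEqual(f.read(), "hello")
```

Listing the mixin before `unittest.TestCase` matters here: it lets the mixin's `tearDownClass` run first and then hand over to the base class through `super()`.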
 
tests/common_utils/wav_utils.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from pathlib import Path
8
- import typing as tp
9
-
10
- import torch
11
- import torchaudio
12
-
13
-
14
- def get_white_noise(chs: int = 1, num_frames: int = 1):
15
- wav = torch.randn(chs, num_frames)
16
- return wav
17
-
18
-
19
- def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
20
- wav = torch.randn(bs, chs, num_frames)
21
- return wav
22
-
23
-
24
- def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
25
- fp = Path(path)
26
- kwargs: tp.Dict[str, tp.Any] = {}
27
- if fp.suffix == '.wav':
28
- kwargs['encoding'] = 'PCM_S'
29
- kwargs['bits_per_sample'] = 16
30
- elif fp.suffix == '.mp3':
31
- kwargs['compression'] = 320
32
- torchaudio.save(str(fp), wav, sample_rate, **kwargs)
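Note on the deleted `save_wav` helper: for `.wav` files it requests signed PCM at 16 bits per sample through `torchaudio`. For readers without `torchaudio` installed, the sketch below shows a standard-library analogue using the `wave` module. It is a substitute technique rather than the deleted code, and the names `save_wav_pcm16` and the `seed` parameter are hypothetical.

```python
import random
import struct
import wave


def get_white_noise(chs=1, num_frames=1, seed=None):
    """Return `chs` lists of `num_frames` Gaussian samples (channel-major)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(num_frames)] for _ in range(chs)]


def save_wav_pcm16(path, wav, sample_rate):
    """Write channel-major float samples as interleaved 16-bit signed PCM.

    Expects at least one channel; samples outside [-1, 1] are clipped.
    """
    chs = len(wav)
    num_frames = len(wav[0])
    frames = bytearray()
    for i in range(num_frames):
        for c in range(chs):
            sample = max(-1.0, min(1.0, wav[c][i]))
            frames += struct.pack("<h", int(round(sample * 32767)))
    with wave.open(path, "wb") as f:
        f.setnchannels(chs)
        f.setsampwidth(2)  # 2 bytes = 16 bits per sample
        f.setframerate(sample_rate)
        f.writeframes(bytes(frames))
```

A call such as `save_wav_pcm16("noise.wav", get_white_noise(2, 8000, seed=0), 8000)` would write one second of stereo noise. Hard clipping is used only to keep the sketch short.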
 
tests/data/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
 
tests/data/test_audio.py DELETED
@@ -1,239 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from itertools import product
8
- import random
9
-
10
- import numpy as np
11
- import torch
12
- import torchaudio
13
-
14
- from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read
15
-
16
- from ..common_utils import TempDirMixin, get_white_noise, save_wav
17
-
18
-
19
- class TestInfo(TempDirMixin):
20
-
21
- def test_info_mp3(self):
22
- sample_rates = [8000, 16_000]
23
- channels = [1, 2]
24
- duration = 1.
25
- for sample_rate, ch in product(sample_rates, channels):
26
- wav = get_white_noise(ch, int(sample_rate * duration))
27
- path = self.get_temp_path('sample_wav.mp3')
28
- save_wav(path, wav, sample_rate)
29
- info = audio_info(path)
30
- assert info.sample_rate == sample_rate
31
- assert info.channels == ch
32
- # we cannot trust torchaudio for num_frames, so we don't check
33
-
34
- def _test_info_format(self, ext: str):
35
- sample_rates = [8000, 16_000]
36
- channels = [1, 2]
37
- duration = 1.
38
- for sample_rate, ch in product(sample_rates, channels):
39
- n_frames = int(sample_rate * duration)
40
- wav = get_white_noise(ch, n_frames)
41
- path = self.get_temp_path(f'sample_wav{ext}')
42
- save_wav(path, wav, sample_rate)
43
- info = audio_info(path)
44
- assert info.sample_rate == sample_rate
45
- assert info.channels == ch
46
- assert np.isclose(info.duration, duration, atol=1e-5)
47
-
48
- def test_info_wav(self):
49
- self._test_info_format('.wav')
50
-
51
- def test_info_flac(self):
52
- self._test_info_format('.flac')
53
-
54
- def test_info_ogg(self):
55
- self._test_info_format('.ogg')
56
-
57
- def test_info_m4a(self):
58
- # TODO: generate m4a file programmatically
59
- # self._test_info_format('.m4a')
60
- pass
61
-
62
-
63
- class TestRead(TempDirMixin):
64
-
65
- def test_read_full_wav(self):
66
- sample_rates = [8000, 16_000]
67
- channels = [1, 2]
68
- duration = 1.
69
- for sample_rate, ch in product(sample_rates, channels):
70
- n_frames = int(sample_rate * duration)
71
- wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
72
- path = self.get_temp_path('sample_wav.wav')
73
- save_wav(path, wav, sample_rate)
74
- read_wav, read_sr = audio_read(path)
75
- assert read_sr == sample_rate
76
- assert read_wav.shape[0] == wav.shape[0]
77
- assert read_wav.shape[1] == wav.shape[1]
78
- assert torch.allclose(read_wav, wav, rtol=1e-03, atol=1e-04)
79
-
80
- def test_read_partial_wav(self):
81
- sample_rates = [8000, 16_000]
82
- channels = [1, 2]
83
- duration = 1.
84
- read_duration = torch.rand(1).item()
85
- for sample_rate, ch in product(sample_rates, channels):
86
- n_frames = int(sample_rate * duration)
87
- read_frames = int(sample_rate * read_duration)
88
- wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
89
- path = self.get_temp_path('sample_wav.wav')
90
- save_wav(path, wav, sample_rate)
91
- read_wav, read_sr = audio_read(path, 0, read_duration)
92
- assert read_sr == sample_rate
93
- assert read_wav.shape[0] == wav.shape[0]
94
- assert read_wav.shape[1] == read_frames
95
- assert torch.allclose(read_wav[..., 0:read_frames], wav[..., 0:read_frames], rtol=1e-03, atol=1e-04)
96
-
97
- def test_read_seek_time_wav(self):
98
- sample_rates = [8000, 16_000]
99
- channels = [1, 2]
100
- duration = 1.
101
- read_duration = 1.
102
- for sample_rate, ch in product(sample_rates, channels):
103
- n_frames = int(sample_rate * duration)
104
- wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
105
- path = self.get_temp_path('sample_wav.wav')
106
- save_wav(path, wav, sample_rate)
107
- seek_time = torch.rand(1).item()
108
- read_wav, read_sr = audio_read(path, seek_time, read_duration)
109
- seek_frames = int(sample_rate * seek_time)
110
- expected_frames = n_frames - seek_frames
111
- assert read_sr == sample_rate
112
- assert read_wav.shape[0] == wav.shape[0]
113
- assert read_wav.shape[1] == expected_frames
114
- assert torch.allclose(read_wav, wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
115
-
116
- def test_read_seek_time_wav_padded(self):
117
- sample_rates = [8000, 16_000]
118
- channels = [1, 2]
119
- duration = 1.
120
- read_duration = 1.
121
- for sample_rate, ch in product(sample_rates, channels):
122
- n_frames = int(sample_rate * duration)
123
- read_frames = int(sample_rate * read_duration)
124
- wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
125
- path = self.get_temp_path('sample_wav.wav')
126
- save_wav(path, wav, sample_rate)
127
- seek_time = torch.rand(1).item()
128
- seek_frames = int(sample_rate * seek_time)
129
- expected_frames = n_frames - seek_frames
130
- read_wav, read_sr = audio_read(path, seek_time, read_duration, pad=True)
131
- expected_pad_wav = torch.zeros(wav.shape[0], read_frames - expected_frames)
132
- assert read_sr == sample_rate
133
- assert read_wav.shape[0] == wav.shape[0]
134
- assert read_wav.shape[1] == read_frames
135
- assert torch.allclose(read_wav[..., :expected_frames], wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
136
- assert torch.allclose(read_wav[..., expected_frames:], expected_pad_wav)
137
-
138
-
139
- class TestAvRead(TempDirMixin):
140
-
141
- def test_avread_seek_base(self):
142
- sample_rates = [8000, 16_000]
143
- channels = [1, 2]
144
- duration = 2.
145
- for sample_rate, ch in product(sample_rates, channels):
146
- n_frames = int(sample_rate * duration)
147
- wav = get_white_noise(ch, n_frames)
148
- path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav')
149
- save_wav(path, wav, sample_rate)
150
- for _ in range(100):
151
- # seek will always load a full duration segment in the file
152
- seek_time = random.uniform(0.0, 1.0)
153
- seek_duration = random.uniform(0.001, 1.0)
154
- read_wav, read_sr = _av_read(path, seek_time, seek_duration)
155
- assert read_sr == sample_rate
156
- assert read_wav.shape[0] == wav.shape[0]
157
- assert read_wav.shape[-1] == int(seek_duration * sample_rate)
158
-
159
- def test_avread_seek_partial(self):
160
- sample_rates = [8000, 16_000]
161
- channels = [1, 2]
162
- duration = 1.
163
- for sample_rate, ch in product(sample_rates, channels):
164
- n_frames = int(sample_rate * duration)
165
- wav = get_white_noise(ch, n_frames)
166
- path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav')
167
- save_wav(path, wav, sample_rate)
168
- for _ in range(100):
169
- # seek will always load a partial segment
170
- seek_time = random.uniform(0.5, 1.)
171
- seek_duration = 1.
172
- expected_num_frames = n_frames - int(seek_time * sample_rate)
173
- read_wav, read_sr = _av_read(path, seek_time, seek_duration)
174
- assert read_sr == sample_rate
175
- assert read_wav.shape[0] == wav.shape[0]
176
- assert read_wav.shape[-1] == expected_num_frames
177
-
178
- def test_avread_seek_outofbound(self):
179
- sample_rates = [8000, 16_000]
180
- channels = [1, 2]
181
- duration = 1.
182
- for sample_rate, ch in product(sample_rates, channels):
183
- n_frames = int(sample_rate * duration)
184
- wav = get_white_noise(ch, n_frames)
185
- path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav')
186
- save_wav(path, wav, sample_rate)
187
- seek_time = 1.5
188
- read_wav, read_sr = _av_read(path, seek_time, 1.)
189
- assert read_sr == sample_rate
190
- assert read_wav.shape[0] == wav.shape[0]
191
- assert read_wav.shape[-1] == 0
192
-
193
- def test_avread_seek_edge(self):
194
- sample_rates = [8000, 16_000]
195
- # some of these values will have
196
- # int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1)
197
- n_frames = [1000, 1001, 1002]
198
- channels = [1, 2]
199
- for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
200
- duration = frames / sample_rate
201
- wav = get_white_noise(ch, frames)
202
- path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav')
203
- save_wav(path, wav, sample_rate)
204
- seek_time = (frames - 1) / sample_rate
205
- seek_frames = int(seek_time * sample_rate)
206
- read_wav, read_sr = _av_read(path, seek_time, duration)
207
- assert read_sr == sample_rate
208
- assert read_wav.shape[0] == wav.shape[0]
209
- assert read_wav.shape[-1] == (frames - seek_frames)
210
-
211
-
212
- class TestAudioWrite(TempDirMixin):
213
-
214
- def test_audio_write_wav(self):
215
- torch.manual_seed(1234)
216
- sample_rates = [8000, 16_000]
217
- n_frames = [1000, 1001, 1002]
218
- channels = [1, 2]
219
- strategies = ["peak", "clip", "rms"]
220
- formats = ["wav", "mp3"]
221
- for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
222
- for format_, strategy in product(formats, strategies):
223
- wav = get_white_noise(ch, frames)
224
- path = self.get_temp_path(f'pred_{sample_rate}_{ch}')
225
- audio_write(path, wav, sample_rate, format_, strategy=strategy)
226
- read_wav, read_sr = torchaudio.load(f'{path}.{format_}')
227
- if format_ == "wav":
228
- assert read_wav.shape == wav.shape
229
-
230
- if format_ == "wav" and strategy in ["peak", "rms"]:
231
- rescaled_read_wav = read_wav / read_wav.abs().max() * wav.abs().max()
232
- # for a Gaussian, the typical max scale will be less than ~5x the std.
233
- # The error when writing to disk will ~ 1/2**15, and when rescaling, 5x that.
234
- # For RMS target, rescaling leaves more headroom by default, leading
235
- # to a 20x rescaling typically
236
- atol = (5 if strategy == "peak" else 20) / 2**15
237
- delta = (rescaled_read_wav - wav).abs().max()
238
- assert torch.allclose(wav, rescaled_read_wav, rtol=0, atol=atol), (delta, atol)
239
- formats = ["wav"] # faster unit tests
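Note on the comment in the deleted `test_avread_seek_edge`: it warns that `int(((frames - 1) / sample_rate) * sample_rate)` may differ from `frames - 1`. That is a floating-point effect, since dividing and multiplying back can land slightly below the integer, and `int()` truncates. The sketch below prints both the truncated and the rounded value for the same grid of inputs as the test; the function names are hypothetical.

```python
def seek_frames_truncated(frames, sample_rate):
    """Index of the last frame, recovered by truncation as in the test."""
    seek_time = (frames - 1) / sample_rate
    return int(seek_time * sample_rate)


def seek_frames_rounded(frames, sample_rate):
    """Same computation, but rounding to the nearest integer."""
    seek_time = (frames - 1) / sample_rate
    return round(seek_time * sample_rate)


for sample_rate in (8000, 16_000):
    for frames in (1000, 1001, 1002):
        truncated = seek_frames_truncated(frames, sample_rate)
        rounded = seek_frames_rounded(frames, sample_rate)
        print(sample_rate, frames, truncated, rounded)
```

The deleted test derives its expected frame count from the same truncated value, so it appears designed to accept either outcome rather than to assume exact arithmetic.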
 
tests/data/test_audio_dataset.py DELETED
@@ -1,352 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from functools import partial
- from itertools import product
- import json
- import math
- import os
- import random
- import typing as tp
-
- import pytest
- import torch
- from torch.utils.data import DataLoader
-
- from audiocraft.data.audio_dataset import (
-     AudioDataset,
-     AudioMeta,
-     _get_audio_meta,
-     load_audio_meta,
-     save_audio_meta
- )
- from audiocraft.data.zip import PathInZip
-
- from ..common_utils import TempDirMixin, get_white_noise, save_wav
-
-
- class TestAudioMeta(TempDirMixin):
-
-     def test_get_audio_meta(self):
-         sample_rates = [8000, 16_000]
-         channels = [1, 2]
-         duration = 1.
-         for sample_rate, ch in product(sample_rates, channels):
-             n_frames = int(duration * sample_rate)
-             wav = get_white_noise(ch, n_frames)
-             path = self.get_temp_path('sample.wav')
-             save_wav(path, wav, sample_rate)
-             m = _get_audio_meta(path, minimal=True)
-             assert m.path == path, 'path does not match'
-             assert m.sample_rate == sample_rate, 'sample rate does not match'
-             assert m.duration == duration, 'duration does not match'
-             assert m.amplitude is None
-             assert m.info_path is None
-
-     def test_save_audio_meta(self):
-         audio_meta = [
-             AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
-             AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
-             ]
-         empty_audio_meta = []
-         for idx, meta in enumerate([audio_meta, empty_audio_meta]):
-             path = self.get_temp_path(f'data_{idx}_save.jsonl')
-             save_audio_meta(path, meta)
-             with open(path, 'r') as f:
-                 lines = f.readlines()
-                 read_meta = [AudioMeta.from_dict(json.loads(line)) for line in lines]
-                 assert len(read_meta) == len(meta)
-                 for m, read_m in zip(meta, read_meta):
-                     assert m == read_m
-
-     def test_load_audio_meta(self):
-         try:
-             import dora
-         except ImportError:
-             dora = None # type: ignore
-
-         audio_meta = [
-             AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
-             AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
-             ]
-         empty_meta = []
-         for idx, meta in enumerate([audio_meta, empty_meta]):
-             path = self.get_temp_path(f'data_{idx}_load.jsonl')
-             with open(path, 'w') as f:
-                 for m in meta:
-                     json_str = json.dumps(m.to_dict()) + '\n'
-                     f.write(json_str)
-             read_meta = load_audio_meta(path)
-             assert len(read_meta) == len(meta)
-             for m, read_m in zip(meta, read_meta):
-                 if dora:
-                     m.path = dora.git_save.to_absolute_path(m.path)
-                 assert m == read_m, f'original={m}, read={read_m}'
-
-
- class TestAudioDataset(TempDirMixin):
-
-     def _create_audio_files(self,
-                             root_name: str,
-                             num_examples: int,
-                             durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
-                             sample_rate: int = 16_000,
-                             channels: int = 1):
-         root_dir = self.get_temp_dir(root_name)
-         for i in range(num_examples):
-             if isinstance(durations, float):
-                 duration = durations
-             elif isinstance(durations, tuple) and len(durations) == 1:
-                 duration = durations[0]
-             elif isinstance(durations, tuple) and len(durations) == 2:
-                 duration = random.uniform(durations[0], durations[1])
-             else:
-                 assert False
-             n_frames = int(duration * sample_rate)
-             wav = get_white_noise(channels, n_frames)
-             path = os.path.join(root_dir, f'example_{i}.wav')
-             save_wav(path, wav, sample_rate)
-         return root_dir
-
-     def _create_audio_dataset(self,
-                               root_name: str,
-                               total_num_examples: int,
-                               durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
-                               sample_rate: int = 16_000,
-                               channels: int = 1,
-                               segment_duration: tp.Optional[float] = None,
-                               num_examples: int = 10,
-                               shuffle: bool = True,
-                               return_info: bool = False):
-         root_dir = self._create_audio_files(root_name, total_num_examples, durations, sample_rate, channels)
-         dataset = AudioDataset.from_path(root_dir,
-                                          minimal_meta=True,
-                                          segment_duration=segment_duration,
-                                          num_samples=num_examples,
-                                          sample_rate=sample_rate,
-                                          channels=channels,
-                                          shuffle=shuffle,
-                                          return_info=return_info)
-         return dataset
-
-     def test_dataset_full(self):
-         total_examples = 10
-         min_duration, max_duration = 1., 4.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration),
-             sample_rate=sample_rate, channels=channels, segment_duration=None)
-         assert len(dataset) == total_examples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] <= int(max_duration * sample_rate)
-             assert sample.shape[1] >= int(min_duration * sample_rate)
-
-     def test_dataset_segment(self):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples)
-         assert len(dataset) == num_samples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] == int(segment_duration * sample_rate)
-
-     def test_dataset_equal_audio_and_segment_durations(self):
-         total_examples = 1
-         num_samples = 2
-         audio_duration = 1.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples)
-         assert len(dataset) == num_samples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] == int(segment_duration * sample_rate)
-         # the random seek_time adds variability on audio read
-         sample_1 = dataset[0]
-         sample_2 = dataset[1]
-         assert not torch.allclose(sample_1, sample_2)
-
-     def test_dataset_samples(self):
-         total_examples = 1
-         num_samples = 2
-         audio_duration = 1.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-
-         create_dataset = partial(
-             self._create_audio_dataset,
-             'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples,
-         )
-
-         dataset = create_dataset(shuffle=True)
-         # when shuffle = True, we have different inputs for the same index across epoch
-         sample_1 = dataset[0]
-         sample_2 = dataset[0]
-         assert not torch.allclose(sample_1, sample_2)
-
-         dataset_noshuffle = create_dataset(shuffle=False)
-         # when shuffle = False, we have same inputs for the same index across epoch
-         sample_1 = dataset_noshuffle[0]
-         sample_2 = dataset_noshuffle[0]
-         assert torch.allclose(sample_1, sample_2)
-
-     def test_dataset_return_info(self):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
-         assert len(dataset) == num_samples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample, segment_info = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] == int(segment_duration * sample_rate)
-             assert segment_info.sample_rate == sample_rate
-             assert segment_info.total_frames == int(segment_duration * sample_rate)
-             assert segment_info.n_frames <= int(segment_duration * sample_rate)
-             assert segment_info.seek_time >= 0
-
-     def test_dataset_return_info_no_segment_duration(self):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = None
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
-         assert len(dataset) == total_examples
-         assert dataset.sample_rate == sample_rate
-         assert dataset.channels == channels
-         for idx in range(len(dataset)):
-             sample, segment_info = dataset[idx]
-             assert sample.shape[0] == channels
-             assert sample.shape[1] == segment_info.total_frames
-             assert segment_info.sample_rate == sample_rate
-             assert segment_info.n_frames <= segment_info.total_frames
-
-     def test_dataset_collate_fn(self):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=False)
-         batch_size = 4
-         dataloader = DataLoader(
-             dataset,
-             batch_size=batch_size,
-             num_workers=0
-         )
-         for idx, batch in enumerate(dataloader):
-             assert batch.shape[0] == batch_size
-
-     @pytest.mark.parametrize("segment_duration", [1.0, None])
-     def test_dataset_with_meta_collate_fn(self, segment_duration):
-         total_examples = 10
-         num_samples = 20
-         min_duration, max_duration = 1., 4.
-         segment_duration = 1.
-         sample_rate = 16_000
-         channels = 1
-         dataset = self._create_audio_dataset(
-             'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
-             channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
-         batch_size = 4
-         dataloader = DataLoader(
-             dataset,
-             batch_size=batch_size,
-             collate_fn=dataset.collater,
-             num_workers=0
-         )
-         for idx, batch in enumerate(dataloader):
-             wav, infos = batch
-             assert wav.shape[0] == batch_size
-             assert len(infos) == batch_size
-
-     @pytest.mark.parametrize("segment_duration,sample_on_weight,sample_on_duration,a_hist,b_hist,c_hist", [
-         [1, True, True, 0.5, 0.5, 0.0],
-         [1, False, True, 0.25, 0.5, 0.25],
-         [1, True, False, 0.666, 0.333, 0.0],
-         [1, False, False, 0.333, 0.333, 0.333],
-         [None, False, False, 0.333, 0.333, 0.333]])
-     def test_sample_with_weight(self, segment_duration, sample_on_weight, sample_on_duration, a_hist, b_hist, c_hist):
-         random.seed(1234)
-         rng = torch.Generator()
-         rng.manual_seed(1234)
-
-         def _get_histogram(dataset, repetitions=20_000):
-             counts = {file_meta.path: 0. for file_meta in meta}
-             for _ in range(repetitions):
-                 file_meta = dataset.sample_file(rng)
-                 counts[file_meta.path] += 1
-             return {name: count / repetitions for name, count in counts.items()}
-
-         meta = [
-             AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
-             AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
-             AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
-         ]
-         dataset = AudioDataset(
-             meta, segment_duration=segment_duration, sample_on_weight=sample_on_weight,
-             sample_on_duration=sample_on_duration)
-         hist = _get_histogram(dataset)
-         assert math.isclose(hist['a'], a_hist, abs_tol=0.01)
-         assert math.isclose(hist['b'], b_hist, abs_tol=0.01)
-         assert math.isclose(hist['c'], c_hist, abs_tol=0.01)
-
-     def test_meta_duration_filter_all(self):
-         meta = [
-             AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
-             AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
-             AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
-         ]
-         try:
-             AudioDataset(meta, segment_duration=11, min_segment_ratio=1)
-             assert False
-         except AssertionError:
-             assert True
-
-     def test_meta_duration_filter_long(self):
-         meta = [
-             AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
-             AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
-             AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
-         ]
-         dataset = AudioDataset(meta, segment_duration=None, min_segment_ratio=1, max_audio_duration=7)
-         assert len(dataset) == 2
 
tests/data/test_audio_utils.py DELETED
@@ -1,110 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import julius
- import torch
- import pytest
-
- from audiocraft.data.audio_utils import (
-     _clip_wav,
-     convert_audio_channels,
-     convert_audio,
-     normalize_audio
- )
- from ..common_utils import get_batch_white_noise
-
-
- class TestConvertAudioChannels:
-
-     def test_convert_audio_channels_downmix(self):
-         b, c, t = 2, 3, 100
-         audio = get_batch_white_noise(b, c, t)
-         mixed = convert_audio_channels(audio, channels=2)
-         assert list(mixed.shape) == [b, 2, t]
-
-     def test_convert_audio_channels_nochange(self):
-         b, c, t = 2, 3, 100
-         audio = get_batch_white_noise(b, c, t)
-         mixed = convert_audio_channels(audio, channels=c)
-         assert list(mixed.shape) == list(audio.shape)
-
-     def test_convert_audio_channels_upmix(self):
-         b, c, t = 2, 1, 100
-         audio = get_batch_white_noise(b, c, t)
-         mixed = convert_audio_channels(audio, channels=3)
-         assert list(mixed.shape) == [b, 3, t]
-
-     def test_convert_audio_channels_upmix_error(self):
-         b, c, t = 2, 2, 100
-         audio = get_batch_white_noise(b, c, t)
-         with pytest.raises(ValueError):
-             convert_audio_channels(audio, channels=3)
-
-
- class TestConvertAudio:
-
-     def test_convert_audio_channels_downmix(self):
-         b, c, dur = 2, 3, 4.
-         sr = 128
-         audio = get_batch_white_noise(b, c, int(sr * dur))
-         out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2)
-         assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]]
-
-     def test_convert_audio_channels_upmix(self):
-         b, c, dur = 2, 1, 4.
-         sr = 128
-         audio = get_batch_white_noise(b, c, int(sr * dur))
-         out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3)
-         assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]]
-
-     def test_convert_audio_upsample(self):
-         b, c, dur = 2, 1, 4.
-         sr = 2
-         new_sr = 3
-         audio = get_batch_white_noise(b, c, int(sr * dur))
-         out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
-         out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
-         assert torch.allclose(out, out_j)
-
-     def test_convert_audio_resample(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         new_sr = 2
-         audio = get_batch_white_noise(b, c, int(sr * dur))
-         out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
-         out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
-         assert torch.allclose(out, out_j)
-
-
- class TestNormalizeAudio:
-
-     def test_clip_wav(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
-         _clip_wav(audio)
-         assert audio.abs().max() <= 1
-
-     def test_normalize_audio_clip(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
-         norm_audio = normalize_audio(audio, strategy='clip')
-         assert norm_audio.abs().max() <= 1
-
-     def test_normalize_audio_rms(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
-         norm_audio = normalize_audio(audio, strategy='rms')
-         assert norm_audio.abs().max() <= 1
-
-     def test_normalize_audio_peak(self):
-         b, c, dur = 2, 1, 4.
-         sr = 3
-         audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
-         norm_audio = normalize_audio(audio, strategy='peak')
-         assert norm_audio.abs().max() <= 1
 
tests/models/test_encodec_model.py DELETED
@@ -1,60 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import random
-
- import numpy as np
- import torch
-
- from audiocraft.models import EncodecModel
- from audiocraft.modules import SEANetEncoder, SEANetDecoder
- from audiocraft.quantization import DummyQuantizer
-
-
- class TestEncodecModel:
-
-     def _create_encodec_model(self,
-                               sample_rate: int,
-                               channels: int,
-                               dim: int = 5,
-                               n_filters: int = 3,
-                               n_residual_layers: int = 1,
-                               ratios: list = [5, 4, 3, 2],
-                               **kwargs):
-         frame_rate = np.prod(ratios)
-         encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
-                                 n_residual_layers=n_residual_layers, ratios=ratios)
-         decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
-                                 n_residual_layers=n_residual_layers, ratios=ratios)
-         quantizer = DummyQuantizer()
-         model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
-                              sample_rate=sample_rate, channels=channels, **kwargs)
-         return model
-
-     def test_model(self):
-         random.seed(1234)
-         sample_rate = 24_000
-         channels = 1
-         model = self._create_encodec_model(sample_rate, channels)
-         for _ in range(10):
-             length = random.randrange(1, 10_000)
-             x = torch.randn(2, channels, length)
-             res = model(x)
-             assert res.x.shape == x.shape
-
-     def test_model_renorm(self):
-         random.seed(1234)
-         sample_rate = 24_000
-         channels = 1
-         model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
-         model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)
-
-         for _ in range(10):
-             length = random.randrange(1, 10_000)
-             x = torch.randn(2, channels, length)
-             codes, scales = model_nonorm.encode(x)
-             codes, scales = model_renorm.encode(x)
-             assert scales is not None
 
tests/models/test_musicgen.py DELETED
@@ -1,50 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import pytest
- import torch
-
- from audiocraft.models import MusicGen
-
-
- class TestSEANetModel:
-     def get_musicgen(self):
-         mg = MusicGen.get_pretrained(name='debug', device='cpu')
-         mg.set_generation_params(duration=2.0)
-         return mg
-
-     def test_base(self):
-         mg = self.get_musicgen()
-         assert mg.frame_rate == 25
-         assert mg.sample_rate == 32000
-         assert mg.audio_channels == 1
-
-     def test_generate_unconditional(self):
-         mg = self.get_musicgen()
-         wav = mg.generate_unconditional(3)
-         assert list(wav.shape) == [3, 1, 64000]
-
-     def test_generate_continuation(self):
-         mg = self.get_musicgen()
-         prompt = torch.randn(3, 1, 32000)
-         wav = mg.generate_continuation(prompt, 32000)
-         assert list(wav.shape) == [3, 1, 64000]
-
-         prompt = torch.randn(2, 1, 32000)
-         wav = mg.generate_continuation(
-             prompt, 32000, ['youpi', 'lapin dort'])
-         assert list(wav.shape) == [2, 1, 64000]
-
-         prompt = torch.randn(2, 1, 32000)
-         with pytest.raises(AssertionError):
-             wav = mg.generate_continuation(
-                 prompt, 32000, ['youpi', 'lapin dort', 'one too many'])
-
-     def test_generate(self):
-         mg = self.get_musicgen()
-         wav = mg.generate(
-             ['youpi', 'lapin dort'])
-         assert list(wav.shape) == [2, 1, 64000]
 
tests/modules/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
 
tests/modules/test_codebooks_patterns.py DELETED
@@ -1,246 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import pytest
- import torch
-
- from audiocraft.modules.codebooks_patterns import (
-     DelayedPatternProvider,
-     ParallelPatternProvider,
-     Pattern,
-     UnrolledPatternProvider,
- )
-
-
- class TestParallelPatternProvider:
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
-     def test_get_pattern(self, n_q: int, timesteps: int):
-         provider = ParallelPatternProvider(n_q)
-         pattern = provider.get_pattern(timesteps)
-         # + 1 to account for 1st step
-         assert len(pattern.layout) == timesteps + 1
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [8, 16, 100])
-     def test_pattern_content(self, n_q: int, timesteps: int):
-         provider = ParallelPatternProvider(n_q)
-         pattern = provider.get_pattern(timesteps)
-         for s, v in enumerate(pattern.layout):
-             for i, code in enumerate(v):
-                 assert i == code.q
-                 assert code.t == s - 1 # account for the 1st empty step
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [8, 16, 100])
-     def test_pattern_max_delay(self, n_q: int, timesteps: int):
-         provider = ParallelPatternProvider(n_q)
-         pattern = provider.get_pattern(timesteps)
-         assert pattern.max_delay == 0
-         assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
-
-
- class TestDelayedPatternProvider:
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
-     def test_get_pattern(self, n_q: int, timesteps: int):
-         delays = [
-             list(range(n_q)),
-             [0] + [1] * (n_q - 1),
-             [0] + [4] * (n_q - 1),
-         ]
-         for delay in delays:
-             provider = DelayedPatternProvider(n_q, delay)
-             pattern = provider.get_pattern(timesteps)
-             # + 1 to account for 1st step
-             assert len(pattern.layout) == timesteps + max(delay) + 1
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [8, 16, 100])
-     def test_pattern_content(self, n_q: int, timesteps: int):
-         provider = DelayedPatternProvider(n_q)
-         pattern = provider.get_pattern(timesteps)
-         for s, v in enumerate(pattern.layout):
-             for i, code in enumerate(v):
-                 assert i == code.q
-                 assert code.t == max(0, s - code.q - 1)
-
-     @pytest.mark.parametrize("timesteps", [8, 16, 100])
-     @pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]])
-     def test_pattern_max_delay(self, timesteps: int, delay: list):
-         provider = DelayedPatternProvider(len(delay), delay)
-         pattern = provider.get_pattern(timesteps)
-         assert pattern.max_delay == max(delay)
-         assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
-
-
- class TestUnrolledPatternProvider:
-
-     @pytest.mark.parametrize("timesteps", [0, 1, 16])
-     @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
-     @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
-     def test_get_pattern(self, timesteps: int, flattening: list, delays: list):
-         n_q = len(flattening)
-         max_delay = max(delays)
-         provider = UnrolledPatternProvider(n_q, flattening, delays)
-         pattern = provider.get_pattern(timesteps)
-         assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay
-
-     @pytest.mark.parametrize("timesteps", [0, 1, 16])
-     @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
-     @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
-     def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list):
-         n_q = len(flattening)
-         max_delay = max(delays)
-         provider = UnrolledPatternProvider(n_q, flattening, delays)
-         pattern = provider.get_pattern(timesteps)
-         assert pattern.max_delay == max_delay
-
-
- class TestPattern:
-
-     def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
-         """Reference method to build the sequence from the pattern without using fancy scatter."""
-         bs, n_q, T = z.shape
-         z = z.cpu().numpy()
-         assert n_q == pattern.n_q
-         assert T <= pattern.timesteps
-         inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy()
-         inp[:] = special_token
-         for s, v in enumerate(pattern.layout):
-             for (t, q) in v:
-                 if t < T:
-                     inp[:, q, s] = z[:, q, t]
-         return torch.from_numpy(inp)
-
-     def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
-         """Reference method to revert the sequence from the pattern without using fancy scatter."""
-         z = z.cpu().numpy()
-         bs, n_q, S = z.shape
-         assert pattern.n_q == n_q
-         inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy()
-         inp[:] = special_token
-         for s, v in enumerate(pattern.layout):
-             for (t, q) in v:
-                 if t < pattern.timesteps:
-                     inp[:, q, t] = z[:, q, s]
-         return torch.from_numpy(inp)
-
-     def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float):
-         """Reference method to revert the logits from the pattern without using fancy scatter."""
-         z = z.cpu().numpy()
-         bs, card, n_q, S = z.shape
-         assert pattern.n_q == n_q
-         ref_layout = pattern.layout
-         inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy()
-         inp[:] = special_token
-         for s, v in enumerate(ref_layout[1:]):
-             if s < S:
-                 for (t, q) in v:
-                     if t < pattern.timesteps:
-                         inp[:, :, q, t] = z[:, :, q, s]
-         return torch.from_numpy(inp)
-
-     def _get_pattern_providers(self, n_q: int):
-         pattern_provider_1 = ParallelPatternProvider(n_q)
-         pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q)))
-         pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1))
-         pattern_provider_4 = UnrolledPatternProvider(
-             n_q, flattening=list(range(n_q)), delays=[0] * n_q
-         )
-         pattern_provider_5 = UnrolledPatternProvider(
-             n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q
-         )
-         pattern_provider_6 = UnrolledPatternProvider(
-             n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1)
-         )
-         return [
-             pattern_provider_1,
-             pattern_provider_2,
-             pattern_provider_3,
-             pattern_provider_4,
-             pattern_provider_5,
-             pattern_provider_6,
-         ]
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [16, 72])
-     def test_build_pattern_sequence(self, n_q: int, timesteps: int):
-         bs = 2
-         card = 256
-         special_token = card
-
-         pattern_providers = self._get_pattern_providers(n_q)
-         for pattern_provider in pattern_providers:
-             pattern = pattern_provider.get_pattern(timesteps)
-             # we can correctly build the sequence from the pattern
-             z = torch.randint(0, card, (bs, n_q, timesteps))
-             ref_res = self.ref_build_pattern_sequence(z, pattern, special_token)
-             res, indexes, mask = pattern.build_pattern_sequence(z, special_token)
-             assert (res == ref_res).float().mean() == 1.0
-
-             # expected assertion fails on the number of timesteps
-             invalid_timesteps = [timesteps + 1]
-             if pattern.num_sequence_steps != pattern.timesteps:
-                 invalid_timesteps.append(pattern.num_sequence_steps)
-             for i_timesteps in invalid_timesteps:
-                 z2 = torch.randint(0, card, (bs, n_q, i_timesteps))
-                 with pytest.raises(AssertionError):
-                     pattern.build_pattern_sequence(z2, special_token)
-
-             # expected assertion fails on the number of codebooks
-             invalid_qs = [0, n_q - 1, n_q + 1]
-             for i_q in invalid_qs:
-                 z3 = torch.randint(0, card, (bs, i_q, timesteps))
-                 with pytest.raises(AssertionError):
-                     pattern.build_pattern_sequence(z3, special_token)
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [16, 72])
-     def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
-         bs = 2
-         card = 256
-         special_token = card
-
-         pattern_providers = self._get_pattern_providers(n_q)
-         for pattern_provider in pattern_providers:
-             pattern = pattern_provider.get_pattern(timesteps)
-             # this works assuming previous tests are successful
-             z = torch.randint(0, card, (bs, n_q, timesteps))
-             s = self.ref_build_pattern_sequence(z, pattern, special_token)
-             ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token)
-             # ensure our reference script retrieve the original sequence
-             assert z.shape == ref_out.shape
-             assert (z == ref_out).float().mean() == 1.0
-             # now we can test the scatter version
-             out, indexes, mask = pattern.revert_pattern_sequence(s, special_token)
-             assert out.shape == ref_out.shape
-             assert (out == ref_out).float().mean() == 1.0
-
-     @pytest.mark.parametrize("n_q", [1, 4, 32])
-     @pytest.mark.parametrize("timesteps", [16, 72])
-     @pytest.mark.parametrize("card", [1, 2, 256, 1024])
-     def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int):
-         bs = 2
-         special_token = card
-         logits_special_token = float('nan')
-
-         pattern_providers = self._get_pattern_providers(n_q)
-         for pattern_provider in pattern_providers:
-             pattern = pattern_provider.get_pattern(timesteps)
-             # this works assuming previous tests are successful
-             z = torch.randint(0, card, (bs, n_q, timesteps))
-             s = self.ref_build_pattern_sequence(z, pattern, special_token)
-             logits = torch.randn((bs, card, n_q, s.shape[-1]))
-             ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token)
-             # ensure our reference script retrieve the original sequence
-             assert ref_out.shape == torch.Size([bs, card, n_q, timesteps])
-             # now we can test the scatter version
-             out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token)
-             assert out.shape == ref_out.shape
-             assert (out == ref_out).float().mean() == 1.0
 
tests/modules/test_conv.py DELETED
@@ -1,203 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from itertools import product
8
- import math
9
- import random
10
-
11
- import pytest
12
- import torch
13
- from torch import nn
14
-
15
- from audiocraft.modules import (
16
- NormConv1d,
17
- NormConvTranspose1d,
18
- StreamableConv1d,
19
- StreamableConvTranspose1d,
20
- pad1d,
21
- unpad1d,
22
- )
23
-
24
-
25
- def test_get_extra_padding_for_conv1d():
26
- # TODO: Implement me!
27
- pass
28
-
29
-
30
- def test_pad1d_zeros():
31
- x = torch.randn(1, 1, 20)
32
-
33
- xp1 = pad1d(x, (0, 5), mode='constant', value=0.)
34
- assert xp1.shape[-1] == 25
35
- xp2 = pad1d(x, (5, 5), mode='constant', value=0.)
36
- assert xp2.shape[-1] == 30
37
- xp3 = pad1d(x, (0, 0), mode='constant', value=0.)
38
- assert xp3.shape[-1] == 20
39
- xp4 = pad1d(x, (10, 30), mode='constant', value=0.)
40
- assert xp4.shape[-1] == 60
41
-
42
- with pytest.raises(AssertionError):
43
- pad1d(x, (-1, 0), mode='constant', value=0.)
44
-
45
- with pytest.raises(AssertionError):
46
- pad1d(x, (0, -1), mode='constant', value=0.)
47
-
48
- with pytest.raises(AssertionError):
49
- pad1d(x, (-1, -1), mode='constant', value=0.)
50
-
51
-
52
- def test_pad1d_reflect():
53
- x = torch.randn(1, 1, 20)
54
-
55
- xp1 = pad1d(x, (0, 5), mode='reflect', value=0.)
56
- assert xp1.shape[-1] == 25
57
- xp2 = pad1d(x, (5, 5), mode='reflect', value=0.)
58
- assert xp2.shape[-1] == 30
59
- xp3 = pad1d(x, (0, 0), mode='reflect', value=0.)
60
- assert xp3.shape[-1] == 20
61
- xp4 = pad1d(x, (10, 30), mode='reflect', value=0.)
62
- assert xp4.shape[-1] == 60
63
-
64
- with pytest.raises(AssertionError):
65
- pad1d(x, (-1, 0), mode='reflect', value=0.)
66
-
67
- with pytest.raises(AssertionError):
68
- pad1d(x, (0, -1), mode='reflect', value=0.)
69
-
70
- with pytest.raises(AssertionError):
71
- pad1d(x, (-1, -1), mode='reflect', value=0.)
72
-
73
-
74
- def test_unpad1d():
75
- x = torch.randn(1, 1, 20)
76
-
77
- u1 = unpad1d(x, (5, 5))
78
- assert u1.shape[-1] == 10
79
- u2 = unpad1d(x, (0, 5))
80
- assert u2.shape[-1] == 15
81
- u3 = unpad1d(x, (5, 0))
82
- assert u3.shape[-1] == 15
83
- u4 = unpad1d(x, (0, 0))
84
- assert u4.shape[-1] == x.shape[-1]
85
-
86
- with pytest.raises(AssertionError):
87
- unpad1d(x, (-1, 0))
88
-
89
- with pytest.raises(AssertionError):
90
- unpad1d(x, (0, -1))
91
-
92
- with pytest.raises(AssertionError):
93
- unpad1d(x, (-1, -1))
94
-
95
-
96
- class TestNormConv1d:
97
-
98
- def test_norm_conv1d_modules(self):
99
- N, C, T = 2, 2, random.randrange(1, 100_000)
100
- t0 = torch.randn(N, C, T)
101
-
102
- C_out, kernel_size, stride = 1, 4, 1
103
- expected_out_length = int((T - kernel_size) / stride + 1)
104
- wn_conv = NormConv1d(C, 1, kernel_size=4, norm='weight_norm')
105
- gn_conv = NormConv1d(C, 1, kernel_size=4, norm='time_group_norm')
106
- nn_conv = NormConv1d(C, 1, kernel_size=4, norm='none')
107
-
108
- assert isinstance(wn_conv.norm, nn.Identity)
109
- assert isinstance(wn_conv.conv, nn.Conv1d)
110
-
111
- assert isinstance(gn_conv.norm, nn.GroupNorm)
112
- assert isinstance(gn_conv.conv, nn.Conv1d)
113
-
114
- assert isinstance(nn_conv.norm, nn.Identity)
115
- assert isinstance(nn_conv.conv, nn.Conv1d)
116
-
117
- for conv_layer in [wn_conv, gn_conv, nn_conv]:
118
- out = conv_layer(t0)
119
- assert isinstance(out, torch.Tensor)
120
- assert list(out.shape) == [N, C_out, expected_out_length]
121
-
122
-
123
- class TestNormConvTranspose1d:
124
-
125
- def test_normalizations(self):
126
- N, C, T = 2, 2, random.randrange(1, 100_000)
127
- t0 = torch.randn(N, C, T)
128
-
129
- C_out, kernel_size, stride = 1, 4, 1
130
- expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1
131
-
132
- wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm')
133
- gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm')
134
- nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none')
135
-
136
- assert isinstance(wn_convtr.norm, nn.Identity)
137
- assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d)
138
-
139
- assert isinstance(gn_convtr.norm, nn.GroupNorm)
140
- assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d)
141
-
142
- assert isinstance(nn_convtr.norm, nn.Identity)
143
- assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d)
144
-
145
- for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]:
146
- out = convtr_layer(t0)
147
- assert isinstance(out, torch.Tensor)
148
- assert list(out.shape) == [N, C_out, expected_out_length]
149
-
150
-
151
- class TestStreamableConv1d:
152
-
153
- def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation):
154
- # StreamableConv1d internally pads to make sure that the last window is full
155
- padding_total = (kernel_size - 1) * dilation - (stride - 1)
156
- n_frames = (length - kernel_size + padding_total) / stride + 1
157
- ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
158
- return ideal_length // stride
159
-
160
- def test_streamable_conv1d(self):
161
- N, C, T = 2, 2, random.randrange(1, 100_000)
162
- t0 = torch.randn(N, C, T)
163
- C_out = 1
164
-
165
- # conv params are [(kernel_size, stride, dilation)]
166
- conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)]
167
- for causal, (kernel_size, stride, dilation) in product([False, True], conv_params):
168
- expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation)
169
- sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal)
170
- out = sconv(t0)
171
- assert isinstance(out, torch.Tensor)
172
- print(list(out.shape), [N, C_out, expected_out_length])
173
- assert list(out.shape) == [N, C_out, expected_out_length]
174
-
175
-
176
- class TestStreamableConvTranspose1d:
177
-
178
- def get_streamable_convtr1d_output_length(self, length, kernel_size, stride):
179
- padding_total = (kernel_size - stride)
180
- return (length - 1) * stride - padding_total + (kernel_size - 1) + 1
181
-
182
- def test_streamable_convtr1d(self):
183
- N, C, T = 2, 2, random.randrange(1, 100_000)
184
- t0 = torch.randn(N, C, T)
185
-
186
- C_out = 1
187
-
188
- with pytest.raises(AssertionError):
189
- StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=False, trim_right_ratio=0.5)
190
- StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=-1.)
191
- StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=2)
192
-
193
- # causal params are [(causal, trim_right)]
194
- causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)]
195
- # conv params are [(kernel_size, stride)]
196
- conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)]
197
- for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params):
198
- expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride)
199
- sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride,
200
- causal=causal, trim_right_ratio=trim_right_ratio)
201
- out = sconvtr(t0)
202
- assert isinstance(out, torch.Tensor)
203
- assert list(out.shape) == [N, C_out, expected_out_length]
tests/modules/test_lstm.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import random
8
- import torch
9
-
10
- from audiocraft.modules.lstm import StreamableLSTM
11
-
12
-
13
- class TestStreamableLSTM:
14
-
15
- def test_lstm(self):
16
- B, C, T = 4, 2, random.randint(1, 100)
17
-
18
- lstm = StreamableLSTM(C, 3, skip=False)
19
- x = torch.randn(B, C, T)
20
- y = lstm(x)
21
-
22
- print(y.shape)
23
- assert y.shape == torch.Size([B, C, T])
24
-
25
- def test_lstm_skip(self):
26
- B, C, T = 4, 2, random.randint(1, 100)
27
-
28
- lstm = StreamableLSTM(C, 3, skip=True)
29
- x = torch.randn(B, C, T)
30
- y = lstm(x)
31
-
32
- assert y.shape == torch.Size([B, C, T])
tests/modules/test_rope.py DELETED
@@ -1,160 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import torch
8
-
9
- from audiocraft.modules.rope import RotaryEmbedding
10
- from audiocraft.modules.transformer import StreamingTransformer
11
-
12
-
13
- def test_rope():
14
- B, T, H, C = 8, 75, 16, 128
15
-
16
- rope = RotaryEmbedding(dim=C)
17
- xq = torch.rand((B, T, H, C))
18
- xk = torch.rand((B, T, H, C))
19
- xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
20
-
21
- assert list(xq_out.shape) == [B, T, H, C]
22
- assert list(xk_out.shape) == [B, T, H, C]
23
-
24
-
25
- def test_rope_io_dtypes():
26
- B, T, H, C = 8, 75, 16, 128
27
-
28
- rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
29
- rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)
30
-
31
- # Test bfloat16 inputs w/ both 32 and 64 precision rope.
32
- xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
33
- xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
34
- xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16)
35
- assert xq_out.dtype == torch.bfloat16
36
- xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16)
37
- assert xq_out.dtype == torch.bfloat16
38
-
39
- # Test float32 inputs w/ both 32 and 64 precision rope.
40
- xq_32 = torch.rand((B, T, H, C)).to(torch.float32)
41
- xk_32 = torch.rand((B, T, H, C)).to(torch.float32)
42
- xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32)
43
- assert xq_out.dtype == torch.float32
44
- xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32)
45
- assert xq_out.dtype == torch.float32
46
-
47
-
48
- def test_transformer_with_rope():
49
- torch.manual_seed(1234)
50
- for pos in ['rope', 'sin_rope']:
51
- tr = StreamingTransformer(
52
- 16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
53
- positional_embedding=pos)
54
- tr.eval()
55
- steps = 12
56
- x = torch.randn(3, steps, 16)
57
-
58
- out = tr(x)
59
- assert list(out.shape) == list(x.shape)
60
-
61
-
62
- @torch.no_grad()
63
- def test_rope_streaming():
64
- torch.manual_seed(1234)
65
- tr = StreamingTransformer(
66
- 16, 4, 2, causal=True, dropout=0.,
67
- custom=True, positional_embedding='rope')
68
- tr.eval()
69
- steps = 12
70
- x = torch.randn(3, steps, 16)
71
-
72
- ref = tr(x)
73
-
74
- with tr.streaming():
75
- outs = []
76
- frame_sizes = [1] * steps
77
-
78
- for frame_size in frame_sizes:
79
- frame = x[:, :frame_size]
80
- x = x[:, frame_size:]
81
- outs.append(tr(frame))
82
-
83
- out = torch.cat(outs, dim=1)
84
- assert list(out.shape) == [3, steps, 16]
85
- delta = torch.norm(out - ref) / torch.norm(out)
86
- assert delta < 1e-6, delta
87
-
88
-
89
- @torch.no_grad()
90
- def test_rope_streaming_past_context():
91
- torch.manual_seed(1234)
92
-
93
- for context in [None, 10]:
94
- tr = StreamingTransformer(
95
- 16, 4, 1 if context else 2,
96
- causal=True, past_context=context, custom=True,
97
- dropout=0., positional_embedding='rope')
98
- tr.eval()
99
-
100
- steps = 20
101
- x = torch.randn(3, steps, 16)
102
- ref = tr(x)
103
-
104
- with tr.streaming():
105
- outs = []
106
- frame_sizes = [1] * steps
107
-
108
- for frame_size in frame_sizes:
109
- frame = x[:, :frame_size]
110
- x = x[:, frame_size:]
111
- outs.append(tr(frame))
112
-
113
- out = torch.cat(outs, dim=1)
114
- assert list(out.shape) == [3, steps, 16]
115
- delta = torch.norm(out - ref) / torch.norm(out)
116
- assert delta < 1e-6, delta
117
-
118
-
119
- def test_rope_memory_efficient():
120
- torch.manual_seed(1234)
121
- tr = StreamingTransformer(
122
- 16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
123
- positional_embedding='rope')
124
- tr_mem_efficient = StreamingTransformer(
125
- 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
126
- positional_embedding='rope')
127
- tr_mem_efficient.load_state_dict(tr.state_dict())
128
- tr.eval()
129
- steps = 12
130
- x = torch.randn(3, steps, 16)
131
-
132
- with torch.no_grad():
133
- y = tr(x)
134
- y2 = tr_mem_efficient(x)
135
- # Check at float precision b/c this is the rope default.
136
- assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()
137
-
138
-
139
- def test_rope_with_xpos():
140
- B, T, H, C = 8, 75, 16, 128
141
-
142
- rope = RotaryEmbedding(dim=C, xpos=True)
143
- xq = torch.rand((B, T, H, C))
144
- xk = torch.rand((B, T, H, C))
145
- xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
146
-
147
- assert list(xq_out.shape) == [B, T, H, C]
148
- assert list(xk_out.shape) == [B, T, H, C]
149
-
150
-
151
- def test_positional_scale():
152
- B, T, H, C = 8, 75, 16, 128
153
-
154
- rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
155
- xq = torch.rand((B, T, H, C))
156
- xk = torch.rand((B, T, H, C))
157
- xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
158
-
159
- assert torch.allclose(xq, xq_out)
160
- assert torch.allclose(xk, xk_out)
tests/modules/test_seanet.py DELETED
@@ -1,115 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from itertools import product
8
-
9
- import pytest
10
- import torch
11
-
12
- from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
13
- from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
14
-
15
-
16
- class TestSEANetModel:
17
-
18
- def test_base(self):
19
- encoder = SEANetEncoder()
20
- decoder = SEANetDecoder()
21
-
22
- x = torch.randn(1, 1, 24000)
23
- z = encoder(x)
24
- assert list(z.shape) == [1, 128, 75], z.shape
25
- y = decoder(z)
26
- assert y.shape == x.shape, (x.shape, y.shape)
27
-
28
- def test_causal(self):
29
- encoder = SEANetEncoder(causal=True)
30
- decoder = SEANetDecoder(causal=True)
31
- x = torch.randn(1, 1, 24000)
32
-
33
- z = encoder(x)
34
- assert list(z.shape) == [1, 128, 75], z.shape
35
- y = decoder(z)
36
- assert y.shape == x.shape, (x.shape, y.shape)
37
-
38
- def test_conv_skip_connection(self):
39
- encoder = SEANetEncoder(true_skip=False)
40
- decoder = SEANetDecoder(true_skip=False)
41
-
42
- x = torch.randn(1, 1, 24000)
43
- z = encoder(x)
44
- assert list(z.shape) == [1, 128, 75], z.shape
45
- y = decoder(z)
46
- assert y.shape == x.shape, (x.shape, y.shape)
47
-
48
- def test_seanet_encoder_decoder_final_act(self):
49
- encoder = SEANetEncoder(true_skip=False)
50
- decoder = SEANetDecoder(true_skip=False, final_activation='Tanh')
51
-
52
- x = torch.randn(1, 1, 24000)
53
- z = encoder(x)
54
- assert list(z.shape) == [1, 128, 75], z.shape
55
- y = decoder(z)
56
- assert y.shape == x.shape, (x.shape, y.shape)
57
-
58
- def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
59
- n_blocks = 0
60
- for layer in encoder.model:
61
- if isinstance(layer, StreamableConv1d):
62
- n_blocks += 1
63
- assert layer.conv.norm_type == 'none' if n_blocks <= n_disable_blocks else norm
64
- elif isinstance(layer, SEANetResnetBlock):
65
- for resnet_layer in layer.block:
66
- if isinstance(resnet_layer, StreamableConv1d):
67
- # here we add + 1 to n_blocks as we increment n_blocks just after the block
68
- assert resnet_layer.conv.norm_type == 'none' if (n_blocks + 1) <= n_disable_blocks else norm
69
-
70
- def test_encoder_disable_norm(self):
71
- n_residuals = [0, 1, 3]
72
- disable_blocks = [0, 1, 2, 3, 4, 5, 6]
73
- norms = ['weight_norm', 'none']
74
- for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
75
- encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
76
- disable_norm_outer_blocks=disable_blocks)
77
- self._check_encoder_blocks_norm(encoder, disable_blocks, norm)
78
-
79
- def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
80
- n_blocks = 0
81
- for layer in decoder.model:
82
- if isinstance(layer, StreamableConv1d):
83
- n_blocks += 1
84
- assert layer.conv.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
85
- elif isinstance(layer, StreamableConvTranspose1d):
86
- n_blocks += 1
87
- assert layer.convtr.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
88
- elif isinstance(layer, SEANetResnetBlock):
89
- for resnet_layer in layer.block:
90
- if isinstance(resnet_layer, StreamableConv1d):
91
- assert resnet_layer.conv.norm_type == 'none' \
92
- if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
93
-
94
- def test_decoder_disable_norm(self):
95
- n_residuals = [0, 1, 3]
96
- disable_blocks = [0, 1, 2, 3, 4, 5, 6]
97
- norms = ['weight_norm', 'none']
98
- for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
99
- decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
100
- disable_norm_outer_blocks=disable_blocks)
101
- self._check_decoder_blocks_norm(decoder, disable_blocks, norm)
102
-
103
- def test_disable_norm_raises_exception(self):
104
- # Invalid disable_norm_outer_blocks values raise exceptions
105
- with pytest.raises(AssertionError):
106
- SEANetEncoder(disable_norm_outer_blocks=-1)
107
-
108
- with pytest.raises(AssertionError):
109
- SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
110
-
111
- with pytest.raises(AssertionError):
112
- SEANetDecoder(disable_norm_outer_blocks=-1)
113
-
114
- with pytest.raises(AssertionError):
115
- SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
tests/modules/test_transformer.py DELETED
@@ -1,247 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from itertools import product
8
-
9
- import pytest
10
- import torch
11
-
12
- from audiocraft.modules.transformer import StreamingMultiheadAttention, StreamingTransformer
13
-
14
-
15
- def test_transformer_causal_streaming():
16
- torch.manual_seed(1234)
17
-
18
- for context, custom in product([None, 10], [False, True]):
19
- # Test that causality and receptive fields are properly handled.
20
- # looking at the gradients
21
- tr = StreamingTransformer(
22
- 16, 4, 1 if context else 2,
23
- causal=True, past_context=context, custom=custom,
24
- dropout=0.)
25
- steps = 20
26
- for k in [0, 10, 15, 19]:
27
- x = torch.randn(4, steps, 16, requires_grad=True)
28
- y = tr(x)
29
- y[:, k].abs().sum().backward()
30
- if k + 1 < steps:
31
- assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm()
32
- assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm()
33
- if context is not None and k > context:
34
- limit = k - context - 1
35
- assert torch.allclose(x.grad[:, :limit],
36
- torch.tensor(0.)), x.grad[:, :limit].norm()
37
-
38
- # Now check that streaming gives the same result at batch eval.
39
- x = torch.randn(4, steps, 16)
40
- y = tr(x)
41
- ys = []
42
- with tr.streaming():
43
- for k in range(steps):
44
- chunk = x[:, k:k + 1, :]
45
- ys.append(tr(chunk))
46
- y_stream = torch.cat(ys, dim=1)
47
- delta = torch.norm(y_stream - y) / torch.norm(y)
48
- assert delta < 1e-6, delta
49
-
50
-
51
- def test_transformer_vs_pytorch():
52
- torch.manual_seed(1234)
53
- # Check that in the non causal setting, we get the same result as
54
- # PyTorch Transformer encoder.
55
- for custom in [False, True]:
56
- tr = StreamingTransformer(
57
- 16, 4, 2,
58
- causal=False, custom=custom, dropout=0., positional_scale=0.)
59
- layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True)
60
- tr_ref = torch.nn.TransformerEncoder(layer, 2)
61
- tr.load_state_dict(tr_ref.state_dict())
62
-
63
- x = torch.randn(4, 20, 16)
64
- y = tr(x)
65
- y2 = tr_ref(x)
66
- delta = torch.norm(y2 - y) / torch.norm(y)
67
- assert delta < 1e-6, delta
68
-
69
-
70
- def test_streaming_api():
71
- tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.)
72
- tr.eval()
73
- steps = 12
74
- x = torch.randn(1, steps, 16)
75
-
76
- with torch.no_grad():
77
- with tr.streaming():
78
- _ = tr(x[:, :1])
79
- state = {k: v.clone() for k, v in tr.get_streaming_state().items()}
80
- y = tr(x[:, 1:2])
81
- tr.set_streaming_state(state)
82
- y2 = tr(x[:, 1:2])
83
- assert torch.allclose(y, y2), (y - y2).norm()
84
- assert tr.flush() is None
85
-
86
-
87
- def test_memory_efficient():
88
- torch.manual_seed(1234)
89
- tr = StreamingTransformer(
90
- 16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
91
- tr_mem_efficient = StreamingTransformer(
92
- 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
93
- tr_mem_efficient.load_state_dict(tr.state_dict())
94
- tr.eval()
95
- steps = 12
96
- x = torch.randn(3, steps, 16)
97
-
98
- with torch.no_grad():
99
- y = tr(x)
100
- y2 = tr_mem_efficient(x)
101
- assert torch.allclose(y, y2), (y - y2).norm()
102
-
103
-
104
- def test_attention_as_float32():
105
- torch.manual_seed(1234)
106
- cases = [
107
- {'custom': True},
108
- {'custom': False},
109
- ]
110
- for case in cases:
111
- tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case)
112
- tr_float32 = StreamingTransformer(
113
- 16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case)
114
- if not case['custom']:
115
- # we are not using autocast here because it doesn't really
116
- # work as expected on CPU, so we have to manually cast the weights of the MHA.
117
- for layer in tr_float32.layers:
118
- layer.self_attn.mha.to(torch.float32)
119
- tr_float32.load_state_dict(tr.state_dict())
120
- steps = 12
121
- x = torch.randn(3, steps, 16, dtype=torch.bfloat16)
122
-
123
- with torch.no_grad():
124
- y = tr(x)
125
- y2 = tr_float32(x)
126
- assert not torch.allclose(y, y2), (y - y2).norm()
127
-
128
-
129
- @torch.no_grad()
130
- def test_streaming_memory_efficient():
131
- torch.manual_seed(1234)
132
- tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
133
- tr_mem_efficient = StreamingTransformer(
134
- 16, 4, 2, dropout=0., memory_efficient=True, causal=True)
135
- tr.load_state_dict(tr_mem_efficient.state_dict())
136
- tr.eval()
137
- tr_mem_efficient.eval()
138
- steps = 12
139
- x = torch.randn(3, steps, 16)
140
-
141
- ref = tr(x)
142
-
143
- with tr_mem_efficient.streaming():
144
- outs = []
145
- # frame_sizes = [2] + [1] * (steps - 2)
146
- frame_sizes = [1] * steps
147
-
148
- for frame_size in frame_sizes:
149
- frame = x[:, :frame_size]
150
- x = x[:, frame_size:]
151
- outs.append(tr_mem_efficient(frame))
152
-
153
- out = torch.cat(outs, dim=1)
154
- delta = torch.norm(out - ref) / torch.norm(out)
155
- assert delta < 1e-6, delta
156
-
157
-
158
- def test_cross_attention():
159
- torch.manual_seed(1234)
160
- for norm_first in [True, False]:
161
- m = StreamingTransformer(
162
- 16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True)
163
- m_cross = StreamingTransformer(
164
- 16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True)
165
- m_cross.load_state_dict(m.state_dict(), strict=False)
166
- x = torch.randn(2, 5, 16)
167
- cross_x = torch.randn(2, 3, 16)
168
- y_ref = m(x)
169
- y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x)
170
- # With norm_first, the two should be exactly yhe same,
171
- # but with norm_first=False, we get 2 normalization in a row
172
- # and the epsilon value leads to a tiny change.
173
- atol = 0. if norm_first else 1e-6
174
- print((y_ref - y_cross_zero).norm() / y_ref.norm())
175
- assert torch.allclose(y_ref, y_cross_zero, atol=atol)
176
-
177
- # We now expect a difference even with a generous atol of 1e-2.
178
- y_cross = m_cross(x, cross_attention_src=cross_x)
179
- assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2)
180
-
181
- with pytest.raises(AssertionError):
182
- _ = m_cross(x)
183
- _ = m(x, cross_attention_src=cross_x)
184
-
185
-
186
- def test_cross_attention_compat():
187
- torch.manual_seed(1234)
188
- num_heads = 2
189
- dim = num_heads * 64
190
- with pytest.raises(AssertionError):
191
- StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True)
192
-
193
- cross_attn = StreamingMultiheadAttention(
194
- dim, num_heads, dropout=0, cross_attention=True, custom=True)
195
- ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True)
196
-
197
- # We can load the regular attention state dict
198
- # so we have compat when loading old checkpoints.
199
- cross_attn.load_state_dict(ref_attn.state_dict())
200
-
201
- queries = torch.randn(3, 7, dim)
202
- keys = torch.randn(3, 9, dim)
203
- values = torch.randn(3, 9, dim)
204
-
205
- y = cross_attn(queries, keys, values)[0]
206
- y_ref = ref_attn(queries, keys, values)[0]
207
- assert torch.allclose(y, y_ref, atol=1e-7)
208
-
209
- # Now let's check that streaming is working properly.
210
- with cross_attn.streaming():
211
- ys = []
212
- for step in range(queries.shape[1]):
213
- ys.append(cross_attn(queries[:, step: step + 1], keys, values)[0])
214
- y_streaming = torch.cat(ys, dim=1)
215
- assert torch.allclose(y_streaming, y, atol=1e-7)
216
-
217
-
218
- def test_repeat_kv():
219
- torch.manual_seed(1234)
220
- num_heads = 8
221
- kv_repeat = 4
222
- dim = num_heads * 64
223
- with pytest.raises(AssertionError):
224
- mha = StreamingMultiheadAttention(
225
- dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True)
226
- mha = StreamingMultiheadAttention(
227
- dim, num_heads, causal=True, kv_repeat=kv_repeat)
228
- mha = StreamingMultiheadAttention(
229
- dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True)
230
- x = torch.randn(4, 18, dim)
231
- y = mha(x, x, x)[0]
232
- assert x.shape == y.shape
233
-
234
-
235
- def test_qk_layer_norm():
236
- torch.manual_seed(1234)
237
- tr = StreamingTransformer(
238
- 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False)
239
- steps = 12
240
- x = torch.randn(3, steps, 16)
241
- y = tr(x)
242
-
243
- tr = StreamingTransformer(
244
- 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True)
245
- z = torch.randn(3, 21, 16)
246
- y = tr(x, cross_attention_src=z)
247
- assert y.shape == x.shape
tests/quantization/test_vq.py DELETED
@@ -1,18 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import torch
8
-
9
- from audiocraft.quantization.vq import ResidualVectorQuantizer
10
-
11
-
12
- class TestResidualVectorQuantizer:
13
-
14
- def test_rvq(self):
15
- x = torch.randn(1, 16, 2048)
16
- vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
17
- res = vq(x, 1.)
18
- assert res.x.shape == torch.Size([1, 16, 2048])
tests/utils/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.