Spaces:
Running
on
T4
Running
on
T4
Initial Upload HF
Browse files- CHANGELOG.md +0 -18
- CODE_OF_CONDUCT.md +0 -80
- CONTRIBUTING.md +0 -35
- README.md +19 -15
- app_batched.py +0 -221
- tests/__init__.py +0 -5
- tests/common_utils/__init__.py +0 -9
- tests/common_utils/temp_utils.py +0 -56
- tests/common_utils/wav_utils.py +0 -32
- tests/data/__init__.py +0 -5
- tests/data/test_audio.py +0 -239
- tests/data/test_audio_dataset.py +0 -352
- tests/data/test_audio_utils.py +0 -110
- tests/models/test_encodec_model.py +0 -60
- tests/models/test_musicgen.py +0 -50
- tests/modules/__init__.py +0 -5
- tests/modules/test_codebooks_patterns.py +0 -246
- tests/modules/test_conv.py +0 -203
- tests/modules/test_lstm.py +0 -32
- tests/modules/test_rope.py +0 -160
- tests/modules/test_seanet.py +0 -115
- tests/modules/test_transformer.py +0 -247
- tests/quantization/test_vq.py +0 -18
- tests/utils/__init__.py +0 -5
CHANGELOG.md
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
# Changelog
|
2 |
-
|
3 |
-
All notable changes to this project will be documented in this file.
|
4 |
-
|
5 |
-
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
6 |
-
|
7 |
-
## [0.0.2a] - TBD
|
8 |
-
|
9 |
-
Improved demo, fixed top p (thanks @jnordberg).
|
10 |
-
|
11 |
-
Compressor tanh on output to avoid clipping with some style (especially piano).
|
12 |
-
Now repeating the conditioning periodically if it is too short.
|
13 |
-
|
14 |
-
More options when launching Gradio app locally (thanks @ashleykleynhans).
|
15 |
-
|
16 |
-
## [0.0.1] - 2023-06-09
|
17 |
-
|
18 |
-
Initial release, with model evaluation only.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CODE_OF_CONDUCT.md
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
# Code of Conduct
|
2 |
-
|
3 |
-
## Our Pledge
|
4 |
-
|
5 |
-
In the interest of fostering an open and welcoming environment, we as
|
6 |
-
contributors and maintainers pledge to make participation in our project and
|
7 |
-
our community a harassment-free experience for everyone, regardless of age, body
|
8 |
-
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
9 |
-
level of experience, education, socio-economic status, nationality, personal
|
10 |
-
appearance, race, religion, or sexual identity and orientation.
|
11 |
-
|
12 |
-
## Our Standards
|
13 |
-
|
14 |
-
Examples of behavior that contributes to creating a positive environment
|
15 |
-
include:
|
16 |
-
|
17 |
-
* Using welcoming and inclusive language
|
18 |
-
* Being respectful of differing viewpoints and experiences
|
19 |
-
* Gracefully accepting constructive criticism
|
20 |
-
* Focusing on what is best for the community
|
21 |
-
* Showing empathy towards other community members
|
22 |
-
|
23 |
-
Examples of unacceptable behavior by participants include:
|
24 |
-
|
25 |
-
* The use of sexualized language or imagery and unwelcome sexual attention or
|
26 |
-
advances
|
27 |
-
* Trolling, insulting/derogatory comments, and personal or political attacks
|
28 |
-
* Public or private harassment
|
29 |
-
* Publishing others' private information, such as a physical or electronic
|
30 |
-
address, without explicit permission
|
31 |
-
* Other conduct which could reasonably be considered inappropriate in a
|
32 |
-
professional setting
|
33 |
-
|
34 |
-
## Our Responsibilities
|
35 |
-
|
36 |
-
Project maintainers are responsible for clarifying the standards of acceptable
|
37 |
-
behavior and are expected to take appropriate and fair corrective action in
|
38 |
-
response to any instances of unacceptable behavior.
|
39 |
-
|
40 |
-
Project maintainers have the right and responsibility to remove, edit, or
|
41 |
-
reject comments, commits, code, wiki edits, issues, and other contributions
|
42 |
-
that are not aligned to this Code of Conduct, or to ban temporarily or
|
43 |
-
permanently any contributor for other behaviors that they deem inappropriate,
|
44 |
-
threatening, offensive, or harmful.
|
45 |
-
|
46 |
-
## Scope
|
47 |
-
|
48 |
-
This Code of Conduct applies within all project spaces, and it also applies when
|
49 |
-
an individual is representing the project or its community in public spaces.
|
50 |
-
Examples of representing a project or community include using an official
|
51 |
-
project e-mail address, posting via an official social media account, or acting
|
52 |
-
as an appointed representative at an online or offline event. Representation of
|
53 |
-
a project may be further defined and clarified by project maintainers.
|
54 |
-
|
55 |
-
This Code of Conduct also applies outside the project spaces when there is a
|
56 |
-
reasonable belief that an individual's behavior may have a negative impact on
|
57 |
-
the project or its community.
|
58 |
-
|
59 |
-
## Enforcement
|
60 |
-
|
61 |
-
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
62 |
-
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
63 |
-
complaints will be reviewed and investigated and will result in a response that
|
64 |
-
is deemed necessary and appropriate to the circumstances. The project team is
|
65 |
-
obligated to maintain confidentiality with regard to the reporter of an incident.
|
66 |
-
Further details of specific enforcement policies may be posted separately.
|
67 |
-
|
68 |
-
Project maintainers who do not follow or enforce the Code of Conduct in good
|
69 |
-
faith may face temporary or permanent repercussions as determined by other
|
70 |
-
members of the project's leadership.
|
71 |
-
|
72 |
-
## Attribution
|
73 |
-
|
74 |
-
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
75 |
-
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
76 |
-
|
77 |
-
[homepage]: https://www.contributor-covenant.org
|
78 |
-
|
79 |
-
For answers to common questions about this code of conduct, see
|
80 |
-
https://www.contributor-covenant.org/faq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CONTRIBUTING.md
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
# Contributing to Audiocraft
|
2 |
-
|
3 |
-
We want to make contributing to this project as easy and transparent as
|
4 |
-
possible.
|
5 |
-
|
6 |
-
## Pull Requests
|
7 |
-
|
8 |
-
Audiocraft is the implementation of a research paper.
|
9 |
-
Therefore, we do not plan on accepting many pull requests for new features.
|
10 |
-
We certainly welcome them for bug fixes.
|
11 |
-
|
12 |
-
1. Fork the repo and create your branch from `main`.
|
13 |
-
2. If you've added code that should be tested, add tests.
|
14 |
-
3. If you've changed APIs, update the documentation.
|
15 |
-
4. Ensure the test suite passes.
|
16 |
-
5. Make sure your code lints.
|
17 |
-
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
18 |
-
|
19 |
-
## Contributor License Agreement ("CLA")
|
20 |
-
In order to accept your pull request, we need you to submit a CLA. You only need
|
21 |
-
to do this once to work on any of Meta's open source projects.
|
22 |
-
|
23 |
-
Complete your CLA here: <https://code.facebook.com/cla>
|
24 |
-
|
25 |
-
## Issues
|
26 |
-
We use GitHub issues to track public bugs. Please ensure your description is
|
27 |
-
clear and has sufficient instructions to be able to reproduce the issue.
|
28 |
-
|
29 |
-
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
30 |
-
disclosure of security bugs. In those cases, please go through the process
|
31 |
-
outlined on that page and do not file a public issue.
|
32 |
-
|
33 |
-
## License
|
34 |
-
By contributing to encodec, you agree that your contributions will be licensed
|
35 |
-
under the LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -1,3 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Audiocraft
|
2 |
![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
|
3 |
![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
|
@@ -129,18 +148,3 @@ Yes. We will soon release the training code for MusicGen and EnCodec.
|
|
129 |
## License
|
130 |
* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
|
131 |
* The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
|
132 |
-
---
|
133 |
-
title: UnlimitedMusicGen
|
134 |
-
emoji: 🚀
|
135 |
-
colorFrom: pink
|
136 |
-
colorTo: red
|
137 |
-
sdk: gradio
|
138 |
-
sdk_version: 3.34.0
|
139 |
-
app_file: app.py
|
140 |
-
pinned: false
|
141 |
-
license: creativeml-openrail-m
|
142 |
-
---
|
143 |
-
|
144 |
-
[arxiv]: https://arxiv.org/abs/2306.05284
|
145 |
-
[musicgen_samples]: https://ai.honu.io/papers/musicgen/
|
146 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
---
|
2 |
+
title: UnlimitedMusicGen
|
3 |
+
emoji: 🚀
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.34.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: creativeml-openrail-m
|
11 |
+
---
|
12 |
+
|
13 |
+
[arxiv]: https://arxiv.org/abs/2306.05284
|
14 |
+
[musicgen_samples]: https://ai.honu.io/papers/musicgen/
|
15 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
16 |
+
|
17 |
+
# UnlimitedMusicGen
|
18 |
+
This is my modification of the Audiocraft project to enable unlimited Audio generation. I have added a few features to the original project to enable this. I have also added a few features to the gradio interface to make it easier to use.
|
19 |
+
|
20 |
# Audiocraft
|
21 |
![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
|
22 |
![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
|
|
|
148 |
## License
|
149 |
* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
|
150 |
* The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app_batched.py
DELETED
@@ -1,221 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
-
All rights reserved.
|
4 |
-
|
5 |
-
This source code is licensed under the license found in the
|
6 |
-
LICENSE file in the root directory of this source tree.
|
7 |
-
"""
|
8 |
-
|
9 |
-
import argparse
|
10 |
-
from concurrent.futures import ProcessPoolExecutor
|
11 |
-
import subprocess as sp
|
12 |
-
from tempfile import NamedTemporaryFile
|
13 |
-
import time
|
14 |
-
import warnings
|
15 |
-
import torch
|
16 |
-
import gradio as gr
|
17 |
-
from audiocraft.data.audio_utils import convert_audio
|
18 |
-
from audiocraft.data.audio import audio_write
|
19 |
-
from audiocraft.models import MusicGen
|
20 |
-
|
21 |
-
|
22 |
-
MODEL = None
|
23 |
-
|
24 |
-
_old_call = sp.call
|
25 |
-
|
26 |
-
|
27 |
-
def _call_nostderr(*args, **kwargs):
|
28 |
-
# Avoid ffmpeg vomitting on the logs.
|
29 |
-
kwargs['stderr'] = sp.DEVNULL
|
30 |
-
kwargs['stdout'] = sp.DEVNULL
|
31 |
-
_old_call(*args, **kwargs)
|
32 |
-
|
33 |
-
|
34 |
-
sp.call = _call_nostderr
|
35 |
-
pool = ProcessPoolExecutor(3)
|
36 |
-
pool.__enter__()
|
37 |
-
|
38 |
-
|
39 |
-
def make_waveform(*args, **kwargs):
|
40 |
-
be = time.time()
|
41 |
-
with warnings.catch_warnings():
|
42 |
-
warnings.simplefilter('ignore')
|
43 |
-
out = gr.make_waveform(*args, **kwargs)
|
44 |
-
print("Make a video took", time.time() - be)
|
45 |
-
return out
|
46 |
-
|
47 |
-
|
48 |
-
def load_model():
|
49 |
-
print("Loading model")
|
50 |
-
return MusicGen.get_pretrained("melody")
|
51 |
-
|
52 |
-
|
53 |
-
def predict(texts, melodies):
|
54 |
-
global MODEL
|
55 |
-
if MODEL is None:
|
56 |
-
MODEL = load_model()
|
57 |
-
|
58 |
-
duration = 12
|
59 |
-
max_text_length = 512
|
60 |
-
texts = [text[:max_text_length] for text in texts]
|
61 |
-
MODEL.set_generation_params(duration=duration)
|
62 |
-
|
63 |
-
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
64 |
-
be = time.time()
|
65 |
-
processed_melodies = []
|
66 |
-
target_sr = 32000
|
67 |
-
target_ac = 1
|
68 |
-
for melody in melodies:
|
69 |
-
if melody is None:
|
70 |
-
processed_melodies.append(None)
|
71 |
-
else:
|
72 |
-
sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
|
73 |
-
if melody.dim() == 1:
|
74 |
-
melody = melody[None]
|
75 |
-
melody = melody[..., :int(sr * duration)]
|
76 |
-
melody = convert_audio(melody, sr, target_sr, target_ac)
|
77 |
-
processed_melodies.append(melody)
|
78 |
-
|
79 |
-
outputs = MODEL.generate_with_chroma(
|
80 |
-
descriptions=texts,
|
81 |
-
melody_wavs=processed_melodies,
|
82 |
-
melody_sample_rate=target_sr,
|
83 |
-
progress=False
|
84 |
-
)
|
85 |
-
|
86 |
-
outputs = outputs.detach().cpu().float()
|
87 |
-
out_files = []
|
88 |
-
for output in outputs:
|
89 |
-
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
90 |
-
audio_write(
|
91 |
-
file.name, output, MODEL.sample_rate, strategy="loudness",
|
92 |
-
loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
|
93 |
-
out_files.append(pool.submit(make_waveform, file.name))
|
94 |
-
res = [[out_file.result() for out_file in out_files]]
|
95 |
-
print("batch finished", len(texts), time.time() - be)
|
96 |
-
return res
|
97 |
-
|
98 |
-
|
99 |
-
def ui(**kwargs):
|
100 |
-
with gr.Blocks() as demo:
|
101 |
-
gr.Markdown(
|
102 |
-
"""
|
103 |
-
# MusicGen
|
104 |
-
|
105 |
-
This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
|
106 |
-
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
|
107 |
-
<br/>
|
108 |
-
<a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
|
109 |
-
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
110 |
-
for longer sequences, more control and no queue.</p>
|
111 |
-
"""
|
112 |
-
)
|
113 |
-
with gr.Row():
|
114 |
-
with gr.Column():
|
115 |
-
with gr.Row():
|
116 |
-
text = gr.Text(label="Describe your music", lines=2, interactive=True)
|
117 |
-
melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
|
118 |
-
with gr.Row():
|
119 |
-
submit = gr.Button("Generate")
|
120 |
-
with gr.Column():
|
121 |
-
output = gr.Video(label="Generated Music")
|
122 |
-
submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=8)
|
123 |
-
gr.Examples(
|
124 |
-
fn=predict,
|
125 |
-
examples=[
|
126 |
-
[
|
127 |
-
"An 80s driving pop song with heavy drums and synth pads in the background",
|
128 |
-
"./assets/bach.mp3",
|
129 |
-
],
|
130 |
-
[
|
131 |
-
"A cheerful country song with acoustic guitars",
|
132 |
-
"./assets/bolero_ravel.mp3",
|
133 |
-
],
|
134 |
-
[
|
135 |
-
"90s rock song with electric guitar and heavy drums",
|
136 |
-
None,
|
137 |
-
],
|
138 |
-
[
|
139 |
-
"a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
|
140 |
-
"./assets/bach.mp3",
|
141 |
-
],
|
142 |
-
[
|
143 |
-
"lofi slow bpm electro chill with organic samples",
|
144 |
-
None,
|
145 |
-
],
|
146 |
-
],
|
147 |
-
inputs=[text, melody],
|
148 |
-
outputs=[output]
|
149 |
-
)
|
150 |
-
gr.Markdown("""
|
151 |
-
### More details
|
152 |
-
|
153 |
-
The model will generate 12 seconds of audio based on the description you provided.
|
154 |
-
You can optionaly provide a reference audio from which a broad melody will be extracted.
|
155 |
-
The model will then try to follow both the description and melody provided.
|
156 |
-
All samples are generated with the `melody` model.
|
157 |
-
|
158 |
-
You can also use your own GPU or a Google Colab by following the instructions on our repo.
|
159 |
-
|
160 |
-
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
|
161 |
-
for more details.
|
162 |
-
""")
|
163 |
-
|
164 |
-
# Show the interface
|
165 |
-
launch_kwargs = {}
|
166 |
-
username = kwargs.get('username')
|
167 |
-
password = kwargs.get('password')
|
168 |
-
server_port = kwargs.get('server_port', 0)
|
169 |
-
inbrowser = kwargs.get('inbrowser', False)
|
170 |
-
share = kwargs.get('share', False)
|
171 |
-
server_name = kwargs.get('listen')
|
172 |
-
|
173 |
-
launch_kwargs['server_name'] = server_name
|
174 |
-
|
175 |
-
if username and password:
|
176 |
-
launch_kwargs['auth'] = (username, password)
|
177 |
-
if server_port > 0:
|
178 |
-
launch_kwargs['server_port'] = server_port
|
179 |
-
if inbrowser:
|
180 |
-
launch_kwargs['inbrowser'] = inbrowser
|
181 |
-
if share:
|
182 |
-
launch_kwargs['share'] = share
|
183 |
-
demo.queue(max_size=60).launch(**launch_kwargs)
|
184 |
-
|
185 |
-
if __name__ == "__main__":
|
186 |
-
parser = argparse.ArgumentParser()
|
187 |
-
parser.add_argument(
|
188 |
-
'--listen',
|
189 |
-
type=str,
|
190 |
-
default='127.0.0.1',
|
191 |
-
help='IP to listen on for connections to Gradio',
|
192 |
-
)
|
193 |
-
parser.add_argument(
|
194 |
-
'--username', type=str, default='', help='Username for authentication'
|
195 |
-
)
|
196 |
-
parser.add_argument(
|
197 |
-
'--password', type=str, default='', help='Password for authentication'
|
198 |
-
)
|
199 |
-
parser.add_argument(
|
200 |
-
'--server_port',
|
201 |
-
type=int,
|
202 |
-
default=0,
|
203 |
-
help='Port to run the server listener on',
|
204 |
-
)
|
205 |
-
parser.add_argument(
|
206 |
-
'--inbrowser', action='store_true', help='Open in browser'
|
207 |
-
)
|
208 |
-
parser.add_argument(
|
209 |
-
'--share', action='store_true', help='Share the gradio UI'
|
210 |
-
)
|
211 |
-
|
212 |
-
args = parser.parse_args()
|
213 |
-
|
214 |
-
ui(
|
215 |
-
username=args.username,
|
216 |
-
password=args.password,
|
217 |
-
inbrowser=args.inbrowser,
|
218 |
-
server_port=args.server_port,
|
219 |
-
share=args.share,
|
220 |
-
listen=args.listen
|
221 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
tests/common_utils/__init__.py
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
# flake8: noqa
|
8 |
-
from .temp_utils import TempDirMixin
|
9 |
-
from .wav_utils import get_batch_white_noise, get_white_noise, save_wav
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/common_utils/temp_utils.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import os
|
8 |
-
import tempfile
|
9 |
-
|
10 |
-
|
11 |
-
class TempDirMixin:
|
12 |
-
"""Mixin to provide easy access to temp dir.
|
13 |
-
"""
|
14 |
-
|
15 |
-
temp_dir_ = None
|
16 |
-
|
17 |
-
@classmethod
|
18 |
-
def get_base_temp_dir(cls):
|
19 |
-
# If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory.
|
20 |
-
# this is handy for debugging.
|
21 |
-
key = "AUDIOCRAFT_TEST_DIR"
|
22 |
-
if key in os.environ:
|
23 |
-
return os.environ[key]
|
24 |
-
if cls.temp_dir_ is None:
|
25 |
-
cls.temp_dir_ = tempfile.TemporaryDirectory()
|
26 |
-
return cls.temp_dir_.name
|
27 |
-
|
28 |
-
@classmethod
|
29 |
-
def tearDownClass(cls):
|
30 |
-
if cls.temp_dir_ is not None:
|
31 |
-
try:
|
32 |
-
cls.temp_dir_.cleanup()
|
33 |
-
cls.temp_dir_ = None
|
34 |
-
except PermissionError:
|
35 |
-
# On Windows there is a know issue with `shutil.rmtree`,
|
36 |
-
# which fails intermittenly.
|
37 |
-
# https://github.com/python/cpython/issues/74168
|
38 |
-
# Following the above thread, we ignore it.
|
39 |
-
pass
|
40 |
-
super().tearDownClass()
|
41 |
-
|
42 |
-
@property
|
43 |
-
def id(self):
|
44 |
-
return self.__class__.__name__
|
45 |
-
|
46 |
-
def get_temp_path(self, *paths):
|
47 |
-
temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
|
48 |
-
path = os.path.join(temp_dir, *paths)
|
49 |
-
os.makedirs(os.path.dirname(path), exist_ok=True)
|
50 |
-
return path
|
51 |
-
|
52 |
-
def get_temp_dir(self, *paths):
|
53 |
-
temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
|
54 |
-
path = os.path.join(temp_dir, *paths)
|
55 |
-
os.makedirs(path, exist_ok=True)
|
56 |
-
return path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/common_utils/wav_utils.py
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from pathlib import Path
|
8 |
-
import typing as tp
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import torchaudio
|
12 |
-
|
13 |
-
|
14 |
-
def get_white_noise(chs: int = 1, num_frames: int = 1):
|
15 |
-
wav = torch.randn(chs, num_frames)
|
16 |
-
return wav
|
17 |
-
|
18 |
-
|
19 |
-
def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
|
20 |
-
wav = torch.randn(bs, chs, num_frames)
|
21 |
-
return wav
|
22 |
-
|
23 |
-
|
24 |
-
def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
|
25 |
-
fp = Path(path)
|
26 |
-
kwargs: tp.Dict[str, tp.Any] = {}
|
27 |
-
if fp.suffix == '.wav':
|
28 |
-
kwargs['encoding'] = 'PCM_S'
|
29 |
-
kwargs['bits_per_sample'] = 16
|
30 |
-
elif fp.suffix == '.mp3':
|
31 |
-
kwargs['compression'] = 320
|
32 |
-
torchaudio.save(str(fp), wav, sample_rate, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/data/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
tests/data/test_audio.py
DELETED
@@ -1,239 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from itertools import product
|
8 |
-
import random
|
9 |
-
|
10 |
-
import numpy as np
|
11 |
-
import torch
|
12 |
-
import torchaudio
|
13 |
-
|
14 |
-
from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read
|
15 |
-
|
16 |
-
from ..common_utils import TempDirMixin, get_white_noise, save_wav
|
17 |
-
|
18 |
-
|
19 |
-
class TestInfo(TempDirMixin):
|
20 |
-
|
21 |
-
def test_info_mp3(self):
|
22 |
-
sample_rates = [8000, 16_000]
|
23 |
-
channels = [1, 2]
|
24 |
-
duration = 1.
|
25 |
-
for sample_rate, ch in product(sample_rates, channels):
|
26 |
-
wav = get_white_noise(ch, int(sample_rate * duration))
|
27 |
-
path = self.get_temp_path('sample_wav.mp3')
|
28 |
-
save_wav(path, wav, sample_rate)
|
29 |
-
info = audio_info(path)
|
30 |
-
assert info.sample_rate == sample_rate
|
31 |
-
assert info.channels == ch
|
32 |
-
# we cannot trust torchaudio for num_frames, so we don't check
|
33 |
-
|
34 |
-
def _test_info_format(self, ext: str):
|
35 |
-
sample_rates = [8000, 16_000]
|
36 |
-
channels = [1, 2]
|
37 |
-
duration = 1.
|
38 |
-
for sample_rate, ch in product(sample_rates, channels):
|
39 |
-
n_frames = int(sample_rate * duration)
|
40 |
-
wav = get_white_noise(ch, n_frames)
|
41 |
-
path = self.get_temp_path(f'sample_wav{ext}')
|
42 |
-
save_wav(path, wav, sample_rate)
|
43 |
-
info = audio_info(path)
|
44 |
-
assert info.sample_rate == sample_rate
|
45 |
-
assert info.channels == ch
|
46 |
-
assert np.isclose(info.duration, duration, atol=1e-5)
|
47 |
-
|
48 |
-
def test_info_wav(self):
|
49 |
-
self._test_info_format('.wav')
|
50 |
-
|
51 |
-
def test_info_flac(self):
|
52 |
-
self._test_info_format('.flac')
|
53 |
-
|
54 |
-
def test_info_ogg(self):
|
55 |
-
self._test_info_format('.ogg')
|
56 |
-
|
57 |
-
def test_info_m4a(self):
|
58 |
-
# TODO: generate m4a file programmatically
|
59 |
-
# self._test_info_format('.m4a')
|
60 |
-
pass
|
61 |
-
|
62 |
-
|
63 |
-
class TestRead(TempDirMixin):
|
64 |
-
|
65 |
-
def test_read_full_wav(self):
|
66 |
-
sample_rates = [8000, 16_000]
|
67 |
-
channels = [1, 2]
|
68 |
-
duration = 1.
|
69 |
-
for sample_rate, ch in product(sample_rates, channels):
|
70 |
-
n_frames = int(sample_rate * duration)
|
71 |
-
wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
|
72 |
-
path = self.get_temp_path('sample_wav.wav')
|
73 |
-
save_wav(path, wav, sample_rate)
|
74 |
-
read_wav, read_sr = audio_read(path)
|
75 |
-
assert read_sr == sample_rate
|
76 |
-
assert read_wav.shape[0] == wav.shape[0]
|
77 |
-
assert read_wav.shape[1] == wav.shape[1]
|
78 |
-
assert torch.allclose(read_wav, wav, rtol=1e-03, atol=1e-04)
|
79 |
-
|
80 |
-
def test_read_partial_wav(self):
|
81 |
-
sample_rates = [8000, 16_000]
|
82 |
-
channels = [1, 2]
|
83 |
-
duration = 1.
|
84 |
-
read_duration = torch.rand(1).item()
|
85 |
-
for sample_rate, ch in product(sample_rates, channels):
|
86 |
-
n_frames = int(sample_rate * duration)
|
87 |
-
read_frames = int(sample_rate * read_duration)
|
88 |
-
wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
|
89 |
-
path = self.get_temp_path('sample_wav.wav')
|
90 |
-
save_wav(path, wav, sample_rate)
|
91 |
-
read_wav, read_sr = audio_read(path, 0, read_duration)
|
92 |
-
assert read_sr == sample_rate
|
93 |
-
assert read_wav.shape[0] == wav.shape[0]
|
94 |
-
assert read_wav.shape[1] == read_frames
|
95 |
-
assert torch.allclose(read_wav[..., 0:read_frames], wav[..., 0:read_frames], rtol=1e-03, atol=1e-04)
|
96 |
-
|
97 |
-
def test_read_seek_time_wav(self):
|
98 |
-
sample_rates = [8000, 16_000]
|
99 |
-
channels = [1, 2]
|
100 |
-
duration = 1.
|
101 |
-
read_duration = 1.
|
102 |
-
for sample_rate, ch in product(sample_rates, channels):
|
103 |
-
n_frames = int(sample_rate * duration)
|
104 |
-
wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
|
105 |
-
path = self.get_temp_path('sample_wav.wav')
|
106 |
-
save_wav(path, wav, sample_rate)
|
107 |
-
seek_time = torch.rand(1).item()
|
108 |
-
read_wav, read_sr = audio_read(path, seek_time, read_duration)
|
109 |
-
seek_frames = int(sample_rate * seek_time)
|
110 |
-
expected_frames = n_frames - seek_frames
|
111 |
-
assert read_sr == sample_rate
|
112 |
-
assert read_wav.shape[0] == wav.shape[0]
|
113 |
-
assert read_wav.shape[1] == expected_frames
|
114 |
-
assert torch.allclose(read_wav, wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
|
115 |
-
|
116 |
-
def test_read_seek_time_wav_padded(self):
|
117 |
-
sample_rates = [8000, 16_000]
|
118 |
-
channels = [1, 2]
|
119 |
-
duration = 1.
|
120 |
-
read_duration = 1.
|
121 |
-
for sample_rate, ch in product(sample_rates, channels):
|
122 |
-
n_frames = int(sample_rate * duration)
|
123 |
-
read_frames = int(sample_rate * read_duration)
|
124 |
-
wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
|
125 |
-
path = self.get_temp_path('sample_wav.wav')
|
126 |
-
save_wav(path, wav, sample_rate)
|
127 |
-
seek_time = torch.rand(1).item()
|
128 |
-
seek_frames = int(sample_rate * seek_time)
|
129 |
-
expected_frames = n_frames - seek_frames
|
130 |
-
read_wav, read_sr = audio_read(path, seek_time, read_duration, pad=True)
|
131 |
-
expected_pad_wav = torch.zeros(wav.shape[0], read_frames - expected_frames)
|
132 |
-
assert read_sr == sample_rate
|
133 |
-
assert read_wav.shape[0] == wav.shape[0]
|
134 |
-
assert read_wav.shape[1] == read_frames
|
135 |
-
assert torch.allclose(read_wav[..., :expected_frames], wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
|
136 |
-
assert torch.allclose(read_wav[..., expected_frames:], expected_pad_wav)
|
137 |
-
|
138 |
-
|
139 |
-
class TestAvRead(TempDirMixin):
|
140 |
-
|
141 |
-
def test_avread_seek_base(self):
|
142 |
-
sample_rates = [8000, 16_000]
|
143 |
-
channels = [1, 2]
|
144 |
-
duration = 2.
|
145 |
-
for sample_rate, ch in product(sample_rates, channels):
|
146 |
-
n_frames = int(sample_rate * duration)
|
147 |
-
wav = get_white_noise(ch, n_frames)
|
148 |
-
path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav')
|
149 |
-
save_wav(path, wav, sample_rate)
|
150 |
-
for _ in range(100):
|
151 |
-
# seek will always load a full duration segment in the file
|
152 |
-
seek_time = random.uniform(0.0, 1.0)
|
153 |
-
seek_duration = random.uniform(0.001, 1.0)
|
154 |
-
read_wav, read_sr = _av_read(path, seek_time, seek_duration)
|
155 |
-
assert read_sr == sample_rate
|
156 |
-
assert read_wav.shape[0] == wav.shape[0]
|
157 |
-
assert read_wav.shape[-1] == int(seek_duration * sample_rate)
|
158 |
-
|
159 |
-
def test_avread_seek_partial(self):
|
160 |
-
sample_rates = [8000, 16_000]
|
161 |
-
channels = [1, 2]
|
162 |
-
duration = 1.
|
163 |
-
for sample_rate, ch in product(sample_rates, channels):
|
164 |
-
n_frames = int(sample_rate * duration)
|
165 |
-
wav = get_white_noise(ch, n_frames)
|
166 |
-
path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav')
|
167 |
-
save_wav(path, wav, sample_rate)
|
168 |
-
for _ in range(100):
|
169 |
-
# seek will always load a partial segment
|
170 |
-
seek_time = random.uniform(0.5, 1.)
|
171 |
-
seek_duration = 1.
|
172 |
-
expected_num_frames = n_frames - int(seek_time * sample_rate)
|
173 |
-
read_wav, read_sr = _av_read(path, seek_time, seek_duration)
|
174 |
-
assert read_sr == sample_rate
|
175 |
-
assert read_wav.shape[0] == wav.shape[0]
|
176 |
-
assert read_wav.shape[-1] == expected_num_frames
|
177 |
-
|
178 |
-
def test_avread_seek_outofbound(self):
|
179 |
-
sample_rates = [8000, 16_000]
|
180 |
-
channels = [1, 2]
|
181 |
-
duration = 1.
|
182 |
-
for sample_rate, ch in product(sample_rates, channels):
|
183 |
-
n_frames = int(sample_rate * duration)
|
184 |
-
wav = get_white_noise(ch, n_frames)
|
185 |
-
path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav')
|
186 |
-
save_wav(path, wav, sample_rate)
|
187 |
-
seek_time = 1.5
|
188 |
-
read_wav, read_sr = _av_read(path, seek_time, 1.)
|
189 |
-
assert read_sr == sample_rate
|
190 |
-
assert read_wav.shape[0] == wav.shape[0]
|
191 |
-
assert read_wav.shape[-1] == 0
|
192 |
-
|
193 |
-
def test_avread_seek_edge(self):
|
194 |
-
sample_rates = [8000, 16_000]
|
195 |
-
# some of these values will have
|
196 |
-
# int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1)
|
197 |
-
n_frames = [1000, 1001, 1002]
|
198 |
-
channels = [1, 2]
|
199 |
-
for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
|
200 |
-
duration = frames / sample_rate
|
201 |
-
wav = get_white_noise(ch, frames)
|
202 |
-
path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav')
|
203 |
-
save_wav(path, wav, sample_rate)
|
204 |
-
seek_time = (frames - 1) / sample_rate
|
205 |
-
seek_frames = int(seek_time * sample_rate)
|
206 |
-
read_wav, read_sr = _av_read(path, seek_time, duration)
|
207 |
-
assert read_sr == sample_rate
|
208 |
-
assert read_wav.shape[0] == wav.shape[0]
|
209 |
-
assert read_wav.shape[-1] == (frames - seek_frames)
|
210 |
-
|
211 |
-
|
212 |
-
class TestAudioWrite(TempDirMixin):
|
213 |
-
|
214 |
-
def test_audio_write_wav(self):
|
215 |
-
torch.manual_seed(1234)
|
216 |
-
sample_rates = [8000, 16_000]
|
217 |
-
n_frames = [1000, 1001, 1002]
|
218 |
-
channels = [1, 2]
|
219 |
-
strategies = ["peak", "clip", "rms"]
|
220 |
-
formats = ["wav", "mp3"]
|
221 |
-
for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
|
222 |
-
for format_, strategy in product(formats, strategies):
|
223 |
-
wav = get_white_noise(ch, frames)
|
224 |
-
path = self.get_temp_path(f'pred_{sample_rate}_{ch}')
|
225 |
-
audio_write(path, wav, sample_rate, format_, strategy=strategy)
|
226 |
-
read_wav, read_sr = torchaudio.load(f'{path}.{format_}')
|
227 |
-
if format_ == "wav":
|
228 |
-
assert read_wav.shape == wav.shape
|
229 |
-
|
230 |
-
if format_ == "wav" and strategy in ["peak", "rms"]:
|
231 |
-
rescaled_read_wav = read_wav / read_wav.abs().max() * wav.abs().max()
|
232 |
-
# for a Gaussian, the typical max scale will be less than ~5x the std.
|
233 |
-
# The error when writing to disk will ~ 1/2**15, and when rescaling, 5x that.
|
234 |
-
# For RMS target, rescaling leaves more headroom by default, leading
|
235 |
-
# to a 20x rescaling typically
|
236 |
-
atol = (5 if strategy == "peak" else 20) / 2**15
|
237 |
-
delta = (rescaled_read_wav - wav).abs().max()
|
238 |
-
assert torch.allclose(wav, rescaled_read_wav, rtol=0, atol=atol), (delta, atol)
|
239 |
-
formats = ["wav"] # faster unit tests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/data/test_audio_dataset.py
DELETED
@@ -1,352 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from functools import partial
|
8 |
-
from itertools import product
|
9 |
-
import json
|
10 |
-
import math
|
11 |
-
import os
|
12 |
-
import random
|
13 |
-
import typing as tp
|
14 |
-
|
15 |
-
import pytest
|
16 |
-
import torch
|
17 |
-
from torch.utils.data import DataLoader
|
18 |
-
|
19 |
-
from audiocraft.data.audio_dataset import (
|
20 |
-
AudioDataset,
|
21 |
-
AudioMeta,
|
22 |
-
_get_audio_meta,
|
23 |
-
load_audio_meta,
|
24 |
-
save_audio_meta
|
25 |
-
)
|
26 |
-
from audiocraft.data.zip import PathInZip
|
27 |
-
|
28 |
-
from ..common_utils import TempDirMixin, get_white_noise, save_wav
|
29 |
-
|
30 |
-
|
31 |
-
class TestAudioMeta(TempDirMixin):
|
32 |
-
|
33 |
-
def test_get_audio_meta(self):
|
34 |
-
sample_rates = [8000, 16_000]
|
35 |
-
channels = [1, 2]
|
36 |
-
duration = 1.
|
37 |
-
for sample_rate, ch in product(sample_rates, channels):
|
38 |
-
n_frames = int(duration * sample_rate)
|
39 |
-
wav = get_white_noise(ch, n_frames)
|
40 |
-
path = self.get_temp_path('sample.wav')
|
41 |
-
save_wav(path, wav, sample_rate)
|
42 |
-
m = _get_audio_meta(path, minimal=True)
|
43 |
-
assert m.path == path, 'path does not match'
|
44 |
-
assert m.sample_rate == sample_rate, 'sample rate does not match'
|
45 |
-
assert m.duration == duration, 'duration does not match'
|
46 |
-
assert m.amplitude is None
|
47 |
-
assert m.info_path is None
|
48 |
-
|
49 |
-
def test_save_audio_meta(self):
|
50 |
-
audio_meta = [
|
51 |
-
AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
|
52 |
-
AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
|
53 |
-
]
|
54 |
-
empty_audio_meta = []
|
55 |
-
for idx, meta in enumerate([audio_meta, empty_audio_meta]):
|
56 |
-
path = self.get_temp_path(f'data_{idx}_save.jsonl')
|
57 |
-
save_audio_meta(path, meta)
|
58 |
-
with open(path, 'r') as f:
|
59 |
-
lines = f.readlines()
|
60 |
-
read_meta = [AudioMeta.from_dict(json.loads(line)) for line in lines]
|
61 |
-
assert len(read_meta) == len(meta)
|
62 |
-
for m, read_m in zip(meta, read_meta):
|
63 |
-
assert m == read_m
|
64 |
-
|
65 |
-
def test_load_audio_meta(self):
|
66 |
-
try:
|
67 |
-
import dora
|
68 |
-
except ImportError:
|
69 |
-
dora = None # type: ignore
|
70 |
-
|
71 |
-
audio_meta = [
|
72 |
-
AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
|
73 |
-
AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
|
74 |
-
]
|
75 |
-
empty_meta = []
|
76 |
-
for idx, meta in enumerate([audio_meta, empty_meta]):
|
77 |
-
path = self.get_temp_path(f'data_{idx}_load.jsonl')
|
78 |
-
with open(path, 'w') as f:
|
79 |
-
for m in meta:
|
80 |
-
json_str = json.dumps(m.to_dict()) + '\n'
|
81 |
-
f.write(json_str)
|
82 |
-
read_meta = load_audio_meta(path)
|
83 |
-
assert len(read_meta) == len(meta)
|
84 |
-
for m, read_m in zip(meta, read_meta):
|
85 |
-
if dora:
|
86 |
-
m.path = dora.git_save.to_absolute_path(m.path)
|
87 |
-
assert m == read_m, f'original={m}, read={read_m}'
|
88 |
-
|
89 |
-
|
90 |
-
class TestAudioDataset(TempDirMixin):
|
91 |
-
|
92 |
-
def _create_audio_files(self,
|
93 |
-
root_name: str,
|
94 |
-
num_examples: int,
|
95 |
-
durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
|
96 |
-
sample_rate: int = 16_000,
|
97 |
-
channels: int = 1):
|
98 |
-
root_dir = self.get_temp_dir(root_name)
|
99 |
-
for i in range(num_examples):
|
100 |
-
if isinstance(durations, float):
|
101 |
-
duration = durations
|
102 |
-
elif isinstance(durations, tuple) and len(durations) == 1:
|
103 |
-
duration = durations[0]
|
104 |
-
elif isinstance(durations, tuple) and len(durations) == 2:
|
105 |
-
duration = random.uniform(durations[0], durations[1])
|
106 |
-
else:
|
107 |
-
assert False
|
108 |
-
n_frames = int(duration * sample_rate)
|
109 |
-
wav = get_white_noise(channels, n_frames)
|
110 |
-
path = os.path.join(root_dir, f'example_{i}.wav')
|
111 |
-
save_wav(path, wav, sample_rate)
|
112 |
-
return root_dir
|
113 |
-
|
114 |
-
def _create_audio_dataset(self,
|
115 |
-
root_name: str,
|
116 |
-
total_num_examples: int,
|
117 |
-
durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
|
118 |
-
sample_rate: int = 16_000,
|
119 |
-
channels: int = 1,
|
120 |
-
segment_duration: tp.Optional[float] = None,
|
121 |
-
num_examples: int = 10,
|
122 |
-
shuffle: bool = True,
|
123 |
-
return_info: bool = False):
|
124 |
-
root_dir = self._create_audio_files(root_name, total_num_examples, durations, sample_rate, channels)
|
125 |
-
dataset = AudioDataset.from_path(root_dir,
|
126 |
-
minimal_meta=True,
|
127 |
-
segment_duration=segment_duration,
|
128 |
-
num_samples=num_examples,
|
129 |
-
sample_rate=sample_rate,
|
130 |
-
channels=channels,
|
131 |
-
shuffle=shuffle,
|
132 |
-
return_info=return_info)
|
133 |
-
return dataset
|
134 |
-
|
135 |
-
def test_dataset_full(self):
|
136 |
-
total_examples = 10
|
137 |
-
min_duration, max_duration = 1., 4.
|
138 |
-
sample_rate = 16_000
|
139 |
-
channels = 1
|
140 |
-
dataset = self._create_audio_dataset(
|
141 |
-
'dset', total_examples, durations=(min_duration, max_duration),
|
142 |
-
sample_rate=sample_rate, channels=channels, segment_duration=None)
|
143 |
-
assert len(dataset) == total_examples
|
144 |
-
assert dataset.sample_rate == sample_rate
|
145 |
-
assert dataset.channels == channels
|
146 |
-
for idx in range(len(dataset)):
|
147 |
-
sample = dataset[idx]
|
148 |
-
assert sample.shape[0] == channels
|
149 |
-
assert sample.shape[1] <= int(max_duration * sample_rate)
|
150 |
-
assert sample.shape[1] >= int(min_duration * sample_rate)
|
151 |
-
|
152 |
-
def test_dataset_segment(self):
|
153 |
-
total_examples = 10
|
154 |
-
num_samples = 20
|
155 |
-
min_duration, max_duration = 1., 4.
|
156 |
-
segment_duration = 1.
|
157 |
-
sample_rate = 16_000
|
158 |
-
channels = 1
|
159 |
-
dataset = self._create_audio_dataset(
|
160 |
-
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
|
161 |
-
channels=channels, segment_duration=segment_duration, num_examples=num_samples)
|
162 |
-
assert len(dataset) == num_samples
|
163 |
-
assert dataset.sample_rate == sample_rate
|
164 |
-
assert dataset.channels == channels
|
165 |
-
for idx in range(len(dataset)):
|
166 |
-
sample = dataset[idx]
|
167 |
-
assert sample.shape[0] == channels
|
168 |
-
assert sample.shape[1] == int(segment_duration * sample_rate)
|
169 |
-
|
170 |
-
def test_dataset_equal_audio_and_segment_durations(self):
|
171 |
-
total_examples = 1
|
172 |
-
num_samples = 2
|
173 |
-
audio_duration = 1.
|
174 |
-
segment_duration = 1.
|
175 |
-
sample_rate = 16_000
|
176 |
-
channels = 1
|
177 |
-
dataset = self._create_audio_dataset(
|
178 |
-
'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
|
179 |
-
channels=channels, segment_duration=segment_duration, num_examples=num_samples)
|
180 |
-
assert len(dataset) == num_samples
|
181 |
-
assert dataset.sample_rate == sample_rate
|
182 |
-
assert dataset.channels == channels
|
183 |
-
for idx in range(len(dataset)):
|
184 |
-
sample = dataset[idx]
|
185 |
-
assert sample.shape[0] == channels
|
186 |
-
assert sample.shape[1] == int(segment_duration * sample_rate)
|
187 |
-
# the random seek_time adds variability on audio read
|
188 |
-
sample_1 = dataset[0]
|
189 |
-
sample_2 = dataset[1]
|
190 |
-
assert not torch.allclose(sample_1, sample_2)
|
191 |
-
|
192 |
-
def test_dataset_samples(self):
|
193 |
-
total_examples = 1
|
194 |
-
num_samples = 2
|
195 |
-
audio_duration = 1.
|
196 |
-
segment_duration = 1.
|
197 |
-
sample_rate = 16_000
|
198 |
-
channels = 1
|
199 |
-
|
200 |
-
create_dataset = partial(
|
201 |
-
self._create_audio_dataset,
|
202 |
-
'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
|
203 |
-
channels=channels, segment_duration=segment_duration, num_examples=num_samples,
|
204 |
-
)
|
205 |
-
|
206 |
-
dataset = create_dataset(shuffle=True)
|
207 |
-
# when shuffle = True, we have different inputs for the same index across epoch
|
208 |
-
sample_1 = dataset[0]
|
209 |
-
sample_2 = dataset[0]
|
210 |
-
assert not torch.allclose(sample_1, sample_2)
|
211 |
-
|
212 |
-
dataset_noshuffle = create_dataset(shuffle=False)
|
213 |
-
# when shuffle = False, we have same inputs for the same index across epoch
|
214 |
-
sample_1 = dataset_noshuffle[0]
|
215 |
-
sample_2 = dataset_noshuffle[0]
|
216 |
-
assert torch.allclose(sample_1, sample_2)
|
217 |
-
|
218 |
-
def test_dataset_return_info(self):
|
219 |
-
total_examples = 10
|
220 |
-
num_samples = 20
|
221 |
-
min_duration, max_duration = 1., 4.
|
222 |
-
segment_duration = 1.
|
223 |
-
sample_rate = 16_000
|
224 |
-
channels = 1
|
225 |
-
dataset = self._create_audio_dataset(
|
226 |
-
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
|
227 |
-
channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
|
228 |
-
assert len(dataset) == num_samples
|
229 |
-
assert dataset.sample_rate == sample_rate
|
230 |
-
assert dataset.channels == channels
|
231 |
-
for idx in range(len(dataset)):
|
232 |
-
sample, segment_info = dataset[idx]
|
233 |
-
assert sample.shape[0] == channels
|
234 |
-
assert sample.shape[1] == int(segment_duration * sample_rate)
|
235 |
-
assert segment_info.sample_rate == sample_rate
|
236 |
-
assert segment_info.total_frames == int(segment_duration * sample_rate)
|
237 |
-
assert segment_info.n_frames <= int(segment_duration * sample_rate)
|
238 |
-
assert segment_info.seek_time >= 0
|
239 |
-
|
240 |
-
def test_dataset_return_info_no_segment_duration(self):
|
241 |
-
total_examples = 10
|
242 |
-
num_samples = 20
|
243 |
-
min_duration, max_duration = 1., 4.
|
244 |
-
segment_duration = None
|
245 |
-
sample_rate = 16_000
|
246 |
-
channels = 1
|
247 |
-
dataset = self._create_audio_dataset(
|
248 |
-
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
|
249 |
-
channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
|
250 |
-
assert len(dataset) == total_examples
|
251 |
-
assert dataset.sample_rate == sample_rate
|
252 |
-
assert dataset.channels == channels
|
253 |
-
for idx in range(len(dataset)):
|
254 |
-
sample, segment_info = dataset[idx]
|
255 |
-
assert sample.shape[0] == channels
|
256 |
-
assert sample.shape[1] == segment_info.total_frames
|
257 |
-
assert segment_info.sample_rate == sample_rate
|
258 |
-
assert segment_info.n_frames <= segment_info.total_frames
|
259 |
-
|
260 |
-
def test_dataset_collate_fn(self):
|
261 |
-
total_examples = 10
|
262 |
-
num_samples = 20
|
263 |
-
min_duration, max_duration = 1., 4.
|
264 |
-
segment_duration = 1.
|
265 |
-
sample_rate = 16_000
|
266 |
-
channels = 1
|
267 |
-
dataset = self._create_audio_dataset(
|
268 |
-
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
|
269 |
-
channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=False)
|
270 |
-
batch_size = 4
|
271 |
-
dataloader = DataLoader(
|
272 |
-
dataset,
|
273 |
-
batch_size=batch_size,
|
274 |
-
num_workers=0
|
275 |
-
)
|
276 |
-
for idx, batch in enumerate(dataloader):
|
277 |
-
assert batch.shape[0] == batch_size
|
278 |
-
|
279 |
-
@pytest.mark.parametrize("segment_duration", [1.0, None])
|
280 |
-
def test_dataset_with_meta_collate_fn(self, segment_duration):
|
281 |
-
total_examples = 10
|
282 |
-
num_samples = 20
|
283 |
-
min_duration, max_duration = 1., 4.
|
284 |
-
segment_duration = 1.
|
285 |
-
sample_rate = 16_000
|
286 |
-
channels = 1
|
287 |
-
dataset = self._create_audio_dataset(
|
288 |
-
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
|
289 |
-
channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
|
290 |
-
batch_size = 4
|
291 |
-
dataloader = DataLoader(
|
292 |
-
dataset,
|
293 |
-
batch_size=batch_size,
|
294 |
-
collate_fn=dataset.collater,
|
295 |
-
num_workers=0
|
296 |
-
)
|
297 |
-
for idx, batch in enumerate(dataloader):
|
298 |
-
wav, infos = batch
|
299 |
-
assert wav.shape[0] == batch_size
|
300 |
-
assert len(infos) == batch_size
|
301 |
-
|
302 |
-
@pytest.mark.parametrize("segment_duration,sample_on_weight,sample_on_duration,a_hist,b_hist,c_hist", [
|
303 |
-
[1, True, True, 0.5, 0.5, 0.0],
|
304 |
-
[1, False, True, 0.25, 0.5, 0.25],
|
305 |
-
[1, True, False, 0.666, 0.333, 0.0],
|
306 |
-
[1, False, False, 0.333, 0.333, 0.333],
|
307 |
-
[None, False, False, 0.333, 0.333, 0.333]])
|
308 |
-
def test_sample_with_weight(self, segment_duration, sample_on_weight, sample_on_duration, a_hist, b_hist, c_hist):
|
309 |
-
random.seed(1234)
|
310 |
-
rng = torch.Generator()
|
311 |
-
rng.manual_seed(1234)
|
312 |
-
|
313 |
-
def _get_histogram(dataset, repetitions=20_000):
|
314 |
-
counts = {file_meta.path: 0. for file_meta in meta}
|
315 |
-
for _ in range(repetitions):
|
316 |
-
file_meta = dataset.sample_file(rng)
|
317 |
-
counts[file_meta.path] += 1
|
318 |
-
return {name: count / repetitions for name, count in counts.items()}
|
319 |
-
|
320 |
-
meta = [
|
321 |
-
AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
|
322 |
-
AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
|
323 |
-
AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
|
324 |
-
]
|
325 |
-
dataset = AudioDataset(
|
326 |
-
meta, segment_duration=segment_duration, sample_on_weight=sample_on_weight,
|
327 |
-
sample_on_duration=sample_on_duration)
|
328 |
-
hist = _get_histogram(dataset)
|
329 |
-
assert math.isclose(hist['a'], a_hist, abs_tol=0.01)
|
330 |
-
assert math.isclose(hist['b'], b_hist, abs_tol=0.01)
|
331 |
-
assert math.isclose(hist['c'], c_hist, abs_tol=0.01)
|
332 |
-
|
333 |
-
def test_meta_duration_filter_all(self):
|
334 |
-
meta = [
|
335 |
-
AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
|
336 |
-
AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
|
337 |
-
AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
|
338 |
-
]
|
339 |
-
try:
|
340 |
-
AudioDataset(meta, segment_duration=11, min_segment_ratio=1)
|
341 |
-
assert False
|
342 |
-
except AssertionError:
|
343 |
-
assert True
|
344 |
-
|
345 |
-
def test_meta_duration_filter_long(self):
|
346 |
-
meta = [
|
347 |
-
AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
|
348 |
-
AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
|
349 |
-
AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
|
350 |
-
]
|
351 |
-
dataset = AudioDataset(meta, segment_duration=None, min_segment_ratio=1, max_audio_duration=7)
|
352 |
-
assert len(dataset) == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/data/test_audio_utils.py
DELETED
@@ -1,110 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import julius
|
8 |
-
import torch
|
9 |
-
import pytest
|
10 |
-
|
11 |
-
from audiocraft.data.audio_utils import (
|
12 |
-
_clip_wav,
|
13 |
-
convert_audio_channels,
|
14 |
-
convert_audio,
|
15 |
-
normalize_audio
|
16 |
-
)
|
17 |
-
from ..common_utils import get_batch_white_noise
|
18 |
-
|
19 |
-
|
20 |
-
class TestConvertAudioChannels:
|
21 |
-
|
22 |
-
def test_convert_audio_channels_downmix(self):
|
23 |
-
b, c, t = 2, 3, 100
|
24 |
-
audio = get_batch_white_noise(b, c, t)
|
25 |
-
mixed = convert_audio_channels(audio, channels=2)
|
26 |
-
assert list(mixed.shape) == [b, 2, t]
|
27 |
-
|
28 |
-
def test_convert_audio_channels_nochange(self):
|
29 |
-
b, c, t = 2, 3, 100
|
30 |
-
audio = get_batch_white_noise(b, c, t)
|
31 |
-
mixed = convert_audio_channels(audio, channels=c)
|
32 |
-
assert list(mixed.shape) == list(audio.shape)
|
33 |
-
|
34 |
-
def test_convert_audio_channels_upmix(self):
|
35 |
-
b, c, t = 2, 1, 100
|
36 |
-
audio = get_batch_white_noise(b, c, t)
|
37 |
-
mixed = convert_audio_channels(audio, channels=3)
|
38 |
-
assert list(mixed.shape) == [b, 3, t]
|
39 |
-
|
40 |
-
def test_convert_audio_channels_upmix_error(self):
|
41 |
-
b, c, t = 2, 2, 100
|
42 |
-
audio = get_batch_white_noise(b, c, t)
|
43 |
-
with pytest.raises(ValueError):
|
44 |
-
convert_audio_channels(audio, channels=3)
|
45 |
-
|
46 |
-
|
47 |
-
class TestConvertAudio:
|
48 |
-
|
49 |
-
def test_convert_audio_channels_downmix(self):
|
50 |
-
b, c, dur = 2, 3, 4.
|
51 |
-
sr = 128
|
52 |
-
audio = get_batch_white_noise(b, c, int(sr * dur))
|
53 |
-
out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2)
|
54 |
-
assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]]
|
55 |
-
|
56 |
-
def test_convert_audio_channels_upmix(self):
|
57 |
-
b, c, dur = 2, 1, 4.
|
58 |
-
sr = 128
|
59 |
-
audio = get_batch_white_noise(b, c, int(sr * dur))
|
60 |
-
out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3)
|
61 |
-
assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]]
|
62 |
-
|
63 |
-
def test_convert_audio_upsample(self):
|
64 |
-
b, c, dur = 2, 1, 4.
|
65 |
-
sr = 2
|
66 |
-
new_sr = 3
|
67 |
-
audio = get_batch_white_noise(b, c, int(sr * dur))
|
68 |
-
out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
|
69 |
-
out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
|
70 |
-
assert torch.allclose(out, out_j)
|
71 |
-
|
72 |
-
def test_convert_audio_resample(self):
|
73 |
-
b, c, dur = 2, 1, 4.
|
74 |
-
sr = 3
|
75 |
-
new_sr = 2
|
76 |
-
audio = get_batch_white_noise(b, c, int(sr * dur))
|
77 |
-
out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
|
78 |
-
out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
|
79 |
-
assert torch.allclose(out, out_j)
|
80 |
-
|
81 |
-
|
82 |
-
class TestNormalizeAudio:
|
83 |
-
|
84 |
-
def test_clip_wav(self):
|
85 |
-
b, c, dur = 2, 1, 4.
|
86 |
-
sr = 3
|
87 |
-
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
|
88 |
-
_clip_wav(audio)
|
89 |
-
assert audio.abs().max() <= 1
|
90 |
-
|
91 |
-
def test_normalize_audio_clip(self):
|
92 |
-
b, c, dur = 2, 1, 4.
|
93 |
-
sr = 3
|
94 |
-
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
|
95 |
-
norm_audio = normalize_audio(audio, strategy='clip')
|
96 |
-
assert norm_audio.abs().max() <= 1
|
97 |
-
|
98 |
-
def test_normalize_audio_rms(self):
|
99 |
-
b, c, dur = 2, 1, 4.
|
100 |
-
sr = 3
|
101 |
-
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
|
102 |
-
norm_audio = normalize_audio(audio, strategy='rms')
|
103 |
-
assert norm_audio.abs().max() <= 1
|
104 |
-
|
105 |
-
def test_normalize_audio_peak(self):
|
106 |
-
b, c, dur = 2, 1, 4.
|
107 |
-
sr = 3
|
108 |
-
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
|
109 |
-
norm_audio = normalize_audio(audio, strategy='peak')
|
110 |
-
assert norm_audio.abs().max() <= 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/models/test_encodec_model.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import random
|
8 |
-
|
9 |
-
import numpy as np
|
10 |
-
import torch
|
11 |
-
|
12 |
-
from audiocraft.models import EncodecModel
|
13 |
-
from audiocraft.modules import SEANetEncoder, SEANetDecoder
|
14 |
-
from audiocraft.quantization import DummyQuantizer
|
15 |
-
|
16 |
-
|
17 |
-
class TestEncodecModel:
|
18 |
-
|
19 |
-
def _create_encodec_model(self,
|
20 |
-
sample_rate: int,
|
21 |
-
channels: int,
|
22 |
-
dim: int = 5,
|
23 |
-
n_filters: int = 3,
|
24 |
-
n_residual_layers: int = 1,
|
25 |
-
ratios: list = [5, 4, 3, 2],
|
26 |
-
**kwargs):
|
27 |
-
frame_rate = np.prod(ratios)
|
28 |
-
encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
|
29 |
-
n_residual_layers=n_residual_layers, ratios=ratios)
|
30 |
-
decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
|
31 |
-
n_residual_layers=n_residual_layers, ratios=ratios)
|
32 |
-
quantizer = DummyQuantizer()
|
33 |
-
model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
|
34 |
-
sample_rate=sample_rate, channels=channels, **kwargs)
|
35 |
-
return model
|
36 |
-
|
37 |
-
def test_model(self):
|
38 |
-
random.seed(1234)
|
39 |
-
sample_rate = 24_000
|
40 |
-
channels = 1
|
41 |
-
model = self._create_encodec_model(sample_rate, channels)
|
42 |
-
for _ in range(10):
|
43 |
-
length = random.randrange(1, 10_000)
|
44 |
-
x = torch.randn(2, channels, length)
|
45 |
-
res = model(x)
|
46 |
-
assert res.x.shape == x.shape
|
47 |
-
|
48 |
-
def test_model_renorm(self):
|
49 |
-
random.seed(1234)
|
50 |
-
sample_rate = 24_000
|
51 |
-
channels = 1
|
52 |
-
model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
|
53 |
-
model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)
|
54 |
-
|
55 |
-
for _ in range(10):
|
56 |
-
length = random.randrange(1, 10_000)
|
57 |
-
x = torch.randn(2, channels, length)
|
58 |
-
codes, scales = model_nonorm.encode(x)
|
59 |
-
codes, scales = model_renorm.encode(x)
|
60 |
-
assert scales is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/models/test_musicgen.py
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import pytest
|
8 |
-
import torch
|
9 |
-
|
10 |
-
from audiocraft.models import MusicGen
|
11 |
-
|
12 |
-
|
13 |
-
class TestSEANetModel:
|
14 |
-
def get_musicgen(self):
|
15 |
-
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
16 |
-
mg.set_generation_params(duration=2.0)
|
17 |
-
return mg
|
18 |
-
|
19 |
-
def test_base(self):
|
20 |
-
mg = self.get_musicgen()
|
21 |
-
assert mg.frame_rate == 25
|
22 |
-
assert mg.sample_rate == 32000
|
23 |
-
assert mg.audio_channels == 1
|
24 |
-
|
25 |
-
def test_generate_unconditional(self):
|
26 |
-
mg = self.get_musicgen()
|
27 |
-
wav = mg.generate_unconditional(3)
|
28 |
-
assert list(wav.shape) == [3, 1, 64000]
|
29 |
-
|
30 |
-
def test_generate_continuation(self):
|
31 |
-
mg = self.get_musicgen()
|
32 |
-
prompt = torch.randn(3, 1, 32000)
|
33 |
-
wav = mg.generate_continuation(prompt, 32000)
|
34 |
-
assert list(wav.shape) == [3, 1, 64000]
|
35 |
-
|
36 |
-
prompt = torch.randn(2, 1, 32000)
|
37 |
-
wav = mg.generate_continuation(
|
38 |
-
prompt, 32000, ['youpi', 'lapin dort'])
|
39 |
-
assert list(wav.shape) == [2, 1, 64000]
|
40 |
-
|
41 |
-
prompt = torch.randn(2, 1, 32000)
|
42 |
-
with pytest.raises(AssertionError):
|
43 |
-
wav = mg.generate_continuation(
|
44 |
-
prompt, 32000, ['youpi', 'lapin dort', 'one too many'])
|
45 |
-
|
46 |
-
def test_generate(self):
|
47 |
-
mg = self.get_musicgen()
|
48 |
-
wav = mg.generate(
|
49 |
-
['youpi', 'lapin dort'])
|
50 |
-
assert list(wav.shape) == [2, 1, 64000]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_codebooks_patterns.py
DELETED
@@ -1,246 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import pytest
|
8 |
-
import torch
|
9 |
-
|
10 |
-
from audiocraft.modules.codebooks_patterns import (
|
11 |
-
DelayedPatternProvider,
|
12 |
-
ParallelPatternProvider,
|
13 |
-
Pattern,
|
14 |
-
UnrolledPatternProvider,
|
15 |
-
)
|
16 |
-
|
17 |
-
|
18 |
-
class TestParallelPatternProvider:
|
19 |
-
|
20 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
21 |
-
@pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
|
22 |
-
def test_get_pattern(self, n_q: int, timesteps: int):
|
23 |
-
provider = ParallelPatternProvider(n_q)
|
24 |
-
pattern = provider.get_pattern(timesteps)
|
25 |
-
# + 1 to account for 1st step
|
26 |
-
assert len(pattern.layout) == timesteps + 1
|
27 |
-
|
28 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
29 |
-
@pytest.mark.parametrize("timesteps", [8, 16, 100])
|
30 |
-
def test_pattern_content(self, n_q: int, timesteps: int):
|
31 |
-
provider = ParallelPatternProvider(n_q)
|
32 |
-
pattern = provider.get_pattern(timesteps)
|
33 |
-
for s, v in enumerate(pattern.layout):
|
34 |
-
for i, code in enumerate(v):
|
35 |
-
assert i == code.q
|
36 |
-
assert code.t == s - 1 # account for the 1st empty step
|
37 |
-
|
38 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
39 |
-
@pytest.mark.parametrize("timesteps", [8, 16, 100])
|
40 |
-
def test_pattern_max_delay(self, n_q: int, timesteps: int):
|
41 |
-
provider = ParallelPatternProvider(n_q)
|
42 |
-
pattern = provider.get_pattern(timesteps)
|
43 |
-
assert pattern.max_delay == 0
|
44 |
-
assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
|
45 |
-
|
46 |
-
|
47 |
-
class TestDelayedPatternProvider:
|
48 |
-
|
49 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
50 |
-
@pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
|
51 |
-
def test_get_pattern(self, n_q: int, timesteps: int):
|
52 |
-
delays = [
|
53 |
-
list(range(n_q)),
|
54 |
-
[0] + [1] * (n_q - 1),
|
55 |
-
[0] + [4] * (n_q - 1),
|
56 |
-
]
|
57 |
-
for delay in delays:
|
58 |
-
provider = DelayedPatternProvider(n_q, delay)
|
59 |
-
pattern = provider.get_pattern(timesteps)
|
60 |
-
# + 1 to account for 1st step
|
61 |
-
assert len(pattern.layout) == timesteps + max(delay) + 1
|
62 |
-
|
63 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
64 |
-
@pytest.mark.parametrize("timesteps", [8, 16, 100])
|
65 |
-
def test_pattern_content(self, n_q: int, timesteps: int):
|
66 |
-
provider = DelayedPatternProvider(n_q)
|
67 |
-
pattern = provider.get_pattern(timesteps)
|
68 |
-
for s, v in enumerate(pattern.layout):
|
69 |
-
for i, code in enumerate(v):
|
70 |
-
assert i == code.q
|
71 |
-
assert code.t == max(0, s - code.q - 1)
|
72 |
-
|
73 |
-
@pytest.mark.parametrize("timesteps", [8, 16, 100])
|
74 |
-
@pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]])
|
75 |
-
def test_pattern_max_delay(self, timesteps: int, delay: list):
|
76 |
-
provider = DelayedPatternProvider(len(delay), delay)
|
77 |
-
pattern = provider.get_pattern(timesteps)
|
78 |
-
assert pattern.max_delay == max(delay)
|
79 |
-
assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
|
80 |
-
|
81 |
-
|
82 |
-
class TestUnrolledPatternProvider:
|
83 |
-
|
84 |
-
@pytest.mark.parametrize("timesteps", [0, 1, 16])
|
85 |
-
@pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
|
86 |
-
@pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
|
87 |
-
def test_get_pattern(self, timesteps: int, flattening: list, delays: list):
|
88 |
-
n_q = len(flattening)
|
89 |
-
max_delay = max(delays)
|
90 |
-
provider = UnrolledPatternProvider(n_q, flattening, delays)
|
91 |
-
pattern = provider.get_pattern(timesteps)
|
92 |
-
assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay
|
93 |
-
|
94 |
-
@pytest.mark.parametrize("timesteps", [0, 1, 16])
|
95 |
-
@pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
|
96 |
-
@pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
|
97 |
-
def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list):
|
98 |
-
n_q = len(flattening)
|
99 |
-
max_delay = max(delays)
|
100 |
-
provider = UnrolledPatternProvider(n_q, flattening, delays)
|
101 |
-
pattern = provider.get_pattern(timesteps)
|
102 |
-
assert pattern.max_delay == max_delay
|
103 |
-
|
104 |
-
|
105 |
-
class TestPattern:
|
106 |
-
|
107 |
-
def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
|
108 |
-
"""Reference method to build the sequence from the pattern without using fancy scatter."""
|
109 |
-
bs, n_q, T = z.shape
|
110 |
-
z = z.cpu().numpy()
|
111 |
-
assert n_q == pattern.n_q
|
112 |
-
assert T <= pattern.timesteps
|
113 |
-
inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy()
|
114 |
-
inp[:] = special_token
|
115 |
-
for s, v in enumerate(pattern.layout):
|
116 |
-
for (t, q) in v:
|
117 |
-
if t < T:
|
118 |
-
inp[:, q, s] = z[:, q, t]
|
119 |
-
return torch.from_numpy(inp)
|
120 |
-
|
121 |
-
def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
|
122 |
-
"""Reference method to revert the sequence from the pattern without using fancy scatter."""
|
123 |
-
z = z.cpu().numpy()
|
124 |
-
bs, n_q, S = z.shape
|
125 |
-
assert pattern.n_q == n_q
|
126 |
-
inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy()
|
127 |
-
inp[:] = special_token
|
128 |
-
for s, v in enumerate(pattern.layout):
|
129 |
-
for (t, q) in v:
|
130 |
-
if t < pattern.timesteps:
|
131 |
-
inp[:, q, t] = z[:, q, s]
|
132 |
-
return torch.from_numpy(inp)
|
133 |
-
|
134 |
-
def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float):
|
135 |
-
"""Reference method to revert the logits from the pattern without using fancy scatter."""
|
136 |
-
z = z.cpu().numpy()
|
137 |
-
bs, card, n_q, S = z.shape
|
138 |
-
assert pattern.n_q == n_q
|
139 |
-
ref_layout = pattern.layout
|
140 |
-
inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy()
|
141 |
-
inp[:] = special_token
|
142 |
-
for s, v in enumerate(ref_layout[1:]):
|
143 |
-
if s < S:
|
144 |
-
for (t, q) in v:
|
145 |
-
if t < pattern.timesteps:
|
146 |
-
inp[:, :, q, t] = z[:, :, q, s]
|
147 |
-
return torch.from_numpy(inp)
|
148 |
-
|
149 |
-
def _get_pattern_providers(self, n_q: int):
|
150 |
-
pattern_provider_1 = ParallelPatternProvider(n_q)
|
151 |
-
pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q)))
|
152 |
-
pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1))
|
153 |
-
pattern_provider_4 = UnrolledPatternProvider(
|
154 |
-
n_q, flattening=list(range(n_q)), delays=[0] * n_q
|
155 |
-
)
|
156 |
-
pattern_provider_5 = UnrolledPatternProvider(
|
157 |
-
n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q
|
158 |
-
)
|
159 |
-
pattern_provider_6 = UnrolledPatternProvider(
|
160 |
-
n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1)
|
161 |
-
)
|
162 |
-
return [
|
163 |
-
pattern_provider_1,
|
164 |
-
pattern_provider_2,
|
165 |
-
pattern_provider_3,
|
166 |
-
pattern_provider_4,
|
167 |
-
pattern_provider_5,
|
168 |
-
pattern_provider_6,
|
169 |
-
]
|
170 |
-
|
171 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
172 |
-
@pytest.mark.parametrize("timesteps", [16, 72])
|
173 |
-
def test_build_pattern_sequence(self, n_q: int, timesteps: int):
|
174 |
-
bs = 2
|
175 |
-
card = 256
|
176 |
-
special_token = card
|
177 |
-
|
178 |
-
pattern_providers = self._get_pattern_providers(n_q)
|
179 |
-
for pattern_provider in pattern_providers:
|
180 |
-
pattern = pattern_provider.get_pattern(timesteps)
|
181 |
-
# we can correctly build the sequence from the pattern
|
182 |
-
z = torch.randint(0, card, (bs, n_q, timesteps))
|
183 |
-
ref_res = self.ref_build_pattern_sequence(z, pattern, special_token)
|
184 |
-
res, indexes, mask = pattern.build_pattern_sequence(z, special_token)
|
185 |
-
assert (res == ref_res).float().mean() == 1.0
|
186 |
-
|
187 |
-
# expected assertion fails on the number of timesteps
|
188 |
-
invalid_timesteps = [timesteps + 1]
|
189 |
-
if pattern.num_sequence_steps != pattern.timesteps:
|
190 |
-
invalid_timesteps.append(pattern.num_sequence_steps)
|
191 |
-
for i_timesteps in invalid_timesteps:
|
192 |
-
z2 = torch.randint(0, card, (bs, n_q, i_timesteps))
|
193 |
-
with pytest.raises(AssertionError):
|
194 |
-
pattern.build_pattern_sequence(z2, special_token)
|
195 |
-
|
196 |
-
# expected assertion fails on the number of codebooks
|
197 |
-
invalid_qs = [0, n_q - 1, n_q + 1]
|
198 |
-
for i_q in invalid_qs:
|
199 |
-
z3 = torch.randint(0, card, (bs, i_q, timesteps))
|
200 |
-
with pytest.raises(AssertionError):
|
201 |
-
pattern.build_pattern_sequence(z3, special_token)
|
202 |
-
|
203 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
204 |
-
@pytest.mark.parametrize("timesteps", [16, 72])
|
205 |
-
def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
|
206 |
-
bs = 2
|
207 |
-
card = 256
|
208 |
-
special_token = card
|
209 |
-
|
210 |
-
pattern_providers = self._get_pattern_providers(n_q)
|
211 |
-
for pattern_provider in pattern_providers:
|
212 |
-
pattern = pattern_provider.get_pattern(timesteps)
|
213 |
-
# this works assuming previous tests are successful
|
214 |
-
z = torch.randint(0, card, (bs, n_q, timesteps))
|
215 |
-
s = self.ref_build_pattern_sequence(z, pattern, special_token)
|
216 |
-
ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token)
|
217 |
-
# ensure our reference script retrieve the original sequence
|
218 |
-
assert z.shape == ref_out.shape
|
219 |
-
assert (z == ref_out).float().mean() == 1.0
|
220 |
-
# now we can test the scatter version
|
221 |
-
out, indexes, mask = pattern.revert_pattern_sequence(s, special_token)
|
222 |
-
assert out.shape == ref_out.shape
|
223 |
-
assert (out == ref_out).float().mean() == 1.0
|
224 |
-
|
225 |
-
@pytest.mark.parametrize("n_q", [1, 4, 32])
|
226 |
-
@pytest.mark.parametrize("timesteps", [16, 72])
|
227 |
-
@pytest.mark.parametrize("card", [1, 2, 256, 1024])
|
228 |
-
def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int):
|
229 |
-
bs = 2
|
230 |
-
special_token = card
|
231 |
-
logits_special_token = float('nan')
|
232 |
-
|
233 |
-
pattern_providers = self._get_pattern_providers(n_q)
|
234 |
-
for pattern_provider in pattern_providers:
|
235 |
-
pattern = pattern_provider.get_pattern(timesteps)
|
236 |
-
# this works assuming previous tests are successful
|
237 |
-
z = torch.randint(0, card, (bs, n_q, timesteps))
|
238 |
-
s = self.ref_build_pattern_sequence(z, pattern, special_token)
|
239 |
-
logits = torch.randn((bs, card, n_q, s.shape[-1]))
|
240 |
-
ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token)
|
241 |
-
# ensure our reference script retrieve the original sequence
|
242 |
-
assert ref_out.shape == torch.Size([bs, card, n_q, timesteps])
|
243 |
-
# now we can test the scatter version
|
244 |
-
out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token)
|
245 |
-
assert out.shape == ref_out.shape
|
246 |
-
assert (out == ref_out).float().mean() == 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_conv.py
DELETED
@@ -1,203 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from itertools import product
|
8 |
-
import math
|
9 |
-
import random
|
10 |
-
|
11 |
-
import pytest
|
12 |
-
import torch
|
13 |
-
from torch import nn
|
14 |
-
|
15 |
-
from audiocraft.modules import (
|
16 |
-
NormConv1d,
|
17 |
-
NormConvTranspose1d,
|
18 |
-
StreamableConv1d,
|
19 |
-
StreamableConvTranspose1d,
|
20 |
-
pad1d,
|
21 |
-
unpad1d,
|
22 |
-
)
|
23 |
-
|
24 |
-
|
25 |
-
def test_get_extra_padding_for_conv1d():
|
26 |
-
# TODO: Implement me!
|
27 |
-
pass
|
28 |
-
|
29 |
-
|
30 |
-
def test_pad1d_zeros():
|
31 |
-
x = torch.randn(1, 1, 20)
|
32 |
-
|
33 |
-
xp1 = pad1d(x, (0, 5), mode='constant', value=0.)
|
34 |
-
assert xp1.shape[-1] == 25
|
35 |
-
xp2 = pad1d(x, (5, 5), mode='constant', value=0.)
|
36 |
-
assert xp2.shape[-1] == 30
|
37 |
-
xp3 = pad1d(x, (0, 0), mode='constant', value=0.)
|
38 |
-
assert xp3.shape[-1] == 20
|
39 |
-
xp4 = pad1d(x, (10, 30), mode='constant', value=0.)
|
40 |
-
assert xp4.shape[-1] == 60
|
41 |
-
|
42 |
-
with pytest.raises(AssertionError):
|
43 |
-
pad1d(x, (-1, 0), mode='constant', value=0.)
|
44 |
-
|
45 |
-
with pytest.raises(AssertionError):
|
46 |
-
pad1d(x, (0, -1), mode='constant', value=0.)
|
47 |
-
|
48 |
-
with pytest.raises(AssertionError):
|
49 |
-
pad1d(x, (-1, -1), mode='constant', value=0.)
|
50 |
-
|
51 |
-
|
52 |
-
def test_pad1d_reflect():
|
53 |
-
x = torch.randn(1, 1, 20)
|
54 |
-
|
55 |
-
xp1 = pad1d(x, (0, 5), mode='reflect', value=0.)
|
56 |
-
assert xp1.shape[-1] == 25
|
57 |
-
xp2 = pad1d(x, (5, 5), mode='reflect', value=0.)
|
58 |
-
assert xp2.shape[-1] == 30
|
59 |
-
xp3 = pad1d(x, (0, 0), mode='reflect', value=0.)
|
60 |
-
assert xp3.shape[-1] == 20
|
61 |
-
xp4 = pad1d(x, (10, 30), mode='reflect', value=0.)
|
62 |
-
assert xp4.shape[-1] == 60
|
63 |
-
|
64 |
-
with pytest.raises(AssertionError):
|
65 |
-
pad1d(x, (-1, 0), mode='reflect', value=0.)
|
66 |
-
|
67 |
-
with pytest.raises(AssertionError):
|
68 |
-
pad1d(x, (0, -1), mode='reflect', value=0.)
|
69 |
-
|
70 |
-
with pytest.raises(AssertionError):
|
71 |
-
pad1d(x, (-1, -1), mode='reflect', value=0.)
|
72 |
-
|
73 |
-
|
74 |
-
def test_unpad1d():
|
75 |
-
x = torch.randn(1, 1, 20)
|
76 |
-
|
77 |
-
u1 = unpad1d(x, (5, 5))
|
78 |
-
assert u1.shape[-1] == 10
|
79 |
-
u2 = unpad1d(x, (0, 5))
|
80 |
-
assert u2.shape[-1] == 15
|
81 |
-
u3 = unpad1d(x, (5, 0))
|
82 |
-
assert u3.shape[-1] == 15
|
83 |
-
u4 = unpad1d(x, (0, 0))
|
84 |
-
assert u4.shape[-1] == x.shape[-1]
|
85 |
-
|
86 |
-
with pytest.raises(AssertionError):
|
87 |
-
unpad1d(x, (-1, 0))
|
88 |
-
|
89 |
-
with pytest.raises(AssertionError):
|
90 |
-
unpad1d(x, (0, -1))
|
91 |
-
|
92 |
-
with pytest.raises(AssertionError):
|
93 |
-
unpad1d(x, (-1, -1))
|
94 |
-
|
95 |
-
|
96 |
-
class TestNormConv1d:
|
97 |
-
|
98 |
-
def test_norm_conv1d_modules(self):
|
99 |
-
N, C, T = 2, 2, random.randrange(1, 100_000)
|
100 |
-
t0 = torch.randn(N, C, T)
|
101 |
-
|
102 |
-
C_out, kernel_size, stride = 1, 4, 1
|
103 |
-
expected_out_length = int((T - kernel_size) / stride + 1)
|
104 |
-
wn_conv = NormConv1d(C, 1, kernel_size=4, norm='weight_norm')
|
105 |
-
gn_conv = NormConv1d(C, 1, kernel_size=4, norm='time_group_norm')
|
106 |
-
nn_conv = NormConv1d(C, 1, kernel_size=4, norm='none')
|
107 |
-
|
108 |
-
assert isinstance(wn_conv.norm, nn.Identity)
|
109 |
-
assert isinstance(wn_conv.conv, nn.Conv1d)
|
110 |
-
|
111 |
-
assert isinstance(gn_conv.norm, nn.GroupNorm)
|
112 |
-
assert isinstance(gn_conv.conv, nn.Conv1d)
|
113 |
-
|
114 |
-
assert isinstance(nn_conv.norm, nn.Identity)
|
115 |
-
assert isinstance(nn_conv.conv, nn.Conv1d)
|
116 |
-
|
117 |
-
for conv_layer in [wn_conv, gn_conv, nn_conv]:
|
118 |
-
out = conv_layer(t0)
|
119 |
-
assert isinstance(out, torch.Tensor)
|
120 |
-
assert list(out.shape) == [N, C_out, expected_out_length]
|
121 |
-
|
122 |
-
|
123 |
-
class TestNormConvTranspose1d:
|
124 |
-
|
125 |
-
def test_normalizations(self):
|
126 |
-
N, C, T = 2, 2, random.randrange(1, 100_000)
|
127 |
-
t0 = torch.randn(N, C, T)
|
128 |
-
|
129 |
-
C_out, kernel_size, stride = 1, 4, 1
|
130 |
-
expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1
|
131 |
-
|
132 |
-
wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm')
|
133 |
-
gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm')
|
134 |
-
nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none')
|
135 |
-
|
136 |
-
assert isinstance(wn_convtr.norm, nn.Identity)
|
137 |
-
assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d)
|
138 |
-
|
139 |
-
assert isinstance(gn_convtr.norm, nn.GroupNorm)
|
140 |
-
assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d)
|
141 |
-
|
142 |
-
assert isinstance(nn_convtr.norm, nn.Identity)
|
143 |
-
assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d)
|
144 |
-
|
145 |
-
for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]:
|
146 |
-
out = convtr_layer(t0)
|
147 |
-
assert isinstance(out, torch.Tensor)
|
148 |
-
assert list(out.shape) == [N, C_out, expected_out_length]
|
149 |
-
|
150 |
-
|
151 |
-
class TestStreamableConv1d:
|
152 |
-
|
153 |
-
def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation):
|
154 |
-
# StreamableConv1d internally pads to make sure that the last window is full
|
155 |
-
padding_total = (kernel_size - 1) * dilation - (stride - 1)
|
156 |
-
n_frames = (length - kernel_size + padding_total) / stride + 1
|
157 |
-
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
158 |
-
return ideal_length // stride
|
159 |
-
|
160 |
-
def test_streamable_conv1d(self):
|
161 |
-
N, C, T = 2, 2, random.randrange(1, 100_000)
|
162 |
-
t0 = torch.randn(N, C, T)
|
163 |
-
C_out = 1
|
164 |
-
|
165 |
-
# conv params are [(kernel_size, stride, dilation)]
|
166 |
-
conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)]
|
167 |
-
for causal, (kernel_size, stride, dilation) in product([False, True], conv_params):
|
168 |
-
expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation)
|
169 |
-
sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal)
|
170 |
-
out = sconv(t0)
|
171 |
-
assert isinstance(out, torch.Tensor)
|
172 |
-
print(list(out.shape), [N, C_out, expected_out_length])
|
173 |
-
assert list(out.shape) == [N, C_out, expected_out_length]
|
174 |
-
|
175 |
-
|
176 |
-
class TestStreamableConvTranspose1d:
|
177 |
-
|
178 |
-
def get_streamable_convtr1d_output_length(self, length, kernel_size, stride):
|
179 |
-
padding_total = (kernel_size - stride)
|
180 |
-
return (length - 1) * stride - padding_total + (kernel_size - 1) + 1
|
181 |
-
|
182 |
-
def test_streamable_convtr1d(self):
|
183 |
-
N, C, T = 2, 2, random.randrange(1, 100_000)
|
184 |
-
t0 = torch.randn(N, C, T)
|
185 |
-
|
186 |
-
C_out = 1
|
187 |
-
|
188 |
-
with pytest.raises(AssertionError):
|
189 |
-
StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=False, trim_right_ratio=0.5)
|
190 |
-
StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=-1.)
|
191 |
-
StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=2)
|
192 |
-
|
193 |
-
# causal params are [(causal, trim_right)]
|
194 |
-
causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)]
|
195 |
-
# conv params are [(kernel_size, stride)]
|
196 |
-
conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)]
|
197 |
-
for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params):
|
198 |
-
expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride)
|
199 |
-
sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride,
|
200 |
-
causal=causal, trim_right_ratio=trim_right_ratio)
|
201 |
-
out = sconvtr(t0)
|
202 |
-
assert isinstance(out, torch.Tensor)
|
203 |
-
assert list(out.shape) == [N, C_out, expected_out_length]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_lstm.py
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import random
|
8 |
-
import torch
|
9 |
-
|
10 |
-
from audiocraft.modules.lstm import StreamableLSTM
|
11 |
-
|
12 |
-
|
13 |
-
class TestStreamableLSTM:
|
14 |
-
|
15 |
-
def test_lstm(self):
|
16 |
-
B, C, T = 4, 2, random.randint(1, 100)
|
17 |
-
|
18 |
-
lstm = StreamableLSTM(C, 3, skip=False)
|
19 |
-
x = torch.randn(B, C, T)
|
20 |
-
y = lstm(x)
|
21 |
-
|
22 |
-
print(y.shape)
|
23 |
-
assert y.shape == torch.Size([B, C, T])
|
24 |
-
|
25 |
-
def test_lstm_skip(self):
|
26 |
-
B, C, T = 4, 2, random.randint(1, 100)
|
27 |
-
|
28 |
-
lstm = StreamableLSTM(C, 3, skip=True)
|
29 |
-
x = torch.randn(B, C, T)
|
30 |
-
y = lstm(x)
|
31 |
-
|
32 |
-
assert y.shape == torch.Size([B, C, T])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_rope.py
DELETED
@@ -1,160 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import torch
|
8 |
-
|
9 |
-
from audiocraft.modules.rope import RotaryEmbedding
|
10 |
-
from audiocraft.modules.transformer import StreamingTransformer
|
11 |
-
|
12 |
-
|
13 |
-
def test_rope():
|
14 |
-
B, T, H, C = 8, 75, 16, 128
|
15 |
-
|
16 |
-
rope = RotaryEmbedding(dim=C)
|
17 |
-
xq = torch.rand((B, T, H, C))
|
18 |
-
xk = torch.rand((B, T, H, C))
|
19 |
-
xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
|
20 |
-
|
21 |
-
assert list(xq_out.shape) == [B, T, H, C]
|
22 |
-
assert list(xk_out.shape) == [B, T, H, C]
|
23 |
-
|
24 |
-
|
25 |
-
def test_rope_io_dtypes():
|
26 |
-
B, T, H, C = 8, 75, 16, 128
|
27 |
-
|
28 |
-
rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
|
29 |
-
rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)
|
30 |
-
|
31 |
-
# Test bfloat16 inputs w/ both 32 and 64 precision rope.
|
32 |
-
xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
|
33 |
-
xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
|
34 |
-
xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16)
|
35 |
-
assert xq_out.dtype == torch.bfloat16
|
36 |
-
xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16)
|
37 |
-
assert xq_out.dtype == torch.bfloat16
|
38 |
-
|
39 |
-
# Test float32 inputs w/ both 32 and 64 precision rope.
|
40 |
-
xq_32 = torch.rand((B, T, H, C)).to(torch.float32)
|
41 |
-
xk_32 = torch.rand((B, T, H, C)).to(torch.float32)
|
42 |
-
xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32)
|
43 |
-
assert xq_out.dtype == torch.float32
|
44 |
-
xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32)
|
45 |
-
assert xq_out.dtype == torch.float32
|
46 |
-
|
47 |
-
|
48 |
-
def test_transformer_with_rope():
|
49 |
-
torch.manual_seed(1234)
|
50 |
-
for pos in ['rope', 'sin_rope']:
|
51 |
-
tr = StreamingTransformer(
|
52 |
-
16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
|
53 |
-
positional_embedding=pos)
|
54 |
-
tr.eval()
|
55 |
-
steps = 12
|
56 |
-
x = torch.randn(3, steps, 16)
|
57 |
-
|
58 |
-
out = tr(x)
|
59 |
-
assert list(out.shape) == list(x.shape)
|
60 |
-
|
61 |
-
|
62 |
-
@torch.no_grad()
|
63 |
-
def test_rope_streaming():
|
64 |
-
torch.manual_seed(1234)
|
65 |
-
tr = StreamingTransformer(
|
66 |
-
16, 4, 2, causal=True, dropout=0.,
|
67 |
-
custom=True, positional_embedding='rope')
|
68 |
-
tr.eval()
|
69 |
-
steps = 12
|
70 |
-
x = torch.randn(3, steps, 16)
|
71 |
-
|
72 |
-
ref = tr(x)
|
73 |
-
|
74 |
-
with tr.streaming():
|
75 |
-
outs = []
|
76 |
-
frame_sizes = [1] * steps
|
77 |
-
|
78 |
-
for frame_size in frame_sizes:
|
79 |
-
frame = x[:, :frame_size]
|
80 |
-
x = x[:, frame_size:]
|
81 |
-
outs.append(tr(frame))
|
82 |
-
|
83 |
-
out = torch.cat(outs, dim=1)
|
84 |
-
assert list(out.shape) == [3, steps, 16]
|
85 |
-
delta = torch.norm(out - ref) / torch.norm(out)
|
86 |
-
assert delta < 1e-6, delta
|
87 |
-
|
88 |
-
|
89 |
-
@torch.no_grad()
|
90 |
-
def test_rope_streaming_past_context():
|
91 |
-
torch.manual_seed(1234)
|
92 |
-
|
93 |
-
for context in [None, 10]:
|
94 |
-
tr = StreamingTransformer(
|
95 |
-
16, 4, 1 if context else 2,
|
96 |
-
causal=True, past_context=context, custom=True,
|
97 |
-
dropout=0., positional_embedding='rope')
|
98 |
-
tr.eval()
|
99 |
-
|
100 |
-
steps = 20
|
101 |
-
x = torch.randn(3, steps, 16)
|
102 |
-
ref = tr(x)
|
103 |
-
|
104 |
-
with tr.streaming():
|
105 |
-
outs = []
|
106 |
-
frame_sizes = [1] * steps
|
107 |
-
|
108 |
-
for frame_size in frame_sizes:
|
109 |
-
frame = x[:, :frame_size]
|
110 |
-
x = x[:, frame_size:]
|
111 |
-
outs.append(tr(frame))
|
112 |
-
|
113 |
-
out = torch.cat(outs, dim=1)
|
114 |
-
assert list(out.shape) == [3, steps, 16]
|
115 |
-
delta = torch.norm(out - ref) / torch.norm(out)
|
116 |
-
assert delta < 1e-6, delta
|
117 |
-
|
118 |
-
|
119 |
-
def test_rope_memory_efficient():
|
120 |
-
torch.manual_seed(1234)
|
121 |
-
tr = StreamingTransformer(
|
122 |
-
16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
|
123 |
-
positional_embedding='rope')
|
124 |
-
tr_mem_efficient = StreamingTransformer(
|
125 |
-
16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
|
126 |
-
positional_embedding='rope')
|
127 |
-
tr_mem_efficient.load_state_dict(tr.state_dict())
|
128 |
-
tr.eval()
|
129 |
-
steps = 12
|
130 |
-
x = torch.randn(3, steps, 16)
|
131 |
-
|
132 |
-
with torch.no_grad():
|
133 |
-
y = tr(x)
|
134 |
-
y2 = tr_mem_efficient(x)
|
135 |
-
# Check at float precision b/c this is the rope default.
|
136 |
-
assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()
|
137 |
-
|
138 |
-
|
139 |
-
def test_rope_with_xpos():
|
140 |
-
B, T, H, C = 8, 75, 16, 128
|
141 |
-
|
142 |
-
rope = RotaryEmbedding(dim=C, xpos=True)
|
143 |
-
xq = torch.rand((B, T, H, C))
|
144 |
-
xk = torch.rand((B, T, H, C))
|
145 |
-
xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
|
146 |
-
|
147 |
-
assert list(xq_out.shape) == [B, T, H, C]
|
148 |
-
assert list(xk_out.shape) == [B, T, H, C]
|
149 |
-
|
150 |
-
|
151 |
-
def test_positional_scale():
|
152 |
-
B, T, H, C = 8, 75, 16, 128
|
153 |
-
|
154 |
-
rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
|
155 |
-
xq = torch.rand((B, T, H, C))
|
156 |
-
xk = torch.rand((B, T, H, C))
|
157 |
-
xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
|
158 |
-
|
159 |
-
assert torch.allclose(xq, xq_out)
|
160 |
-
assert torch.allclose(xk, xk_out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_seanet.py
DELETED
@@ -1,115 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from itertools import product
|
8 |
-
|
9 |
-
import pytest
|
10 |
-
import torch
|
11 |
-
|
12 |
-
from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
|
13 |
-
from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
|
14 |
-
|
15 |
-
|
16 |
-
class TestSEANetModel:
|
17 |
-
|
18 |
-
def test_base(self):
|
19 |
-
encoder = SEANetEncoder()
|
20 |
-
decoder = SEANetDecoder()
|
21 |
-
|
22 |
-
x = torch.randn(1, 1, 24000)
|
23 |
-
z = encoder(x)
|
24 |
-
assert list(z.shape) == [1, 128, 75], z.shape
|
25 |
-
y = decoder(z)
|
26 |
-
assert y.shape == x.shape, (x.shape, y.shape)
|
27 |
-
|
28 |
-
def test_causal(self):
|
29 |
-
encoder = SEANetEncoder(causal=True)
|
30 |
-
decoder = SEANetDecoder(causal=True)
|
31 |
-
x = torch.randn(1, 1, 24000)
|
32 |
-
|
33 |
-
z = encoder(x)
|
34 |
-
assert list(z.shape) == [1, 128, 75], z.shape
|
35 |
-
y = decoder(z)
|
36 |
-
assert y.shape == x.shape, (x.shape, y.shape)
|
37 |
-
|
38 |
-
def test_conv_skip_connection(self):
|
39 |
-
encoder = SEANetEncoder(true_skip=False)
|
40 |
-
decoder = SEANetDecoder(true_skip=False)
|
41 |
-
|
42 |
-
x = torch.randn(1, 1, 24000)
|
43 |
-
z = encoder(x)
|
44 |
-
assert list(z.shape) == [1, 128, 75], z.shape
|
45 |
-
y = decoder(z)
|
46 |
-
assert y.shape == x.shape, (x.shape, y.shape)
|
47 |
-
|
48 |
-
def test_seanet_encoder_decoder_final_act(self):
|
49 |
-
encoder = SEANetEncoder(true_skip=False)
|
50 |
-
decoder = SEANetDecoder(true_skip=False, final_activation='Tanh')
|
51 |
-
|
52 |
-
x = torch.randn(1, 1, 24000)
|
53 |
-
z = encoder(x)
|
54 |
-
assert list(z.shape) == [1, 128, 75], z.shape
|
55 |
-
y = decoder(z)
|
56 |
-
assert y.shape == x.shape, (x.shape, y.shape)
|
57 |
-
|
58 |
-
def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
|
59 |
-
n_blocks = 0
|
60 |
-
for layer in encoder.model:
|
61 |
-
if isinstance(layer, StreamableConv1d):
|
62 |
-
n_blocks += 1
|
63 |
-
assert layer.conv.norm_type == 'none' if n_blocks <= n_disable_blocks else norm
|
64 |
-
elif isinstance(layer, SEANetResnetBlock):
|
65 |
-
for resnet_layer in layer.block:
|
66 |
-
if isinstance(resnet_layer, StreamableConv1d):
|
67 |
-
# here we add + 1 to n_blocks as we increment n_blocks just after the block
|
68 |
-
assert resnet_layer.conv.norm_type == 'none' if (n_blocks + 1) <= n_disable_blocks else norm
|
69 |
-
|
70 |
-
def test_encoder_disable_norm(self):
|
71 |
-
n_residuals = [0, 1, 3]
|
72 |
-
disable_blocks = [0, 1, 2, 3, 4, 5, 6]
|
73 |
-
norms = ['weight_norm', 'none']
|
74 |
-
for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
|
75 |
-
encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
|
76 |
-
disable_norm_outer_blocks=disable_blocks)
|
77 |
-
self._check_encoder_blocks_norm(encoder, disable_blocks, norm)
|
78 |
-
|
79 |
-
def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
|
80 |
-
n_blocks = 0
|
81 |
-
for layer in decoder.model:
|
82 |
-
if isinstance(layer, StreamableConv1d):
|
83 |
-
n_blocks += 1
|
84 |
-
assert layer.conv.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
|
85 |
-
elif isinstance(layer, StreamableConvTranspose1d):
|
86 |
-
n_blocks += 1
|
87 |
-
assert layer.convtr.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
|
88 |
-
elif isinstance(layer, SEANetResnetBlock):
|
89 |
-
for resnet_layer in layer.block:
|
90 |
-
if isinstance(resnet_layer, StreamableConv1d):
|
91 |
-
assert resnet_layer.conv.norm_type == 'none' \
|
92 |
-
if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
|
93 |
-
|
94 |
-
def test_decoder_disable_norm(self):
|
95 |
-
n_residuals = [0, 1, 3]
|
96 |
-
disable_blocks = [0, 1, 2, 3, 4, 5, 6]
|
97 |
-
norms = ['weight_norm', 'none']
|
98 |
-
for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
|
99 |
-
decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
|
100 |
-
disable_norm_outer_blocks=disable_blocks)
|
101 |
-
self._check_decoder_blocks_norm(decoder, disable_blocks, norm)
|
102 |
-
|
103 |
-
def test_disable_norm_raises_exception(self):
|
104 |
-
# Invalid disable_norm_outer_blocks values raise exceptions
|
105 |
-
with pytest.raises(AssertionError):
|
106 |
-
SEANetEncoder(disable_norm_outer_blocks=-1)
|
107 |
-
|
108 |
-
with pytest.raises(AssertionError):
|
109 |
-
SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
|
110 |
-
|
111 |
-
with pytest.raises(AssertionError):
|
112 |
-
SEANetDecoder(disable_norm_outer_blocks=-1)
|
113 |
-
|
114 |
-
with pytest.raises(AssertionError):
|
115 |
-
SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/modules/test_transformer.py
DELETED
@@ -1,247 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from itertools import product
|
8 |
-
|
9 |
-
import pytest
|
10 |
-
import torch
|
11 |
-
|
12 |
-
from audiocraft.modules.transformer import StreamingMultiheadAttention, StreamingTransformer
|
13 |
-
|
14 |
-
|
15 |
-
def test_transformer_causal_streaming():
|
16 |
-
torch.manual_seed(1234)
|
17 |
-
|
18 |
-
for context, custom in product([None, 10], [False, True]):
|
19 |
-
# Test that causality and receptive fields are properly handled.
|
20 |
-
# looking at the gradients
|
21 |
-
tr = StreamingTransformer(
|
22 |
-
16, 4, 1 if context else 2,
|
23 |
-
causal=True, past_context=context, custom=custom,
|
24 |
-
dropout=0.)
|
25 |
-
steps = 20
|
26 |
-
for k in [0, 10, 15, 19]:
|
27 |
-
x = torch.randn(4, steps, 16, requires_grad=True)
|
28 |
-
y = tr(x)
|
29 |
-
y[:, k].abs().sum().backward()
|
30 |
-
if k + 1 < steps:
|
31 |
-
assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm()
|
32 |
-
assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm()
|
33 |
-
if context is not None and k > context:
|
34 |
-
limit = k - context - 1
|
35 |
-
assert torch.allclose(x.grad[:, :limit],
|
36 |
-
torch.tensor(0.)), x.grad[:, :limit].norm()
|
37 |
-
|
38 |
-
# Now check that streaming gives the same result at batch eval.
|
39 |
-
x = torch.randn(4, steps, 16)
|
40 |
-
y = tr(x)
|
41 |
-
ys = []
|
42 |
-
with tr.streaming():
|
43 |
-
for k in range(steps):
|
44 |
-
chunk = x[:, k:k + 1, :]
|
45 |
-
ys.append(tr(chunk))
|
46 |
-
y_stream = torch.cat(ys, dim=1)
|
47 |
-
delta = torch.norm(y_stream - y) / torch.norm(y)
|
48 |
-
assert delta < 1e-6, delta
|
49 |
-
|
50 |
-
|
51 |
-
def test_transformer_vs_pytorch():
|
52 |
-
torch.manual_seed(1234)
|
53 |
-
# Check that in the non causal setting, we get the same result as
|
54 |
-
# PyTorch Transformer encoder.
|
55 |
-
for custom in [False, True]:
|
56 |
-
tr = StreamingTransformer(
|
57 |
-
16, 4, 2,
|
58 |
-
causal=False, custom=custom, dropout=0., positional_scale=0.)
|
59 |
-
layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True)
|
60 |
-
tr_ref = torch.nn.TransformerEncoder(layer, 2)
|
61 |
-
tr.load_state_dict(tr_ref.state_dict())
|
62 |
-
|
63 |
-
x = torch.randn(4, 20, 16)
|
64 |
-
y = tr(x)
|
65 |
-
y2 = tr_ref(x)
|
66 |
-
delta = torch.norm(y2 - y) / torch.norm(y)
|
67 |
-
assert delta < 1e-6, delta
|
68 |
-
|
69 |
-
|
70 |
-
def test_streaming_api():
|
71 |
-
tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.)
|
72 |
-
tr.eval()
|
73 |
-
steps = 12
|
74 |
-
x = torch.randn(1, steps, 16)
|
75 |
-
|
76 |
-
with torch.no_grad():
|
77 |
-
with tr.streaming():
|
78 |
-
_ = tr(x[:, :1])
|
79 |
-
state = {k: v.clone() for k, v in tr.get_streaming_state().items()}
|
80 |
-
y = tr(x[:, 1:2])
|
81 |
-
tr.set_streaming_state(state)
|
82 |
-
y2 = tr(x[:, 1:2])
|
83 |
-
assert torch.allclose(y, y2), (y - y2).norm()
|
84 |
-
assert tr.flush() is None
|
85 |
-
|
86 |
-
|
87 |
-
def test_memory_efficient():
|
88 |
-
torch.manual_seed(1234)
|
89 |
-
tr = StreamingTransformer(
|
90 |
-
16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
|
91 |
-
tr_mem_efficient = StreamingTransformer(
|
92 |
-
16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
|
93 |
-
tr_mem_efficient.load_state_dict(tr.state_dict())
|
94 |
-
tr.eval()
|
95 |
-
steps = 12
|
96 |
-
x = torch.randn(3, steps, 16)
|
97 |
-
|
98 |
-
with torch.no_grad():
|
99 |
-
y = tr(x)
|
100 |
-
y2 = tr_mem_efficient(x)
|
101 |
-
assert torch.allclose(y, y2), (y - y2).norm()
|
102 |
-
|
103 |
-
|
104 |
-
def test_attention_as_float32():
|
105 |
-
torch.manual_seed(1234)
|
106 |
-
cases = [
|
107 |
-
{'custom': True},
|
108 |
-
{'custom': False},
|
109 |
-
]
|
110 |
-
for case in cases:
|
111 |
-
tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case)
|
112 |
-
tr_float32 = StreamingTransformer(
|
113 |
-
16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case)
|
114 |
-
if not case['custom']:
|
115 |
-
# we are not using autocast here because it doesn't really
|
116 |
-
# work as expected on CPU, so we have to manually cast the weights of the MHA.
|
117 |
-
for layer in tr_float32.layers:
|
118 |
-
layer.self_attn.mha.to(torch.float32)
|
119 |
-
tr_float32.load_state_dict(tr.state_dict())
|
120 |
-
steps = 12
|
121 |
-
x = torch.randn(3, steps, 16, dtype=torch.bfloat16)
|
122 |
-
|
123 |
-
with torch.no_grad():
|
124 |
-
y = tr(x)
|
125 |
-
y2 = tr_float32(x)
|
126 |
-
assert not torch.allclose(y, y2), (y - y2).norm()
|
127 |
-
|
128 |
-
|
129 |
-
@torch.no_grad()
|
130 |
-
def test_streaming_memory_efficient():
|
131 |
-
torch.manual_seed(1234)
|
132 |
-
tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
|
133 |
-
tr_mem_efficient = StreamingTransformer(
|
134 |
-
16, 4, 2, dropout=0., memory_efficient=True, causal=True)
|
135 |
-
tr.load_state_dict(tr_mem_efficient.state_dict())
|
136 |
-
tr.eval()
|
137 |
-
tr_mem_efficient.eval()
|
138 |
-
steps = 12
|
139 |
-
x = torch.randn(3, steps, 16)
|
140 |
-
|
141 |
-
ref = tr(x)
|
142 |
-
|
143 |
-
with tr_mem_efficient.streaming():
|
144 |
-
outs = []
|
145 |
-
# frame_sizes = [2] + [1] * (steps - 2)
|
146 |
-
frame_sizes = [1] * steps
|
147 |
-
|
148 |
-
for frame_size in frame_sizes:
|
149 |
-
frame = x[:, :frame_size]
|
150 |
-
x = x[:, frame_size:]
|
151 |
-
outs.append(tr_mem_efficient(frame))
|
152 |
-
|
153 |
-
out = torch.cat(outs, dim=1)
|
154 |
-
delta = torch.norm(out - ref) / torch.norm(out)
|
155 |
-
assert delta < 1e-6, delta
|
156 |
-
|
157 |
-
|
158 |
-
def test_cross_attention():
|
159 |
-
torch.manual_seed(1234)
|
160 |
-
for norm_first in [True, False]:
|
161 |
-
m = StreamingTransformer(
|
162 |
-
16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True)
|
163 |
-
m_cross = StreamingTransformer(
|
164 |
-
16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True)
|
165 |
-
m_cross.load_state_dict(m.state_dict(), strict=False)
|
166 |
-
x = torch.randn(2, 5, 16)
|
167 |
-
cross_x = torch.randn(2, 3, 16)
|
168 |
-
y_ref = m(x)
|
169 |
-
y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x)
|
170 |
-
# With norm_first, the two should be exactly yhe same,
|
171 |
-
# but with norm_first=False, we get 2 normalization in a row
|
172 |
-
# and the epsilon value leads to a tiny change.
|
173 |
-
atol = 0. if norm_first else 1e-6
|
174 |
-
print((y_ref - y_cross_zero).norm() / y_ref.norm())
|
175 |
-
assert torch.allclose(y_ref, y_cross_zero, atol=atol)
|
176 |
-
|
177 |
-
# We now expect a difference even with a generous atol of 1e-2.
|
178 |
-
y_cross = m_cross(x, cross_attention_src=cross_x)
|
179 |
-
assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2)
|
180 |
-
|
181 |
-
with pytest.raises(AssertionError):
|
182 |
-
_ = m_cross(x)
|
183 |
-
_ = m(x, cross_attention_src=cross_x)
|
184 |
-
|
185 |
-
|
186 |
-
def test_cross_attention_compat():
|
187 |
-
torch.manual_seed(1234)
|
188 |
-
num_heads = 2
|
189 |
-
dim = num_heads * 64
|
190 |
-
with pytest.raises(AssertionError):
|
191 |
-
StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True)
|
192 |
-
|
193 |
-
cross_attn = StreamingMultiheadAttention(
|
194 |
-
dim, num_heads, dropout=0, cross_attention=True, custom=True)
|
195 |
-
ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True)
|
196 |
-
|
197 |
-
# We can load the regular attention state dict
|
198 |
-
# so we have compat when loading old checkpoints.
|
199 |
-
cross_attn.load_state_dict(ref_attn.state_dict())
|
200 |
-
|
201 |
-
queries = torch.randn(3, 7, dim)
|
202 |
-
keys = torch.randn(3, 9, dim)
|
203 |
-
values = torch.randn(3, 9, dim)
|
204 |
-
|
205 |
-
y = cross_attn(queries, keys, values)[0]
|
206 |
-
y_ref = ref_attn(queries, keys, values)[0]
|
207 |
-
assert torch.allclose(y, y_ref, atol=1e-7)
|
208 |
-
|
209 |
-
# Now let's check that streaming is working properly.
|
210 |
-
with cross_attn.streaming():
|
211 |
-
ys = []
|
212 |
-
for step in range(queries.shape[1]):
|
213 |
-
ys.append(cross_attn(queries[:, step: step + 1], keys, values)[0])
|
214 |
-
y_streaming = torch.cat(ys, dim=1)
|
215 |
-
assert torch.allclose(y_streaming, y, atol=1e-7)
|
216 |
-
|
217 |
-
|
218 |
-
def test_repeat_kv():
|
219 |
-
torch.manual_seed(1234)
|
220 |
-
num_heads = 8
|
221 |
-
kv_repeat = 4
|
222 |
-
dim = num_heads * 64
|
223 |
-
with pytest.raises(AssertionError):
|
224 |
-
mha = StreamingMultiheadAttention(
|
225 |
-
dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True)
|
226 |
-
mha = StreamingMultiheadAttention(
|
227 |
-
dim, num_heads, causal=True, kv_repeat=kv_repeat)
|
228 |
-
mha = StreamingMultiheadAttention(
|
229 |
-
dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True)
|
230 |
-
x = torch.randn(4, 18, dim)
|
231 |
-
y = mha(x, x, x)[0]
|
232 |
-
assert x.shape == y.shape
|
233 |
-
|
234 |
-
|
235 |
-
def test_qk_layer_norm():
|
236 |
-
torch.manual_seed(1234)
|
237 |
-
tr = StreamingTransformer(
|
238 |
-
16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False)
|
239 |
-
steps = 12
|
240 |
-
x = torch.randn(3, steps, 16)
|
241 |
-
y = tr(x)
|
242 |
-
|
243 |
-
tr = StreamingTransformer(
|
244 |
-
16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True)
|
245 |
-
z = torch.randn(3, 21, 16)
|
246 |
-
y = tr(x, cross_attention_src=z)
|
247 |
-
assert y.shape == x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/quantization/test_vq.py
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import torch
|
8 |
-
|
9 |
-
from audiocraft.quantization.vq import ResidualVectorQuantizer
|
10 |
-
|
11 |
-
|
12 |
-
class TestResidualVectorQuantizer:
|
13 |
-
|
14 |
-
def test_rvq(self):
|
15 |
-
x = torch.randn(1, 16, 2048)
|
16 |
-
vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
|
17 |
-
res = vq(x, 1.)
|
18 |
-
assert res.x.shape == torch.Size([1, 16, 2048])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/utils/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|