Spaces:
Runtime error
Runtime error
GrandaddyShmax
commited on
Commit
•
00f2826
1
Parent(s):
166f7f5
first push
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +63 -0
- CHANGELOG.md +28 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +35 -0
- Dockerfile +26 -0
- LICENSE +21 -0
- LICENSE_weights +399 -0
- MANIFEST.in +9 -0
- Makefile +40 -0
- README.md +4 -4
- app.py +1839 -0
- assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 +0 -0
- assets/bach.mp3 +0 -0
- assets/bolero_ravel.mp3 +0 -0
- assets/sirens_and_a_humming_engine_approach_and_pass.mp3 +0 -0
- audiocraft/__init__.py +26 -0
- audiocraft/adversarial/__init__.py +22 -0
- audiocraft/adversarial/discriminators/__init__.py +10 -0
- audiocraft/adversarial/discriminators/base.py +34 -0
- audiocraft/adversarial/discriminators/mpd.py +106 -0
- audiocraft/adversarial/discriminators/msd.py +126 -0
- audiocraft/adversarial/discriminators/msstftd.py +134 -0
- audiocraft/adversarial/losses.py +228 -0
- audiocraft/data/__init__.py +10 -0
- audiocraft/data/audio.py +216 -0
- audiocraft/data/audio_dataset.py +587 -0
- audiocraft/data/audio_utils.py +177 -0
- audiocraft/data/info_audio_dataset.py +110 -0
- audiocraft/data/music_dataset.py +270 -0
- audiocraft/data/sound_dataset.py +330 -0
- audiocraft/data/zip.py +76 -0
- audiocraft/environment.py +176 -0
- audiocraft/grids/__init__.py +6 -0
- audiocraft/grids/_base_explorers.py +80 -0
- audiocraft/grids/audiogen/__init__.py +6 -0
- audiocraft/grids/audiogen/audiogen_base_16khz.py +23 -0
- audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py +68 -0
- audiocraft/grids/compression/__init__.py +6 -0
- audiocraft/grids/compression/_explorers.py +55 -0
- audiocraft/grids/compression/debug.py +31 -0
- audiocraft/grids/compression/encodec_audiogen_16khz.py +29 -0
- audiocraft/grids/compression/encodec_base_24khz.py +28 -0
- audiocraft/grids/compression/encodec_musicgen_32khz.py +34 -0
- audiocraft/grids/diffusion/4_bands_base_32khz.py +27 -0
- audiocraft/grids/diffusion/__init__.py +6 -0
- audiocraft/grids/diffusion/_explorers.py +66 -0
- audiocraft/grids/musicgen/__init__.py +6 -0
- audiocraft/grids/musicgen/_explorers.py +93 -0
- audiocraft/grids/musicgen/musicgen_base_32khz.py +43 -0
- audiocraft/grids/musicgen/musicgen_base_cached_32khz.py +67 -0
.gitignore
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# macOS dir files
|
10 |
+
.DS_Store
|
11 |
+
|
12 |
+
# Distribution / packaging
|
13 |
+
.Python
|
14 |
+
env/
|
15 |
+
build/
|
16 |
+
develop-eggs/
|
17 |
+
dist/
|
18 |
+
downloads/
|
19 |
+
eggs/
|
20 |
+
.eggs/
|
21 |
+
lib/
|
22 |
+
lib64/
|
23 |
+
parts/
|
24 |
+
sdist/
|
25 |
+
var/
|
26 |
+
wheels/
|
27 |
+
*.egg-info/
|
28 |
+
.installed.cfg
|
29 |
+
*.egg
|
30 |
+
.ipynb_checkpoints
|
31 |
+
|
32 |
+
# Tests and linter
|
33 |
+
.pytest_cache/
|
34 |
+
.mypy_cache/
|
35 |
+
.coverage
|
36 |
+
|
37 |
+
# docs
|
38 |
+
/api_docs
|
39 |
+
|
40 |
+
# dotenv
|
41 |
+
.env
|
42 |
+
.envrc
|
43 |
+
|
44 |
+
# virtualenv
|
45 |
+
.venv
|
46 |
+
venv/
|
47 |
+
ENV/
|
48 |
+
|
49 |
+
# egs with manifest files
|
50 |
+
egs/*
|
51 |
+
!egs/example
|
52 |
+
# local datasets
|
53 |
+
dataset/*
|
54 |
+
!dataset/example
|
55 |
+
|
56 |
+
# personal notebooks & scripts
|
57 |
+
*/local_scripts
|
58 |
+
*/notes
|
59 |
+
.vscode/
|
60 |
+
/notebooks
|
61 |
+
/local_scripts
|
62 |
+
/notes
|
63 |
+
/cache
|
CHANGELOG.md
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Changelog
|
2 |
+
|
3 |
+
All notable changes to this project will be documented in this file.
|
4 |
+
|
5 |
+
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
6 |
+
|
7 |
+
## [1.0.0] - 2023-08-02
|
8 |
+
|
9 |
+
Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
|
10 |
+
Added pretrained model for AudioGen and MultiBandDiffusion.
|
11 |
+
|
12 |
+
## [0.0.2] - 2023-08-01
|
13 |
+
|
14 |
+
Improved demo, fixed top p (thanks @jnordberg).
|
15 |
+
|
16 |
+
Compressor tanh on output to avoid clipping with some style (especially piano).
|
17 |
+
Now repeating the conditioning periodically if it is too short.
|
18 |
+
|
19 |
+
More options when launching Gradio app locally (thanks @ashleykleynhans).
|
20 |
+
|
21 |
+
Testing out PyTorch 2.0 memory efficient attention.
|
22 |
+
|
23 |
+
Added extended generation (infinite length) by slowly moving the windows.
|
24 |
+
Note that other implementations exist: https://github.com/camenduru/MusicGen-colab.
|
25 |
+
|
26 |
+
## [0.0.1] - 2023-06-09
|
27 |
+
|
28 |
+
Initial release, with model evaluation only.
|
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code of Conduct
|
2 |
+
|
3 |
+
## Our Pledge
|
4 |
+
|
5 |
+
In the interest of fostering an open and welcoming environment, we as
|
6 |
+
contributors and maintainers pledge to make participation in our project and
|
7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
9 |
+
level of experience, education, socio-economic status, nationality, personal
|
10 |
+
appearance, race, religion, or sexual identity and orientation.
|
11 |
+
|
12 |
+
## Our Standards
|
13 |
+
|
14 |
+
Examples of behavior that contributes to creating a positive environment
|
15 |
+
include:
|
16 |
+
|
17 |
+
* Using welcoming and inclusive language
|
18 |
+
* Being respectful of differing viewpoints and experiences
|
19 |
+
* Gracefully accepting constructive criticism
|
20 |
+
* Focusing on what is best for the community
|
21 |
+
* Showing empathy towards other community members
|
22 |
+
|
23 |
+
Examples of unacceptable behavior by participants include:
|
24 |
+
|
25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
26 |
+
advances
|
27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
28 |
+
* Public or private harassment
|
29 |
+
* Publishing others' private information, such as a physical or electronic
|
30 |
+
address, without explicit permission
|
31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
32 |
+
professional setting
|
33 |
+
|
34 |
+
## Our Responsibilities
|
35 |
+
|
36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
38 |
+
response to any instances of unacceptable behavior.
|
39 |
+
|
40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
44 |
+
threatening, offensive, or harmful.
|
45 |
+
|
46 |
+
## Scope
|
47 |
+
|
48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
49 |
+
an individual is representing the project or its community in public spaces.
|
50 |
+
Examples of representing a project or community include using an official
|
51 |
+
project e-mail address, posting via an official social media account, or acting
|
52 |
+
as an appointed representative at an online or offline event. Representation of
|
53 |
+
a project may be further defined and clarified by project maintainers.
|
54 |
+
|
55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
57 |
+
the project or its community.
|
58 |
+
|
59 |
+
## Enforcement
|
60 |
+
|
61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
62 |
+
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
63 |
+
complaints will be reviewed and investigated and will result in a response that
|
64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
66 |
+
Further details of specific enforcement policies may be posted separately.
|
67 |
+
|
68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
69 |
+
faith may face temporary or permanent repercussions as determined by other
|
70 |
+
members of the project's leadership.
|
71 |
+
|
72 |
+
## Attribution
|
73 |
+
|
74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
76 |
+
|
77 |
+
[homepage]: https://www.contributor-covenant.org
|
78 |
+
|
79 |
+
For answers to common questions about this code of conduct, see
|
80 |
+
https://www.contributor-covenant.org/faq
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing to AudioCraft
|
2 |
+
|
3 |
+
We want to make contributing to this project as easy and transparent as
|
4 |
+
possible.
|
5 |
+
|
6 |
+
## Pull Requests
|
7 |
+
|
8 |
+
AudioCraft is the implementation of a research paper.
|
9 |
+
Therefore, we do not plan on accepting many pull requests for new features.
|
10 |
+
We certainly welcome them for bug fixes.
|
11 |
+
|
12 |
+
1. Fork the repo and create your branch from `main`.
|
13 |
+
2. If you've added code that should be tested, add tests.
|
14 |
+
3. If you've changed APIs, update the documentation.
|
15 |
+
4. Ensure the test suite passes.
|
16 |
+
5. Make sure your code lints.
|
17 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
18 |
+
|
19 |
+
## Contributor License Agreement ("CLA")
|
20 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
21 |
+
to do this once to work on any of Meta's open source projects.
|
22 |
+
|
23 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
24 |
+
|
25 |
+
## Issues
|
26 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
27 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
28 |
+
|
29 |
+
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
30 |
+
disclosure of security bugs. In those cases, please go through the process
|
31 |
+
outlined on that page and do not file a public issue.
|
32 |
+
|
33 |
+
## License
|
34 |
+
By contributing to encodec, you agree that your contributions will be licensed
|
35 |
+
under the LICENSE file in the root directory of this source tree.
|
Dockerfile
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.8.0-base-ubuntu22.04
|
2 |
+
|
3 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
4 |
+
PYTHONUNBUFFERED=1 \
|
5 |
+
PYTHONIOENCODING=UTF-8
|
6 |
+
RUN --mount=type=cache,target=/var/cache/apt --mount=type=cache,target=/var/lib/apt apt update &&\
|
7 |
+
apt install -y \
|
8 |
+
wget \
|
9 |
+
git \
|
10 |
+
pkg-config \
|
11 |
+
python3 \
|
12 |
+
python3-pip \
|
13 |
+
python-is-python3 \
|
14 |
+
ffmpeg \
|
15 |
+
libnvrtc11.2 \
|
16 |
+
libtcmalloc-minimal4
|
17 |
+
|
18 |
+
RUN useradd -m -u 1000 ac
|
19 |
+
RUN --mount=type=cache,target=/root/.cache python -m pip install --upgrade pip wheel
|
20 |
+
ENV TORCH_COMMAND="pip install torch==2.0.1+cu118 torchaudio --extra-index-url https://download.pytorch.org/whl/cu118"
|
21 |
+
RUN --mount=type=cache,target=/root/.cache python -m $TORCH_COMMAND
|
22 |
+
RUN ln -s /usr/lib/x86_64-linux-gnu/libnvrtc.so.11.2 /usr/lib/x86_64-linux-gnu/libnvrtc.so
|
23 |
+
USER 1000
|
24 |
+
RUN mkdir ~/.cache
|
25 |
+
RUN --mount=type=cache,target=/home/ac/.cache --mount=source=.,target=/home/ac/audiocraft python -m pip install -r /home/ac/audiocraft/requirements.txt
|
26 |
+
WORKDIR /home/ac/audiocraft
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
LICENSE_weights
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Attribution-NonCommercial 4.0 International
|
2 |
+
|
3 |
+
=======================================================================
|
4 |
+
|
5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
6 |
+
does not provide legal services or legal advice. Distribution of
|
7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
8 |
+
other relationship. Creative Commons makes its licenses and related
|
9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
10 |
+
warranties regarding its licenses, any material licensed under their
|
11 |
+
terms and conditions, or any related information. Creative Commons
|
12 |
+
disclaims all liability for damages resulting from their use to the
|
13 |
+
fullest extent possible.
|
14 |
+
|
15 |
+
Using Creative Commons Public Licenses
|
16 |
+
|
17 |
+
Creative Commons public licenses provide a standard set of terms and
|
18 |
+
conditions that creators and other rights holders may use to share
|
19 |
+
original works of authorship and other material subject to copyright
|
20 |
+
and certain other rights specified in the public license below. The
|
21 |
+
following considerations are for informational purposes only, are not
|
22 |
+
exhaustive, and do not form part of our licenses.
|
23 |
+
|
24 |
+
Considerations for licensors: Our public licenses are
|
25 |
+
intended for use by those authorized to give the public
|
26 |
+
permission to use material in ways otherwise restricted by
|
27 |
+
copyright and certain other rights. Our licenses are
|
28 |
+
irrevocable. Licensors should read and understand the terms
|
29 |
+
and conditions of the license they choose before applying it.
|
30 |
+
Licensors should also secure all rights necessary before
|
31 |
+
applying our licenses so that the public can reuse the
|
32 |
+
material as expected. Licensors should clearly mark any
|
33 |
+
material not subject to the license. This includes other CC-
|
34 |
+
licensed material, or material used under an exception or
|
35 |
+
limitation to copyright. More considerations for licensors:
|
36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
37 |
+
|
38 |
+
Considerations for the public: By using one of our public
|
39 |
+
licenses, a licensor grants the public permission to use the
|
40 |
+
licensed material under specified terms and conditions. If
|
41 |
+
the licensor's permission is not necessary for any reason--for
|
42 |
+
example, because of any applicable exception or limitation to
|
43 |
+
copyright--then that use is not regulated by the license. Our
|
44 |
+
licenses grant only permissions under copyright and certain
|
45 |
+
other rights that a licensor has authority to grant. Use of
|
46 |
+
the licensed material may still be restricted for other
|
47 |
+
reasons, including because others have copyright or other
|
48 |
+
rights in the material. A licensor may make special requests,
|
49 |
+
such as asking that all changes be marked or described.
|
50 |
+
Although not required by our licenses, you are encouraged to
|
51 |
+
respect those requests where reasonable. More_considerations
|
52 |
+
for the public:
|
53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
54 |
+
|
55 |
+
=======================================================================
|
56 |
+
|
57 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
58 |
+
License
|
59 |
+
|
60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
61 |
+
to be bound by the terms and conditions of this Creative Commons
|
62 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
63 |
+
License"). To the extent this Public License may be interpreted as a
|
64 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
65 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
66 |
+
such rights in consideration of benefits the Licensor receives from
|
67 |
+
making the Licensed Material available under these terms and
|
68 |
+
conditions.
|
69 |
+
|
70 |
+
Section 1 -- Definitions.
|
71 |
+
|
72 |
+
a. Adapted Material means material subject to Copyright and Similar
|
73 |
+
Rights that is derived from or based upon the Licensed Material
|
74 |
+
and in which the Licensed Material is translated, altered,
|
75 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
76 |
+
permission under the Copyright and Similar Rights held by the
|
77 |
+
Licensor. For purposes of this Public License, where the Licensed
|
78 |
+
Material is a musical work, performance, or sound recording,
|
79 |
+
Adapted Material is always produced where the Licensed Material is
|
80 |
+
synched in timed relation with a moving image.
|
81 |
+
|
82 |
+
b. Adapter's License means the license You apply to Your Copyright
|
83 |
+
and Similar Rights in Your contributions to Adapted Material in
|
84 |
+
accordance with the terms and conditions of this Public License.
|
85 |
+
|
86 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
87 |
+
closely related to copyright including, without limitation,
|
88 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
89 |
+
Rights, without regard to how the rights are labeled or
|
90 |
+
categorized. For purposes of this Public License, the rights
|
91 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
92 |
+
Rights.
|
93 |
+
d. Effective Technological Measures means those measures that, in the
|
94 |
+
absence of proper authority, may not be circumvented under laws
|
95 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
96 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
97 |
+
agreements.
|
98 |
+
|
99 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
100 |
+
any other exception or limitation to Copyright and Similar Rights
|
101 |
+
that applies to Your use of the Licensed Material.
|
102 |
+
|
103 |
+
f. Licensed Material means the artistic or literary work, database,
|
104 |
+
or other material to which the Licensor applied this Public
|
105 |
+
License.
|
106 |
+
|
107 |
+
g. Licensed Rights means the rights granted to You subject to the
|
108 |
+
terms and conditions of this Public License, which are limited to
|
109 |
+
all Copyright and Similar Rights that apply to Your use of the
|
110 |
+
Licensed Material and that the Licensor has authority to license.
|
111 |
+
|
112 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
113 |
+
under this Public License.
|
114 |
+
|
115 |
+
i. NonCommercial means not primarily intended for or directed towards
|
116 |
+
commercial advantage or monetary compensation. For purposes of
|
117 |
+
this Public License, the exchange of the Licensed Material for
|
118 |
+
other material subject to Copyright and Similar Rights by digital
|
119 |
+
file-sharing or similar means is NonCommercial provided there is
|
120 |
+
no payment of monetary compensation in connection with the
|
121 |
+
exchange.
|
122 |
+
|
123 |
+
j. Share means to provide material to the public by any means or
|
124 |
+
process that requires permission under the Licensed Rights, such
|
125 |
+
as reproduction, public display, public performance, distribution,
|
126 |
+
dissemination, communication, or importation, and to make material
|
127 |
+
available to the public including in ways that members of the
|
128 |
+
public may access the material from a place and at a time
|
129 |
+
individually chosen by them.
|
130 |
+
|
131 |
+
k. Sui Generis Database Rights means rights other than copyright
|
132 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
133 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
134 |
+
as amended and/or succeeded, as well as other essentially
|
135 |
+
equivalent rights anywhere in the world.
|
136 |
+
|
137 |
+
l. You means the individual or entity exercising the Licensed Rights
|
138 |
+
under this Public License. Your has a corresponding meaning.
|
139 |
+
|
140 |
+
Section 2 -- Scope.
|
141 |
+
|
142 |
+
a. License grant.
|
143 |
+
|
144 |
+
1. Subject to the terms and conditions of this Public License,
|
145 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
146 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
147 |
+
exercise the Licensed Rights in the Licensed Material to:
|
148 |
+
|
149 |
+
a. reproduce and Share the Licensed Material, in whole or
|
150 |
+
in part, for NonCommercial purposes only; and
|
151 |
+
|
152 |
+
b. produce, reproduce, and Share Adapted Material for
|
153 |
+
NonCommercial purposes only.
|
154 |
+
|
155 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
156 |
+
Exceptions and Limitations apply to Your use, this Public
|
157 |
+
License does not apply, and You do not need to comply with
|
158 |
+
its terms and conditions.
|
159 |
+
|
160 |
+
3. Term. The term of this Public License is specified in Section
|
161 |
+
6(a).
|
162 |
+
|
163 |
+
4. Media and formats; technical modifications allowed. The
|
164 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
165 |
+
all media and formats whether now known or hereafter created,
|
166 |
+
and to make technical modifications necessary to do so. The
|
167 |
+
Licensor waives and/or agrees not to assert any right or
|
168 |
+
authority to forbid You from making technical modifications
|
169 |
+
necessary to exercise the Licensed Rights, including
|
170 |
+
technical modifications necessary to circumvent Effective
|
171 |
+
Technological Measures. For purposes of this Public License,
|
172 |
+
simply making modifications authorized by this Section 2(a)
|
173 |
+
(4) never produces Adapted Material.
|
174 |
+
|
175 |
+
5. Downstream recipients.
|
176 |
+
|
177 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
178 |
+
recipient of the Licensed Material automatically
|
179 |
+
receives an offer from the Licensor to exercise the
|
180 |
+
Licensed Rights under the terms and conditions of this
|
181 |
+
Public License.
|
182 |
+
|
183 |
+
b. No downstream restrictions. You may not offer or impose
|
184 |
+
any additional or different terms or conditions on, or
|
185 |
+
apply any Effective Technological Measures to, the
|
186 |
+
Licensed Material if doing so restricts exercise of the
|
187 |
+
Licensed Rights by any recipient of the Licensed
|
188 |
+
Material.
|
189 |
+
|
190 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
191 |
+
may be construed as permission to assert or imply that You
|
192 |
+
are, or that Your use of the Licensed Material is, connected
|
193 |
+
with, or sponsored, endorsed, or granted official status by,
|
194 |
+
the Licensor or others designated to receive attribution as
|
195 |
+
provided in Section 3(a)(1)(A)(i).
|
196 |
+
|
197 |
+
b. Other rights.
|
198 |
+
|
199 |
+
1. Moral rights, such as the right of integrity, are not
|
200 |
+
licensed under this Public License, nor are publicity,
|
201 |
+
privacy, and/or other similar personality rights; however, to
|
202 |
+
the extent possible, the Licensor waives and/or agrees not to
|
203 |
+
assert any such rights held by the Licensor to the limited
|
204 |
+
extent necessary to allow You to exercise the Licensed
|
205 |
+
Rights, but not otherwise.
|
206 |
+
|
207 |
+
2. Patent and trademark rights are not licensed under this
|
208 |
+
Public License.
|
209 |
+
|
210 |
+
3. To the extent possible, the Licensor waives any right to
|
211 |
+
collect royalties from You for the exercise of the Licensed
|
212 |
+
Rights, whether directly or through a collecting society
|
213 |
+
under any voluntary or waivable statutory or compulsory
|
214 |
+
licensing scheme. In all other cases the Licensor expressly
|
215 |
+
reserves any right to collect such royalties, including when
|
216 |
+
the Licensed Material is used other than for NonCommercial
|
217 |
+
purposes.
|
218 |
+
|
219 |
+
Section 3 -- License Conditions.
|
220 |
+
|
221 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
222 |
+
following conditions.
|
223 |
+
|
224 |
+
a. Attribution.
|
225 |
+
|
226 |
+
1. If You Share the Licensed Material (including in modified
|
227 |
+
form), You must:
|
228 |
+
|
229 |
+
a. retain the following if it is supplied by the Licensor
|
230 |
+
with the Licensed Material:
|
231 |
+
|
232 |
+
i. identification of the creator(s) of the Licensed
|
233 |
+
Material and any others designated to receive
|
234 |
+
attribution, in any reasonable manner requested by
|
235 |
+
the Licensor (including by pseudonym if
|
236 |
+
designated);
|
237 |
+
|
238 |
+
ii. a copyright notice;
|
239 |
+
|
240 |
+
iii. a notice that refers to this Public License;
|
241 |
+
|
242 |
+
iv. a notice that refers to the disclaimer of
|
243 |
+
warranties;
|
244 |
+
|
245 |
+
v. a URI or hyperlink to the Licensed Material to the
|
246 |
+
extent reasonably practicable;
|
247 |
+
|
248 |
+
b. indicate if You modified the Licensed Material and
|
249 |
+
retain an indication of any previous modifications; and
|
250 |
+
|
251 |
+
c. indicate the Licensed Material is licensed under this
|
252 |
+
Public License, and include the text of, or the URI or
|
253 |
+
hyperlink to, this Public License.
|
254 |
+
|
255 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
256 |
+
reasonable manner based on the medium, means, and context in
|
257 |
+
which You Share the Licensed Material. For example, it may be
|
258 |
+
reasonable to satisfy the conditions by providing a URI or
|
259 |
+
hyperlink to a resource that includes the required
|
260 |
+
information.
|
261 |
+
|
262 |
+
3. If requested by the Licensor, You must remove any of the
|
263 |
+
information required by Section 3(a)(1)(A) to the extent
|
264 |
+
reasonably practicable.
|
265 |
+
|
266 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
267 |
+
License You apply must not prevent recipients of the Adapted
|
268 |
+
Material from complying with this Public License.
|
269 |
+
|
270 |
+
Section 4 -- Sui Generis Database Rights.
|
271 |
+
|
272 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
273 |
+
apply to Your use of the Licensed Material:
|
274 |
+
|
275 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
276 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
277 |
+
portion of the contents of the database for NonCommercial purposes
|
278 |
+
only;
|
279 |
+
|
280 |
+
b. if You include all or a substantial portion of the database
|
281 |
+
contents in a database in which You have Sui Generis Database
|
282 |
+
Rights, then the database in which You have Sui Generis Database
|
283 |
+
Rights (but not its individual contents) is Adapted Material; and
|
284 |
+
|
285 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
286 |
+
all or a substantial portion of the contents of the database.
|
287 |
+
|
288 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
289 |
+
replace Your obligations under this Public License where the Licensed
|
290 |
+
Rights include other Copyright and Similar Rights.
|
291 |
+
|
292 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
293 |
+
|
294 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
295 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
296 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
297 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
298 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
299 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
300 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
301 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
302 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
303 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
304 |
+
|
305 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
306 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
307 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
308 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
309 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
310 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
311 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
312 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
313 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
314 |
+
|
315 |
+
c. The disclaimer of warranties and limitation of liability provided
|
316 |
+
above shall be interpreted in a manner that, to the extent
|
317 |
+
possible, most closely approximates an absolute disclaimer and
|
318 |
+
waiver of all liability.
|
319 |
+
|
320 |
+
Section 6 -- Term and Termination.
|
321 |
+
|
322 |
+
a. This Public License applies for the term of the Copyright and
|
323 |
+
Similar Rights licensed here. However, if You fail to comply with
|
324 |
+
this Public License, then Your rights under this Public License
|
325 |
+
terminate automatically.
|
326 |
+
|
327 |
+
b. Where Your right to use the Licensed Material has terminated under
|
328 |
+
Section 6(a), it reinstates:
|
329 |
+
|
330 |
+
1. automatically as of the date the violation is cured, provided
|
331 |
+
it is cured within 30 days of Your discovery of the
|
332 |
+
violation; or
|
333 |
+
|
334 |
+
2. upon express reinstatement by the Licensor.
|
335 |
+
|
336 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
337 |
+
right the Licensor may have to seek remedies for Your violations
|
338 |
+
of this Public License.
|
339 |
+
|
340 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
341 |
+
Licensed Material under separate terms or conditions or stop
|
342 |
+
distributing the Licensed Material at any time; however, doing so
|
343 |
+
will not terminate this Public License.
|
344 |
+
|
345 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
346 |
+
License.
|
347 |
+
|
348 |
+
Section 7 -- Other Terms and Conditions.
|
349 |
+
|
350 |
+
a. The Licensor shall not be bound by any additional or different
|
351 |
+
terms or conditions communicated by You unless expressly agreed.
|
352 |
+
|
353 |
+
b. Any arrangements, understandings, or agreements regarding the
|
354 |
+
Licensed Material not stated herein are separate from and
|
355 |
+
independent of the terms and conditions of this Public License.
|
356 |
+
|
357 |
+
Section 8 -- Interpretation.
|
358 |
+
|
359 |
+
a. For the avoidance of doubt, this Public License does not, and
|
360 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
361 |
+
conditions on any use of the Licensed Material that could lawfully
|
362 |
+
be made without permission under this Public License.
|
363 |
+
|
364 |
+
b. To the extent possible, if any provision of this Public License is
|
365 |
+
deemed unenforceable, it shall be automatically reformed to the
|
366 |
+
minimum extent necessary to make it enforceable. If the provision
|
367 |
+
cannot be reformed, it shall be severed from this Public License
|
368 |
+
without affecting the enforceability of the remaining terms and
|
369 |
+
conditions.
|
370 |
+
|
371 |
+
c. No term or condition of this Public License will be waived and no
|
372 |
+
failure to comply consented to unless expressly agreed to by the
|
373 |
+
Licensor.
|
374 |
+
|
375 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
376 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
377 |
+
that apply to the Licensor or You, including from the legal
|
378 |
+
processes of any jurisdiction or authority.
|
379 |
+
|
380 |
+
=======================================================================
|
381 |
+
|
382 |
+
Creative Commons is not a party to its public
|
383 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
384 |
+
its public licenses to material it publishes and in those instances
|
385 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
386 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
387 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
388 |
+
material is shared under a Creative Commons public license or as
|
389 |
+
otherwise permitted by the Creative Commons policies published at
|
390 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
391 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
392 |
+
of Creative Commons without its prior written consent including,
|
393 |
+
without limitation, in connection with any unauthorized modifications
|
394 |
+
to any of its public licenses or any other arrangements,
|
395 |
+
understandings, or agreements concerning use of licensed material. For
|
396 |
+
the avoidance of doubt, this paragraph does not form part of the
|
397 |
+
public licenses.
|
398 |
+
|
399 |
+
Creative Commons may be contacted at creativecommons.org.
|
MANIFEST.in
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
include Makefile
|
2 |
+
include LICENSE
|
3 |
+
include LICENSE_weights
|
4 |
+
include *.md
|
5 |
+
include *.ini
|
6 |
+
include requirements.txt
|
7 |
+
include audiocraft/py.typed
|
8 |
+
include assets/*.mp3
|
9 |
+
recursive-include conf *.yaml
|
Makefile
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
|
2 |
+
dataset.train.num_samples=10 dataset.valid.num_samples=10 \
|
3 |
+
dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
|
4 |
+
logging.level=DEBUG
|
5 |
+
INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 616d7b3c
|
6 |
+
INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
|
7 |
+
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c
|
8 |
+
INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
|
9 |
+
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c
|
10 |
+
INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
|
11 |
+
checkpoint.save_last=false # Using compression model from 616d7b3c
|
12 |
+
|
13 |
+
default: linter tests
|
14 |
+
|
15 |
+
install:
|
16 |
+
pip install -U pip
|
17 |
+
pip install -U -e '.[dev]'
|
18 |
+
|
19 |
+
linter:
|
20 |
+
flake8 audiocraft && mypy audiocraft
|
21 |
+
flake8 tests && mypy tests
|
22 |
+
|
23 |
+
tests:
|
24 |
+
coverage run -m pytest tests
|
25 |
+
coverage report
|
26 |
+
|
27 |
+
tests_integ:
|
28 |
+
$(INTEG_COMPRESSION)
|
29 |
+
$(INTEG_MBD)
|
30 |
+
$(INTEG_MUSICGEN)
|
31 |
+
$(INTEG_AUDIOGEN)
|
32 |
+
|
33 |
+
|
34 |
+
api_docs:
|
35 |
+
pdoc3 --html -o api_docs -f audiocraft
|
36 |
+
|
37 |
+
dist:
|
38 |
+
python setup.py sdist
|
39 |
+
|
40 |
+
.PHONY: linter tests api_docs dist
|
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
---
|
2 |
-
title: AudioCraft Plus
|
3 |
-
emoji:
|
4 |
colorFrom: yellow
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.39.0
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
license: mit
|
11 |
---
|
12 |
|
|
|
1 |
---
|
2 |
+
title: AudioCraft Plus v2.0.0a
|
3 |
+
emoji: 🎶
|
4 |
colorFrom: yellow
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.39.0
|
8 |
app_file: app.py
|
9 |
+
pinned: yes
|
10 |
license: mit
|
11 |
---
|
12 |
|
app.py
ADDED
@@ -0,0 +1,1839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
|
8 |
+
# also released under the MIT license.
|
9 |
+
|
10 |
+
import argparse
|
11 |
+
from concurrent.futures import ProcessPoolExecutor
|
12 |
+
import os
|
13 |
+
from pathlib import Path
|
14 |
+
import subprocess as sp
|
15 |
+
from tempfile import NamedTemporaryFile
|
16 |
+
import time
|
17 |
+
import warnings
|
18 |
+
import glob
|
19 |
+
import re
|
20 |
+
from PIL import Image
|
21 |
+
from pydub import AudioSegment
|
22 |
+
from datetime import datetime
|
23 |
+
|
24 |
+
import json
|
25 |
+
import shutil
|
26 |
+
import taglib
|
27 |
+
import torch
|
28 |
+
import torchaudio
|
29 |
+
import gradio as gr
|
30 |
+
import numpy as np
|
31 |
+
import typing as tp
|
32 |
+
|
33 |
+
from audiocraft.data.audio_utils import convert_audio
|
34 |
+
from audiocraft.data.audio import audio_write
|
35 |
+
from audiocraft.models import AudioGen, MusicGen, MultiBandDiffusion
|
36 |
+
from audiocraft.utils import ui
|
37 |
+
import random, string
|
38 |
+
|
39 |
+
version = "2.0.0a"
|
40 |
+
|
41 |
+
theme = gr.themes.Base(
|
42 |
+
primary_hue="lime",
|
43 |
+
secondary_hue="lime",
|
44 |
+
neutral_hue="neutral",
|
45 |
+
).set(
|
46 |
+
button_primary_background_fill_hover='*primary_500',
|
47 |
+
button_primary_background_fill_hover_dark='*primary_500',
|
48 |
+
button_secondary_background_fill_hover='*primary_500',
|
49 |
+
button_secondary_background_fill_hover_dark='*primary_500'
|
50 |
+
)
|
51 |
+
|
52 |
+
MODEL = None # Last used model
|
53 |
+
MODELS = None
|
54 |
+
UNLOAD_MODEL = False
|
55 |
+
MOVE_TO_CPU = False
|
56 |
+
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
|
57 |
+
print(IS_BATCHED)
|
58 |
+
MAX_BATCH_SIZE = 12
|
59 |
+
BATCHED_DURATION = 15
|
60 |
+
INTERRUPTING = False
|
61 |
+
MBD = None
|
62 |
+
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
|
63 |
+
_old_call = sp.call
|
64 |
+
|
65 |
+
|
66 |
+
def generate_random_string(length):
|
67 |
+
characters = string.ascii_letters + string.digits
|
68 |
+
return ''.join(random.choice(characters) for _ in range(length))
|
69 |
+
|
70 |
+
|
71 |
+
def resize_video(input_path, output_path, target_width, target_height):
|
72 |
+
ffmpeg_cmd = [
|
73 |
+
'ffmpeg',
|
74 |
+
'-y',
|
75 |
+
'-i', input_path,
|
76 |
+
'-vf', f'scale={target_width}:{target_height}',
|
77 |
+
'-c:a', 'copy',
|
78 |
+
output_path
|
79 |
+
]
|
80 |
+
sp.run(ffmpeg_cmd)
|
81 |
+
|
82 |
+
|
83 |
+
def _call_nostderr(*args, **kwargs):
|
84 |
+
# Avoid ffmpeg vomiting on the logs.
|
85 |
+
kwargs['stderr'] = sp.DEVNULL
|
86 |
+
kwargs['stdout'] = sp.DEVNULL
|
87 |
+
_old_call(*args, **kwargs)
|
88 |
+
|
89 |
+
|
90 |
+
sp.call = _call_nostderr
|
91 |
+
# Preallocating the pool of processes.
|
92 |
+
pool = ProcessPoolExecutor(4)
|
93 |
+
pool.__enter__()
|
94 |
+
|
95 |
+
|
96 |
+
def interrupt():
|
97 |
+
global INTERRUPTING
|
98 |
+
INTERRUPTING = True
|
99 |
+
|
100 |
+
|
101 |
+
class FileCleaner:
|
102 |
+
def __init__(self, file_lifetime: float = 3600):
|
103 |
+
self.file_lifetime = file_lifetime
|
104 |
+
self.files = []
|
105 |
+
|
106 |
+
def add(self, path: tp.Union[str, Path]):
|
107 |
+
self._cleanup()
|
108 |
+
self.files.append((time.time(), Path(path)))
|
109 |
+
|
110 |
+
def _cleanup(self):
|
111 |
+
now = time.time()
|
112 |
+
for time_added, path in list(self.files):
|
113 |
+
if now - time_added > self.file_lifetime:
|
114 |
+
if path.exists():
|
115 |
+
path.unlink()
|
116 |
+
self.files.pop(0)
|
117 |
+
else:
|
118 |
+
break
|
119 |
+
|
120 |
+
|
121 |
+
file_cleaner = FileCleaner()
|
122 |
+
|
123 |
+
|
124 |
+
def make_waveform(*args, **kwargs):
|
125 |
+
# Further remove some warnings.
|
126 |
+
be = time.time()
|
127 |
+
with warnings.catch_warnings():
|
128 |
+
warnings.simplefilter('ignore')
|
129 |
+
height = kwargs.pop('height')
|
130 |
+
width = kwargs.pop('width')
|
131 |
+
if height < 256:
|
132 |
+
height = 256
|
133 |
+
if width < 256:
|
134 |
+
width = 256
|
135 |
+
waveform_video = gr.make_waveform(*args, **kwargs)
|
136 |
+
out = f"{generate_random_string(12)}.mp4"
|
137 |
+
image = kwargs.get('bg_image', None)
|
138 |
+
if image is None:
|
139 |
+
resize_video(waveform_video, out, 900, 300)
|
140 |
+
else:
|
141 |
+
resize_video(waveform_video, out, width, height)
|
142 |
+
print("Make a video took", time.time() - be)
|
143 |
+
return out
|
144 |
+
|
145 |
+
|
146 |
+
def load_model(version='GrandaddyShmax/musicgen-melody', custom_model=None, base_model='GrandaddyShmax/musicgen-medium', gen_type="music"):
|
147 |
+
global MODEL, MODELS
|
148 |
+
print("Loading model", version)
|
149 |
+
if MODELS is None:
|
150 |
+
if version == 'GrandaddyShmax/musicgen-custom':
|
151 |
+
MODEL = MusicGen.get_pretrained(base_model)
|
152 |
+
file_path = os.path.abspath("models/" + str(custom_model) + ".pt")
|
153 |
+
MODEL.lm.load_state_dict(torch.load(file_path))
|
154 |
+
else:
|
155 |
+
if gen_type == "music":
|
156 |
+
MODEL = MusicGen.get_pretrained(version)
|
157 |
+
elif gen_type == "audio":
|
158 |
+
MODEL = AudioGen.get_pretrained(version)
|
159 |
+
|
160 |
+
return
|
161 |
+
|
162 |
+
else:
|
163 |
+
t1 = time.monotonic()
|
164 |
+
if MODEL is not None:
|
165 |
+
MODEL.to('cpu') # move to cache
|
166 |
+
print("Previous model moved to CPU in %.2fs" % (time.monotonic() - t1))
|
167 |
+
t1 = time.monotonic()
|
168 |
+
if version != 'GrandaddyShmax/musicgen-custom' and MODELS.get(version) is None:
|
169 |
+
print("Loading model %s from disk" % version)
|
170 |
+
if gen_type == "music":
|
171 |
+
result = MusicGen.get_pretrained(version)
|
172 |
+
elif gen_type == "audio":
|
173 |
+
result = AudioGen.get_pretrained(version)
|
174 |
+
MODELS[version] = result
|
175 |
+
print("Model loaded in %.2fs" % (time.monotonic() - t1))
|
176 |
+
MODEL = result
|
177 |
+
return
|
178 |
+
result = MODELS[version].to('cuda')
|
179 |
+
print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
|
180 |
+
MODEL = result
|
181 |
+
|
182 |
+
def get_audio_info(audio_path):
|
183 |
+
if audio_path is not None:
|
184 |
+
if audio_path.name.endswith(".wav") or audio_path.name.endswith(".mp4") or audio_path.name.endswith(".json"):
|
185 |
+
if not audio_path.name.endswith(".json"):
|
186 |
+
with taglib.File(audio_path.name, save_on_exit=False) as song:
|
187 |
+
if 'COMMENT' not in song.tags:
|
188 |
+
return "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted. (Discord removes metadata from mp4 and wav files, so you can't use them)"
|
189 |
+
json_string = song.tags['COMMENT'][0]
|
190 |
+
data = json.loads(json_string)
|
191 |
+
global_prompt = str("\nGlobal Prompt: " + (data['global_prompt'] if data['global_prompt'] != "" else "none")) if 'global_prompt' in data else ""
|
192 |
+
bpm = str("\nBPM: " + data['bpm']) if 'bpm' in data else ""
|
193 |
+
key = str("\nKey: " + data['key']) if 'key' in data else ""
|
194 |
+
scale = str("\nScale: " + data['scale']) if 'scale' in data else ""
|
195 |
+
prompts = str("\nPrompts: " + (data['texts'] if data['texts'] != "['']" else "none")) if 'texts' in data else ""
|
196 |
+
duration = str("\nDuration: " + data['duration']) if 'duration' in data else ""
|
197 |
+
overlap = str("\nOverlap: " + data['overlap']) if 'overlap' in data else ""
|
198 |
+
seed = str("\nSeed: " + data['seed']) if 'seed' in data else ""
|
199 |
+
audio_mode = str("\nAudio Mode: " + data['audio_mode']) if 'audio_mode' in data else ""
|
200 |
+
input_length = str("\nInput Length: " + data['input_length']) if 'input_length' in data else ""
|
201 |
+
channel = str("\nChannel: " + data['channel']) if 'channel' in data else ""
|
202 |
+
sr_select = str("\nSample Rate: " + data['sr_select']) if 'sr_select' in data else ""
|
203 |
+
gen_type = str(data['generator'] + "gen-") if 'generator' in data else ""
|
204 |
+
model = str("\nModel: " + gen_type + data['model']) if 'model' in data else ""
|
205 |
+
custom_model = str("\nCustom Model: " + data['custom_model']) if 'custom_model' in data else ""
|
206 |
+
base_model = str("\nBase Model: " + data['base_model']) if 'base_model' in data else ""
|
207 |
+
decoder = str("\nDecoder: " + data['decoder']) if 'decoder' in data else ""
|
208 |
+
topk = str("\nTopk: " + data['topk']) if 'topk' in data else ""
|
209 |
+
topp = str("\nTopp: " + data['topp']) if 'topp' in data else ""
|
210 |
+
temperature = str("\nTemperature: " + data['temperature']) if 'temperature' in data else ""
|
211 |
+
cfg_coef = str("\nClassifier Free Guidance: " + data['cfg_coef']) if 'cfg_coef' in data else ""
|
212 |
+
version = str("Version: " + data['version']) if 'version' in data else "Version: Unknown"
|
213 |
+
info = str(version + global_prompt + bpm + key + scale + prompts + duration + overlap + seed + audio_mode + input_length + channel + sr_select + model + custom_model + base_model + decoder + topk + topp + temperature + cfg_coef)
|
214 |
+
if info == "":
|
215 |
+
return "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted. (Discord removes metadata from mp4 and wav files, so you can't use them)"
|
216 |
+
return info
|
217 |
+
else:
|
218 |
+
with open(audio_path.name) as json_file:
|
219 |
+
data = json.load(json_file)
|
220 |
+
#if 'global_prompt' not in data:
|
221 |
+
#return "No tags found. Either the file is not generated by MusicGen+ V1.2.8a and higher or the tags are corrupted."
|
222 |
+
global_prompt = str("\nGlobal Prompt: " + (data['global_prompt'] if data['global_prompt'] != "" else "none")) if 'global_prompt' in data else ""
|
223 |
+
bpm = str("\nBPM: " + data['bpm']) if 'bpm' in data else ""
|
224 |
+
key = str("\nKey: " + data['key']) if 'key' in data else ""
|
225 |
+
scale = str("\nScale: " + data['scale']) if 'scale' in data else ""
|
226 |
+
prompts = str("\nPrompts: " + (data['texts'] if data['texts'] != "['']" else "none")) if 'texts' in data else ""
|
227 |
+
duration = str("\nDuration: " + data['duration']) if 'duration' in data else ""
|
228 |
+
overlap = str("\nOverlap: " + data['overlap']) if 'overlap' in data else ""
|
229 |
+
seed = str("\nSeed: " + data['seed']) if 'seed' in data else ""
|
230 |
+
audio_mode = str("\nAudio Mode: " + data['audio_mode']) if 'audio_mode' in data else ""
|
231 |
+
input_length = str("\nInput Length: " + data['input_length']) if 'input_length' in data else ""
|
232 |
+
channel = str("\nChannel: " + data['channel']) if 'channel' in data else ""
|
233 |
+
sr_select = str("\nSample Rate: " + data['sr_select']) if 'sr_select' in data else ""
|
234 |
+
gen_type = str(data['generator'] + "gen-") if 'generator' in data else ""
|
235 |
+
model = str("\nModel: " + gen_type + data['model']) if 'model' in data else ""
|
236 |
+
custom_model = str("\nCustom Model: " + data['custom_model']) if 'custom_model' in data else ""
|
237 |
+
base_model = str("\nBase Model: " + data['base_model']) if 'base_model' in data else ""
|
238 |
+
decoder = str("\nDecoder: " + data['decoder']) if 'decoder' in data else ""
|
239 |
+
topk = str("\nTopk: " + data['topk']) if 'topk' in data else ""
|
240 |
+
topp = str("\nTopp: " + data['topp']) if 'topp' in data else ""
|
241 |
+
temperature = str("\nTemperature: " + data['temperature']) if 'temperature' in data else ""
|
242 |
+
cfg_coef = str("\nClassifier Free Guidance: " + data['cfg_coef']) if 'cfg_coef' in data else ""
|
243 |
+
version = str("Version: " + data['version']) if 'version' in data else "Version: Unknown"
|
244 |
+
info = str(version + global_prompt + bpm + key + scale + prompts + duration + overlap + seed + audio_mode + input_length + channel + sr_select + model + custom_model + base_model + decoder + topk + topp + temperature + cfg_coef)
|
245 |
+
if info == "":
|
246 |
+
return "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted."
|
247 |
+
return info
|
248 |
+
else:
|
249 |
+
return "Only .wav ,.mp4 and .json files are supported"
|
250 |
+
else:
|
251 |
+
return None
|
252 |
+
|
253 |
+
|
254 |
+
def info_to_params(audio_path):
|
255 |
+
if audio_path is not None:
|
256 |
+
if audio_path.name.endswith(".wav") or audio_path.name.endswith(".mp4") or audio_path.name.endswith(".json"):
|
257 |
+
if not audio_path.name.endswith(".json"):
|
258 |
+
with taglib.File(audio_path.name, save_on_exit=False) as song:
|
259 |
+
if 'COMMENT' not in song.tags:
|
260 |
+
return "Default", False, "", 120, "C", "Major", "large", None, "medium", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
|
261 |
+
json_string = song.tags['COMMENT'][0]
|
262 |
+
data = json.loads(json_string)
|
263 |
+
struc_prompt = (False if data['bpm'] == "none" else True) if 'bpm' in data else False
|
264 |
+
global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
|
265 |
+
bpm = (120 if data['bpm'] == "none" else int(data['bpm'])) if 'bpm' in data else 120
|
266 |
+
key = ("C" if data['key'] == "none" else data['key']) if 'key' in data else "C"
|
267 |
+
scale = ("Major" if data['scale'] == "none" else data['scale']) if 'scale' in data else "Major"
|
268 |
+
model = data['model'] if 'model' in data else "large"
|
269 |
+
custom_model = (data['custom_model'] if data['custom_model'] in get_available_models() else None) if 'custom_model' in data else None
|
270 |
+
base_model = data['base_model'] if 'base_model' in data else "medium"
|
271 |
+
decoder = data['decoder'] if 'decoder' in data else "Default"
|
272 |
+
if 'texts' not in data:
|
273 |
+
unique_prompts = 1
|
274 |
+
text = ["", "", "", "", "", "", "", "", "", ""]
|
275 |
+
repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
276 |
+
else:
|
277 |
+
s = data['texts']
|
278 |
+
s = re.findall(r"'(.*?)'", s)
|
279 |
+
text = []
|
280 |
+
repeat = []
|
281 |
+
i = 0
|
282 |
+
for elem in s:
|
283 |
+
if elem.strip():
|
284 |
+
if i == 0 or elem != s[i-1]:
|
285 |
+
text.append(elem)
|
286 |
+
repeat.append(1)
|
287 |
+
else:
|
288 |
+
repeat[-1] += 1
|
289 |
+
i += 1
|
290 |
+
text.extend([""] * (10 - len(text)))
|
291 |
+
repeat.extend([1] * (10 - len(repeat)))
|
292 |
+
unique_prompts = len([t for t in text if t])
|
293 |
+
audio_mode = ("sample" if data['audio_mode'] == "none" else data['audio_mode']) if 'audio_mode' in data else "sample"
|
294 |
+
duration = int(data['duration']) if 'duration' in data else 10
|
295 |
+
topk = float(data['topk']) if 'topk' in data else 250
|
296 |
+
topp = float(data['topp']) if 'topp' in data else 0
|
297 |
+
temperature = float(data['temperature']) if 'temperature' in data else 1.0
|
298 |
+
cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
|
299 |
+
seed = int(data['seed']) if 'seed' in data else -1
|
300 |
+
overlap = int(data['overlap']) if 'overlap' in data else 12
|
301 |
+
channel = data['channel'] if 'channel' in data else "stereo"
|
302 |
+
sr_select = data['sr_select'] if 'sr_select' in data else "48000"
|
303 |
+
return decoder, struc_prompt, global_prompt, bpm, key, scale, model, custom_model, base_model, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], audio_mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
|
304 |
+
else:
|
305 |
+
with open(audio_path.name) as json_file:
|
306 |
+
data = json.load(json_file)
|
307 |
+
struc_prompt = (False if data['bpm'] == "none" else True) if 'bpm' in data else False
|
308 |
+
global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
|
309 |
+
bpm = (120 if data['bpm'] == "none" else int(data['bpm'])) if 'bpm' in data else 120
|
310 |
+
key = ("C" if data['key'] == "none" else data['key']) if 'key' in data else "C"
|
311 |
+
scale = ("Major" if data['scale'] == "none" else data['scale']) if 'scale' in data else "Major"
|
312 |
+
model = data['model'] if 'model' in data else "large"
|
313 |
+
custom_model = (data['custom_model'] if data['custom_model'] in get_available_models() else None) if 'custom_model' in data else None
|
314 |
+
base_model = data['base_model'] if 'base_model' in data else "medium"
|
315 |
+
decoder = data['decoder'] if 'decoder' in data else "Default"
|
316 |
+
if 'texts' not in data:
|
317 |
+
unique_prompts = 1
|
318 |
+
text = ["", "", "", "", "", "", "", "", "", ""]
|
319 |
+
repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
320 |
+
else:
|
321 |
+
s = data['texts']
|
322 |
+
s = re.findall(r"'(.*?)'", s)
|
323 |
+
text = []
|
324 |
+
repeat = []
|
325 |
+
i = 0
|
326 |
+
for elem in s:
|
327 |
+
if elem.strip():
|
328 |
+
if i == 0 or elem != s[i-1]:
|
329 |
+
text.append(elem)
|
330 |
+
repeat.append(1)
|
331 |
+
else:
|
332 |
+
repeat[-1] += 1
|
333 |
+
i += 1
|
334 |
+
text.extend([""] * (10 - len(text)))
|
335 |
+
repeat.extend([1] * (10 - len(repeat)))
|
336 |
+
unique_prompts = len([t for t in text if t])
|
337 |
+
audio_mode = ("sample" if data['audio_mode'] == "none" else data['audio_mode']) if 'audio_mode' in data else "sample"
|
338 |
+
duration = int(data['duration']) if 'duration' in data else 10
|
339 |
+
topk = float(data['topk']) if 'topk' in data else 250
|
340 |
+
topp = float(data['topp']) if 'topp' in data else 0
|
341 |
+
temperature = float(data['temperature']) if 'temperature' in data else 1.0
|
342 |
+
cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
|
343 |
+
seed = int(data['seed']) if 'seed' in data else -1
|
344 |
+
overlap = int(data['overlap']) if 'overlap' in data else 12
|
345 |
+
channel = data['channel'] if 'channel' in data else "stereo"
|
346 |
+
sr_select = data['sr_select'] if 'sr_select' in data else "48000"
|
347 |
+
return decoder, struc_prompt, global_prompt, bpm, key, scale, model, custom_model, base_model, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], audio_mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
|
348 |
+
else:
|
349 |
+
return "Default", False, "", 120, "C", "Major", "large", None, "medium", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
|
350 |
+
else:
|
351 |
+
return "Default", False, "", 120, "C", "Major", "large", None, "medium", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
|
352 |
+
|
353 |
+
|
354 |
+
def info_to_params_a(audio_path):
|
355 |
+
if audio_path is not None:
|
356 |
+
if audio_path.name.endswith(".wav") or audio_path.name.endswith(".mp4") or audio_path.name.endswith(".json"):
|
357 |
+
if not audio_path.name.endswith(".json"):
|
358 |
+
with taglib.File(audio_path.name, save_on_exit=False) as song:
|
359 |
+
if 'COMMENT' not in song.tags:
|
360 |
+
return "Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
|
361 |
+
json_string = song.tags['COMMENT'][0]
|
362 |
+
data = json.loads(json_string)
|
363 |
+
struc_prompt = (False if data['global_prompt'] == "" else True) if 'global_prompt' in data else False
|
364 |
+
global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
|
365 |
+
decoder = data['decoder'] if 'decoder' in data else "Default"
|
366 |
+
if 'texts' not in data:
|
367 |
+
unique_prompts = 1
|
368 |
+
text = ["", "", "", "", "", "", "", "", "", ""]
|
369 |
+
repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
370 |
+
else:
|
371 |
+
s = data['texts']
|
372 |
+
s = re.findall(r"'(.*?)'", s)
|
373 |
+
text = []
|
374 |
+
repeat = []
|
375 |
+
i = 0
|
376 |
+
for elem in s:
|
377 |
+
if elem.strip():
|
378 |
+
if i == 0 or elem != s[i-1]:
|
379 |
+
text.append(elem)
|
380 |
+
repeat.append(1)
|
381 |
+
else:
|
382 |
+
repeat[-1] += 1
|
383 |
+
i += 1
|
384 |
+
text.extend([""] * (10 - len(text)))
|
385 |
+
repeat.extend([1] * (10 - len(repeat)))
|
386 |
+
unique_prompts = len([t for t in text if t])
|
387 |
+
duration = int(data['duration']) if 'duration' in data else 10
|
388 |
+
topk = float(data['topk']) if 'topk' in data else 250
|
389 |
+
topp = float(data['topp']) if 'topp' in data else 0
|
390 |
+
temperature = float(data['temperature']) if 'temperature' in data else 1.0
|
391 |
+
cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
|
392 |
+
seed = int(data['seed']) if 'seed' in data else -1
|
393 |
+
overlap = int(data['overlap']) if 'overlap' in data else 12
|
394 |
+
channel = data['channel'] if 'channel' in data else "stereo"
|
395 |
+
sr_select = data['sr_select'] if 'sr_select' in data else "48000"
|
396 |
+
return decoder, struc_prompt, global_prompt, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
|
397 |
+
else:
|
398 |
+
with open(audio_path.name) as json_file:
|
399 |
+
data = json.load(json_file)
|
400 |
+
struc_prompt = (False if data['global_prompt'] == "" else True) if 'global_prompt' in data else False
|
401 |
+
global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
|
402 |
+
decoder = data['decoder'] if 'decoder' in data else "Default"
|
403 |
+
if 'texts' not in data:
|
404 |
+
unique_prompts = 1
|
405 |
+
text = ["", "", "", "", "", "", "", "", "", ""]
|
406 |
+
repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
407 |
+
else:
|
408 |
+
s = data['texts']
|
409 |
+
s = re.findall(r"'(.*?)'", s)
|
410 |
+
text = []
|
411 |
+
repeat = []
|
412 |
+
i = 0
|
413 |
+
for elem in s:
|
414 |
+
if elem.strip():
|
415 |
+
if i == 0 or elem != s[i-1]:
|
416 |
+
text.append(elem)
|
417 |
+
repeat.append(1)
|
418 |
+
else:
|
419 |
+
repeat[-1] += 1
|
420 |
+
i += 1
|
421 |
+
text.extend([""] * (10 - len(text)))
|
422 |
+
repeat.extend([1] * (10 - len(repeat)))
|
423 |
+
unique_prompts = len([t for t in text if t])
|
424 |
+
duration = int(data['duration']) if 'duration' in data else 10
|
425 |
+
topk = float(data['topk']) if 'topk' in data else 250
|
426 |
+
topp = float(data['topp']) if 'topp' in data else 0
|
427 |
+
temperature = float(data['temperature']) if 'temperature' in data else 1.0
|
428 |
+
cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
|
429 |
+
seed = int(data['seed']) if 'seed' in data else -1
|
430 |
+
overlap = int(data['overlap']) if 'overlap' in data else 12
|
431 |
+
channel = data['channel'] if 'channel' in data else "stereo"
|
432 |
+
sr_select = data['sr_select'] if 'sr_select' in data else "48000"
|
433 |
+
return decoder, struc_prompt, global_prompt, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
|
434 |
+
|
435 |
+
else:
|
436 |
+
return "Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
|
437 |
+
else:
|
438 |
+
return "Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
|
439 |
+
|
440 |
+
|
441 |
+
def make_pseudo_stereo (filename, sr_select, pan, delay):
|
442 |
+
if pan:
|
443 |
+
temp = AudioSegment.from_wav(filename)
|
444 |
+
if sr_select != "32000":
|
445 |
+
temp = temp.set_frame_rate(int(sr_select))
|
446 |
+
left = temp.pan(-0.5) - 5
|
447 |
+
right = temp.pan(0.6) - 5
|
448 |
+
temp = left.overlay(right, position=5)
|
449 |
+
temp.export(filename, format="wav")
|
450 |
+
if delay:
|
451 |
+
waveform, sample_rate = torchaudio.load(filename) # load mono WAV file
|
452 |
+
delay_seconds = 0.01 # set delay 10ms
|
453 |
+
delay_samples = int(delay_seconds * sample_rate) # Calculating delay value in number of samples
|
454 |
+
stereo_waveform = torch.stack([waveform[0], torch.cat((torch.zeros(delay_samples), waveform[0][:-delay_samples]))]) # Generate a stereo file with original mono audio and delayed version
|
455 |
+
torchaudio.save(filename, stereo_waveform, sample_rate)
|
456 |
+
return
|
457 |
+
|
458 |
+
|
459 |
+
def normalize_audio(audio_data):
|
460 |
+
audio_data = audio_data.astype(np.float32)
|
461 |
+
max_value = np.max(np.abs(audio_data))
|
462 |
+
audio_data /= max_value
|
463 |
+
return audio_data
|
464 |
+
|
465 |
+
|
466 |
+
def load_diffusion():
|
467 |
+
global MBD
|
468 |
+
if MBD is None:
|
469 |
+
print("loading MBD")
|
470 |
+
MBD = MultiBandDiffusion.get_mbd_musicgen()
|
471 |
+
|
472 |
+
|
473 |
+
def unload_diffusion():
|
474 |
+
global MBD
|
475 |
+
if MBD is not None:
|
476 |
+
print("unloading MBD")
|
477 |
+
MBD = None
|
478 |
+
|
479 |
+
|
480 |
+
def _do_predictions(gen_type, texts, melodies, sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=False, **gen_kwargs):
|
481 |
+
if gen_type == "music":
|
482 |
+
maximum_size = 29.5
|
483 |
+
elif gen_type == "audio":
|
484 |
+
maximum_size = 9.5
|
485 |
+
cut_size = 0
|
486 |
+
input_length = 0
|
487 |
+
sampleP = None
|
488 |
+
if sample is not None:
|
489 |
+
globalSR, sampleM = sample[0], sample[1]
|
490 |
+
sampleM = normalize_audio(sampleM)
|
491 |
+
sampleM = torch.from_numpy(sampleM).t()
|
492 |
+
if sampleM.dim() == 1:
|
493 |
+
sampleM = sampleM.unsqueeze(0)
|
494 |
+
sample_length = sampleM.shape[sampleM.dim() - 1] / globalSR
|
495 |
+
if trim_start >= sample_length:
|
496 |
+
trim_start = sample_length - 0.5
|
497 |
+
if trim_end >= sample_length:
|
498 |
+
trim_end = sample_length - 0.5
|
499 |
+
if trim_start + trim_end >= sample_length:
|
500 |
+
tmp = sample_length - 0.5
|
501 |
+
trim_start = tmp / 2
|
502 |
+
trim_end = tmp / 2
|
503 |
+
sampleM = sampleM[..., int(globalSR * trim_start):int(globalSR * (sample_length - trim_end))]
|
504 |
+
sample_length = sample_length - (trim_start + trim_end)
|
505 |
+
if sample_length > maximum_size:
|
506 |
+
cut_size = sample_length - maximum_size
|
507 |
+
sampleP = sampleM[..., :int(globalSR * cut_size)]
|
508 |
+
sampleM = sampleM[..., int(globalSR * cut_size):]
|
509 |
+
if sample_length >= duration:
|
510 |
+
duration = sample_length + 0.5
|
511 |
+
input_length = sample_length
|
512 |
+
global MODEL
|
513 |
+
MODEL.set_generation_params(duration=(duration - cut_size), **gen_kwargs)
|
514 |
+
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies], [None if sample is None else (sample[0], sample[1].shape)])
|
515 |
+
be = time.time()
|
516 |
+
processed_melodies = []
|
517 |
+
if gen_type == "music":
|
518 |
+
target_sr = 32000
|
519 |
+
elif gen_type == "audio":
|
520 |
+
target_sr = 16000
|
521 |
+
target_ac = 1
|
522 |
+
|
523 |
+
for melody in melodies:
|
524 |
+
if melody is None:
|
525 |
+
processed_melodies.append(None)
|
526 |
+
else:
|
527 |
+
sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
|
528 |
+
if melody.dim() == 1:
|
529 |
+
melody = melody[None]
|
530 |
+
melody = melody[..., :int(sr * duration)]
|
531 |
+
melody = convert_audio(melody, sr, target_sr, target_ac)
|
532 |
+
processed_melodies.append(melody)
|
533 |
+
|
534 |
+
if sample is not None:
|
535 |
+
if sampleP is None:
|
536 |
+
if gen_type == "music":
|
537 |
+
outputs = MODEL.generate_continuation(
|
538 |
+
prompt=sampleM,
|
539 |
+
prompt_sample_rate=globalSR,
|
540 |
+
descriptions=texts,
|
541 |
+
progress=progress,
|
542 |
+
return_tokens=USE_DIFFUSION
|
543 |
+
)
|
544 |
+
elif gen_type == "audio":
|
545 |
+
outputs = MODEL.generate_continuation(
|
546 |
+
prompt=sampleM,
|
547 |
+
prompt_sample_rate=globalSR,
|
548 |
+
descriptions=texts,
|
549 |
+
progress=progress
|
550 |
+
)
|
551 |
+
else:
|
552 |
+
if sampleP.dim() > 1:
|
553 |
+
sampleP = convert_audio(sampleP, globalSR, target_sr, target_ac)
|
554 |
+
sampleP = sampleP.to(MODEL.device).float().unsqueeze(0)
|
555 |
+
if gen_type == "music":
|
556 |
+
outputs = MODEL.generate_continuation(
|
557 |
+
prompt=sampleM,
|
558 |
+
prompt_sample_rate=globalSR,
|
559 |
+
descriptions=texts,
|
560 |
+
progress=progress,
|
561 |
+
return_tokens=USE_DIFFUSION
|
562 |
+
)
|
563 |
+
elif gen_type == "audio":
|
564 |
+
outputs = MODEL.generate_continuation(
|
565 |
+
prompt=sampleM,
|
566 |
+
prompt_sample_rate=globalSR,
|
567 |
+
descriptions=texts,
|
568 |
+
progress=progress
|
569 |
+
)
|
570 |
+
outputs = torch.cat([sampleP, outputs], 2)
|
571 |
+
|
572 |
+
elif any(m is not None for m in processed_melodies):
|
573 |
+
if gen_type == "music":
|
574 |
+
outputs = MODEL.generate_with_chroma(
|
575 |
+
descriptions=texts,
|
576 |
+
melody_wavs=processed_melodies,
|
577 |
+
melody_sample_rate=target_sr,
|
578 |
+
progress=progress,
|
579 |
+
return_tokens=USE_DIFFUSION
|
580 |
+
)
|
581 |
+
elif gen_type == "audio":
|
582 |
+
outputs = MODEL.generate_with_chroma(
|
583 |
+
descriptions=texts,
|
584 |
+
melody_wavs=processed_melodies,
|
585 |
+
melody_sample_rate=target_sr,
|
586 |
+
progress=progress
|
587 |
+
)
|
588 |
+
else:
|
589 |
+
if gen_type == "music":
|
590 |
+
outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
|
591 |
+
elif gen_type == "audio":
|
592 |
+
outputs = MODEL.generate(texts, progress=progress)
|
593 |
+
|
594 |
+
if USE_DIFFUSION:
|
595 |
+
print("outputs: " + str(outputs))
|
596 |
+
outputs_diffusion = MBD.tokens_to_wav(outputs[1])
|
597 |
+
outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
|
598 |
+
outputs = outputs.detach().cpu().float()
|
599 |
+
backups = outputs
|
600 |
+
if channel == "stereo":
|
601 |
+
outputs = convert_audio(outputs, target_sr, int(sr_select), 2)
|
602 |
+
elif channel == "mono" and sr_select != "32000":
|
603 |
+
outputs = convert_audio(outputs, target_sr, int(sr_select), 1)
|
604 |
+
out_files = []
|
605 |
+
out_audios = []
|
606 |
+
out_backup = []
|
607 |
+
for output in outputs:
|
608 |
+
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
609 |
+
audio_write(
|
610 |
+
file.name, output, (MODEL.sample_rate if channel == "stereo effect" else int(sr_select)), strategy="loudness",
|
611 |
+
loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
|
612 |
+
|
613 |
+
if channel == "stereo effect":
|
614 |
+
make_pseudo_stereo(file.name, sr_select, pan=True, delay=True);
|
615 |
+
|
616 |
+
out_files.append(pool.submit(make_waveform, file.name, bg_image=image, bg_color=background, bars_color=(bar1, bar2), fg_alpha=1.0, bar_count=75, height=height, width=width))
|
617 |
+
out_audios.append(file.name)
|
618 |
+
file_cleaner.add(file.name)
|
619 |
+
print(f'wav: {file.name}')
|
620 |
+
for backup in backups:
|
621 |
+
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
622 |
+
audio_write(
|
623 |
+
file.name, backup, MODEL.sample_rate, strategy="loudness",
|
624 |
+
loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
|
625 |
+
out_backup.append(file.name)
|
626 |
+
file_cleaner.add(file.name)
|
627 |
+
res = [out_file.result() for out_file in out_files]
|
628 |
+
res_audio = out_audios
|
629 |
+
res_backup = out_backup
|
630 |
+
for file in res:
|
631 |
+
file_cleaner.add(file)
|
632 |
+
print(f'video: {file}')
|
633 |
+
print("batch finished", len(texts), time.time() - be)
|
634 |
+
print("Tempfiles currently stored: ", len(file_cleaner.files))
|
635 |
+
if MOVE_TO_CPU:
|
636 |
+
MODEL.to('cpu')
|
637 |
+
if UNLOAD_MODEL:
|
638 |
+
MODEL = None
|
639 |
+
torch.cuda.empty_cache()
|
640 |
+
torch.cuda.ipc_collect()
|
641 |
+
return res, res_audio, res_backup, input_length
|
642 |
+
|
643 |
+
|
644 |
+
def predict_batched(texts, melodies):
|
645 |
+
max_text_length = 512
|
646 |
+
texts = [text[:max_text_length] for text in texts]
|
647 |
+
load_model('melody')
|
648 |
+
res = _do_predictions(texts, melodies, BATCHED_DURATION)
|
649 |
+
return res
|
650 |
+
|
651 |
+
|
652 |
+
def add_tags(filename, tags):
|
653 |
+
json_string = None
|
654 |
+
|
655 |
+
data = {
|
656 |
+
"global_prompt": tags[0],
|
657 |
+
"bpm": tags[1],
|
658 |
+
"key": tags[2],
|
659 |
+
"scale": tags[3],
|
660 |
+
"texts": tags[4],
|
661 |
+
"duration": tags[5],
|
662 |
+
"overlap": tags[6],
|
663 |
+
"seed": tags[7],
|
664 |
+
"audio_mode": tags[8],
|
665 |
+
"input_length": tags[9],
|
666 |
+
"channel": tags[10],
|
667 |
+
"sr_select": tags[11],
|
668 |
+
"model": tags[12],
|
669 |
+
"custom_model": tags[13],
|
670 |
+
"base_model": tags[14],
|
671 |
+
"decoder": tags[15],
|
672 |
+
"topk": tags[16],
|
673 |
+
"topp": tags[17],
|
674 |
+
"temperature": tags[18],
|
675 |
+
"cfg_coef": tags[19],
|
676 |
+
"generator": tags[20],
|
677 |
+
"version": version
|
678 |
+
}
|
679 |
+
|
680 |
+
json_string = json.dumps(data)
|
681 |
+
|
682 |
+
if os.path.exists(filename):
|
683 |
+
with taglib.File(filename, save_on_exit=True) as song:
|
684 |
+
song.tags = {'COMMENT': json_string }
|
685 |
+
|
686 |
+
json_file = open(tags[7] + '.json', 'w')
|
687 |
+
json_file.write(json_string)
|
688 |
+
json_file.close()
|
689 |
+
|
690 |
+
return json_file.name;
|
691 |
+
|
692 |
+
|
693 |
+
def save_outputs(mp4, wav_tmp, tags, gen_type):
|
694 |
+
# mp4: .mp4 file name in root running folder of app.py
|
695 |
+
# wav_tmp: temporary wav file located in %TEMP% folder
|
696 |
+
# seed - used seed
|
697 |
+
# exanple BgnJtr4Pn1AJ.mp4, C:\Users\Alex\AppData\Local\Temp\tmp4ermrebs.wav, 195123182343465
|
698 |
+
# procedure read generated .mp4 and wav files, rename it by using seed as name,
|
699 |
+
# and will store it to ./output/today_date/wav and ./output/today_date/mp4 folders.
|
700 |
+
# if file with same seed number already exist its make postfix in name like seed(n)
|
701 |
+
# where is n - consiqunce number 1-2-3-4 and so on
|
702 |
+
# then we store generated mp4 and wav into destination folders.
|
703 |
+
|
704 |
+
current_date = datetime.now().strftime("%Y%m%d")
|
705 |
+
wav_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type,'wav')
|
706 |
+
mp4_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type,'mp4')
|
707 |
+
json_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type,'json')
|
708 |
+
os.makedirs(wav_directory, exist_ok=True)
|
709 |
+
os.makedirs(mp4_directory, exist_ok=True)
|
710 |
+
os.makedirs(json_directory, exist_ok=True)
|
711 |
+
|
712 |
+
filename = str(tags[7]) + '.wav'
|
713 |
+
target = os.path.join(wav_directory, filename)
|
714 |
+
counter = 1
|
715 |
+
while os.path.exists(target):
|
716 |
+
filename = str(tags[7]) + f'({counter})' + '.wav'
|
717 |
+
target = os.path.join(wav_directory, filename)
|
718 |
+
counter += 1
|
719 |
+
|
720 |
+
shutil.copyfile(wav_tmp, target); # make copy of original file
|
721 |
+
json_file = add_tags(target, tags);
|
722 |
+
|
723 |
+
wav_target=target;
|
724 |
+
target=target.replace('wav', 'mp4');
|
725 |
+
mp4_target=target;
|
726 |
+
|
727 |
+
mp4=r'./' +mp4;
|
728 |
+
shutil.copyfile(mp4, target); # make copy of original file
|
729 |
+
_ = add_tags(target, tags);
|
730 |
+
|
731 |
+
target=target.replace('mp4', 'json'); # change the extension to json
|
732 |
+
json_target=target; # store the json target
|
733 |
+
|
734 |
+
with open(target, 'w') as f: # open a writable file object
|
735 |
+
shutil.copyfile(json_file, target); # make copy of original file
|
736 |
+
|
737 |
+
os.remove(json_file)
|
738 |
+
|
739 |
+
return wav_target, mp4_target, json_target;
|
740 |
+
|
741 |
+
|
742 |
+
def clear_cash():
|
743 |
+
# delete all temporary files genegated my system
|
744 |
+
current_date = datetime.now().date()
|
745 |
+
current_directory = os.getcwd()
|
746 |
+
files = glob.glob(os.path.join(current_directory, '*.mp4'))
|
747 |
+
for file in files:
|
748 |
+
creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
|
749 |
+
if creation_date == current_date:
|
750 |
+
os.remove(file)
|
751 |
+
|
752 |
+
temp_directory = os.environ.get('TEMP')
|
753 |
+
files = glob.glob(os.path.join(temp_directory, 'tmp*.mp4'))
|
754 |
+
for file in files:
|
755 |
+
creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
|
756 |
+
if creation_date == current_date:
|
757 |
+
os.remove(file)
|
758 |
+
|
759 |
+
files = glob.glob(os.path.join(temp_directory, 'tmp*.wav'))
|
760 |
+
for file in files:
|
761 |
+
creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
|
762 |
+
if creation_date == current_date:
|
763 |
+
os.remove(file)
|
764 |
+
|
765 |
+
files = glob.glob(os.path.join(temp_directory, 'tmp*.png'))
|
766 |
+
for file in files:
|
767 |
+
creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
|
768 |
+
if creation_date == current_date:
|
769 |
+
os.remove(file)
|
770 |
+
return
|
771 |
+
|
772 |
+
|
773 |
+
def s2t(seconds, seconds2):
|
774 |
+
# convert seconds to time format
|
775 |
+
# seconds - time in seconds
|
776 |
+
# return time in format 00:00
|
777 |
+
m, s = divmod(seconds, 60)
|
778 |
+
m2, s2 = divmod(seconds2, 60)
|
779 |
+
if seconds != 0 and seconds < seconds2:
|
780 |
+
s = s + 1
|
781 |
+
return ("%02d:%02d - %02d:%02d" % (m, s, m2, s2))
|
782 |
+
|
783 |
+
|
784 |
+
def calc_time(gen_type, s, duration, overlap, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9):
|
785 |
+
# calculate the time of generation
|
786 |
+
# overlap - overlap in seconds
|
787 |
+
# d0-d9 - drag
|
788 |
+
# return time in seconds
|
789 |
+
d_amount = [int(d0), int(d1), int(d2), int(d3), int(d4), int(d5), int(d6), int(d7), int(d8), int(d9)]
|
790 |
+
calc = []
|
791 |
+
tracks = []
|
792 |
+
time = 0
|
793 |
+
s = s - 1
|
794 |
+
max_time = duration
|
795 |
+
max_limit = 0
|
796 |
+
if gen_type == "music":
|
797 |
+
max_limit = 30
|
798 |
+
elif gen_type == "audio":
|
799 |
+
max_limit = 10
|
800 |
+
track_add = max_limit - overlap
|
801 |
+
tracks.append(max_limit + ((d_amount[0] - 1) * track_add))
|
802 |
+
for i in range(1, 10):
|
803 |
+
tracks.append(d_amount[i] * track_add)
|
804 |
+
|
805 |
+
if tracks[0] >= max_time or s == 0:
|
806 |
+
calc.append(s2t(time, max_time))
|
807 |
+
time = max_time
|
808 |
+
else:
|
809 |
+
calc.append(s2t(time, tracks[0]))
|
810 |
+
time = tracks[0]
|
811 |
+
|
812 |
+
for i in range(1, 10):
|
813 |
+
if time + tracks[i] >= max_time or i == s:
|
814 |
+
calc.append(s2t(time, max_time))
|
815 |
+
time = max_time
|
816 |
+
else:
|
817 |
+
calc.append(s2t(time, time + tracks[i]))
|
818 |
+
time = time + tracks[i]
|
819 |
+
|
820 |
+
return calc[0], calc[1], calc[2], calc[3], calc[4], calc[5], calc[6], calc[7], calc[8], calc[9]
|
821 |
+
|
822 |
+
|
823 |
+
def predict_full(gen_type, model, decoder, custom_model, base_model, prompt_amount, struc_prompt, bpm, key, scale, global_prompt, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select, progress=gr.Progress()):
|
824 |
+
global INTERRUPTING
|
825 |
+
global USE_DIFFUSION
|
826 |
+
INTERRUPTING = False
|
827 |
+
|
828 |
+
if gen_type == "audio":
|
829 |
+
custom_model = None
|
830 |
+
base_model = "medium"
|
831 |
+
|
832 |
+
if temperature < 0:
|
833 |
+
raise gr.Error("Temperature must be >= 0.")
|
834 |
+
if topk < 0:
|
835 |
+
raise gr.Error("Topk must be non-negative.")
|
836 |
+
if topp < 0:
|
837 |
+
raise gr.Error("Topp must be non-negative.")
|
838 |
+
|
839 |
+
if trim_start < 0:
|
840 |
+
trim_start = 0
|
841 |
+
if trim_end < 0:
|
842 |
+
trim_end = 0
|
843 |
+
|
844 |
+
topk = int(topk)
|
845 |
+
|
846 |
+
if decoder == "MultiBand_Diffusion":
|
847 |
+
USE_DIFFUSION = True
|
848 |
+
load_diffusion()
|
849 |
+
else:
|
850 |
+
USE_DIFFUSION = False
|
851 |
+
unload_diffusion()
|
852 |
+
|
853 |
+
if gen_type == "music":
|
854 |
+
model_shrt = model
|
855 |
+
model = "GrandaddyShmax/musicgen-" + model
|
856 |
+
elif gen_type == "audio":
|
857 |
+
model_shrt = model
|
858 |
+
model = "GrandaddyShmax/audiogen-" + model
|
859 |
+
base_model_shrt = base_model
|
860 |
+
base_model = "GrandaddyShmax/musicgen-" + base_model
|
861 |
+
|
862 |
+
if MODEL is None or MODEL.name != (model):
|
863 |
+
load_model(model, custom_model, base_model, gen_type)
|
864 |
+
else:
|
865 |
+
if MOVE_TO_CPU:
|
866 |
+
MODEL.to('cuda')
|
867 |
+
|
868 |
+
if seed < 0:
|
869 |
+
seed = random.randint(0, 0xffff_ffff_ffff)
|
870 |
+
torch.manual_seed(seed)
|
871 |
+
|
872 |
+
def _progress(generated, to_generate):
|
873 |
+
progress((min(generated, to_generate), to_generate))
|
874 |
+
if INTERRUPTING:
|
875 |
+
raise gr.Error("Interrupted.")
|
876 |
+
MODEL.set_custom_progress_callback(_progress)
|
877 |
+
|
878 |
+
audio_mode = "none"
|
879 |
+
melody = None
|
880 |
+
sample = None
|
881 |
+
if audio:
|
882 |
+
audio_mode = mode
|
883 |
+
if mode == "sample":
|
884 |
+
sample = audio
|
885 |
+
elif mode == "melody":
|
886 |
+
melody = audio
|
887 |
+
|
888 |
+
base_model = "none" if model != "custom" else base_model
|
889 |
+
custom_model = "none" if model != "custom" else custom_model
|
890 |
+
|
891 |
+
text_cat = [p0, p1, p2, p3, p4, p5, p6, p7, p8, p9]
|
892 |
+
drag_cat = [d0, d1, d2, d3, d4, d5, d6, d7, d8, d9]
|
893 |
+
texts = []
|
894 |
+
raw_texts = []
|
895 |
+
ind = 0
|
896 |
+
ind2 = 0
|
897 |
+
while ind < prompt_amount:
|
898 |
+
for ind2 in range(int(drag_cat[ind])):
|
899 |
+
if not struc_prompt:
|
900 |
+
texts.append(text_cat[ind])
|
901 |
+
global_prompt = "none"
|
902 |
+
bpm = "none"
|
903 |
+
key = "none"
|
904 |
+
scale = "none"
|
905 |
+
raw_texts.append(text_cat[ind])
|
906 |
+
else:
|
907 |
+
if gen_type == "music":
|
908 |
+
bpm_str = str(bpm) + " bpm"
|
909 |
+
key_str = ", " + str(key) + " " + str(scale)
|
910 |
+
global_str = (", " + str(global_prompt)) if str(global_prompt) != "" else ""
|
911 |
+
elif gen_type == "audio":
|
912 |
+
bpm_str = ""
|
913 |
+
key_str = ""
|
914 |
+
global_str = (str(global_prompt)) if str(global_prompt) != "" else ""
|
915 |
+
texts_str = (", " + str(text_cat[ind])) if str(text_cat[ind]) != "" else ""
|
916 |
+
texts.append(bpm_str + key_str + global_str + texts_str)
|
917 |
+
raw_texts.append(text_cat[ind])
|
918 |
+
ind2 = 0
|
919 |
+
ind = ind + 1
|
920 |
+
|
921 |
+
outs, outs_audio, outs_backup, input_length = _do_predictions(
|
922 |
+
gen_type, [texts], [melody], sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=True,
|
923 |
+
top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, extend_stride=MODEL.max_duration-overlap)
|
924 |
+
tags = [str(global_prompt), str(bpm), str(key), str(scale), str(raw_texts), str(duration), str(overlap), str(seed), str(audio_mode), str(input_length), str(channel), str(sr_select), str(model_shrt), str(custom_model), str(base_model_shrt), str(decoder), str(topk), str(topp), str(temperature), str(cfg_coef), str(gen_type)]
|
925 |
+
wav_target, mp4_target, json_target = save_outputs(outs[0], outs_audio[0], tags, gen_type);
|
926 |
+
# Removes the temporary files.
|
927 |
+
for out in outs:
|
928 |
+
os.remove(out)
|
929 |
+
for out in outs_audio:
|
930 |
+
os.remove(out)
|
931 |
+
|
932 |
+
return mp4_target, wav_target, outs_backup[0], [mp4_target, wav_target, json_target], seed
|
933 |
+
|
934 |
+
|
935 |
+
max_textboxes = 10
|
936 |
+
|
937 |
+
|
938 |
+
def get_available_models():
|
939 |
+
return sorted([re.sub('.pt$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('.pt')])
|
940 |
+
|
941 |
+
|
942 |
+
def toggle_audio_src(choice):
|
943 |
+
if choice == "mic":
|
944 |
+
return gr.update(source="microphone", value=None, label="Microphone")
|
945 |
+
else:
|
946 |
+
return gr.update(source="upload", value=None, label="File")
|
947 |
+
|
948 |
+
|
949 |
+
def ui_full(launch_kwargs):
|
950 |
+
with gr.Blocks(title='AudioCraft Plus', theme=theme) as interface:
|
951 |
+
gr.Markdown(
|
952 |
+
"""
|
953 |
+
# AudioCraft Plus - v2.0.0a
|
954 |
+
|
955 |
+
### An All-in-One AudioCraft WebUI
|
956 |
+
|
957 |
+
#### **Disclaimer:** This will not run on CPU only. Its best to clone this App and run on GPU instance!
|
958 |
+
**Alternatively**, you can run this for free on a google colab:
|
959 |
+
https://colab.research.google.com/github/camenduru/MusicGen-colab/blob/main/MusicGen_ClownOfMadness_plus_colab.ipynb
|
960 |
+
|
961 |
+
**Or**, run this locally on your PC:
|
962 |
+
https://github.com/GrandaddyShmax/audiocraft_plus/tree/main
|
963 |
+
|
964 |
+
Thanks to: facebookresearch, Camenduru, rkfg, oobabooga, AlexHK and GrandaddyShmax
|
965 |
+
"""
|
966 |
+
)
|
967 |
+
with gr.Tab("MusicGen"):
|
968 |
+
gr.Markdown(
|
969 |
+
"""
|
970 |
+
### MusicGen
|
971 |
+
"""
|
972 |
+
)
|
973 |
+
with gr.Row():
|
974 |
+
with gr.Column():
|
975 |
+
with gr.Tab("Generation"):
|
976 |
+
with gr.Accordion("Structure Prompts", open=False):
|
977 |
+
with gr.Column():
|
978 |
+
with gr.Row():
|
979 |
+
struc_prompts = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
|
980 |
+
bpm = gr.Number(label="BPM", value=120, interactive=True, scale=1, precision=0)
|
981 |
+
key = gr.Dropdown(["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "Bb", "B"], label="Key", value="C", interactive=True)
|
982 |
+
scale = gr.Dropdown(["Major", "Minor"], label="Scale", value="Major", interactive=True)
|
983 |
+
with gr.Row():
|
984 |
+
global_prompt = gr.Text(label="Global Prompt", interactive=True, scale=3)
|
985 |
+
with gr.Row():
|
986 |
+
s = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
|
987 |
+
#s_mode = gr.Radio(["segmentation", "batch"], value="segmentation", interactive=True, scale=1, label="Generation Mode")
|
988 |
+
with gr.Column():
|
989 |
+
textboxes = []
|
990 |
+
prompts = []
|
991 |
+
repeats = []
|
992 |
+
calcs = []
|
993 |
+
with gr.Row():
|
994 |
+
text0 = gr.Text(label="Input Text", interactive=True, scale=4)
|
995 |
+
prompts.append(text0)
|
996 |
+
drag0 = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
|
997 |
+
repeats.append(drag0)
|
998 |
+
calc0 = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
|
999 |
+
calcs.append(calc0)
|
1000 |
+
for i in range(max_textboxes):
|
1001 |
+
with gr.Row(visible=False) as t:
|
1002 |
+
text = gr.Text(label="Input Text", interactive=True, scale=3)
|
1003 |
+
repeat = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
|
1004 |
+
calc = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
|
1005 |
+
textboxes.append(t)
|
1006 |
+
prompts.append(text)
|
1007 |
+
repeats.append(repeat)
|
1008 |
+
calcs.append(calc)
|
1009 |
+
to_calc = gr.Button("Calculate Timings", variant="secondary")
|
1010 |
+
with gr.Row():
|
1011 |
+
duration = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
|
1012 |
+
with gr.Row():
|
1013 |
+
overlap = gr.Slider(minimum=1, maximum=29, value=12, step=1, label="Overlap", interactive=True)
|
1014 |
+
with gr.Row():
|
1015 |
+
seed = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
|
1016 |
+
gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed], queue=False)
|
1017 |
+
reuse_seed = gr.Button('\u267b\ufe0f', scale=1)
|
1018 |
+
|
1019 |
+
with gr.Tab("Audio"):
|
1020 |
+
with gr.Row():
|
1021 |
+
with gr.Column():
|
1022 |
+
input_type = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
|
1023 |
+
mode = gr.Radio(["melody", "sample"], label="Input Audio Mode (optional)", value="sample", interactive=True)
|
1024 |
+
with gr.Row():
|
1025 |
+
trim_start = gr.Number(label="Trim Start", value=0, interactive=True)
|
1026 |
+
trim_end = gr.Number(label="Trim End", value=0, interactive=True)
|
1027 |
+
audio = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)
|
1028 |
+
|
1029 |
+
with gr.Tab("Customization"):
|
1030 |
+
with gr.Row():
|
1031 |
+
with gr.Column():
|
1032 |
+
background = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
|
1033 |
+
bar1 = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
|
1034 |
+
bar2 = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
|
1035 |
+
with gr.Column():
|
1036 |
+
image = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
|
1037 |
+
with gr.Row():
|
1038 |
+
height = gr.Number(label="Height", value=512, interactive=True)
|
1039 |
+
width = gr.Number(label="Width", value=768, interactive=True)
|
1040 |
+
|
1041 |
+
with gr.Tab("Settings"):
|
1042 |
+
with gr.Row():
|
1043 |
+
channel = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
|
1044 |
+
sr_select = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
|
1045 |
+
with gr.Row():
|
1046 |
+
model = gr.Radio(["melody", "small", "medium", "large", "custom"], label="Model", value="large", interactive=True, scale=1)
|
1047 |
+
with gr.Column():
|
1048 |
+
dropdown = gr.Dropdown(choices=get_available_models(), value=("No models found" if len(get_available_models()) < 1 else get_available_models()[0]), label='Custom Model (models folder)', elem_classes='slim-dropdown', interactive=True)
|
1049 |
+
ui.create_refresh_button(dropdown, lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
|
1050 |
+
basemodel = gr.Radio(["small", "medium", "melody", "large"], label="Base Model", value="medium", interactive=True, scale=1)
|
1051 |
+
with gr.Row():
|
1052 |
+
decoder = gr.Radio(["Default", "MultiBand_Diffusion"], label="Decoder", value="Default", interactive=True)
|
1053 |
+
with gr.Row():
|
1054 |
+
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
1055 |
+
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
1056 |
+
temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
|
1057 |
+
cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
|
1058 |
+
with gr.Row():
|
1059 |
+
submit = gr.Button("Generate", variant="primary")
|
1060 |
+
# Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
|
1061 |
+
_ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
|
1062 |
+
with gr.Column() as c:
|
1063 |
+
with gr.Tab("Output"):
|
1064 |
+
output = gr.Video(label="Generated Music", scale=0)
|
1065 |
+
with gr.Row():
|
1066 |
+
audio_only = gr.Audio(type="numpy", label="Audio Only", interactive=False)
|
1067 |
+
backup_only = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
|
1068 |
+
send_audio = gr.Button("Send to Input Audio")
|
1069 |
+
seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
|
1070 |
+
download = gr.File(label="Generated Files", interactive=False)
|
1071 |
+
with gr.Tab("Wiki"):
|
1072 |
+
gr.Markdown(
|
1073 |
+
"""
|
1074 |
+
- **[Generate (button)]:**
|
1075 |
+
Generates the music with the given settings and prompts.
|
1076 |
+
|
1077 |
+
- **[Interrupt (button)]:**
|
1078 |
+
Stops the music generation as soon as it can, providing an incomplete output.
|
1079 |
+
|
1080 |
+
---
|
1081 |
+
|
1082 |
+
### Generation Tab:
|
1083 |
+
|
1084 |
+
#### Structure Prompts:
|
1085 |
+
|
1086 |
+
This feature helps reduce repetetive prompts by allowing you to set global prompts
|
1087 |
+
that will be used for all prompt segments.
|
1088 |
+
|
1089 |
+
- **[Structure Prompts (checkbox)]:**
|
1090 |
+
Enable/Disable the structure prompts feature.
|
1091 |
+
|
1092 |
+
- **[BPM (number)]:**
|
1093 |
+
Beats per minute of the generated music.
|
1094 |
+
|
1095 |
+
- **[Key (dropdown)]:**
|
1096 |
+
The key of the generated music.
|
1097 |
+
|
1098 |
+
- **[Scale (dropdown)]:**
|
1099 |
+
The scale of the generated music.
|
1100 |
+
|
1101 |
+
- **[Global Prompt (text)]:**
|
1102 |
+
Here write the prompt that you wish to be used for all prompt segments.
|
1103 |
+
|
1104 |
+
#### Multi-Prompt:
|
1105 |
+
|
1106 |
+
This feature allows you to control the music, adding variation to different time segments.
|
1107 |
+
You have up to 10 prompt segments. the first prompt will always be 30s long
|
1108 |
+
the other prompts will be [30s - overlap].
|
1109 |
+
for example if the overlap is 10s, each prompt segment will be 20s.
|
1110 |
+
|
1111 |
+
- **[Prompt Segments (number)]:**
|
1112 |
+
Amount of unique prompt to generate throughout the music generation.
|
1113 |
+
|
1114 |
+
- **[Prompt/Input Text (prompt)]:**
|
1115 |
+
Here describe the music you wish the model to generate.
|
1116 |
+
|
1117 |
+
- **[Repeat (number)]:**
|
1118 |
+
Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).
|
1119 |
+
|
1120 |
+
- **[Time (text)]:**
|
1121 |
+
The time of the prompt segment.
|
1122 |
+
|
1123 |
+
- **[Calculate Timings (button)]:**
|
1124 |
+
Calculates the timings of the prompt segments.
|
1125 |
+
|
1126 |
+
- **[Duration (number)]:**
|
1127 |
+
How long you want the generated music to be (in seconds).
|
1128 |
+
|
1129 |
+
- **[Overlap (number)]:**
|
1130 |
+
How much each new segment will reference the previous segment (in seconds).
|
1131 |
+
For example, if you choose 20s: Each new segment after the first one will reference the previous segment 20s
|
1132 |
+
and will generate only 10s of new music. The model can only process 30s of music.
|
1133 |
+
|
1134 |
+
- **[Seed (number)]:**
|
1135 |
+
Your generated music id. If you wish to generate the exact same music,
|
1136 |
+
place the exact seed with the exact prompts
|
1137 |
+
(This way you can also extend specific song that was generated short).
|
1138 |
+
|
1139 |
+
- **[Random Seed (button)]:**
|
1140 |
+
Gives "-1" as a seed, which counts as a random seed.
|
1141 |
+
|
1142 |
+
- **[Copy Previous Seed (button)]:**
|
1143 |
+
Copies the seed from the output seed (if you don't feel like doing it manualy).
|
1144 |
+
|
1145 |
+
---
|
1146 |
+
|
1147 |
+
### Audio Tab:
|
1148 |
+
|
1149 |
+
- **[Input Type (selection)]:**
|
1150 |
+
`File` mode allows you to upload an audio file to use as input
|
1151 |
+
`Mic` mode allows you to use your microphone as input
|
1152 |
+
|
1153 |
+
- **[Input Audio Mode (selection)]:**
|
1154 |
+
`Melody` mode only works with the melody model: it conditions the music generation to reference the melody
|
1155 |
+
`Sample` mode works with any model: it gives a music sample to the model to generate its continuation.
|
1156 |
+
|
1157 |
+
- **[Trim Start and Trim End (numbers)]:**
|
1158 |
+
`Trim Start` set how much you'd like to trim the input audio from the start
|
1159 |
+
`Trim End` same as the above but from the end
|
1160 |
+
|
1161 |
+
- **[Input Audio (audio file)]:**
|
1162 |
+
Input here the audio you wish to use with "melody" or "sample" mode.
|
1163 |
+
|
1164 |
+
---
|
1165 |
+
|
1166 |
+
### Customization Tab:
|
1167 |
+
|
1168 |
+
- **[Background Color (color)]:**
|
1169 |
+
Works only if you don't upload image. Color of the background of the waveform.
|
1170 |
+
|
1171 |
+
- **[Bar Color Start (color)]:**
|
1172 |
+
First color of the waveform bars.
|
1173 |
+
|
1174 |
+
- **[Bar Color End (color)]:**
|
1175 |
+
Second color of the waveform bars.
|
1176 |
+
|
1177 |
+
- **[Background Image (image)]:**
|
1178 |
+
Background image that you wish to be attached to the generated video along with the waveform.
|
1179 |
+
|
1180 |
+
- **[Height and Width (numbers)]:**
|
1181 |
+
Output video resolution, only works with image.
|
1182 |
+
(minimum height and width is 256).
|
1183 |
+
|
1184 |
+
---
|
1185 |
+
|
1186 |
+
### Settings Tab:
|
1187 |
+
|
1188 |
+
- **[Output Audio Channels (selection)]:**
|
1189 |
+
With this you can select the amount of channels that you wish for your output audio.
|
1190 |
+
`mono` is a straightforward single channel audio
|
1191 |
+
`stereo` is a dual channel audio but it will sound more or less like mono
|
1192 |
+
`stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.
|
1193 |
+
|
1194 |
+
- **[Output Audio Sample Rate (dropdown)]:**
|
1195 |
+
The output audio sample rate, the model default is 32000.
|
1196 |
+
|
1197 |
+
- **[Model (selection)]:**
|
1198 |
+
Here you can choose which model you wish to use:
|
1199 |
+
`melody` model is based on the medium model with a unique feature that lets you use melody conditioning
|
1200 |
+
`small` model is trained on 300M parameters
|
1201 |
+
`medium` model is trained on 1.5B parameters
|
1202 |
+
`large` model is trained on 3.3B parameters
|
1203 |
+
`custom` model runs the custom model that you provided.
|
1204 |
+
|
1205 |
+
- **[Custom Model (selection)]:**
|
1206 |
+
This dropdown will show you models that are placed in the `models` folder
|
1207 |
+
you must select `custom` in the model options in order to use it.
|
1208 |
+
|
1209 |
+
- **[Refresh (button)]:**
|
1210 |
+
Refreshes the dropdown list for custom model.
|
1211 |
+
|
1212 |
+
- **[Base Model (selection)]:**
|
1213 |
+
Choose here the model that your custom model is based on.
|
1214 |
+
|
1215 |
+
- **[Decoder (selection)]:**
|
1216 |
+
Choose here the decoder that you wish to use:
|
1217 |
+
`Default` is the default decoder
|
1218 |
+
`MultiBand_Diffusion` is a decoder that uses diffusion to generate the audio.
|
1219 |
+
|
1220 |
+
- **[Top-k (number)]:**
|
1221 |
+
is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.
|
1222 |
+
|
1223 |
+
- **[Top-p (number)]:**
|
1224 |
+
also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
|
1225 |
+
|
1226 |
+
- **[Temperature (number)]:**
|
1227 |
+
is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.
|
1228 |
+
|
1229 |
+
- **[Classifier Free Guidance (number)]:**
|
1230 |
+
refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture.
|
1231 |
+
"""
|
1232 |
+
)
|
1233 |
+
with gr.Tab("AudioGen"):
|
1234 |
+
gr.Markdown(
|
1235 |
+
"""
|
1236 |
+
### AudioGen
|
1237 |
+
"""
|
1238 |
+
)
|
1239 |
+
with gr.Row():
|
1240 |
+
with gr.Column():
|
1241 |
+
with gr.Tab("Generation"):
|
1242 |
+
with gr.Accordion("Structure Prompts", open=False):
|
1243 |
+
with gr.Row():
|
1244 |
+
struc_prompts_a = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
|
1245 |
+
global_prompt_a = gr.Text(label="Global Prompt", interactive=True, scale=3)
|
1246 |
+
with gr.Row():
|
1247 |
+
s_a = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
|
1248 |
+
with gr.Column():
|
1249 |
+
textboxes_a = []
|
1250 |
+
prompts_a = []
|
1251 |
+
repeats_a = []
|
1252 |
+
calcs_a = []
|
1253 |
+
with gr.Row():
|
1254 |
+
text0_a = gr.Text(label="Input Text", interactive=True, scale=4)
|
1255 |
+
prompts_a.append(text0_a)
|
1256 |
+
drag0_a = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
|
1257 |
+
repeats_a.append(drag0_a)
|
1258 |
+
calc0_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
|
1259 |
+
calcs_a.append(calc0_a)
|
1260 |
+
for i in range(max_textboxes):
|
1261 |
+
with gr.Row(visible=False) as t_a:
|
1262 |
+
text_a = gr.Text(label="Input Text", interactive=True, scale=3)
|
1263 |
+
repeat_a = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
|
1264 |
+
calc_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
|
1265 |
+
textboxes_a.append(t_a)
|
1266 |
+
prompts_a.append(text_a)
|
1267 |
+
repeats_a.append(repeat_a)
|
1268 |
+
calcs_a.append(calc_a)
|
1269 |
+
to_calc_a = gr.Button("Calculate Timings", variant="secondary")
|
1270 |
+
with gr.Row():
|
1271 |
+
duration_a = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
|
1272 |
+
with gr.Row():
|
1273 |
+
overlap_a = gr.Slider(minimum=1, maximum=9, value=2, step=1, label="Overlap", interactive=True)
|
1274 |
+
with gr.Row():
|
1275 |
+
seed_a = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
|
1276 |
+
gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed_a], queue=False)
|
1277 |
+
reuse_seed_a = gr.Button('\u267b\ufe0f', scale=1)
|
1278 |
+
|
1279 |
+
with gr.Tab("Audio"):
|
1280 |
+
with gr.Row():
|
1281 |
+
with gr.Column():
|
1282 |
+
input_type_a = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
|
1283 |
+
mode_a = gr.Radio(["sample"], label="Input Audio Mode (optional)", value="sample", interactive=False, visible=False)
|
1284 |
+
with gr.Row():
|
1285 |
+
trim_start_a = gr.Number(label="Trim Start", value=0, interactive=True)
|
1286 |
+
trim_end_a = gr.Number(label="Trim End", value=0, interactive=True)
|
1287 |
+
audio_a = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)
|
1288 |
+
|
1289 |
+
with gr.Tab("Customization"):
|
1290 |
+
with gr.Row():
|
1291 |
+
with gr.Column():
|
1292 |
+
background_a = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
|
1293 |
+
bar1_a = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
|
1294 |
+
bar2_a = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
|
1295 |
+
with gr.Column():
|
1296 |
+
image_a = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
|
1297 |
+
with gr.Row():
|
1298 |
+
height_a = gr.Number(label="Height", value=512, interactive=True)
|
1299 |
+
width_a = gr.Number(label="Width", value=768, interactive=True)
|
1300 |
+
|
1301 |
+
with gr.Tab("Settings"):
|
1302 |
+
with gr.Row():
|
1303 |
+
channel_a = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
|
1304 |
+
sr_select_a = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
|
1305 |
+
with gr.Row():
|
1306 |
+
model_a = gr.Radio(["medium"], label="Model", value="medium", interactive=False, visible=False)
|
1307 |
+
decoder_a = gr.Radio(["Default"], label="Decoder", value="Default", interactive=False, visible=False)
|
1308 |
+
with gr.Row():
|
1309 |
+
topk_a = gr.Number(label="Top-k", value=250, interactive=True)
|
1310 |
+
topp_a = gr.Number(label="Top-p", value=0, interactive=True)
|
1311 |
+
temperature_a = gr.Number(label="Temperature", value=1.0, interactive=True)
|
1312 |
+
cfg_coef_a = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
|
1313 |
+
with gr.Row():
|
1314 |
+
submit_a = gr.Button("Generate", variant="primary")
|
1315 |
+
_ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
|
1316 |
+
with gr.Column():
|
1317 |
+
with gr.Tab("Output"):
|
1318 |
+
output_a = gr.Video(label="Generated Audio", scale=0)
|
1319 |
+
with gr.Row():
|
1320 |
+
audio_only_a = gr.Audio(type="numpy", label="Audio Only", interactive=False)
|
1321 |
+
backup_only_a = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
|
1322 |
+
send_audio_a = gr.Button("Send to Input Audio")
|
1323 |
+
seed_used_a = gr.Number(label='Seed used', value=-1, interactive=False)
|
1324 |
+
download_a = gr.File(label="Generated Files", interactive=False)
|
1325 |
+
with gr.Tab("Wiki"):
|
1326 |
+
gr.Markdown(
|
1327 |
+
"""
|
1328 |
+
- **[Generate (button)]:**
|
1329 |
+
Generates the audio with the given settings and prompts.
|
1330 |
+
|
1331 |
+
- **[Interrupt (button)]:**
|
1332 |
+
Stops the audio generation as soon as it can, providing an incomplete output.
|
1333 |
+
|
1334 |
+
---
|
1335 |
+
|
1336 |
+
### Generation Tab:
|
1337 |
+
|
1338 |
+
#### Structure Prompts:
|
1339 |
+
|
1340 |
+
This feature helps reduce repetetive prompts by allowing you to set global prompts
|
1341 |
+
that will be used for all prompt segments.
|
1342 |
+
|
1343 |
+
- **[Structure Prompts (checkbox)]:**
|
1344 |
+
Enable/Disable the structure prompts feature.
|
1345 |
+
|
1346 |
+
- **[Global Prompt (text)]:**
|
1347 |
+
Here write the prompt that you wish to be used for all prompt segments.
|
1348 |
+
|
1349 |
+
#### Multi-Prompt:
|
1350 |
+
|
1351 |
+
This feature allows you to control the audio, adding variation to different time segments.
|
1352 |
+
You have up to 10 prompt segments. the first prompt will always be 10s long
|
1353 |
+
the other prompts will be [10s - overlap].
|
1354 |
+
for example if the overlap is 2s, each prompt segment will be 8s.
|
1355 |
+
|
1356 |
+
- **[Prompt Segments (number)]:**
|
1357 |
+
Amount of unique prompt to generate throughout the audio generation.
|
1358 |
+
|
1359 |
+
- **[Prompt/Input Text (prompt)]:**
|
1360 |
+
Here describe the audio you wish the model to generate.
|
1361 |
+
|
1362 |
+
- **[Repeat (number)]:**
|
1363 |
+
Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).
|
1364 |
+
|
1365 |
+
- **[Time (text)]:**
|
1366 |
+
The time of the prompt segment.
|
1367 |
+
|
1368 |
+
- **[Calculate Timings (button)]:**
|
1369 |
+
Calculates the timings of the prompt segments.
|
1370 |
+
|
1371 |
+
- **[Duration (number)]:**
|
1372 |
+
How long you want the generated audio to be (in seconds).
|
1373 |
+
|
1374 |
+
- **[Overlap (number)]:**
|
1375 |
+
How much each new segment will reference the previous segment (in seconds).
|
1376 |
+
For example, if you choose 2s: Each new segment after the first one will reference the previous segment 2s
|
1377 |
+
and will generate only 8s of new audio. The model can only process 10s of music.
|
1378 |
+
|
1379 |
+
- **[Seed (number)]:**
|
1380 |
+
Your generated audio id. If you wish to generate the exact same audio,
|
1381 |
+
place the exact seed with the exact prompts
|
1382 |
+
(This way you can also extend specific song that was generated short).
|
1383 |
+
|
1384 |
+
- **[Random Seed (button)]:**
|
1385 |
+
Gives "-1" as a seed, which counts as a random seed.
|
1386 |
+
|
1387 |
+
- **[Copy Previous Seed (button)]:**
|
1388 |
+
Copies the seed from the output seed (if you don't feel like doing it manualy).
|
1389 |
+
|
1390 |
+
---
|
1391 |
+
|
1392 |
+
### Audio Tab:
|
1393 |
+
|
1394 |
+
- **[Input Type (selection)]:**
|
1395 |
+
`File` mode allows you to upload an audio file to use as input
|
1396 |
+
`Mic` mode allows you to use your microphone as input
|
1397 |
+
|
1398 |
+
- **[Trim Start and Trim End (numbers)]:**
|
1399 |
+
`Trim Start` set how much you'd like to trim the input audio from the start
|
1400 |
+
`Trim End` same as the above but from the end
|
1401 |
+
|
1402 |
+
- **[Input Audio (audio file)]:**
|
1403 |
+
Input here the audio you wish to use.
|
1404 |
+
|
1405 |
+
---
|
1406 |
+
|
1407 |
+
### Customization Tab:
|
1408 |
+
|
1409 |
+
- **[Background Color (color)]:**
|
1410 |
+
Works only if you don't upload image. Color of the background of the waveform.
|
1411 |
+
|
1412 |
+
- **[Bar Color Start (color)]:**
|
1413 |
+
First color of the waveform bars.
|
1414 |
+
|
1415 |
+
- **[Bar Color End (color)]:**
|
1416 |
+
Second color of the waveform bars.
|
1417 |
+
|
1418 |
+
- **[Background Image (image)]:**
|
1419 |
+
Background image that you wish to be attached to the generated video along with the waveform.
|
1420 |
+
|
1421 |
+
- **[Height and Width (numbers)]:**
|
1422 |
+
Output video resolution, only works with image.
|
1423 |
+
(minimum height and width is 256).
|
1424 |
+
|
1425 |
+
---
|
1426 |
+
|
1427 |
+
### Settings Tab:
|
1428 |
+
|
1429 |
+
- **[Output Audio Channels (selection)]:**
|
1430 |
+
With this you can select the amount of channels that you wish for your output audio.
|
1431 |
+
`mono` is a straightforward single channel audio
|
1432 |
+
`stereo` is a dual channel audio but it will sound more or less like mono
|
1433 |
+
`stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.
|
1434 |
+
|
1435 |
+
- **[Output Audio Sample Rate (dropdown)]:**
|
1436 |
+
The output audio sample rate, the model default is 32000.
|
1437 |
+
|
1438 |
+
- **[Top-k (number)]:**
|
1439 |
+
is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.
|
1440 |
+
|
1441 |
+
- **[Top-p (number)]:**
|
1442 |
+
also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
|
1443 |
+
|
1444 |
+
- **[Temperature (number)]:**
|
1445 |
+
is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.
|
1446 |
+
|
1447 |
+
- **[Classifier Free Guidance (number)]:**
|
1448 |
+
refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture.
|
1449 |
+
"""
|
1450 |
+
)
|
1451 |
+
with gr.Tab("Audio Info"):
|
1452 |
+
gr.Markdown(
|
1453 |
+
"""
|
1454 |
+
### Audio Info
|
1455 |
+
"""
|
1456 |
+
)
|
1457 |
+
with gr.Row():
|
1458 |
+
with gr.Column():
|
1459 |
+
in_audio = gr.File(type="file", label="Input Any Audio", interactive=True)
|
1460 |
+
with gr.Row():
|
1461 |
+
send_gen = gr.Button("Send to MusicGen", variant="primary")
|
1462 |
+
send_gen_a = gr.Button("Send to AudioGen", variant="primary")
|
1463 |
+
with gr.Column():
|
1464 |
+
info = gr.Textbox(label="Audio Info", lines=10, interactive=False)
|
1465 |
+
with gr.Tab("Changelog"):
|
1466 |
+
gr.Markdown(
|
1467 |
+
"""
|
1468 |
+
## Changelog:
|
1469 |
+
|
1470 |
+
### v2.0.0a
|
1471 |
+
|
1472 |
+
- Forgot to move all the update to app.py from temp2.py... oops
|
1473 |
+
|
1474 |
+
|
1475 |
+
|
1476 |
+
### v2.0.0
|
1477 |
+
|
1478 |
+
- Changed name from MusicGen+ to AudioCraft Plus
|
1479 |
+
|
1480 |
+
- Complete overhaul of the repo "backend" with the latest changes from the main facebookresearch repo
|
1481 |
+
|
1482 |
+
- Added a new decoder: MultiBand_Diffusion
|
1483 |
+
|
1484 |
+
- Added AudioGen: a new tab for generating audio
|
1485 |
+
|
1486 |
+
|
1487 |
+
|
1488 |
+
### v1.2.8c
|
1489 |
+
|
1490 |
+
- Implemented Reverse compatibility for audio info tab with previous versions
|
1491 |
+
|
1492 |
+
|
1493 |
+
|
1494 |
+
### v1.2.8b
|
1495 |
+
|
1496 |
+
- Fixed the error when loading default models
|
1497 |
+
|
1498 |
+
|
1499 |
+
|
1500 |
+
### v1.2.8a
|
1501 |
+
|
1502 |
+
- Adapted Audio info tab to work with the new structure prompts feature
|
1503 |
+
|
1504 |
+
- Now custom models actually work, make sure you select the correct base model
|
1505 |
+
|
1506 |
+
|
1507 |
+
|
1508 |
+
### v1.2.8
|
1509 |
+
|
1510 |
+
- Now you will also recieve json file with metadata of generated audio
|
1511 |
+
|
1512 |
+
- Added error messages in Audio Info tab
|
1513 |
+
|
1514 |
+
- Added structure prompts: you can select bpm, key and global prompt for all prompts
|
1515 |
+
|
1516 |
+
- Added time display next to each prompt, can be calculated with "Calculate Timings" button
|
1517 |
+
|
1518 |
+
|
1519 |
+
|
1520 |
+
### v1.2.7
|
1521 |
+
|
1522 |
+
- When sending generated audio to Input Audio, it will send a backup audio with default settings
|
1523 |
+
(best for continuos generation)
|
1524 |
+
|
1525 |
+
- Added Metadata to generated audio (Thanks to AlexHK ♥)
|
1526 |
+
|
1527 |
+
- Added Audio Info tab that will display the metadata of the input audio
|
1528 |
+
|
1529 |
+
- Added "send to Text2Audio" button in Audio Info tab
|
1530 |
+
|
1531 |
+
- Generated audio is now stored in the "output" folder (Thanks to AlexHK ♥)
|
1532 |
+
|
1533 |
+
- Added an output area with generated files and download buttons
|
1534 |
+
|
1535 |
+
- Enhanced Stereo effect (Thanks to AlexHK ♥)
|
1536 |
+
|
1537 |
+
|
1538 |
+
|
1539 |
+
### v1.2.6
|
1540 |
+
|
1541 |
+
- Added option to generate in stereo (instead of only mono)
|
1542 |
+
|
1543 |
+
- Added dropdown for selecting output sample rate (model default is 32000)
|
1544 |
+
|
1545 |
+
|
1546 |
+
|
1547 |
+
### v1.2.5a
|
1548 |
+
|
1549 |
+
- Added file cleaner (This comes from the main facebookresearch repo)
|
1550 |
+
|
1551 |
+
- Reorganized a little, moved audio to a seperate tab
|
1552 |
+
|
1553 |
+
|
1554 |
+
|
1555 |
+
### v1.2.5
|
1556 |
+
|
1557 |
+
- Gave a unique lime theme to the webui
|
1558 |
+
|
1559 |
+
- Added additional output for audio only
|
1560 |
+
|
1561 |
+
- Added button to send generated audio to Input Audio
|
1562 |
+
|
1563 |
+
- Added option to trim Input Audio
|
1564 |
+
|
1565 |
+
|
1566 |
+
|
1567 |
+
### v1.2.4
|
1568 |
+
|
1569 |
+
- Added mic input (This comes from the main facebookresearch repo)
|
1570 |
+
|
1571 |
+
|
1572 |
+
|
1573 |
+
### v1.2.3
|
1574 |
+
|
1575 |
+
- Added option to change video size to fit the image you upload
|
1576 |
+
|
1577 |
+
|
1578 |
+
|
1579 |
+
### v1.2.2
|
1580 |
+
|
1581 |
+
- Added Wiki, Changelog and About tabs
|
1582 |
+
|
1583 |
+
|
1584 |
+
|
1585 |
+
### v1.2.1
|
1586 |
+
|
1587 |
+
- Added tabs and organized the entire interface
|
1588 |
+
|
1589 |
+
- Added option to attach image to the output video
|
1590 |
+
|
1591 |
+
- Added option to load fine-tuned models (Yet to be tested)
|
1592 |
+
|
1593 |
+
|
1594 |
+
|
1595 |
+
### v1.2.0
|
1596 |
+
|
1597 |
+
- Added Multi-Prompt
|
1598 |
+
|
1599 |
+
|
1600 |
+
|
1601 |
+
### v1.1.3
|
1602 |
+
|
1603 |
+
- Added customization options for generated waveform
|
1604 |
+
|
1605 |
+
|
1606 |
+
|
1607 |
+
### v1.1.2
|
1608 |
+
|
1609 |
+
- Removed sample length limit: now you can input audio of any length as music sample
|
1610 |
+
|
1611 |
+
|
1612 |
+
|
1613 |
+
### v1.1.1
|
1614 |
+
|
1615 |
+
- Improved music sample audio quality when using music continuation
|
1616 |
+
|
1617 |
+
|
1618 |
+
|
1619 |
+
### v1.1.0
|
1620 |
+
|
1621 |
+
- Rebuilt the repo on top of the latest structure of the main MusicGen repo
|
1622 |
+
|
1623 |
+
- Improved Music continuation feature
|
1624 |
+
|
1625 |
+
|
1626 |
+
|
1627 |
+
### v1.0.0 - Stable Version
|
1628 |
+
|
1629 |
+
- Added Music continuation
|
1630 |
+
"""
|
1631 |
+
)
|
1632 |
+
with gr.Tab("About"):
|
1633 |
+
gen_type = gr.Text(value="music", interactive=False, visible=False)
|
1634 |
+
gen_type_a = gr.Text(value="audio", interactive=False, visible=False)
|
1635 |
+
gr.Markdown(
|
1636 |
+
"""
|
1637 |
+
This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
|
1638 |
+
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
|
1639 |
+
|
1640 |
+
## MusicGen+ is an extended version of the original MusicGen by facebookresearch.
|
1641 |
+
|
1642 |
+
### Repo: https://github.com/GrandaddyShmax/audiocraft_plus/tree/plus
|
1643 |
+
|
1644 |
+
---
|
1645 |
+
|
1646 |
+
### This project was possible thanks to:
|
1647 |
+
|
1648 |
+
#### GrandaddyShmax - https://github.com/GrandaddyShmax
|
1649 |
+
|
1650 |
+
#### Camenduru - https://github.com/camenduru
|
1651 |
+
|
1652 |
+
#### rkfg - https://github.com/rkfg
|
1653 |
+
|
1654 |
+
#### oobabooga - https://github.com/oobabooga
|
1655 |
+
|
1656 |
+
#### AlexHK - https://github.com/alanhk147
|
1657 |
+
"""
|
1658 |
+
)
|
1659 |
+
|
1660 |
+
send_gen.click(info_to_params, inputs=[in_audio], outputs=[decoder, struc_prompts, global_prompt, bpm, key, scale, model, dropdown, basemodel, s, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select], queue=False)
|
1661 |
+
reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False)
|
1662 |
+
send_audio.click(fn=lambda x: x, inputs=[backup_only], outputs=[audio], queue=False)
|
1663 |
+
submit.click(predict_full, inputs=[gen_type, model, decoder, dropdown, basemodel, s, struc_prompts, bpm, key, scale, global_prompt, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select], outputs=[output, audio_only, backup_only, download, seed_used])
|
1664 |
+
input_type.change(toggle_audio_src, input_type, [audio], queue=False, show_progress=False)
|
1665 |
+
to_calc.click(calc_time, inputs=[gen_type, s, duration, overlap, repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9]], outputs=[calcs[0], calcs[1], calcs[2], calcs[3], calcs[4], calcs[5], calcs[6], calcs[7], calcs[8], calcs[9]], queue=False)
|
1666 |
+
|
1667 |
+
send_gen_a.click(info_to_params_a, inputs=[in_audio], outputs=[decoder_a, struc_prompts_a, global_prompt_a, s_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, channel_a, sr_select_a], queue=False)
|
1668 |
+
reuse_seed_a.click(fn=lambda x: x, inputs=[seed_used_a], outputs=[seed_a], queue=False)
|
1669 |
+
send_audio_a.click(fn=lambda x: x, inputs=[backup_only_a], outputs=[audio_a], queue=False)
|
1670 |
+
submit_a.click(predict_full, inputs=[gen_type_a, model_a, decoder_a, dropdown, basemodel, s_a, struc_prompts_a, bpm, key, scale, global_prompt_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], audio_a, mode_a, trim_start_a, trim_end_a, duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, image_a, height_a, width_a, background_a, bar1_a, bar2_a, channel_a, sr_select_a], outputs=[output_a, audio_only_a, backup_only_a, download_a, seed_used_a])
|
1671 |
+
input_type_a.change(toggle_audio_src, input_type_a, [audio_a], queue=False, show_progress=False)
|
1672 |
+
to_calc_a.click(calc_time, inputs=[gen_type_a, s_a, duration_a, overlap_a, repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9]], outputs=[calcs_a[0], calcs_a[1], calcs_a[2], calcs_a[3], calcs_a[4], calcs_a[5], calcs_a[6], calcs_a[7], calcs_a[8], calcs_a[9]], queue=False)
|
1673 |
+
|
1674 |
+
in_audio.change(get_audio_info, in_audio, outputs=[info])
|
1675 |
+
|
1676 |
+
def variable_outputs(k):
|
1677 |
+
k = int(k) - 1
|
1678 |
+
return [gr.Textbox.update(visible=True)]*k + [gr.Textbox.update(visible=False)]*(max_textboxes-k)
|
1679 |
+
def get_size(image):
|
1680 |
+
if image is not None:
|
1681 |
+
img = Image.open(image)
|
1682 |
+
img_height = img.height
|
1683 |
+
img_width = img.width
|
1684 |
+
if (img_height%2) != 0:
|
1685 |
+
img_height = img_height + 1
|
1686 |
+
if (img_width%2) != 0:
|
1687 |
+
img_width = img_width + 1
|
1688 |
+
return img_height, img_width
|
1689 |
+
else:
|
1690 |
+
return 512, 768
|
1691 |
+
|
1692 |
+
image.change(get_size, image, outputs=[height, width])
|
1693 |
+
image_a.change(get_size, image_a, outputs=[height_a, width_a])
|
1694 |
+
s.change(variable_outputs, s, textboxes)
|
1695 |
+
s_a.change(variable_outputs, s_a, textboxes_a)
|
1696 |
+
interface.queue().launch(**launch_kwargs)
|
1697 |
+
|
1698 |
+
|
1699 |
+
def ui_batched(launch_kwargs):
|
1700 |
+
with gr.Blocks() as demo:
|
1701 |
+
gr.Markdown(
|
1702 |
+
"""
|
1703 |
+
# MusicGen
|
1704 |
+
|
1705 |
+
This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
|
1706 |
+
a simple and controllable model for music generation
|
1707 |
+
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
|
1708 |
+
<br/>
|
1709 |
+
<a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
|
1710 |
+
style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
|
1711 |
+
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
|
1712 |
+
src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
1713 |
+
for longer sequences, more control and no queue.</p>
|
1714 |
+
"""
|
1715 |
+
)
|
1716 |
+
with gr.Row():
|
1717 |
+
with gr.Column():
|
1718 |
+
with gr.Row():
|
1719 |
+
text = gr.Text(label="Describe your music", lines=2, interactive=True)
|
1720 |
+
with gr.Column():
|
1721 |
+
radio = gr.Radio(["file", "mic"], value="file",
|
1722 |
+
label="Condition on a melody (optional) File or Mic")
|
1723 |
+
melody = gr.Audio(source="upload", type="numpy", label="File",
|
1724 |
+
interactive=True, elem_id="melody-input")
|
1725 |
+
with gr.Row():
|
1726 |
+
submit = gr.Button("Generate")
|
1727 |
+
with gr.Column():
|
1728 |
+
output = gr.Video(label="Generated Music")
|
1729 |
+
audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
|
1730 |
+
submit.click(predict_batched, inputs=[text, melody],
|
1731 |
+
outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
|
1732 |
+
radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
|
1733 |
+
gr.Examples(
|
1734 |
+
fn=predict_batched,
|
1735 |
+
examples=[
|
1736 |
+
[
|
1737 |
+
"An 80s driving pop song with heavy drums and synth pads in the background",
|
1738 |
+
"./assets/bach.mp3",
|
1739 |
+
],
|
1740 |
+
[
|
1741 |
+
"A cheerful country song with acoustic guitars",
|
1742 |
+
"./assets/bolero_ravel.mp3",
|
1743 |
+
],
|
1744 |
+
[
|
1745 |
+
"90s rock song with electric guitar and heavy drums",
|
1746 |
+
None,
|
1747 |
+
],
|
1748 |
+
[
|
1749 |
+
"a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
|
1750 |
+
"./assets/bach.mp3",
|
1751 |
+
],
|
1752 |
+
[
|
1753 |
+
"lofi slow bpm electro chill with organic samples",
|
1754 |
+
None,
|
1755 |
+
],
|
1756 |
+
],
|
1757 |
+
inputs=[text, melody],
|
1758 |
+
outputs=[output]
|
1759 |
+
)
|
1760 |
+
gr.Markdown("""
|
1761 |
+
### More details
|
1762 |
+
|
1763 |
+
The model will generate 12 seconds of audio based on the description you provided.
|
1764 |
+
You can optionally provide a reference audio from which a broad melody will be extracted.
|
1765 |
+
The model will then try to follow both the description and melody provided.
|
1766 |
+
All samples are generated with the `melody` model.
|
1767 |
+
|
1768 |
+
You can also use your own GPU or a Google Colab by following the instructions on our repo.
|
1769 |
+
|
1770 |
+
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
|
1771 |
+
for more details.
|
1772 |
+
""")
|
1773 |
+
|
1774 |
+
demo.queue(max_size=8 * 4).launch(**launch_kwargs)
|
1775 |
+
|
1776 |
+
|
1777 |
+
if __name__ == "__main__":
|
1778 |
+
parser = argparse.ArgumentParser()
|
1779 |
+
parser.add_argument(
|
1780 |
+
'--listen',
|
1781 |
+
type=str,
|
1782 |
+
default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
|
1783 |
+
help='IP to listen on for connections to Gradio',
|
1784 |
+
)
|
1785 |
+
parser.add_argument(
|
1786 |
+
'--username', type=str, default='', help='Username for authentication'
|
1787 |
+
)
|
1788 |
+
parser.add_argument(
|
1789 |
+
'--password', type=str, default='', help='Password for authentication'
|
1790 |
+
)
|
1791 |
+
parser.add_argument(
|
1792 |
+
'--server_port',
|
1793 |
+
type=int,
|
1794 |
+
default=0,
|
1795 |
+
help='Port to run the server listener on',
|
1796 |
+
)
|
1797 |
+
parser.add_argument(
|
1798 |
+
'--inbrowser', action='store_true', help='Open in browser'
|
1799 |
+
)
|
1800 |
+
parser.add_argument(
|
1801 |
+
'--share', action='store_true', help='Share the gradio UI'
|
1802 |
+
)
|
1803 |
+
parser.add_argument(
|
1804 |
+
'--unload_model', action='store_true', help='Unload the model after every generation to save GPU memory'
|
1805 |
+
)
|
1806 |
+
|
1807 |
+
parser.add_argument(
|
1808 |
+
'--unload_to_cpu', action='store_true', help='Move the model to main RAM after every generation to save GPU memory but reload faster than after full unload (see above)'
|
1809 |
+
)
|
1810 |
+
|
1811 |
+
parser.add_argument(
|
1812 |
+
'--cache', action='store_true', help='Cache models in RAM to quickly switch between them'
|
1813 |
+
)
|
1814 |
+
|
1815 |
+
args = parser.parse_args()
|
1816 |
+
UNLOAD_MODEL = args.unload_model
|
1817 |
+
MOVE_TO_CPU = args.unload_to_cpu
|
1818 |
+
if args.cache:
|
1819 |
+
MODELS = {}
|
1820 |
+
|
1821 |
+
launch_kwargs = {}
|
1822 |
+
launch_kwargs['server_name'] = args.listen
|
1823 |
+
|
1824 |
+
if args.username and args.password:
|
1825 |
+
launch_kwargs['auth'] = (args.username, args.password)
|
1826 |
+
if args.server_port:
|
1827 |
+
launch_kwargs['server_port'] = args.server_port
|
1828 |
+
if args.inbrowser:
|
1829 |
+
launch_kwargs['inbrowser'] = args.inbrowser
|
1830 |
+
if args.share:
|
1831 |
+
launch_kwargs['share'] = args.share
|
1832 |
+
|
1833 |
+
# Show the interface
|
1834 |
+
if IS_BATCHED:
|
1835 |
+
global USE_DIFFUSION
|
1836 |
+
USE_DIFFUSION = False
|
1837 |
+
ui_batched(launch_kwargs)
|
1838 |
+
else:
|
1839 |
+
ui_full(launch_kwargs)
|
assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3
ADDED
Binary file (15.2 kB). View file
|
|
assets/bach.mp3
ADDED
Binary file (160 kB). View file
|
|
assets/bolero_ravel.mp3
ADDED
Binary file (161 kB). View file
|
|
assets/sirens_and_a_humming_engine_approach_and_pass.mp3
ADDED
Binary file (15.2 kB). View file
|
|
audiocraft/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
AudioCraft is a general framework for training audio generative models.
|
8 |
+
At the moment we provide the training code for:
|
9 |
+
|
10 |
+
- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
|
11 |
+
text-to-music and melody+text autoregressive generative model.
|
12 |
+
For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
|
13 |
+
`audiocraft.models.musicgen.MusicGen`.
|
14 |
+
- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
|
15 |
+
text-to-general-audio generative model.
|
16 |
+
- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
|
17 |
+
neural audio codec which provides an excellent tokenizer for autoregressive language models.
|
18 |
+
See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
|
19 |
+
- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
|
20 |
+
improves the perceived quality and reduces the artifacts coming from adversarial decoders.
|
21 |
+
"""
|
22 |
+
|
23 |
+
# flake8: noqa
|
24 |
+
from . import data, modules, models
|
25 |
+
|
26 |
+
__version__ = '1.0.0'
|
audiocraft/adversarial/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Adversarial losses and discriminator architectures."""
|
7 |
+
|
8 |
+
# flake8: noqa
|
9 |
+
from .discriminators import (
|
10 |
+
MultiPeriodDiscriminator,
|
11 |
+
MultiScaleDiscriminator,
|
12 |
+
MultiScaleSTFTDiscriminator
|
13 |
+
)
|
14 |
+
from .losses import (
|
15 |
+
AdversarialLoss,
|
16 |
+
AdvLossType,
|
17 |
+
get_adv_criterion,
|
18 |
+
get_fake_criterion,
|
19 |
+
get_real_criterion,
|
20 |
+
FeatLossType,
|
21 |
+
FeatureMatchingLoss
|
22 |
+
)
|
audiocraft/adversarial/discriminators/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# flake8: noqa
|
8 |
+
from .mpd import MultiPeriodDiscriminator
|
9 |
+
from .msd import MultiScaleDiscriminator
|
10 |
+
from .msstftd import MultiScaleSTFTDiscriminator
|
audiocraft/adversarial/discriminators/base.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
|
14 |
+
FeatureMapType = tp.List[torch.Tensor]
|
15 |
+
LogitsType = torch.Tensor
|
16 |
+
MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
|
17 |
+
|
18 |
+
|
19 |
+
class MultiDiscriminator(ABC, nn.Module):
|
20 |
+
"""Base implementation for discriminators composed of sub-discriminators acting at different scales.
|
21 |
+
"""
|
22 |
+
def __init__(self):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
27 |
+
...
|
28 |
+
|
29 |
+
@property
|
30 |
+
@abstractmethod
|
31 |
+
def num_discriminators(self) -> int:
|
32 |
+
"""Number of discriminators.
|
33 |
+
"""
|
34 |
+
...
|
audiocraft/adversarial/discriminators/mpd.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from ...modules import NormConv2d
|
14 |
+
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
|
15 |
+
|
16 |
+
|
17 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
18 |
+
return int((kernel_size * dilation - dilation) / 2)
|
19 |
+
|
20 |
+
|
21 |
+
class PeriodDiscriminator(nn.Module):
|
22 |
+
"""Period sub-discriminator.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
period (int): Period between samples of audio.
|
26 |
+
in_channels (int): Number of input channels.
|
27 |
+
out_channels (int): Number of output channels.
|
28 |
+
n_layers (int): Number of convolutional layers.
|
29 |
+
kernel_sizes (list of int): Kernel sizes for convolutions.
|
30 |
+
stride (int): Stride for convolutions.
|
31 |
+
filters (int): Initial number of filters in convolutions.
|
32 |
+
filters_scale (int): Multiplier of number of filters as we increase depth.
|
33 |
+
max_filters (int): Maximum number of filters.
|
34 |
+
norm (str): Normalization method.
|
35 |
+
activation (str): Activation function.
|
36 |
+
activation_params (dict): Parameters to provide to the activation function.
|
37 |
+
"""
|
38 |
+
def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
|
39 |
+
n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
|
40 |
+
filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
|
41 |
+
norm: str = 'weight_norm', activation: str = 'LeakyReLU',
|
42 |
+
activation_params: dict = {'negative_slope': 0.2}):
|
43 |
+
super().__init__()
|
44 |
+
self.period = period
|
45 |
+
self.n_layers = n_layers
|
46 |
+
self.activation = getattr(torch.nn, activation)(**activation_params)
|
47 |
+
self.convs = nn.ModuleList()
|
48 |
+
in_chs = in_channels
|
49 |
+
for i in range(self.n_layers):
|
50 |
+
out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
|
51 |
+
eff_stride = 1 if i == self.n_layers - 1 else stride
|
52 |
+
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
|
53 |
+
padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
|
54 |
+
in_chs = out_chs
|
55 |
+
self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
|
56 |
+
padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
|
57 |
+
|
58 |
+
def forward(self, x: torch.Tensor):
|
59 |
+
fmap = []
|
60 |
+
# 1d to 2d
|
61 |
+
b, c, t = x.shape
|
62 |
+
if t % self.period != 0: # pad first
|
63 |
+
n_pad = self.period - (t % self.period)
|
64 |
+
x = F.pad(x, (0, n_pad), 'reflect')
|
65 |
+
t = t + n_pad
|
66 |
+
x = x.view(b, c, t // self.period, self.period)
|
67 |
+
|
68 |
+
for conv in self.convs:
|
69 |
+
x = conv(x)
|
70 |
+
x = self.activation(x)
|
71 |
+
fmap.append(x)
|
72 |
+
x = self.conv_post(x)
|
73 |
+
fmap.append(x)
|
74 |
+
# x = torch.flatten(x, 1, -1)
|
75 |
+
|
76 |
+
return x, fmap
|
77 |
+
|
78 |
+
|
79 |
+
class MultiPeriodDiscriminator(MultiDiscriminator):
|
80 |
+
"""Multi-Period (MPD) Discriminator.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
in_channels (int): Number of input channels.
|
84 |
+
out_channels (int): Number of output channels.
|
85 |
+
periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
|
86 |
+
**kwargs: Additional args for `PeriodDiscriminator`
|
87 |
+
"""
|
88 |
+
def __init__(self, in_channels: int = 1, out_channels: int = 1,
|
89 |
+
periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
|
90 |
+
super().__init__()
|
91 |
+
self.discriminators = nn.ModuleList([
|
92 |
+
PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
|
93 |
+
])
|
94 |
+
|
95 |
+
@property
|
96 |
+
def num_discriminators(self):
|
97 |
+
return len(self.discriminators)
|
98 |
+
|
99 |
+
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
100 |
+
logits = []
|
101 |
+
fmaps = []
|
102 |
+
for disc in self.discriminators:
|
103 |
+
logit, fmap = disc(x)
|
104 |
+
logits.append(logit)
|
105 |
+
fmaps.append(fmap)
|
106 |
+
return logits, fmaps
|
audiocraft/adversarial/discriminators/msd.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
from ...modules import NormConv1d
|
14 |
+
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
|
15 |
+
|
16 |
+
|
17 |
+
class ScaleDiscriminator(nn.Module):
|
18 |
+
"""Waveform sub-discriminator.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
in_channels (int): Number of input channels.
|
22 |
+
out_channels (int): Number of output channels.
|
23 |
+
kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
|
24 |
+
filters (int): Number of initial filters for convolutions.
|
25 |
+
max_filters (int): Maximum number of filters.
|
26 |
+
downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
|
27 |
+
inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
|
28 |
+
groups (Sequence[int] or None): Groups for inner convolutions.
|
29 |
+
strides (Sequence[int] or None): Strides for inner convolutions.
|
30 |
+
paddings (Sequence[int] or None): Paddings for inner convolutions.
|
31 |
+
norm (str): Normalization method.
|
32 |
+
activation (str): Activation function.
|
33 |
+
activation_params (dict): Parameters to provide to the activation function.
|
34 |
+
pad (str): Padding for initial convolution.
|
35 |
+
pad_params (dict): Parameters to provide to the padding module.
|
36 |
+
"""
|
37 |
+
def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
|
38 |
+
filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
|
39 |
+
inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
|
40 |
+
strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
|
41 |
+
norm: str = 'weight_norm', activation: str = 'LeakyReLU',
|
42 |
+
activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
|
43 |
+
pad_params: dict = {}):
|
44 |
+
super().__init__()
|
45 |
+
assert len(kernel_sizes) == 2
|
46 |
+
assert kernel_sizes[0] % 2 == 1
|
47 |
+
assert kernel_sizes[1] % 2 == 1
|
48 |
+
assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
|
49 |
+
assert (groups is None or len(groups) == len(downsample_scales))
|
50 |
+
assert (strides is None or len(strides) == len(downsample_scales))
|
51 |
+
assert (paddings is None or len(paddings) == len(downsample_scales))
|
52 |
+
self.activation = getattr(torch.nn, activation)(**activation_params)
|
53 |
+
self.convs = nn.ModuleList()
|
54 |
+
self.convs.append(
|
55 |
+
nn.Sequential(
|
56 |
+
getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
|
57 |
+
NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
|
58 |
+
)
|
59 |
+
)
|
60 |
+
|
61 |
+
in_chs = filters
|
62 |
+
for i, downsample_scale in enumerate(downsample_scales):
|
63 |
+
out_chs = min(in_chs * downsample_scale, max_filters)
|
64 |
+
default_kernel_size = downsample_scale * 10 + 1
|
65 |
+
default_stride = downsample_scale
|
66 |
+
default_padding = (default_kernel_size - 1) // 2
|
67 |
+
default_groups = in_chs // 4
|
68 |
+
self.convs.append(
|
69 |
+
NormConv1d(in_chs, out_chs,
|
70 |
+
kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
|
71 |
+
stride=strides[i] if strides else default_stride,
|
72 |
+
groups=groups[i] if groups else default_groups,
|
73 |
+
padding=paddings[i] if paddings else default_padding,
|
74 |
+
norm=norm))
|
75 |
+
in_chs = out_chs
|
76 |
+
|
77 |
+
out_chs = min(in_chs * 2, max_filters)
|
78 |
+
self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
|
79 |
+
padding=(kernel_sizes[0] - 1) // 2, norm=norm))
|
80 |
+
self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
|
81 |
+
padding=(kernel_sizes[1] - 1) // 2, norm=norm)
|
82 |
+
|
83 |
+
def forward(self, x: torch.Tensor):
|
84 |
+
fmap = []
|
85 |
+
for layer in self.convs:
|
86 |
+
x = layer(x)
|
87 |
+
x = self.activation(x)
|
88 |
+
fmap.append(x)
|
89 |
+
x = self.conv_post(x)
|
90 |
+
fmap.append(x)
|
91 |
+
# x = torch.flatten(x, 1, -1)
|
92 |
+
return x, fmap
|
93 |
+
|
94 |
+
|
95 |
+
class MultiScaleDiscriminator(MultiDiscriminator):
|
96 |
+
"""Multi-Scale (MSD) Discriminator,
|
97 |
+
|
98 |
+
Args:
|
99 |
+
in_channels (int): Number of input channels.
|
100 |
+
out_channels (int): Number of output channels.
|
101 |
+
downsample_factor (int): Downsampling factor between the different scales.
|
102 |
+
scale_norms (Sequence[str]): Normalization for each sub-discriminator.
|
103 |
+
**kwargs: Additional args for ScaleDiscriminator.
|
104 |
+
"""
|
105 |
+
def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
|
106 |
+
scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
|
107 |
+
super().__init__()
|
108 |
+
self.discriminators = nn.ModuleList([
|
109 |
+
ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
|
110 |
+
])
|
111 |
+
self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
|
112 |
+
|
113 |
+
@property
|
114 |
+
def num_discriminators(self):
|
115 |
+
return len(self.discriminators)
|
116 |
+
|
117 |
+
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
118 |
+
logits = []
|
119 |
+
fmaps = []
|
120 |
+
for i, disc in enumerate(self.discriminators):
|
121 |
+
if i != 0:
|
122 |
+
self.downsample(x)
|
123 |
+
logit, fmap = disc(x)
|
124 |
+
logits.append(logit)
|
125 |
+
fmaps.append(fmap)
|
126 |
+
return logits, fmaps
|
audiocraft/adversarial/discriminators/msstftd.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import torchaudio
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
from ...modules import NormConv2d
|
15 |
+
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
|
16 |
+
|
17 |
+
|
18 |
+
def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
|
19 |
+
return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
|
20 |
+
|
21 |
+
|
22 |
+
class DiscriminatorSTFT(nn.Module):
|
23 |
+
"""STFT sub-discriminator.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
filters (int): Number of filters in convolutions.
|
27 |
+
in_channels (int): Number of input channels.
|
28 |
+
out_channels (int): Number of output channels.
|
29 |
+
n_fft (int): Size of FFT for each scale.
|
30 |
+
hop_length (int): Length of hop between STFT windows for each scale.
|
31 |
+
kernel_size (tuple of int): Inner Conv2d kernel sizes.
|
32 |
+
stride (tuple of int): Inner Conv2d strides.
|
33 |
+
dilations (list of int): Inner Conv2d dilation on the time dimension.
|
34 |
+
win_length (int): Window size for each scale.
|
35 |
+
normalized (bool): Whether to normalize by magnitude after stft.
|
36 |
+
norm (str): Normalization method.
|
37 |
+
activation (str): Activation function.
|
38 |
+
activation_params (dict): Parameters to provide to the activation function.
|
39 |
+
growth (int): Growth factor for the filters.
|
40 |
+
"""
|
41 |
+
def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
|
42 |
+
n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
|
43 |
+
filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
|
44 |
+
stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
|
45 |
+
activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
|
46 |
+
super().__init__()
|
47 |
+
assert len(kernel_size) == 2
|
48 |
+
assert len(stride) == 2
|
49 |
+
self.filters = filters
|
50 |
+
self.in_channels = in_channels
|
51 |
+
self.out_channels = out_channels
|
52 |
+
self.n_fft = n_fft
|
53 |
+
self.hop_length = hop_length
|
54 |
+
self.win_length = win_length
|
55 |
+
self.normalized = normalized
|
56 |
+
self.activation = getattr(torch.nn, activation)(**activation_params)
|
57 |
+
self.spec_transform = torchaudio.transforms.Spectrogram(
|
58 |
+
n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
|
59 |
+
normalized=self.normalized, center=False, pad_mode=None, power=None)
|
60 |
+
spec_channels = 2 * self.in_channels
|
61 |
+
self.convs = nn.ModuleList()
|
62 |
+
self.convs.append(
|
63 |
+
NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
|
64 |
+
)
|
65 |
+
in_chs = min(filters_scale * self.filters, max_filters)
|
66 |
+
for i, dilation in enumerate(dilations):
|
67 |
+
out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
|
68 |
+
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
|
69 |
+
dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
|
70 |
+
norm=norm))
|
71 |
+
in_chs = out_chs
|
72 |
+
out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
|
73 |
+
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
|
74 |
+
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
|
75 |
+
norm=norm))
|
76 |
+
self.conv_post = NormConv2d(out_chs, self.out_channels,
|
77 |
+
kernel_size=(kernel_size[0], kernel_size[0]),
|
78 |
+
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
|
79 |
+
norm=norm)
|
80 |
+
|
81 |
+
def forward(self, x: torch.Tensor):
|
82 |
+
fmap = []
|
83 |
+
z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
|
84 |
+
z = torch.cat([z.real, z.imag], dim=1)
|
85 |
+
z = rearrange(z, 'b c w t -> b c t w')
|
86 |
+
for i, layer in enumerate(self.convs):
|
87 |
+
z = layer(z)
|
88 |
+
z = self.activation(z)
|
89 |
+
fmap.append(z)
|
90 |
+
z = self.conv_post(z)
|
91 |
+
return z, fmap
|
92 |
+
|
93 |
+
|
94 |
+
class MultiScaleSTFTDiscriminator(MultiDiscriminator):
|
95 |
+
"""Multi-Scale STFT (MS-STFT) discriminator.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
filters (int): Number of filters in convolutions.
|
99 |
+
in_channels (int): Number of input channels.
|
100 |
+
out_channels (int): Number of output channels.
|
101 |
+
sep_channels (bool): Separate channels to distinct samples for stereo support.
|
102 |
+
n_ffts (Sequence[int]): Size of FFT for each scale.
|
103 |
+
hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
|
104 |
+
win_lengths (Sequence[int]): Window size for each scale.
|
105 |
+
**kwargs: Additional args for STFTDiscriminator.
|
106 |
+
"""
|
107 |
+
def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
|
108 |
+
n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
|
109 |
+
win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
|
110 |
+
super().__init__()
|
111 |
+
assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
|
112 |
+
self.sep_channels = sep_channels
|
113 |
+
self.discriminators = nn.ModuleList([
|
114 |
+
DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
|
115 |
+
n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
|
116 |
+
for i in range(len(n_ffts))
|
117 |
+
])
|
118 |
+
|
119 |
+
@property
|
120 |
+
def num_discriminators(self):
|
121 |
+
return len(self.discriminators)
|
122 |
+
|
123 |
+
def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
|
124 |
+
B, C, T = x.shape
|
125 |
+
return x.view(-1, 1, T)
|
126 |
+
|
127 |
+
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
128 |
+
logits = []
|
129 |
+
fmaps = []
|
130 |
+
for disc in self.discriminators:
|
131 |
+
logit, fmap = disc(x)
|
132 |
+
logits.append(logit)
|
133 |
+
fmaps.append(fmap)
|
134 |
+
return logits, fmaps
|
audiocraft/adversarial/losses.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Utility module to handle adversarial losses without requiring to mess up the main training loop.
|
9 |
+
"""
|
10 |
+
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
import flashy
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
|
19 |
+
ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
|
20 |
+
|
21 |
+
|
22 |
+
AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
|
23 |
+
FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
|
24 |
+
|
25 |
+
|
26 |
+
class AdversarialLoss(nn.Module):
|
27 |
+
"""Adversary training wrapper.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
|
31 |
+
We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
|
32 |
+
where the first item is a list of logits and the second item is a list of feature maps.
|
33 |
+
optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
|
34 |
+
loss (AdvLossType): Loss function for generator training.
|
35 |
+
loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
|
36 |
+
loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
|
37 |
+
loss_feat (FeatLossType): Feature matching loss function for generator training.
|
38 |
+
normalize (bool): Whether to normalize by number of sub-discriminators.
|
39 |
+
|
40 |
+
Example of usage:
|
41 |
+
adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
|
42 |
+
for real in loader:
|
43 |
+
noise = torch.randn(...)
|
44 |
+
fake = model(noise)
|
45 |
+
adv_loss.train_adv(fake, real)
|
46 |
+
loss, _ = adv_loss(fake, real)
|
47 |
+
loss.backward()
|
48 |
+
"""
|
49 |
+
def __init__(self,
|
50 |
+
adversary: nn.Module,
|
51 |
+
optimizer: torch.optim.Optimizer,
|
52 |
+
loss: AdvLossType,
|
53 |
+
loss_real: AdvLossType,
|
54 |
+
loss_fake: AdvLossType,
|
55 |
+
loss_feat: tp.Optional[FeatLossType] = None,
|
56 |
+
normalize: bool = True):
|
57 |
+
super().__init__()
|
58 |
+
self.adversary: nn.Module = adversary
|
59 |
+
flashy.distrib.broadcast_model(self.adversary)
|
60 |
+
self.optimizer = optimizer
|
61 |
+
self.loss = loss
|
62 |
+
self.loss_real = loss_real
|
63 |
+
self.loss_fake = loss_fake
|
64 |
+
self.loss_feat = loss_feat
|
65 |
+
self.normalize = normalize
|
66 |
+
|
67 |
+
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
68 |
+
# Add the optimizer state dict inside our own.
|
69 |
+
super()._save_to_state_dict(destination, prefix, keep_vars)
|
70 |
+
destination[prefix + 'optimizer'] = self.optimizer.state_dict()
|
71 |
+
return destination
|
72 |
+
|
73 |
+
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
74 |
+
# Load optimizer state.
|
75 |
+
self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
|
76 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
77 |
+
|
78 |
+
def get_adversary_pred(self, x):
|
79 |
+
"""Run adversary model, validating expected output format."""
|
80 |
+
logits, fmaps = self.adversary(x)
|
81 |
+
assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
|
82 |
+
f'Expecting a list of tensors as logits but {type(logits)} found.'
|
83 |
+
assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
|
84 |
+
for fmap in fmaps:
|
85 |
+
assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
|
86 |
+
f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
|
87 |
+
return logits, fmaps
|
88 |
+
|
89 |
+
def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
|
90 |
+
"""Train the adversary with the given fake and real example.
|
91 |
+
|
92 |
+
We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
|
93 |
+
The first item being the logits and second item being a list of feature maps for each sub-discriminator.
|
94 |
+
|
95 |
+
This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
|
96 |
+
and call the optimizer.
|
97 |
+
"""
|
98 |
+
loss = torch.tensor(0., device=fake.device)
|
99 |
+
all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
|
100 |
+
all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
|
101 |
+
n_sub_adversaries = len(all_logits_fake_is_fake)
|
102 |
+
for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
|
103 |
+
loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
|
104 |
+
|
105 |
+
if self.normalize:
|
106 |
+
loss /= n_sub_adversaries
|
107 |
+
|
108 |
+
self.optimizer.zero_grad()
|
109 |
+
with flashy.distrib.eager_sync_model(self.adversary):
|
110 |
+
loss.backward()
|
111 |
+
self.optimizer.step()
|
112 |
+
|
113 |
+
return loss
|
114 |
+
|
115 |
+
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
116 |
+
"""Return the loss for the generator, i.e. trying to fool the adversary,
|
117 |
+
and feature matching loss if provided.
|
118 |
+
"""
|
119 |
+
adv = torch.tensor(0., device=fake.device)
|
120 |
+
feat = torch.tensor(0., device=fake.device)
|
121 |
+
with flashy.utils.readonly(self.adversary):
|
122 |
+
all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
|
123 |
+
all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
|
124 |
+
n_sub_adversaries = len(all_logits_fake_is_fake)
|
125 |
+
for logit_fake_is_fake in all_logits_fake_is_fake:
|
126 |
+
adv += self.loss(logit_fake_is_fake)
|
127 |
+
if self.loss_feat:
|
128 |
+
for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
|
129 |
+
feat += self.loss_feat(fmap_fake, fmap_real)
|
130 |
+
|
131 |
+
if self.normalize:
|
132 |
+
adv /= n_sub_adversaries
|
133 |
+
feat /= n_sub_adversaries
|
134 |
+
|
135 |
+
return adv, feat
|
136 |
+
|
137 |
+
|
138 |
+
def get_adv_criterion(loss_type: str) -> tp.Callable:
|
139 |
+
assert loss_type in ADVERSARIAL_LOSSES
|
140 |
+
if loss_type == 'mse':
|
141 |
+
return mse_loss
|
142 |
+
elif loss_type == 'hinge':
|
143 |
+
return hinge_loss
|
144 |
+
elif loss_type == 'hinge2':
|
145 |
+
return hinge2_loss
|
146 |
+
raise ValueError('Unsupported loss')
|
147 |
+
|
148 |
+
|
149 |
+
def get_fake_criterion(loss_type: str) -> tp.Callable:
|
150 |
+
assert loss_type in ADVERSARIAL_LOSSES
|
151 |
+
if loss_type == 'mse':
|
152 |
+
return mse_fake_loss
|
153 |
+
elif loss_type in ['hinge', 'hinge2']:
|
154 |
+
return hinge_fake_loss
|
155 |
+
raise ValueError('Unsupported loss')
|
156 |
+
|
157 |
+
|
158 |
+
def get_real_criterion(loss_type: str) -> tp.Callable:
|
159 |
+
assert loss_type in ADVERSARIAL_LOSSES
|
160 |
+
if loss_type == 'mse':
|
161 |
+
return mse_real_loss
|
162 |
+
elif loss_type in ['hinge', 'hinge2']:
|
163 |
+
return hinge_real_loss
|
164 |
+
raise ValueError('Unsupported loss')
|
165 |
+
|
166 |
+
|
167 |
+
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
|
168 |
+
return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
|
169 |
+
|
170 |
+
|
171 |
+
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
|
172 |
+
return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
|
173 |
+
|
174 |
+
|
175 |
+
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
|
176 |
+
return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
|
177 |
+
|
178 |
+
|
179 |
+
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
|
180 |
+
return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
|
181 |
+
|
182 |
+
|
183 |
+
def mse_loss(x: torch.Tensor) -> torch.Tensor:
|
184 |
+
if x.numel() == 0:
|
185 |
+
return torch.tensor([0.0], device=x.device)
|
186 |
+
return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
|
187 |
+
|
188 |
+
|
189 |
+
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
|
190 |
+
if x.numel() == 0:
|
191 |
+
return torch.tensor([0.0], device=x.device)
|
192 |
+
return -x.mean()
|
193 |
+
|
194 |
+
|
195 |
+
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
|
196 |
+
if x.numel() == 0:
|
197 |
+
return torch.tensor([0.0])
|
198 |
+
return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
|
199 |
+
|
200 |
+
|
201 |
+
class FeatureMatchingLoss(nn.Module):
|
202 |
+
"""Feature matching loss for adversarial training.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
|
206 |
+
normalize (bool): Whether to normalize the loss.
|
207 |
+
by number of feature maps.
|
208 |
+
"""
|
209 |
+
def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
|
210 |
+
super().__init__()
|
211 |
+
self.loss = loss
|
212 |
+
self.normalize = normalize
|
213 |
+
|
214 |
+
def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
|
215 |
+
assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
|
216 |
+
feat_loss = torch.tensor(0., device=fmap_fake[0].device)
|
217 |
+
feat_scale = torch.tensor(0., device=fmap_fake[0].device)
|
218 |
+
n_fmaps = 0
|
219 |
+
for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
|
220 |
+
assert feat_fake.shape == feat_real.shape
|
221 |
+
n_fmaps += 1
|
222 |
+
feat_loss += self.loss(feat_fake, feat_real)
|
223 |
+
feat_scale += torch.mean(torch.abs(feat_real))
|
224 |
+
|
225 |
+
if self.normalize:
|
226 |
+
feat_loss /= n_fmaps
|
227 |
+
|
228 |
+
return feat_loss
|
audiocraft/data/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Audio loading and writing support. Datasets for raw audio
|
7 |
+
or also including some metadata."""
|
8 |
+
|
9 |
+
# flake8: noqa
|
10 |
+
from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
|
audiocraft/data/audio.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Audio IO methods are defined in this module (info, read, write),
|
9 |
+
We rely on av library for faster read when possible, otherwise on torchaudio.
|
10 |
+
"""
|
11 |
+
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from pathlib import Path
|
14 |
+
import logging
|
15 |
+
import typing as tp
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import soundfile
|
19 |
+
import torch
|
20 |
+
from torch.nn import functional as F
|
21 |
+
import torchaudio as ta
|
22 |
+
|
23 |
+
import av
|
24 |
+
|
25 |
+
from .audio_utils import f32_pcm, i16_pcm, normalize_audio
|
26 |
+
|
27 |
+
|
28 |
+
_av_initialized = False
|
29 |
+
|
30 |
+
|
31 |
+
def _init_av():
|
32 |
+
global _av_initialized
|
33 |
+
if _av_initialized:
|
34 |
+
return
|
35 |
+
logger = logging.getLogger('libav.mp3')
|
36 |
+
logger.setLevel(logging.ERROR)
|
37 |
+
_av_initialized = True
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass(frozen=True)
|
41 |
+
class AudioFileInfo:
|
42 |
+
sample_rate: int
|
43 |
+
duration: float
|
44 |
+
channels: int
|
45 |
+
|
46 |
+
|
47 |
+
def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
|
48 |
+
_init_av()
|
49 |
+
with av.open(str(filepath)) as af:
|
50 |
+
stream = af.streams.audio[0]
|
51 |
+
sample_rate = stream.codec_context.sample_rate
|
52 |
+
duration = float(stream.duration * stream.time_base)
|
53 |
+
channels = stream.channels
|
54 |
+
return AudioFileInfo(sample_rate, duration, channels)
|
55 |
+
|
56 |
+
|
57 |
+
def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
|
58 |
+
info = soundfile.info(filepath)
|
59 |
+
return AudioFileInfo(info.samplerate, info.duration, info.channels)
|
60 |
+
|
61 |
+
|
62 |
+
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
|
63 |
+
# torchaudio no longer returns useful duration informations for some formats like mp3s.
|
64 |
+
filepath = Path(filepath)
|
65 |
+
if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
|
66 |
+
# ffmpeg has some weird issue with flac.
|
67 |
+
return _soundfile_info(filepath)
|
68 |
+
else:
|
69 |
+
return _av_info(filepath)
|
70 |
+
|
71 |
+
|
72 |
+
def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
|
73 |
+
"""FFMPEG-based audio file reading using PyAV bindings.
|
74 |
+
Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
filepath (str or Path): Path to audio file to read.
|
78 |
+
seek_time (float): Time at which to start reading in the file.
|
79 |
+
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
80 |
+
Returns:
|
81 |
+
tuple of torch.Tensor, int: Tuple containing audio data and sample rate
|
82 |
+
"""
|
83 |
+
_init_av()
|
84 |
+
with av.open(str(filepath)) as af:
|
85 |
+
stream = af.streams.audio[0]
|
86 |
+
sr = stream.codec_context.sample_rate
|
87 |
+
num_frames = int(sr * duration) if duration >= 0 else -1
|
88 |
+
frame_offset = int(sr * seek_time)
|
89 |
+
# we need a small negative offset otherwise we get some edge artifact
|
90 |
+
# from the mp3 decoder.
|
91 |
+
af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
|
92 |
+
frames = []
|
93 |
+
length = 0
|
94 |
+
for frame in af.decode(streams=stream.index):
|
95 |
+
current_offset = int(frame.rate * frame.pts * frame.time_base)
|
96 |
+
strip = max(0, frame_offset - current_offset)
|
97 |
+
buf = torch.from_numpy(frame.to_ndarray())
|
98 |
+
if buf.shape[0] != stream.channels:
|
99 |
+
buf = buf.view(-1, stream.channels).t()
|
100 |
+
buf = buf[:, strip:]
|
101 |
+
frames.append(buf)
|
102 |
+
length += buf.shape[1]
|
103 |
+
if num_frames > 0 and length >= num_frames:
|
104 |
+
break
|
105 |
+
assert frames
|
106 |
+
# If the above assert fails, it is likely because we seeked past the end of file point,
|
107 |
+
# in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
|
108 |
+
# This will need proper debugging, in due time.
|
109 |
+
wav = torch.cat(frames, dim=1)
|
110 |
+
assert wav.shape[0] == stream.channels
|
111 |
+
if num_frames > 0:
|
112 |
+
wav = wav[:, :num_frames]
|
113 |
+
return f32_pcm(wav), sr
|
114 |
+
|
115 |
+
|
116 |
+
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
117 |
+
duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
|
118 |
+
"""Read audio by picking the most appropriate backend tool based on the audio format.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
filepath (str or Path): Path to audio file to read.
|
122 |
+
seek_time (float): Time at which to start reading in the file.
|
123 |
+
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
124 |
+
pad (bool): Pad output audio if not reaching expected duration.
|
125 |
+
Returns:
|
126 |
+
tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
|
127 |
+
"""
|
128 |
+
fp = Path(filepath)
|
129 |
+
if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
|
130 |
+
# There is some bug with ffmpeg and reading flac
|
131 |
+
info = _soundfile_info(filepath)
|
132 |
+
frames = -1 if duration <= 0 else int(duration * info.sample_rate)
|
133 |
+
frame_offset = int(seek_time * info.sample_rate)
|
134 |
+
wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
|
135 |
+
assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
|
136 |
+
wav = torch.from_numpy(wav).t().contiguous()
|
137 |
+
if len(wav.shape) == 1:
|
138 |
+
wav = torch.unsqueeze(wav, 0)
|
139 |
+
elif (
|
140 |
+
fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
|
141 |
+
and duration <= 0 and seek_time == 0
|
142 |
+
):
|
143 |
+
# Torchaudio is faster if we load an entire file at once.
|
144 |
+
wav, sr = ta.load(fp)
|
145 |
+
else:
|
146 |
+
wav, sr = _av_read(filepath, seek_time, duration)
|
147 |
+
if pad and duration > 0:
|
148 |
+
expected_frames = int(duration * sr)
|
149 |
+
wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
|
150 |
+
return wav, sr
|
151 |
+
|
152 |
+
|
153 |
+
def audio_write(stem_name: tp.Union[str, Path],
|
154 |
+
wav: torch.Tensor, sample_rate: int,
|
155 |
+
format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
|
156 |
+
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
157 |
+
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
158 |
+
loudness_compressor: bool = False,
|
159 |
+
log_clipping: bool = True, make_parent_dir: bool = True,
|
160 |
+
add_suffix: bool = True) -> Path:
|
161 |
+
"""Convenience function for saving audio to disk. Returns the filename the audio was written to.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
stem_name (str or Path): Filename without extension which will be added automatically.
|
165 |
+
format (str): Either "wav" or "mp3".
|
166 |
+
mp3_rate (int): kbps when using mp3s.
|
167 |
+
normalize (bool): if `True` (default), normalizes according to the prescribed
|
168 |
+
strategy (see after). If `False`, the strategy is only used in case clipping
|
169 |
+
would happen.
|
170 |
+
strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
|
171 |
+
i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
|
172 |
+
with extra headroom to avoid clipping. 'clip' just clips.
|
173 |
+
peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
|
174 |
+
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
175 |
+
than the `peak_clip` one to avoid further clipping.
|
176 |
+
loudness_headroom_db (float): Target loudness for loudness normalization.
|
177 |
+
loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
|
178 |
+
when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
|
179 |
+
occurs despite strategy (only for 'rms').
|
180 |
+
make_parent_dir (bool): Make parent directory if it doesn't exist.
|
181 |
+
Returns:
|
182 |
+
Path: Path of the saved audio.
|
183 |
+
"""
|
184 |
+
assert wav.dtype.is_floating_point, "wav is not floating point"
|
185 |
+
if wav.dim() == 1:
|
186 |
+
wav = wav[None]
|
187 |
+
elif wav.dim() > 2:
|
188 |
+
raise ValueError("Input wav should be at most 2 dimension.")
|
189 |
+
assert wav.isfinite().all()
|
190 |
+
wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
|
191 |
+
rms_headroom_db, loudness_headroom_db, loudness_compressor,
|
192 |
+
log_clipping=log_clipping, sample_rate=sample_rate,
|
193 |
+
stem_name=str(stem_name))
|
194 |
+
kwargs: dict = {}
|
195 |
+
if format == 'mp3':
|
196 |
+
suffix = '.mp3'
|
197 |
+
kwargs.update({"compression": mp3_rate})
|
198 |
+
elif format == 'wav':
|
199 |
+
wav = i16_pcm(wav)
|
200 |
+
suffix = '.wav'
|
201 |
+
kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
|
202 |
+
else:
|
203 |
+
raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
|
204 |
+
if not add_suffix:
|
205 |
+
suffix = ''
|
206 |
+
path = Path(str(stem_name) + suffix)
|
207 |
+
if make_parent_dir:
|
208 |
+
path.parent.mkdir(exist_ok=True, parents=True)
|
209 |
+
try:
|
210 |
+
ta.save(path, wav, sample_rate, **kwargs)
|
211 |
+
except Exception:
|
212 |
+
if path.exists():
|
213 |
+
# we do not want to leave half written files around.
|
214 |
+
path.unlink()
|
215 |
+
raise
|
216 |
+
return path
|
audiocraft/data/audio_dataset.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""AudioDataset support. In order to handle a larger number of files
|
7 |
+
without having to scan again the folders, we precompute some metadata
|
8 |
+
(filename, sample rate, duration), and use that to efficiently sample audio segments.
|
9 |
+
"""
|
10 |
+
import argparse
|
11 |
+
import copy
|
12 |
+
from concurrent.futures import ThreadPoolExecutor, Future
|
13 |
+
from dataclasses import dataclass, fields
|
14 |
+
from contextlib import ExitStack
|
15 |
+
from functools import lru_cache
|
16 |
+
import gzip
|
17 |
+
import json
|
18 |
+
import logging
|
19 |
+
import os
|
20 |
+
from pathlib import Path
|
21 |
+
import random
|
22 |
+
import sys
|
23 |
+
import typing as tp
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.nn.functional as F
|
27 |
+
|
28 |
+
from .audio import audio_read, audio_info
|
29 |
+
from .audio_utils import convert_audio
|
30 |
+
from .zip import PathInZip
|
31 |
+
|
32 |
+
try:
|
33 |
+
import dora
|
34 |
+
except ImportError:
|
35 |
+
dora = None # type: ignore
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass(order=True)
|
39 |
+
class BaseInfo:
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def _dict2fields(cls, dictionary: dict):
|
43 |
+
return {
|
44 |
+
field.name: dictionary[field.name]
|
45 |
+
for field in fields(cls) if field.name in dictionary
|
46 |
+
}
|
47 |
+
|
48 |
+
@classmethod
|
49 |
+
def from_dict(cls, dictionary: dict):
|
50 |
+
_dictionary = cls._dict2fields(dictionary)
|
51 |
+
return cls(**_dictionary)
|
52 |
+
|
53 |
+
def to_dict(self):
|
54 |
+
return {
|
55 |
+
field.name: self.__getattribute__(field.name)
|
56 |
+
for field in fields(self)
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
@dataclass(order=True)
|
61 |
+
class AudioMeta(BaseInfo):
|
62 |
+
path: str
|
63 |
+
duration: float
|
64 |
+
sample_rate: int
|
65 |
+
amplitude: tp.Optional[float] = None
|
66 |
+
weight: tp.Optional[float] = None
|
67 |
+
# info_path is used to load additional information about the audio file that is stored in zip files.
|
68 |
+
info_path: tp.Optional[PathInZip] = None
|
69 |
+
|
70 |
+
@classmethod
|
71 |
+
def from_dict(cls, dictionary: dict):
|
72 |
+
base = cls._dict2fields(dictionary)
|
73 |
+
if 'info_path' in base and base['info_path'] is not None:
|
74 |
+
base['info_path'] = PathInZip(base['info_path'])
|
75 |
+
return cls(**base)
|
76 |
+
|
77 |
+
def to_dict(self):
|
78 |
+
d = super().to_dict()
|
79 |
+
if d['info_path'] is not None:
|
80 |
+
d['info_path'] = str(d['info_path'])
|
81 |
+
return d
|
82 |
+
|
83 |
+
|
84 |
+
@dataclass(order=True)
|
85 |
+
class SegmentInfo(BaseInfo):
|
86 |
+
meta: AudioMeta
|
87 |
+
seek_time: float
|
88 |
+
# The following values are given once the audio is processed, e.g.
|
89 |
+
# at the target sample rate and target number of channels.
|
90 |
+
n_frames: int # actual number of frames without padding
|
91 |
+
total_frames: int # total number of frames, padding included
|
92 |
+
sample_rate: int # actual sample rate
|
93 |
+
channels: int # number of audio channels.
|
94 |
+
|
95 |
+
|
96 |
+
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
|
97 |
+
|
98 |
+
logger = logging.getLogger(__name__)
|
99 |
+
|
100 |
+
|
101 |
+
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
|
102 |
+
"""AudioMeta from a path to an audio file.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
file_path (str): Resolved path of valid audio file.
|
106 |
+
minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
|
107 |
+
Returns:
|
108 |
+
AudioMeta: Audio file path and its metadata.
|
109 |
+
"""
|
110 |
+
info = audio_info(file_path)
|
111 |
+
amplitude: tp.Optional[float] = None
|
112 |
+
if not minimal:
|
113 |
+
wav, sr = audio_read(file_path)
|
114 |
+
amplitude = wav.abs().max().item()
|
115 |
+
return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
|
116 |
+
|
117 |
+
|
118 |
+
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
|
119 |
+
"""If Dora is available as a dependency, try to resolve potential relative paths
|
120 |
+
in list of AudioMeta. This method is expected to be used when loading meta from file.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
m (AudioMeta): Audio meta to resolve.
|
124 |
+
fast (bool): If True, uses a really fast check for determining if a file
|
125 |
+
is already absolute or not. Only valid on Linux/Mac.
|
126 |
+
Returns:
|
127 |
+
AudioMeta: Audio meta with resolved path.
|
128 |
+
"""
|
129 |
+
def is_abs(m):
|
130 |
+
if fast:
|
131 |
+
return str(m)[0] == '/'
|
132 |
+
else:
|
133 |
+
os.path.isabs(str(m))
|
134 |
+
|
135 |
+
if not dora:
|
136 |
+
return m
|
137 |
+
|
138 |
+
if not is_abs(m.path):
|
139 |
+
m.path = dora.git_save.to_absolute_path(m.path)
|
140 |
+
if m.info_path is not None and not is_abs(m.info_path.zip_path):
|
141 |
+
m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
|
142 |
+
return m
|
143 |
+
|
144 |
+
|
145 |
+
def find_audio_files(path: tp.Union[Path, str],
|
146 |
+
exts: tp.List[str] = DEFAULT_EXTS,
|
147 |
+
resolve: bool = True,
|
148 |
+
minimal: bool = True,
|
149 |
+
progress: bool = False,
|
150 |
+
workers: int = 0) -> tp.List[AudioMeta]:
|
151 |
+
"""Build a list of AudioMeta from a given path,
|
152 |
+
collecting relevant audio files and fetching meta info.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
path (str or Path): Path to folder containing audio files.
|
156 |
+
exts (list of str): List of file extensions to consider for audio files.
|
157 |
+
minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
|
158 |
+
progress (bool): Whether to log progress on audio files collection.
|
159 |
+
workers (int): number of parallel workers, if 0, use only the current thread.
|
160 |
+
Returns:
|
161 |
+
list of AudioMeta: List of audio file path and its metadata.
|
162 |
+
"""
|
163 |
+
audio_files = []
|
164 |
+
futures: tp.List[Future] = []
|
165 |
+
pool: tp.Optional[ThreadPoolExecutor] = None
|
166 |
+
with ExitStack() as stack:
|
167 |
+
if workers > 0:
|
168 |
+
pool = ThreadPoolExecutor(workers)
|
169 |
+
stack.enter_context(pool)
|
170 |
+
|
171 |
+
if progress:
|
172 |
+
print("Finding audio files...")
|
173 |
+
for root, folders, files in os.walk(path, followlinks=True):
|
174 |
+
for file in files:
|
175 |
+
full_path = Path(root) / file
|
176 |
+
if full_path.suffix.lower() in exts:
|
177 |
+
audio_files.append(full_path)
|
178 |
+
if pool is not None:
|
179 |
+
futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
|
180 |
+
if progress:
|
181 |
+
print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
|
182 |
+
|
183 |
+
if progress:
|
184 |
+
print("Getting audio metadata...")
|
185 |
+
meta: tp.List[AudioMeta] = []
|
186 |
+
for idx, file_path in enumerate(audio_files):
|
187 |
+
try:
|
188 |
+
if pool is None:
|
189 |
+
m = _get_audio_meta(str(file_path), minimal)
|
190 |
+
else:
|
191 |
+
m = futures[idx].result()
|
192 |
+
if resolve:
|
193 |
+
m = _resolve_audio_meta(m)
|
194 |
+
except Exception as err:
|
195 |
+
print("Error with", str(file_path), err, file=sys.stderr)
|
196 |
+
continue
|
197 |
+
meta.append(m)
|
198 |
+
if progress:
|
199 |
+
print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
|
200 |
+
meta.sort()
|
201 |
+
return meta
|
202 |
+
|
203 |
+
|
204 |
+
def load_audio_meta(path: tp.Union[str, Path],
|
205 |
+
resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
|
206 |
+
"""Load list of AudioMeta from an optionally compressed json file.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
path (str or Path): Path to JSON file.
|
210 |
+
resolve (bool): Whether to resolve the path from AudioMeta (default=True).
|
211 |
+
fast (bool): activates some tricks to make things faster.
|
212 |
+
Returns:
|
213 |
+
list of AudioMeta: List of audio file path and its total duration.
|
214 |
+
"""
|
215 |
+
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
216 |
+
with open_fn(path, 'rb') as fp: # type: ignore
|
217 |
+
lines = fp.readlines()
|
218 |
+
meta = []
|
219 |
+
for line in lines:
|
220 |
+
d = json.loads(line)
|
221 |
+
m = AudioMeta.from_dict(d)
|
222 |
+
if resolve:
|
223 |
+
m = _resolve_audio_meta(m, fast=fast)
|
224 |
+
meta.append(m)
|
225 |
+
return meta
|
226 |
+
|
227 |
+
|
228 |
+
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
|
229 |
+
"""Save the audio metadata to the file pointer as json.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
path (str or Path): Path to JSON file.
|
233 |
+
metadata (list of BaseAudioMeta): List of audio meta to save.
|
234 |
+
"""
|
235 |
+
Path(path).parent.mkdir(exist_ok=True, parents=True)
|
236 |
+
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
237 |
+
with open_fn(path, 'wb') as fp: # type: ignore
|
238 |
+
for m in meta:
|
239 |
+
json_str = json.dumps(m.to_dict()) + '\n'
|
240 |
+
json_bytes = json_str.encode('utf-8')
|
241 |
+
fp.write(json_bytes)
|
242 |
+
|
243 |
+
|
244 |
+
class AudioDataset:
|
245 |
+
"""Base audio dataset.
|
246 |
+
|
247 |
+
The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
|
248 |
+
and potentially additional information, by creating random segments from the list of audio
|
249 |
+
files referenced in the metadata and applying minimal data pre-processing such as resampling,
|
250 |
+
mixing of channels, padding, etc.
|
251 |
+
|
252 |
+
If no segment_duration value is provided, the AudioDataset will return the full wav for each
|
253 |
+
audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
|
254 |
+
duration, applying padding if required.
|
255 |
+
|
256 |
+
By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
|
257 |
+
allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
|
258 |
+
original audio meta.
|
259 |
+
|
260 |
+
Note that you can call `start_epoch(epoch)` in order to get
|
261 |
+
a deterministic "randomization" for `shuffle=True`.
|
262 |
+
For a given epoch and dataset index, this will always return the same extract.
|
263 |
+
You can get back some diversity by setting the `shuffle_seed` param.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
meta (list of AudioMeta): List of audio files metadata.
|
267 |
+
segment_duration (float, optional): Optional segment duration of audio to load.
|
268 |
+
If not specified, the dataset will load the full audio segment from the file.
|
269 |
+
shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
|
270 |
+
sample_rate (int): Target sample rate of the loaded audio samples.
|
271 |
+
channels (int): Target number of channels of the loaded audio samples.
|
272 |
+
sample_on_duration (bool): Set to `True` to sample segments with probability
|
273 |
+
dependent on audio file duration. This is only used if `segment_duration` is provided.
|
274 |
+
sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
|
275 |
+
`AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
|
276 |
+
of the file duration and file weight. This is only used if `segment_duration` is provided.
|
277 |
+
min_segment_ratio (float): Minimum segment ratio to use when the audio file
|
278 |
+
is shorter than the desired segment.
|
279 |
+
max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
|
280 |
+
return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
|
281 |
+
min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
|
282 |
+
audio shorter than this will be filtered out.
|
283 |
+
max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
|
284 |
+
audio longer than this will be filtered out.
|
285 |
+
shuffle_seed (int): can be used to further randomize
|
286 |
+
load_wav (bool): if False, skip loading the wav but returns a tensor of 0
|
287 |
+
with the expected segment_duration (which must be provided if load_wav is False).
|
288 |
+
permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
|
289 |
+
are False. Will ensure a permutation on files when going through the dataset.
|
290 |
+
In that case the epoch number must be provided in order for the model
|
291 |
+
to continue the permutation across epochs. In that case, it is assumed
|
292 |
+
that `num_samples = total_batch_size * num_updates_per_epoch`, with
|
293 |
+
`total_batch_size` the overall batch size accounting for all gpus.
|
294 |
+
"""
|
295 |
+
def __init__(self,
|
296 |
+
meta: tp.List[AudioMeta],
|
297 |
+
segment_duration: tp.Optional[float] = None,
|
298 |
+
shuffle: bool = True,
|
299 |
+
num_samples: int = 10_000,
|
300 |
+
sample_rate: int = 48_000,
|
301 |
+
channels: int = 2,
|
302 |
+
pad: bool = True,
|
303 |
+
sample_on_duration: bool = True,
|
304 |
+
sample_on_weight: bool = True,
|
305 |
+
min_segment_ratio: float = 0.5,
|
306 |
+
max_read_retry: int = 10,
|
307 |
+
return_info: bool = False,
|
308 |
+
min_audio_duration: tp.Optional[float] = None,
|
309 |
+
max_audio_duration: tp.Optional[float] = None,
|
310 |
+
shuffle_seed: int = 0,
|
311 |
+
load_wav: bool = True,
|
312 |
+
permutation_on_files: bool = False,
|
313 |
+
):
|
314 |
+
assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
|
315 |
+
assert segment_duration is None or segment_duration > 0
|
316 |
+
assert segment_duration is None or min_segment_ratio >= 0
|
317 |
+
self.segment_duration = segment_duration
|
318 |
+
self.min_segment_ratio = min_segment_ratio
|
319 |
+
self.max_audio_duration = max_audio_duration
|
320 |
+
self.min_audio_duration = min_audio_duration
|
321 |
+
if self.min_audio_duration is not None and self.max_audio_duration is not None:
|
322 |
+
assert self.min_audio_duration <= self.max_audio_duration
|
323 |
+
self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
|
324 |
+
assert len(self.meta) # Fail fast if all data has been filtered.
|
325 |
+
self.total_duration = sum(d.duration for d in self.meta)
|
326 |
+
|
327 |
+
if segment_duration is None:
|
328 |
+
num_samples = len(self.meta)
|
329 |
+
self.num_samples = num_samples
|
330 |
+
self.shuffle = shuffle
|
331 |
+
self.sample_rate = sample_rate
|
332 |
+
self.channels = channels
|
333 |
+
self.pad = pad
|
334 |
+
self.sample_on_weight = sample_on_weight
|
335 |
+
self.sample_on_duration = sample_on_duration
|
336 |
+
self.sampling_probabilities = self._get_sampling_probabilities()
|
337 |
+
self.max_read_retry = max_read_retry
|
338 |
+
self.return_info = return_info
|
339 |
+
self.shuffle_seed = shuffle_seed
|
340 |
+
self.current_epoch: tp.Optional[int] = None
|
341 |
+
self.load_wav = load_wav
|
342 |
+
if not load_wav:
|
343 |
+
assert segment_duration is not None
|
344 |
+
self.permutation_on_files = permutation_on_files
|
345 |
+
if permutation_on_files:
|
346 |
+
assert not self.sample_on_duration
|
347 |
+
assert not self.sample_on_weight
|
348 |
+
assert self.shuffle
|
349 |
+
|
350 |
+
def start_epoch(self, epoch: int):
|
351 |
+
self.current_epoch = epoch
|
352 |
+
|
353 |
+
def __len__(self):
|
354 |
+
return self.num_samples
|
355 |
+
|
356 |
+
def _get_sampling_probabilities(self, normalized: bool = True):
|
357 |
+
"""Return the sampling probabilities for each file inside `self.meta`."""
|
358 |
+
scores: tp.List[float] = []
|
359 |
+
for file_meta in self.meta:
|
360 |
+
score = 1.
|
361 |
+
if self.sample_on_weight and file_meta.weight is not None:
|
362 |
+
score *= file_meta.weight
|
363 |
+
if self.sample_on_duration:
|
364 |
+
score *= file_meta.duration
|
365 |
+
scores.append(score)
|
366 |
+
probabilities = torch.tensor(scores)
|
367 |
+
if normalized:
|
368 |
+
probabilities /= probabilities.sum()
|
369 |
+
return probabilities
|
370 |
+
|
371 |
+
@staticmethod
|
372 |
+
@lru_cache(16)
|
373 |
+
def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
|
374 |
+
# Used to keep the most recent files permutation in memory implicitely.
|
375 |
+
# will work unless someone is using a lot of Datasets in parallel.
|
376 |
+
rng = torch.Generator()
|
377 |
+
rng.manual_seed(base_seed + permutation_index)
|
378 |
+
return torch.randperm(num_files, generator=rng)
|
379 |
+
|
380 |
+
def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
|
381 |
+
"""Sample a given file from `self.meta`. Can be overridden in subclasses.
|
382 |
+
This is only called if `segment_duration` is not None.
|
383 |
+
|
384 |
+
You must use the provided random number generator `rng` for reproducibility.
|
385 |
+
You can further make use of the index accessed.
|
386 |
+
"""
|
387 |
+
if self.permutation_on_files:
|
388 |
+
assert self.current_epoch is not None
|
389 |
+
total_index = self.current_epoch * len(self) + index
|
390 |
+
permutation_index = total_index // len(self.meta)
|
391 |
+
relative_index = total_index % len(self.meta)
|
392 |
+
permutation = AudioDataset._get_file_permutation(
|
393 |
+
len(self.meta), permutation_index, self.shuffle_seed)
|
394 |
+
file_index = permutation[relative_index]
|
395 |
+
return self.meta[file_index]
|
396 |
+
|
397 |
+
if not self.sample_on_weight and not self.sample_on_duration:
|
398 |
+
file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
|
399 |
+
else:
|
400 |
+
file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
|
401 |
+
|
402 |
+
return self.meta[file_index]
|
403 |
+
|
404 |
+
def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
|
405 |
+
# Override this method in subclass if needed.
|
406 |
+
if self.load_wav:
|
407 |
+
return audio_read(path, seek_time, duration, pad=False)
|
408 |
+
else:
|
409 |
+
assert self.segment_duration is not None
|
410 |
+
n_frames = int(self.sample_rate * self.segment_duration)
|
411 |
+
return torch.zeros(self.channels, n_frames), self.sample_rate
|
412 |
+
|
413 |
+
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
|
414 |
+
if self.segment_duration is None:
|
415 |
+
file_meta = self.meta[index]
|
416 |
+
out, sr = audio_read(file_meta.path)
|
417 |
+
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
418 |
+
n_frames = out.shape[-1]
|
419 |
+
segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
|
420 |
+
sample_rate=self.sample_rate, channels=out.shape[0])
|
421 |
+
else:
|
422 |
+
rng = torch.Generator()
|
423 |
+
if self.shuffle:
|
424 |
+
# We use index, plus extra randomness, either totally random if we don't know the epoch.
|
425 |
+
# otherwise we make use of the epoch number and optional shuffle_seed.
|
426 |
+
if self.current_epoch is None:
|
427 |
+
rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
|
428 |
+
else:
|
429 |
+
rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
|
430 |
+
else:
|
431 |
+
# We only use index
|
432 |
+
rng.manual_seed(index)
|
433 |
+
|
434 |
+
for retry in range(self.max_read_retry):
|
435 |
+
file_meta = self.sample_file(index, rng)
|
436 |
+
# We add some variance in the file position even if audio file is smaller than segment
|
437 |
+
# without ending up with empty segments
|
438 |
+
max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
|
439 |
+
seek_time = torch.rand(1, generator=rng).item() * max_seek
|
440 |
+
try:
|
441 |
+
out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
|
442 |
+
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
443 |
+
n_frames = out.shape[-1]
|
444 |
+
target_frames = int(self.segment_duration * self.sample_rate)
|
445 |
+
if self.pad:
|
446 |
+
out = F.pad(out, (0, target_frames - n_frames))
|
447 |
+
segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
|
448 |
+
sample_rate=self.sample_rate, channels=out.shape[0])
|
449 |
+
except Exception as exc:
|
450 |
+
logger.warning("Error opening file %s: %r", file_meta.path, exc)
|
451 |
+
if retry == self.max_read_retry - 1:
|
452 |
+
raise
|
453 |
+
else:
|
454 |
+
break
|
455 |
+
|
456 |
+
if self.return_info:
|
457 |
+
# Returns the wav and additional information on the wave segment
|
458 |
+
return out, segment_info
|
459 |
+
else:
|
460 |
+
return out
|
461 |
+
|
462 |
+
def collater(self, samples):
|
463 |
+
"""The collater function has to be provided to the dataloader
|
464 |
+
if AudioDataset has return_info=True in order to properly collate
|
465 |
+
the samples of a batch.
|
466 |
+
"""
|
467 |
+
if self.segment_duration is None and len(samples) > 1:
|
468 |
+
assert self.pad, "Must allow padding when batching examples of different durations."
|
469 |
+
|
470 |
+
# In this case the audio reaching the collater is of variable length as segment_duration=None.
|
471 |
+
to_pad = self.segment_duration is None and self.pad
|
472 |
+
if to_pad:
|
473 |
+
max_len = max([wav.shape[-1] for wav, _ in samples])
|
474 |
+
|
475 |
+
def _pad_wav(wav):
|
476 |
+
return F.pad(wav, (0, max_len - wav.shape[-1]))
|
477 |
+
|
478 |
+
if self.return_info:
|
479 |
+
if len(samples) > 0:
|
480 |
+
assert len(samples[0]) == 2
|
481 |
+
assert isinstance(samples[0][0], torch.Tensor)
|
482 |
+
assert isinstance(samples[0][1], SegmentInfo)
|
483 |
+
|
484 |
+
wavs = [wav for wav, _ in samples]
|
485 |
+
segment_infos = [copy.deepcopy(info) for _, info in samples]
|
486 |
+
|
487 |
+
if to_pad:
|
488 |
+
# Each wav could be of a different duration as they are not segmented.
|
489 |
+
for i in range(len(samples)):
|
490 |
+
# Determines the total length of the signal with padding, so we update here as we pad.
|
491 |
+
segment_infos[i].total_frames = max_len
|
492 |
+
wavs[i] = _pad_wav(wavs[i])
|
493 |
+
|
494 |
+
wav = torch.stack(wavs)
|
495 |
+
return wav, segment_infos
|
496 |
+
else:
|
497 |
+
assert isinstance(samples[0], torch.Tensor)
|
498 |
+
if to_pad:
|
499 |
+
samples = [_pad_wav(s) for s in samples]
|
500 |
+
return torch.stack(samples)
|
501 |
+
|
502 |
+
def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
503 |
+
"""Filters out audio files with audio durations that will not allow to sample examples from them."""
|
504 |
+
orig_len = len(meta)
|
505 |
+
|
506 |
+
# Filter data that is too short.
|
507 |
+
if self.min_audio_duration is not None:
|
508 |
+
meta = [m for m in meta if m.duration >= self.min_audio_duration]
|
509 |
+
|
510 |
+
# Filter data that is too long.
|
511 |
+
if self.max_audio_duration is not None:
|
512 |
+
meta = [m for m in meta if m.duration <= self.max_audio_duration]
|
513 |
+
|
514 |
+
filtered_len = len(meta)
|
515 |
+
removed_percentage = 100*(1-float(filtered_len)/orig_len)
|
516 |
+
msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
|
517 |
+
if removed_percentage < 10:
|
518 |
+
logging.debug(msg)
|
519 |
+
else:
|
520 |
+
logging.warning(msg)
|
521 |
+
return meta
|
522 |
+
|
523 |
+
@classmethod
|
524 |
+
def from_meta(cls, root: tp.Union[str, Path], **kwargs):
|
525 |
+
"""Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
|
526 |
+
|
527 |
+
Args:
|
528 |
+
root (str or Path): Path to root folder containing audio files.
|
529 |
+
kwargs: Additional keyword arguments for the AudioDataset.
|
530 |
+
"""
|
531 |
+
root = Path(root)
|
532 |
+
if root.is_dir():
|
533 |
+
if (root / 'data.jsonl').exists():
|
534 |
+
root = root / 'data.jsonl'
|
535 |
+
elif (root / 'data.jsonl.gz').exists():
|
536 |
+
root = root / 'data.jsonl.gz'
|
537 |
+
else:
|
538 |
+
raise ValueError("Don't know where to read metadata from in the dir. "
|
539 |
+
"Expecting either a data.jsonl or data.jsonl.gz file but none found.")
|
540 |
+
meta = load_audio_meta(root)
|
541 |
+
return cls(meta, **kwargs)
|
542 |
+
|
543 |
+
@classmethod
|
544 |
+
def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
|
545 |
+
exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
|
546 |
+
"""Instantiate AudioDataset from a path containing (possibly nested) audio files.
|
547 |
+
|
548 |
+
Args:
|
549 |
+
root (str or Path): Path to root folder containing audio files.
|
550 |
+
minimal_meta (bool): Whether to only load minimal metadata or not.
|
551 |
+
exts (list of str): Extensions for audio files.
|
552 |
+
kwargs: Additional keyword arguments for the AudioDataset.
|
553 |
+
"""
|
554 |
+
root = Path(root)
|
555 |
+
if root.is_file():
|
556 |
+
meta = load_audio_meta(root, resolve=True)
|
557 |
+
else:
|
558 |
+
meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
|
559 |
+
return cls(meta, **kwargs)
|
560 |
+
|
561 |
+
|
562 |
+
def main():
|
563 |
+
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
|
564 |
+
parser = argparse.ArgumentParser(
|
565 |
+
prog='audio_dataset',
|
566 |
+
description='Generate .jsonl files by scanning a folder.')
|
567 |
+
parser.add_argument('root', help='Root folder with all the audio files')
|
568 |
+
parser.add_argument('output_meta_file',
|
569 |
+
help='Output file to store the metadata, ')
|
570 |
+
parser.add_argument('--complete',
|
571 |
+
action='store_false', dest='minimal', default=True,
|
572 |
+
help='Retrieve all metadata, even the one that are expansive '
|
573 |
+
'to compute (e.g. normalization).')
|
574 |
+
parser.add_argument('--resolve',
|
575 |
+
action='store_true', default=False,
|
576 |
+
help='Resolve the paths to be absolute and with no symlinks.')
|
577 |
+
parser.add_argument('--workers',
|
578 |
+
default=10, type=int,
|
579 |
+
help='Number of workers.')
|
580 |
+
args = parser.parse_args()
|
581 |
+
meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
|
582 |
+
resolve=args.resolve, minimal=args.minimal, workers=args.workers)
|
583 |
+
save_audio_meta(args.output_meta_file, meta)
|
584 |
+
|
585 |
+
|
586 |
+
if __name__ == '__main__':
|
587 |
+
main()
|
audiocraft/data/audio_utils.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Various utilities for audio convertion (pcm format, sample rate and channels),
|
7 |
+
and volume normalization."""
|
8 |
+
import sys
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import julius
|
12 |
+
import torch
|
13 |
+
import torchaudio
|
14 |
+
|
15 |
+
|
16 |
+
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
|
17 |
+
"""Convert audio to the given number of channels.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
wav (torch.Tensor): Audio wave of shape [B, C, T].
|
21 |
+
channels (int): Expected number of channels as output.
|
22 |
+
Returns:
|
23 |
+
torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
|
24 |
+
"""
|
25 |
+
*shape, src_channels, length = wav.shape
|
26 |
+
if src_channels == channels:
|
27 |
+
pass
|
28 |
+
elif channels == 1:
|
29 |
+
# Case 1:
|
30 |
+
# The caller asked 1-channel audio, and the stream has multiple
|
31 |
+
# channels, downmix all channels.
|
32 |
+
wav = wav.mean(dim=-2, keepdim=True)
|
33 |
+
elif src_channels == 1:
|
34 |
+
# Case 2:
|
35 |
+
# The caller asked for multiple channels, but the input file has
|
36 |
+
# a single channel, replicate the audio over all channels.
|
37 |
+
wav = wav.expand(*shape, channels, length)
|
38 |
+
elif src_channels >= channels:
|
39 |
+
# Case 3:
|
40 |
+
# The caller asked for multiple channels, and the input file has
|
41 |
+
# more channels than requested. In that case return the first channels.
|
42 |
+
wav = wav[..., :channels, :]
|
43 |
+
else:
|
44 |
+
# Case 4: What is a reasonable choice here?
|
45 |
+
raise ValueError('The audio file has less channels than requested but is not mono.')
|
46 |
+
return wav
|
47 |
+
|
48 |
+
|
49 |
+
def convert_audio(wav: torch.Tensor, from_rate: float,
|
50 |
+
to_rate: float, to_channels: int) -> torch.Tensor:
|
51 |
+
"""Convert audio to new sample rate and number of audio channels."""
|
52 |
+
wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
|
53 |
+
wav = convert_audio_channels(wav, to_channels)
|
54 |
+
return wav
|
55 |
+
|
56 |
+
|
57 |
+
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
|
58 |
+
loudness_compressor: bool = False, energy_floor: float = 2e-3):
|
59 |
+
"""Normalize an input signal to a user loudness in dB LKFS.
|
60 |
+
Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
wav (torch.Tensor): Input multichannel audio data.
|
64 |
+
sample_rate (int): Sample rate.
|
65 |
+
loudness_headroom_db (float): Target loudness of the output in dB LUFS.
|
66 |
+
loudness_compressor (bool): Uses tanh for soft clipping.
|
67 |
+
energy_floor (float): anything below that RMS level will not be rescaled.
|
68 |
+
Returns:
|
69 |
+
torch.Tensor: Loudness normalized output data.
|
70 |
+
"""
|
71 |
+
energy = wav.pow(2).mean().sqrt().item()
|
72 |
+
if energy < energy_floor:
|
73 |
+
return wav
|
74 |
+
transform = torchaudio.transforms.Loudness(sample_rate)
|
75 |
+
input_loudness_db = transform(wav).item()
|
76 |
+
# calculate the gain needed to scale to the desired loudness level
|
77 |
+
delta_loudness = -loudness_headroom_db - input_loudness_db
|
78 |
+
gain = 10.0 ** (delta_loudness / 20.0)
|
79 |
+
output = gain * wav
|
80 |
+
if loudness_compressor:
|
81 |
+
output = torch.tanh(output)
|
82 |
+
assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
|
83 |
+
return output
|
84 |
+
|
85 |
+
|
86 |
+
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
|
87 |
+
"""Utility function to clip the audio with logging if specified."""
|
88 |
+
max_scale = wav.abs().max()
|
89 |
+
if log_clipping and max_scale > 1:
|
90 |
+
clamp_prob = (wav.abs() > 1).float().mean().item()
|
91 |
+
print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
|
92 |
+
clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
|
93 |
+
#wav.clamp_(-1, 1)
|
94 |
+
wav = wav.clone().clamp_(-1, 1)
|
95 |
+
|
96 |
+
|
97 |
+
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
|
98 |
+
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
99 |
+
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
100 |
+
loudness_compressor: bool = False, log_clipping: bool = False,
|
101 |
+
sample_rate: tp.Optional[int] = None,
|
102 |
+
stem_name: tp.Optional[str] = None) -> torch.Tensor:
|
103 |
+
"""Normalize the audio according to the prescribed strategy (see after).
|
104 |
+
|
105 |
+
Args:
|
106 |
+
wav (torch.Tensor): Audio data.
|
107 |
+
normalize (bool): if `True` (default), normalizes according to the prescribed
|
108 |
+
strategy (see after). If `False`, the strategy is only used in case clipping
|
109 |
+
would happen.
|
110 |
+
strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
|
111 |
+
i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
|
112 |
+
with extra headroom to avoid clipping. 'clip' just clips.
|
113 |
+
peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
|
114 |
+
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
115 |
+
than the `peak_clip` one to avoid further clipping.
|
116 |
+
loudness_headroom_db (float): Target loudness for loudness normalization.
|
117 |
+
loudness_compressor (bool): If True, uses tanh based soft clipping.
|
118 |
+
log_clipping (bool): If True, basic logging on stderr when clipping still
|
119 |
+
occurs despite strategy (only for 'rms').
|
120 |
+
sample_rate (int): Sample rate for the audio data (required for loudness).
|
121 |
+
stem_name (str, optional): Stem name for clipping logging.
|
122 |
+
Returns:
|
123 |
+
torch.Tensor: Normalized audio.
|
124 |
+
"""
|
125 |
+
scale_peak = 10 ** (-peak_clip_headroom_db / 20)
|
126 |
+
scale_rms = 10 ** (-rms_headroom_db / 20)
|
127 |
+
if strategy == 'peak':
|
128 |
+
rescaling = (scale_peak / wav.abs().max())
|
129 |
+
if normalize or rescaling < 1:
|
130 |
+
wav = wav * rescaling
|
131 |
+
elif strategy == 'clip':
|
132 |
+
wav = wav.clamp(-scale_peak, scale_peak)
|
133 |
+
elif strategy == 'rms':
|
134 |
+
mono = wav.mean(dim=0)
|
135 |
+
rescaling = scale_rms / mono.pow(2).mean().sqrt()
|
136 |
+
if normalize or rescaling < 1:
|
137 |
+
wav = wav * rescaling
|
138 |
+
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
139 |
+
elif strategy == 'loudness':
|
140 |
+
assert sample_rate is not None, "Loudness normalization requires sample rate."
|
141 |
+
wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
|
142 |
+
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
143 |
+
else:
|
144 |
+
assert wav.abs().max() < 1
|
145 |
+
assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
|
146 |
+
return wav
|
147 |
+
|
148 |
+
|
149 |
+
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
150 |
+
"""Convert audio to float 32 bits PCM format.
|
151 |
+
"""
|
152 |
+
if wav.dtype.is_floating_point:
|
153 |
+
return wav
|
154 |
+
elif wav.dtype == torch.int16:
|
155 |
+
return wav.float() / 2**15
|
156 |
+
elif wav.dtype == torch.int32:
|
157 |
+
return wav.float() / 2**31
|
158 |
+
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
159 |
+
|
160 |
+
|
161 |
+
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
|
162 |
+
"""Convert audio to int 16 bits PCM format.
|
163 |
+
|
164 |
+
..Warning:: There exist many formula for doing this conversion. None are perfect
|
165 |
+
due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
|
166 |
+
or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
|
167 |
+
it is possible that `i16_pcm(f32_pcm)) != Identity`.
|
168 |
+
"""
|
169 |
+
if wav.dtype.is_floating_point:
|
170 |
+
assert wav.abs().max() <= 1
|
171 |
+
candidate = (wav * 2 ** 15).round()
|
172 |
+
if candidate.max() >= 2 ** 15: # clipping would occur
|
173 |
+
candidate = (wav * (2 ** 15 - 1)).round()
|
174 |
+
return candidate.short()
|
175 |
+
else:
|
176 |
+
assert wav.dtype == torch.int16
|
177 |
+
return wav
|
audiocraft/data/info_audio_dataset.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Base classes for the datasets that also provide non-audio metadata,
|
7 |
+
e.g. description, text transcription etc.
|
8 |
+
"""
|
9 |
+
from dataclasses import dataclass
|
10 |
+
import logging
|
11 |
+
import math
|
12 |
+
import re
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from .audio_dataset import AudioDataset, AudioMeta
|
18 |
+
from ..environment import AudioCraftEnvironment
|
19 |
+
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
|
26 |
+
"""Monkey-patch meta to match cluster specificities."""
|
27 |
+
meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
|
28 |
+
if meta.info_path is not None:
|
29 |
+
meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
|
30 |
+
return meta
|
31 |
+
|
32 |
+
|
33 |
+
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
34 |
+
"""Monkey-patch all meta to match cluster specificities."""
|
35 |
+
return [_clusterify_meta(m) for m in meta]
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class AudioInfo(SegmentWithAttributes):
|
40 |
+
"""Dummy SegmentInfo with empty attributes.
|
41 |
+
|
42 |
+
The InfoAudioDataset is expected to return metadata that inherits
|
43 |
+
from SegmentWithAttributes class and can return conditioning attributes.
|
44 |
+
|
45 |
+
This basically guarantees all datasets will be compatible with current
|
46 |
+
solver that contain conditioners requiring this.
|
47 |
+
"""
|
48 |
+
audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
|
49 |
+
|
50 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
51 |
+
return ConditioningAttributes()
|
52 |
+
|
53 |
+
|
54 |
+
class InfoAudioDataset(AudioDataset):
|
55 |
+
"""AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
|
56 |
+
|
57 |
+
See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
|
58 |
+
"""
|
59 |
+
def __init__(self, meta: tp.List[AudioMeta], **kwargs):
|
60 |
+
super().__init__(clusterify_all_meta(meta), **kwargs)
|
61 |
+
|
62 |
+
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
|
63 |
+
if not self.return_info:
|
64 |
+
wav = super().__getitem__(index)
|
65 |
+
assert isinstance(wav, torch.Tensor)
|
66 |
+
return wav
|
67 |
+
wav, meta = super().__getitem__(index)
|
68 |
+
return wav, AudioInfo(**meta.to_dict())
|
69 |
+
|
70 |
+
|
71 |
+
def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
|
72 |
+
"""Preprocess a single keyword or possible a list of keywords."""
|
73 |
+
if isinstance(value, list):
|
74 |
+
return get_keyword_list(value)
|
75 |
+
else:
|
76 |
+
return get_keyword(value)
|
77 |
+
|
78 |
+
|
79 |
+
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
|
80 |
+
"""Preprocess a single keyword."""
|
81 |
+
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
82 |
+
return None
|
83 |
+
else:
|
84 |
+
return value.strip()
|
85 |
+
|
86 |
+
|
87 |
+
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
|
88 |
+
"""Preprocess a single keyword."""
|
89 |
+
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
90 |
+
return None
|
91 |
+
else:
|
92 |
+
return value.strip().lower()
|
93 |
+
|
94 |
+
|
95 |
+
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
|
96 |
+
"""Preprocess a list of keywords."""
|
97 |
+
if isinstance(values, str):
|
98 |
+
values = [v.strip() for v in re.split(r'[,\s]', values)]
|
99 |
+
elif isinstance(values, float) and math.isnan(values):
|
100 |
+
values = []
|
101 |
+
if not isinstance(values, list):
|
102 |
+
logger.debug(f"Unexpected keyword list {values}")
|
103 |
+
values = [str(values)]
|
104 |
+
|
105 |
+
kws = [get_keyword(v) for v in values]
|
106 |
+
kw_list = [k for k in kws if k is not None]
|
107 |
+
if len(kw_list) == 0:
|
108 |
+
return None
|
109 |
+
else:
|
110 |
+
return kw_list
|
audiocraft/data/music_dataset.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Dataset of music tracks with rich metadata.
|
7 |
+
"""
|
8 |
+
from dataclasses import dataclass, field, fields, replace
|
9 |
+
import gzip
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
from pathlib import Path
|
13 |
+
import random
|
14 |
+
import typing as tp
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from .info_audio_dataset import (
|
19 |
+
InfoAudioDataset,
|
20 |
+
AudioInfo,
|
21 |
+
get_keyword_list,
|
22 |
+
get_keyword,
|
23 |
+
get_string
|
24 |
+
)
|
25 |
+
from ..modules.conditioners import (
|
26 |
+
ConditioningAttributes,
|
27 |
+
JointEmbedCondition,
|
28 |
+
WavCondition,
|
29 |
+
)
|
30 |
+
from ..utils.utils import warn_once
|
31 |
+
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class MusicInfo(AudioInfo):
|
38 |
+
"""Segment info augmented with music metadata.
|
39 |
+
"""
|
40 |
+
# music-specific metadata
|
41 |
+
title: tp.Optional[str] = None
|
42 |
+
artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
|
43 |
+
key: tp.Optional[str] = None
|
44 |
+
bpm: tp.Optional[float] = None
|
45 |
+
genre: tp.Optional[str] = None
|
46 |
+
moods: tp.Optional[list] = None
|
47 |
+
keywords: tp.Optional[list] = None
|
48 |
+
description: tp.Optional[str] = None
|
49 |
+
name: tp.Optional[str] = None
|
50 |
+
instrument: tp.Optional[str] = None
|
51 |
+
# original wav accompanying the metadata
|
52 |
+
self_wav: tp.Optional[WavCondition] = None
|
53 |
+
# dict mapping attributes names to tuple of wav, text and metadata
|
54 |
+
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
55 |
+
|
56 |
+
@property
|
57 |
+
def has_music_meta(self) -> bool:
|
58 |
+
return self.name is not None
|
59 |
+
|
60 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
61 |
+
out = ConditioningAttributes()
|
62 |
+
for _field in fields(self):
|
63 |
+
key, value = _field.name, getattr(self, _field.name)
|
64 |
+
if key == 'self_wav':
|
65 |
+
out.wav[key] = value
|
66 |
+
elif key == 'joint_embed':
|
67 |
+
for embed_attribute, embed_cond in value.items():
|
68 |
+
out.joint_embed[embed_attribute] = embed_cond
|
69 |
+
else:
|
70 |
+
if isinstance(value, list):
|
71 |
+
value = ' '.join(value)
|
72 |
+
out.text[key] = value
|
73 |
+
return out
|
74 |
+
|
75 |
+
@staticmethod
|
76 |
+
def attribute_getter(attribute):
|
77 |
+
if attribute == 'bpm':
|
78 |
+
preprocess_func = get_bpm
|
79 |
+
elif attribute == 'key':
|
80 |
+
preprocess_func = get_musical_key
|
81 |
+
elif attribute in ['moods', 'keywords']:
|
82 |
+
preprocess_func = get_keyword_list
|
83 |
+
elif attribute in ['genre', 'name', 'instrument']:
|
84 |
+
preprocess_func = get_keyword
|
85 |
+
elif attribute in ['title', 'artist', 'description']:
|
86 |
+
preprocess_func = get_string
|
87 |
+
else:
|
88 |
+
preprocess_func = None
|
89 |
+
return preprocess_func
|
90 |
+
|
91 |
+
@classmethod
|
92 |
+
def from_dict(cls, dictionary: dict, fields_required: bool = False):
|
93 |
+
_dictionary: tp.Dict[str, tp.Any] = {}
|
94 |
+
|
95 |
+
# allow a subset of attributes to not be loaded from the dictionary
|
96 |
+
# these attributes may be populated later
|
97 |
+
post_init_attributes = ['self_wav', 'joint_embed']
|
98 |
+
optional_fields = ['keywords']
|
99 |
+
|
100 |
+
for _field in fields(cls):
|
101 |
+
if _field.name in post_init_attributes:
|
102 |
+
continue
|
103 |
+
elif _field.name not in dictionary:
|
104 |
+
if fields_required and _field.name not in optional_fields:
|
105 |
+
raise KeyError(f"Unexpected missing key: {_field.name}")
|
106 |
+
else:
|
107 |
+
preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
|
108 |
+
value = dictionary[_field.name]
|
109 |
+
if preprocess_func:
|
110 |
+
value = preprocess_func(value)
|
111 |
+
_dictionary[_field.name] = value
|
112 |
+
return cls(**_dictionary)
|
113 |
+
|
114 |
+
|
115 |
+
def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
|
116 |
+
drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
|
117 |
+
"""Augment MusicInfo description with additional metadata fields and potential dropout.
|
118 |
+
Additional textual attributes are added given probability 'merge_text_conditions_p' and
|
119 |
+
the original textual description is dropped from the augmented description given probability drop_desc_p.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
music_info (MusicInfo): The music metadata to augment.
|
123 |
+
merge_text_p (float): Probability of merging additional metadata to the description.
|
124 |
+
If provided value is 0, then no merging is performed.
|
125 |
+
drop_desc_p (float): Probability of dropping the original description on text merge.
|
126 |
+
if provided value is 0, then no drop out is performed.
|
127 |
+
drop_other_p (float): Probability of dropping the other fields used for text augmentation.
|
128 |
+
Returns:
|
129 |
+
MusicInfo: The MusicInfo with augmented textual description.
|
130 |
+
"""
|
131 |
+
def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
|
132 |
+
valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
|
133 |
+
valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
|
134 |
+
keep_field = random.uniform(0, 1) < drop_other_p
|
135 |
+
return valid_field_name and valid_field_value and keep_field
|
136 |
+
|
137 |
+
def process_value(v: tp.Any) -> str:
|
138 |
+
if isinstance(v, (int, float, str)):
|
139 |
+
return str(v)
|
140 |
+
if isinstance(v, list):
|
141 |
+
return ", ".join(v)
|
142 |
+
else:
|
143 |
+
raise ValueError(f"Unknown type for text value! ({type(v), v})")
|
144 |
+
|
145 |
+
description = music_info.description
|
146 |
+
|
147 |
+
metadata_text = ""
|
148 |
+
if random.uniform(0, 1) < merge_text_p:
|
149 |
+
meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
|
150 |
+
for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
|
151 |
+
random.shuffle(meta_pairs)
|
152 |
+
metadata_text = ". ".join(meta_pairs)
|
153 |
+
description = description if not random.uniform(0, 1) < drop_desc_p else None
|
154 |
+
logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
|
155 |
+
|
156 |
+
if description is None:
|
157 |
+
description = metadata_text if len(metadata_text) > 1 else None
|
158 |
+
else:
|
159 |
+
description = ". ".join([description.rstrip('.'), metadata_text])
|
160 |
+
description = description.strip() if description else None
|
161 |
+
|
162 |
+
music_info = replace(music_info)
|
163 |
+
music_info.description = description
|
164 |
+
return music_info
|
165 |
+
|
166 |
+
|
167 |
+
class Paraphraser:
|
168 |
+
def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
|
169 |
+
self.paraphrase_p = paraphrase_p
|
170 |
+
open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
|
171 |
+
with open_fn(paraphrase_source, 'rb') as f: # type: ignore
|
172 |
+
self.paraphrase_source = json.loads(f.read())
|
173 |
+
logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
|
174 |
+
|
175 |
+
def sample_paraphrase(self, audio_path: str, description: str):
|
176 |
+
if random.random() >= self.paraphrase_p:
|
177 |
+
return description
|
178 |
+
info_path = Path(audio_path).with_suffix('.json')
|
179 |
+
if info_path not in self.paraphrase_source:
|
180 |
+
warn_once(logger, f"{info_path} not in paraphrase source!")
|
181 |
+
return description
|
182 |
+
new_desc = random.choice(self.paraphrase_source[info_path])
|
183 |
+
logger.debug(f"{description} -> {new_desc}")
|
184 |
+
return new_desc
|
185 |
+
|
186 |
+
|
187 |
+
class MusicDataset(InfoAudioDataset):
|
188 |
+
"""Music dataset is an AudioDataset with music-related metadata.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
info_fields_required (bool): Whether to enforce having required fields.
|
192 |
+
merge_text_p (float): Probability of merging additional metadata to the description.
|
193 |
+
drop_desc_p (float): Probability of dropping the original description on text merge.
|
194 |
+
drop_other_p (float): Probability of dropping the other fields used for text augmentation.
|
195 |
+
joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
|
196 |
+
paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
|
197 |
+
paraphrases for the description. The json should be a dict with keys are the
|
198 |
+
original info path (e.g. track_path.json) and each value is a list of possible
|
199 |
+
paraphrased.
|
200 |
+
paraphrase_p (float): probability of taking a paraphrase.
|
201 |
+
|
202 |
+
See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
|
203 |
+
"""
|
204 |
+
def __init__(self, *args, info_fields_required: bool = True,
|
205 |
+
merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
|
206 |
+
joint_embed_attributes: tp.List[str] = [],
|
207 |
+
paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
|
208 |
+
**kwargs):
|
209 |
+
kwargs['return_info'] = True # We require the info for each song of the dataset.
|
210 |
+
super().__init__(*args, **kwargs)
|
211 |
+
self.info_fields_required = info_fields_required
|
212 |
+
self.merge_text_p = merge_text_p
|
213 |
+
self.drop_desc_p = drop_desc_p
|
214 |
+
self.drop_other_p = drop_other_p
|
215 |
+
self.joint_embed_attributes = joint_embed_attributes
|
216 |
+
self.paraphraser = None
|
217 |
+
if paraphrase_source is not None:
|
218 |
+
self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
|
219 |
+
|
220 |
+
def __getitem__(self, index):
|
221 |
+
wav, info = super().__getitem__(index)
|
222 |
+
info_data = info.to_dict()
|
223 |
+
music_info_path = Path(info.meta.path).with_suffix('.json')
|
224 |
+
|
225 |
+
if Path(music_info_path).exists():
|
226 |
+
with open(music_info_path, 'r') as json_file:
|
227 |
+
music_data = json.load(json_file)
|
228 |
+
music_data.update(info_data)
|
229 |
+
music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
|
230 |
+
if self.paraphraser is not None:
|
231 |
+
music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
|
232 |
+
if self.merge_text_p:
|
233 |
+
music_info = augment_music_info_description(
|
234 |
+
music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
|
235 |
+
else:
|
236 |
+
music_info = MusicInfo.from_dict(info_data, fields_required=False)
|
237 |
+
|
238 |
+
music_info.self_wav = WavCondition(
|
239 |
+
wav=wav[None], length=torch.tensor([info.n_frames]),
|
240 |
+
sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
|
241 |
+
|
242 |
+
for att in self.joint_embed_attributes:
|
243 |
+
att_value = getattr(music_info, att)
|
244 |
+
joint_embed_cond = JointEmbedCondition(
|
245 |
+
wav[None], [att_value], torch.tensor([info.n_frames]),
|
246 |
+
sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
|
247 |
+
music_info.joint_embed[att] = joint_embed_cond
|
248 |
+
|
249 |
+
return wav, music_info
|
250 |
+
|
251 |
+
|
252 |
+
def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
|
253 |
+
"""Preprocess key keywords, discarding them if there are multiple key defined."""
|
254 |
+
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
255 |
+
return None
|
256 |
+
elif ',' in value:
|
257 |
+
# For now, we discard when multiple keys are defined separated with comas
|
258 |
+
return None
|
259 |
+
else:
|
260 |
+
return value.strip().lower()
|
261 |
+
|
262 |
+
|
263 |
+
def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
|
264 |
+
"""Preprocess to a float."""
|
265 |
+
if value is None:
|
266 |
+
return None
|
267 |
+
try:
|
268 |
+
return float(value)
|
269 |
+
except ValueError:
|
270 |
+
return None
|
audiocraft/data/sound_dataset.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Dataset of audio with a simple description.
|
7 |
+
"""
|
8 |
+
|
9 |
+
from dataclasses import dataclass, fields, replace
|
10 |
+
import json
|
11 |
+
from pathlib import Path
|
12 |
+
import random
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from .info_audio_dataset import (
|
19 |
+
InfoAudioDataset,
|
20 |
+
get_keyword_or_keyword_list
|
21 |
+
)
|
22 |
+
from ..modules.conditioners import (
|
23 |
+
ConditioningAttributes,
|
24 |
+
SegmentWithAttributes,
|
25 |
+
WavCondition,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
EPS = torch.finfo(torch.float32).eps
|
30 |
+
TARGET_LEVEL_LOWER = -35
|
31 |
+
TARGET_LEVEL_UPPER = -15
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class SoundInfo(SegmentWithAttributes):
|
36 |
+
"""Segment info augmented with Sound metadata.
|
37 |
+
"""
|
38 |
+
description: tp.Optional[str] = None
|
39 |
+
self_wav: tp.Optional[torch.Tensor] = None
|
40 |
+
|
41 |
+
@property
|
42 |
+
def has_sound_meta(self) -> bool:
|
43 |
+
return self.description is not None
|
44 |
+
|
45 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
46 |
+
out = ConditioningAttributes()
|
47 |
+
|
48 |
+
for _field in fields(self):
|
49 |
+
key, value = _field.name, getattr(self, _field.name)
|
50 |
+
if key == 'self_wav':
|
51 |
+
out.wav[key] = value
|
52 |
+
else:
|
53 |
+
out.text[key] = value
|
54 |
+
return out
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def attribute_getter(attribute):
|
58 |
+
if attribute == 'description':
|
59 |
+
preprocess_func = get_keyword_or_keyword_list
|
60 |
+
else:
|
61 |
+
preprocess_func = None
|
62 |
+
return preprocess_func
|
63 |
+
|
64 |
+
@classmethod
|
65 |
+
def from_dict(cls, dictionary: dict, fields_required: bool = False):
|
66 |
+
_dictionary: tp.Dict[str, tp.Any] = {}
|
67 |
+
|
68 |
+
# allow a subset of attributes to not be loaded from the dictionary
|
69 |
+
# these attributes may be populated later
|
70 |
+
post_init_attributes = ['self_wav']
|
71 |
+
|
72 |
+
for _field in fields(cls):
|
73 |
+
if _field.name in post_init_attributes:
|
74 |
+
continue
|
75 |
+
elif _field.name not in dictionary:
|
76 |
+
if fields_required:
|
77 |
+
raise KeyError(f"Unexpected missing key: {_field.name}")
|
78 |
+
else:
|
79 |
+
preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
|
80 |
+
value = dictionary[_field.name]
|
81 |
+
if preprocess_func:
|
82 |
+
value = preprocess_func(value)
|
83 |
+
_dictionary[_field.name] = value
|
84 |
+
return cls(**_dictionary)
|
85 |
+
|
86 |
+
|
87 |
+
class SoundDataset(InfoAudioDataset):
|
88 |
+
"""Sound audio dataset: Audio dataset with environmental sound-specific metadata.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
|
92 |
+
external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
|
93 |
+
The metadata files contained in this folder are expected to match the stem of the audio file with
|
94 |
+
a json extension.
|
95 |
+
aug_p (float): Probability of performing audio mixing augmentation on the batch.
|
96 |
+
mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
|
97 |
+
mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
|
98 |
+
mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
|
99 |
+
mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
|
100 |
+
kwargs: Additional arguments for AudioDataset.
|
101 |
+
|
102 |
+
See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
|
103 |
+
"""
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
*args,
|
107 |
+
info_fields_required: bool = True,
|
108 |
+
external_metadata_source: tp.Optional[str] = None,
|
109 |
+
aug_p: float = 0.,
|
110 |
+
mix_p: float = 0.,
|
111 |
+
mix_snr_low: int = -5,
|
112 |
+
mix_snr_high: int = 5,
|
113 |
+
mix_min_overlap: float = 0.5,
|
114 |
+
**kwargs
|
115 |
+
):
|
116 |
+
kwargs['return_info'] = True # We require the info for each song of the dataset.
|
117 |
+
super().__init__(*args, **kwargs)
|
118 |
+
self.info_fields_required = info_fields_required
|
119 |
+
self.external_metadata_source = external_metadata_source
|
120 |
+
self.aug_p = aug_p
|
121 |
+
self.mix_p = mix_p
|
122 |
+
if self.aug_p > 0:
|
123 |
+
assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
|
124 |
+
assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
|
125 |
+
self.mix_snr_low = mix_snr_low
|
126 |
+
self.mix_snr_high = mix_snr_high
|
127 |
+
self.mix_min_overlap = mix_min_overlap
|
128 |
+
|
129 |
+
def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
|
130 |
+
"""Get path of JSON with metadata (description, etc.).
|
131 |
+
If there exists a JSON with the same name as 'path.name', then it will be used.
|
132 |
+
Else, such JSON will be searched for in an external json source folder if it exists.
|
133 |
+
"""
|
134 |
+
info_path = Path(path).with_suffix('.json')
|
135 |
+
if Path(info_path).exists():
|
136 |
+
return info_path
|
137 |
+
elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
|
138 |
+
return Path(self.external_metadata_source) / info_path.name
|
139 |
+
else:
|
140 |
+
raise Exception(f"Unable to find a metadata JSON for path: {path}")
|
141 |
+
|
142 |
+
def __getitem__(self, index):
|
143 |
+
wav, info = super().__getitem__(index)
|
144 |
+
info_data = info.to_dict()
|
145 |
+
info_path = self._get_info_path(info.meta.path)
|
146 |
+
if Path(info_path).exists():
|
147 |
+
with open(info_path, 'r') as json_file:
|
148 |
+
sound_data = json.load(json_file)
|
149 |
+
sound_data.update(info_data)
|
150 |
+
sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
|
151 |
+
# if there are multiple descriptions, sample one randomly
|
152 |
+
if isinstance(sound_info.description, list):
|
153 |
+
sound_info.description = random.choice(sound_info.description)
|
154 |
+
else:
|
155 |
+
sound_info = SoundInfo.from_dict(info_data, fields_required=False)
|
156 |
+
|
157 |
+
sound_info.self_wav = WavCondition(
|
158 |
+
wav=wav[None], length=torch.tensor([info.n_frames]),
|
159 |
+
sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
|
160 |
+
|
161 |
+
return wav, sound_info
|
162 |
+
|
163 |
+
def collater(self, samples):
|
164 |
+
# when training, audio mixing is performed in the collate function
|
165 |
+
wav, sound_info = super().collater(samples) # SoundDataset always returns infos
|
166 |
+
if self.aug_p > 0:
|
167 |
+
wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
|
168 |
+
snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
|
169 |
+
min_overlap=self.mix_min_overlap)
|
170 |
+
return wav, sound_info
|
171 |
+
|
172 |
+
|
173 |
+
def rms_f(x: torch.Tensor) -> torch.Tensor:
|
174 |
+
return (x ** 2).mean(1).pow(0.5)
|
175 |
+
|
176 |
+
|
177 |
+
def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
|
178 |
+
"""Normalize the signal to the target level."""
|
179 |
+
rms = rms_f(audio)
|
180 |
+
scalar = 10 ** (target_level / 20) / (rms + EPS)
|
181 |
+
audio = audio * scalar.unsqueeze(1)
|
182 |
+
return audio
|
183 |
+
|
184 |
+
|
185 |
+
def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
|
186 |
+
return (abs(audio) > clipping_threshold).any(1)
|
187 |
+
|
188 |
+
|
189 |
+
def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
|
190 |
+
start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
|
191 |
+
remainder = src.shape[1] - start
|
192 |
+
if dst.shape[1] > remainder:
|
193 |
+
src[:, start:] = src[:, start:] + dst[:, :remainder]
|
194 |
+
else:
|
195 |
+
src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
|
196 |
+
return src
|
197 |
+
|
198 |
+
|
199 |
+
def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
|
200 |
+
target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
|
201 |
+
"""Function to mix clean speech and noise at various SNR levels.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
|
205 |
+
noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
|
206 |
+
snr (int): SNR level when mixing.
|
207 |
+
min_overlap (float): Minimum overlap between the two mixed sources.
|
208 |
+
target_level (int): Gain level in dB.
|
209 |
+
clipping_threshold (float): Threshold for clipping the audio.
|
210 |
+
Returns:
|
211 |
+
torch.Tensor: The mixed audio, of shape [B, T].
|
212 |
+
"""
|
213 |
+
if clean.shape[1] > noise.shape[1]:
|
214 |
+
noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
|
215 |
+
else:
|
216 |
+
noise = noise[:, :clean.shape[1]]
|
217 |
+
|
218 |
+
# normalizing to -25 dB FS
|
219 |
+
clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
|
220 |
+
clean = normalize(clean, target_level)
|
221 |
+
rmsclean = rms_f(clean)
|
222 |
+
|
223 |
+
noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
|
224 |
+
noise = normalize(noise, target_level)
|
225 |
+
rmsnoise = rms_f(noise)
|
226 |
+
|
227 |
+
# set the noise level for a given SNR
|
228 |
+
noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
|
229 |
+
noisenewlevel = noise * noisescalar
|
230 |
+
|
231 |
+
# mix noise and clean speech
|
232 |
+
noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
|
233 |
+
|
234 |
+
# randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
|
235 |
+
# there is a chance of clipping that might happen with very less probability, which is not a major issue.
|
236 |
+
noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
|
237 |
+
rmsnoisy = rms_f(noisyspeech)
|
238 |
+
scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
|
239 |
+
noisyspeech = noisyspeech * scalarnoisy
|
240 |
+
clean = clean * scalarnoisy
|
241 |
+
noisenewlevel = noisenewlevel * scalarnoisy
|
242 |
+
|
243 |
+
# final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
|
244 |
+
clipped = is_clipped(noisyspeech)
|
245 |
+
if clipped.any():
|
246 |
+
noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
|
247 |
+
noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
|
248 |
+
|
249 |
+
return noisyspeech
|
250 |
+
|
251 |
+
|
252 |
+
def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
|
253 |
+
if snr_low == snr_high:
|
254 |
+
snr = snr_low
|
255 |
+
else:
|
256 |
+
snr = np.random.randint(snr_low, snr_high)
|
257 |
+
mix = snr_mixer(src, dst, snr, min_overlap)
|
258 |
+
return mix
|
259 |
+
|
260 |
+
|
261 |
+
def mix_text(src_text: str, dst_text: str):
|
262 |
+
"""Mix text from different sources by concatenating them."""
|
263 |
+
if src_text == dst_text:
|
264 |
+
return src_text
|
265 |
+
return src_text + " " + dst_text
|
266 |
+
|
267 |
+
|
268 |
+
def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
|
269 |
+
snr_low: int, snr_high: int, min_overlap: float):
|
270 |
+
"""Mix samples within a batch, summing the waveforms and concatenating the text infos.
|
271 |
+
|
272 |
+
Args:
|
273 |
+
wavs (torch.Tensor): Audio tensors of shape [B, C, T].
|
274 |
+
infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
|
275 |
+
aug_p (float): Augmentation probability.
|
276 |
+
mix_p (float): Proportion of items in the batch to mix (and merge) together.
|
277 |
+
snr_low (int): Lowerbound for sampling SNR.
|
278 |
+
snr_high (int): Upperbound for sampling SNR.
|
279 |
+
min_overlap (float): Minimum overlap between mixed samples.
|
280 |
+
Returns:
|
281 |
+
tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
|
282 |
+
and mixed SoundInfo for the given batch.
|
283 |
+
"""
|
284 |
+
# no mixing to perform within the batch
|
285 |
+
if mix_p == 0:
|
286 |
+
return wavs, infos
|
287 |
+
|
288 |
+
if random.uniform(0, 1) < aug_p:
|
289 |
+
# perform all augmentations on waveforms as [B, T]
|
290 |
+
# randomly picking pairs of audio to mix
|
291 |
+
assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
|
292 |
+
wavs = wavs.mean(dim=1, keepdim=False)
|
293 |
+
B, T = wavs.shape
|
294 |
+
k = int(mix_p * B)
|
295 |
+
mixed_sources_idx = torch.randperm(B)[:k]
|
296 |
+
mixed_targets_idx = torch.randperm(B)[:k]
|
297 |
+
aug_wavs = snr_mix(
|
298 |
+
wavs[mixed_sources_idx],
|
299 |
+
wavs[mixed_targets_idx],
|
300 |
+
snr_low,
|
301 |
+
snr_high,
|
302 |
+
min_overlap,
|
303 |
+
)
|
304 |
+
# mixing textual descriptions in metadata
|
305 |
+
descriptions = [info.description for info in infos]
|
306 |
+
aug_infos = []
|
307 |
+
for i, j in zip(mixed_sources_idx, mixed_targets_idx):
|
308 |
+
text = mix_text(descriptions[i], descriptions[j])
|
309 |
+
m = replace(infos[i])
|
310 |
+
m.description = text
|
311 |
+
aug_infos.append(m)
|
312 |
+
|
313 |
+
# back to [B, C, T]
|
314 |
+
aug_wavs = aug_wavs.unsqueeze(1)
|
315 |
+
assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
|
316 |
+
assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
|
317 |
+
assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
|
318 |
+
|
319 |
+
return aug_wavs, aug_infos # [B, C, T]
|
320 |
+
else:
|
321 |
+
# randomly pick samples in the batch to match
|
322 |
+
# the batch size when performing audio mixing
|
323 |
+
B, C, T = wavs.shape
|
324 |
+
k = int(mix_p * B)
|
325 |
+
wav_idx = torch.randperm(B)[:k]
|
326 |
+
wavs = wavs[wav_idx]
|
327 |
+
infos = [infos[i] for i in wav_idx]
|
328 |
+
assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
|
329 |
+
|
330 |
+
return wavs, infos # [B, C, T]
|
audiocraft/data/zip.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Utility for reading some info from inside a zip file.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import typing
|
10 |
+
import zipfile
|
11 |
+
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from functools import lru_cache
|
14 |
+
from typing_extensions import Literal
|
15 |
+
|
16 |
+
|
17 |
+
DEFAULT_SIZE = 32
|
18 |
+
MODE = Literal['r', 'w', 'x', 'a']
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass(order=True)
|
22 |
+
class PathInZip:
|
23 |
+
"""Hold a path of file within a zip file.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
|
27 |
+
Let's assume there is a zip file /some/location/foo.zip
|
28 |
+
and inside of it is a json file located at /data/file1.json,
|
29 |
+
Then we expect path = "/some/location/foo.zip:/data/file1.json".
|
30 |
+
"""
|
31 |
+
|
32 |
+
INFO_PATH_SEP = ':'
|
33 |
+
zip_path: str
|
34 |
+
file_path: str
|
35 |
+
|
36 |
+
def __init__(self, path: str) -> None:
|
37 |
+
split_path = path.split(self.INFO_PATH_SEP)
|
38 |
+
assert len(split_path) == 2
|
39 |
+
self.zip_path, self.file_path = split_path
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def from_paths(cls, zip_path: str, file_path: str):
|
43 |
+
return cls(zip_path + cls.INFO_PATH_SEP + file_path)
|
44 |
+
|
45 |
+
def __str__(self) -> str:
|
46 |
+
return self.zip_path + self.INFO_PATH_SEP + self.file_path
|
47 |
+
|
48 |
+
|
49 |
+
def _open_zip(path: str, mode: MODE = 'r'):
|
50 |
+
return zipfile.ZipFile(path, mode)
|
51 |
+
|
52 |
+
|
53 |
+
_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
|
54 |
+
|
55 |
+
|
56 |
+
def set_zip_cache_size(max_size: int):
|
57 |
+
"""Sets the maximal LRU caching for zip file opening.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
max_size (int): the maximal LRU cache.
|
61 |
+
"""
|
62 |
+
global _cached_open_zip
|
63 |
+
_cached_open_zip = lru_cache(max_size)(_open_zip)
|
64 |
+
|
65 |
+
|
66 |
+
def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
|
67 |
+
"""Opens a file stored inside a zip and returns a file-like object.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
|
71 |
+
mode (str): The mode in which to open the file with.
|
72 |
+
Returns:
|
73 |
+
A file-like object for PathInZip.
|
74 |
+
"""
|
75 |
+
zf = _cached_open_zip(path_in_zip.zip_path)
|
76 |
+
return zf.open(path_in_zip.file_path)
|
audiocraft/environment.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Provides cluster and tools configuration across clusters (slurm, dora, utilities).
|
9 |
+
"""
|
10 |
+
|
11 |
+
import logging
|
12 |
+
import os
|
13 |
+
from pathlib import Path
|
14 |
+
import re
|
15 |
+
import typing as tp
|
16 |
+
|
17 |
+
import omegaconf
|
18 |
+
|
19 |
+
from .utils.cluster import _guess_cluster_type
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class AudioCraftEnvironment:
|
26 |
+
"""Environment configuration for teams and clusters.
|
27 |
+
|
28 |
+
AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
|
29 |
+
or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
|
30 |
+
provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
|
31 |
+
allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
|
32 |
+
map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
|
33 |
+
|
34 |
+
The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
|
35 |
+
Use the following environment variables to specify the cluster, team or configuration:
|
36 |
+
|
37 |
+
AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
|
38 |
+
cannot be inferred automatically.
|
39 |
+
AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
|
40 |
+
If not set, configuration is read from config/teams.yaml.
|
41 |
+
AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
|
42 |
+
Cluster configuration are shared across teams to match compute allocation,
|
43 |
+
specify your cluster configuration in the configuration file under a key mapping
|
44 |
+
your team name.
|
45 |
+
"""
|
46 |
+
_instance = None
|
47 |
+
DEFAULT_TEAM = "default"
|
48 |
+
|
49 |
+
def __init__(self) -> None:
|
50 |
+
"""Loads configuration."""
|
51 |
+
self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
|
52 |
+
cluster_type = _guess_cluster_type()
|
53 |
+
cluster = os.getenv(
|
54 |
+
"AUDIOCRAFT_CLUSTER", cluster_type.value
|
55 |
+
)
|
56 |
+
logger.info("Detecting cluster type %s", cluster_type)
|
57 |
+
|
58 |
+
self.cluster: str = cluster
|
59 |
+
|
60 |
+
config_path = os.getenv(
|
61 |
+
"AUDIOCRAFT_CONFIG",
|
62 |
+
Path(__file__)
|
63 |
+
.parent.parent.joinpath("config/teams", self.team)
|
64 |
+
.with_suffix(".yaml"),
|
65 |
+
)
|
66 |
+
self.config = omegaconf.OmegaConf.load(config_path)
|
67 |
+
self._dataset_mappers = []
|
68 |
+
cluster_config = self._get_cluster_config()
|
69 |
+
if "dataset_mappers" in cluster_config:
|
70 |
+
for pattern, repl in cluster_config["dataset_mappers"].items():
|
71 |
+
regex = re.compile(pattern)
|
72 |
+
self._dataset_mappers.append((regex, repl))
|
73 |
+
|
74 |
+
def _get_cluster_config(self) -> omegaconf.DictConfig:
|
75 |
+
assert isinstance(self.config, omegaconf.DictConfig)
|
76 |
+
return self.config[self.cluster]
|
77 |
+
|
78 |
+
@classmethod
|
79 |
+
def instance(cls):
|
80 |
+
if cls._instance is None:
|
81 |
+
cls._instance = cls()
|
82 |
+
return cls._instance
|
83 |
+
|
84 |
+
@classmethod
|
85 |
+
def reset(cls):
|
86 |
+
"""Clears the environment and forces a reload on next invocation."""
|
87 |
+
cls._instance = None
|
88 |
+
|
89 |
+
@classmethod
|
90 |
+
def get_team(cls) -> str:
|
91 |
+
"""Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
|
92 |
+
If not defined, defaults to "labs".
|
93 |
+
"""
|
94 |
+
return cls.instance().team
|
95 |
+
|
96 |
+
@classmethod
|
97 |
+
def get_cluster(cls) -> str:
|
98 |
+
"""Gets the detected cluster.
|
99 |
+
This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
|
100 |
+
"""
|
101 |
+
return cls.instance().cluster
|
102 |
+
|
103 |
+
@classmethod
|
104 |
+
def get_dora_dir(cls) -> Path:
|
105 |
+
"""Gets the path to the dora directory for the current team and cluster.
|
106 |
+
Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
|
107 |
+
"""
|
108 |
+
cluster_config = cls.instance()._get_cluster_config()
|
109 |
+
dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
|
110 |
+
logger.warning(f"Dora directory: {dora_dir}")
|
111 |
+
return Path(dora_dir)
|
112 |
+
|
113 |
+
@classmethod
|
114 |
+
def get_reference_dir(cls) -> Path:
|
115 |
+
"""Gets the path to the reference directory for the current team and cluster.
|
116 |
+
Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
|
117 |
+
"""
|
118 |
+
cluster_config = cls.instance()._get_cluster_config()
|
119 |
+
return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
|
120 |
+
|
121 |
+
@classmethod
|
122 |
+
def get_slurm_exclude(cls) -> tp.Optional[str]:
|
123 |
+
"""Get the list of nodes to exclude for that cluster."""
|
124 |
+
cluster_config = cls.instance()._get_cluster_config()
|
125 |
+
return cluster_config.get("slurm_exclude")
|
126 |
+
|
127 |
+
@classmethod
|
128 |
+
def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
|
129 |
+
"""Gets the requested partitions for the current team and cluster as a comma-separated string.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
partition_types (list[str], optional): partition types to retrieve. Values must be
|
133 |
+
from ['global', 'team']. If not provided, the global partition is returned.
|
134 |
+
"""
|
135 |
+
if not partition_types:
|
136 |
+
partition_types = ["global"]
|
137 |
+
|
138 |
+
cluster_config = cls.instance()._get_cluster_config()
|
139 |
+
partitions = [
|
140 |
+
cluster_config["partitions"][partition_type]
|
141 |
+
for partition_type in partition_types
|
142 |
+
]
|
143 |
+
return ",".join(partitions)
|
144 |
+
|
145 |
+
@classmethod
|
146 |
+
def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
|
147 |
+
"""Converts reference placeholder in path with configured reference dir to resolve paths.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
path (str or Path): Path to resolve.
|
151 |
+
Returns:
|
152 |
+
Path: Resolved path.
|
153 |
+
"""
|
154 |
+
path = str(path)
|
155 |
+
|
156 |
+
if path.startswith("//reference"):
|
157 |
+
reference_dir = cls.get_reference_dir()
|
158 |
+
logger.warn(f"Reference directory: {reference_dir}")
|
159 |
+
assert (
|
160 |
+
reference_dir.exists() and reference_dir.is_dir()
|
161 |
+
), f"Reference directory does not exist: {reference_dir}."
|
162 |
+
path = re.sub("^//reference", str(reference_dir), path)
|
163 |
+
|
164 |
+
return Path(path)
|
165 |
+
|
166 |
+
@classmethod
|
167 |
+
def apply_dataset_mappers(cls, path: str) -> str:
|
168 |
+
"""Applies dataset mapping regex rules as defined in the configuration.
|
169 |
+
If no rules are defined, the path is returned as-is.
|
170 |
+
"""
|
171 |
+
instance = cls.instance()
|
172 |
+
|
173 |
+
for pattern, repl in instance._dataset_mappers:
|
174 |
+
path = pattern.sub(repl, path)
|
175 |
+
|
176 |
+
return path
|
audiocraft/grids/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Dora Grids."""
|
audiocraft/grids/_base_explorers.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
import time
|
9 |
+
import typing as tp
|
10 |
+
from dora import Explorer
|
11 |
+
import treetable as tt
|
12 |
+
|
13 |
+
|
14 |
+
def get_sheep_ping(sheep) -> tp.Optional[str]:
|
15 |
+
"""Return the amount of time since the Sheep made some update
|
16 |
+
to its log. Returns a str using the relevant time unit."""
|
17 |
+
ping = None
|
18 |
+
if sheep.log is not None and sheep.log.exists():
|
19 |
+
delta = time.time() - sheep.log.stat().st_mtime
|
20 |
+
if delta > 3600 * 24:
|
21 |
+
ping = f'{delta / (3600 * 24):.1f}d'
|
22 |
+
elif delta > 3600:
|
23 |
+
ping = f'{delta / (3600):.1f}h'
|
24 |
+
elif delta > 60:
|
25 |
+
ping = f'{delta / 60:.1f}m'
|
26 |
+
else:
|
27 |
+
ping = f'{delta:.1f}s'
|
28 |
+
return ping
|
29 |
+
|
30 |
+
|
31 |
+
class BaseExplorer(ABC, Explorer):
|
32 |
+
"""Base explorer for AudioCraft grids.
|
33 |
+
|
34 |
+
All task specific solvers are expected to implement the `get_grid_metrics`
|
35 |
+
method to specify logic about metrics to display for a given task.
|
36 |
+
|
37 |
+
If additional stages are used, the child explorer must define how to handle
|
38 |
+
these new stages in the `process_history` and `process_sheep` methods.
|
39 |
+
"""
|
40 |
+
def stages(self):
|
41 |
+
return ["train", "valid", "evaluate"]
|
42 |
+
|
43 |
+
def get_grid_meta(self):
|
44 |
+
"""Returns the list of Meta information to display for each XP/job.
|
45 |
+
"""
|
46 |
+
return [
|
47 |
+
tt.leaf("index", align=">"),
|
48 |
+
tt.leaf("name", wrap=140),
|
49 |
+
tt.leaf("state"),
|
50 |
+
tt.leaf("sig", align=">"),
|
51 |
+
tt.leaf("sid", align="<"),
|
52 |
+
]
|
53 |
+
|
54 |
+
@abstractmethod
|
55 |
+
def get_grid_metrics(self):
|
56 |
+
"""Return the metrics that should be displayed in the tracking table.
|
57 |
+
"""
|
58 |
+
...
|
59 |
+
|
60 |
+
def process_sheep(self, sheep, history):
|
61 |
+
train = {
|
62 |
+
"epoch": len(history),
|
63 |
+
}
|
64 |
+
parts = {"train": train}
|
65 |
+
for metrics in history:
|
66 |
+
for key, sub in metrics.items():
|
67 |
+
part = parts.get(key, {})
|
68 |
+
if 'duration' in sub:
|
69 |
+
# Convert to minutes for readability.
|
70 |
+
sub['duration'] = sub['duration'] / 60.
|
71 |
+
part.update(sub)
|
72 |
+
parts[key] = part
|
73 |
+
ping = get_sheep_ping(sheep)
|
74 |
+
if ping is not None:
|
75 |
+
for name in self.stages():
|
76 |
+
if name not in parts:
|
77 |
+
parts[name] = {}
|
78 |
+
# Add the ping to each part for convenience.
|
79 |
+
parts[name]['ping'] = ping
|
80 |
+
return parts
|
audiocraft/grids/audiogen/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""AudioGen grids."""
|
audiocraft/grids/audiogen/audiogen_base_16khz.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ..musicgen._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
|
12 |
+
def explorer(launcher):
|
13 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
+
launcher.slurm_(gpus=64, partition=partitions)
|
15 |
+
launcher.bind_(solver='audiogen/audiogen_base_16khz')
|
16 |
+
# replace this by the desired environmental sound dataset
|
17 |
+
launcher.bind_(dset='internal/sounds_16khz')
|
18 |
+
|
19 |
+
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
+
medium = {'model/lm/model_scale': 'medium'}
|
21 |
+
|
22 |
+
launcher.bind_(fsdp)
|
23 |
+
launcher(medium)
|
audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Evaluation with objective metrics for the pretrained AudioGen models.
|
9 |
+
This grid takes signature from the training grid and runs evaluation-only stage.
|
10 |
+
|
11 |
+
When running the grid for the first time, please use:
|
12 |
+
REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
|
13 |
+
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
|
14 |
+
|
15 |
+
Note that you need the proper metrics external libraries setup to use all
|
16 |
+
the objective metrics activated in this grid. Refer to the README for more information.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import os
|
20 |
+
|
21 |
+
from ..musicgen._explorers import GenerationEvalExplorer
|
22 |
+
from ...environment import AudioCraftEnvironment
|
23 |
+
from ... import train
|
24 |
+
|
25 |
+
|
26 |
+
def eval(launcher, batch_size: int = 32):
|
27 |
+
opts = {
|
28 |
+
'dset': 'audio/audiocaps_16khz',
|
29 |
+
'solver/audiogen/evaluation': 'objective_eval',
|
30 |
+
'execute_only': 'evaluate',
|
31 |
+
'+dataset.evaluate.batch_size': batch_size,
|
32 |
+
'+metrics.fad.tf.batch_size': 32,
|
33 |
+
}
|
34 |
+
# binary for FAD computation: replace this path with your own path
|
35 |
+
metrics_opts = {
|
36 |
+
'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
|
37 |
+
}
|
38 |
+
opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
|
39 |
+
opt2 = {'transformer_lm.two_step_cfg': True}
|
40 |
+
|
41 |
+
sub = launcher.bind(opts)
|
42 |
+
sub.bind_(metrics_opts)
|
43 |
+
|
44 |
+
# base objective metrics
|
45 |
+
sub(opt1, opt2)
|
46 |
+
|
47 |
+
|
48 |
+
@GenerationEvalExplorer
|
49 |
+
def explorer(launcher):
|
50 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
51 |
+
launcher.slurm_(gpus=4, partition=partitions)
|
52 |
+
|
53 |
+
if 'REGEN' not in os.environ:
|
54 |
+
folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
|
55 |
+
with launcher.job_array():
|
56 |
+
for sig in folder.iterdir():
|
57 |
+
if not sig.is_symlink():
|
58 |
+
continue
|
59 |
+
xp = train.main.get_xp_from_sig(sig.name)
|
60 |
+
launcher(xp.argv)
|
61 |
+
return
|
62 |
+
|
63 |
+
audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
|
64 |
+
audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
|
65 |
+
|
66 |
+
audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
|
67 |
+
audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
|
68 |
+
eval(audiogen_base_medium, batch_size=128)
|
audiocraft/grids/compression/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""EnCodec grids."""
|
audiocraft/grids/compression/_explorers.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import treetable as tt
|
8 |
+
|
9 |
+
from .._base_explorers import BaseExplorer
|
10 |
+
|
11 |
+
|
12 |
+
class CompressionExplorer(BaseExplorer):
|
13 |
+
eval_metrics = ["sisnr", "visqol"]
|
14 |
+
|
15 |
+
def stages(self):
|
16 |
+
return ["train", "valid", "evaluate"]
|
17 |
+
|
18 |
+
def get_grid_meta(self):
|
19 |
+
"""Returns the list of Meta information to display for each XP/job.
|
20 |
+
"""
|
21 |
+
return [
|
22 |
+
tt.leaf("index", align=">"),
|
23 |
+
tt.leaf("name", wrap=140),
|
24 |
+
tt.leaf("state"),
|
25 |
+
tt.leaf("sig", align=">"),
|
26 |
+
]
|
27 |
+
|
28 |
+
def get_grid_metrics(self):
|
29 |
+
"""Return the metrics that should be displayed in the tracking table.
|
30 |
+
"""
|
31 |
+
return [
|
32 |
+
tt.group(
|
33 |
+
"train",
|
34 |
+
[
|
35 |
+
tt.leaf("epoch"),
|
36 |
+
tt.leaf("bandwidth", ".2f"),
|
37 |
+
tt.leaf("adv", ".4f"),
|
38 |
+
tt.leaf("d_loss", ".4f"),
|
39 |
+
],
|
40 |
+
align=">",
|
41 |
+
),
|
42 |
+
tt.group(
|
43 |
+
"valid",
|
44 |
+
[
|
45 |
+
tt.leaf("bandwidth", ".2f"),
|
46 |
+
tt.leaf("adv", ".4f"),
|
47 |
+
tt.leaf("msspec", ".4f"),
|
48 |
+
tt.leaf("sisnr", ".2f"),
|
49 |
+
],
|
50 |
+
align=">",
|
51 |
+
),
|
52 |
+
tt.group(
|
53 |
+
"evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
|
54 |
+
),
|
55 |
+
]
|
audiocraft/grids/compression/debug.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
+
Any new exp added there will be scheduled.
|
10 |
+
You can cancel and experiment by commenting its line.
|
11 |
+
|
12 |
+
This grid is a minimal example for debugging compression task
|
13 |
+
and how to override parameters directly in a grid.
|
14 |
+
Learn more about dora grids: https://github.com/facebookresearch/dora
|
15 |
+
"""
|
16 |
+
|
17 |
+
from ._explorers import CompressionExplorer
|
18 |
+
from ...environment import AudioCraftEnvironment
|
19 |
+
|
20 |
+
|
21 |
+
@CompressionExplorer
|
22 |
+
def explorer(launcher):
|
23 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
24 |
+
launcher.slurm_(gpus=2, partition=partitions)
|
25 |
+
launcher.bind_(solver='compression/debug')
|
26 |
+
|
27 |
+
with launcher.job_array():
|
28 |
+
# base debug task using config from solver=compression/debug
|
29 |
+
launcher()
|
30 |
+
# we can override parameters in the grid to launch additional xps
|
31 |
+
launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
|
audiocraft/grids/compression/encodec_audiogen_16khz.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
+
Any new exp added there will be scheduled.
|
10 |
+
You can cancel and experiment by commenting its line.
|
11 |
+
|
12 |
+
This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
|
13 |
+
"""
|
14 |
+
|
15 |
+
from ._explorers import CompressionExplorer
|
16 |
+
from ...environment import AudioCraftEnvironment
|
17 |
+
|
18 |
+
|
19 |
+
@CompressionExplorer
|
20 |
+
def explorer(launcher):
|
21 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
22 |
+
launcher.slurm_(gpus=8, partition=partitions)
|
23 |
+
# use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
|
24 |
+
# AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
|
25 |
+
launcher.bind_(solver='compression/encodec_audiogen_16khz')
|
26 |
+
# replace this by the desired sound dataset
|
27 |
+
launcher.bind_(dset='internal/sounds_16khz')
|
28 |
+
# launch xp
|
29 |
+
launcher()
|
audiocraft/grids/compression/encodec_base_24khz.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
+
Any new exp added there will be scheduled.
|
10 |
+
You can cancel and experiment by commenting its line.
|
11 |
+
|
12 |
+
This grid shows how to train a base causal EnCodec model at 24 kHz.
|
13 |
+
"""
|
14 |
+
|
15 |
+
from ._explorers import CompressionExplorer
|
16 |
+
from ...environment import AudioCraftEnvironment
|
17 |
+
|
18 |
+
|
19 |
+
@CompressionExplorer
|
20 |
+
def explorer(launcher):
|
21 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
22 |
+
launcher.slurm_(gpus=8, partition=partitions)
|
23 |
+
# base causal EnCodec trained on monophonic audio sampled at 24 kHz
|
24 |
+
launcher.bind_(solver='compression/encodec_base_24khz')
|
25 |
+
# replace this by the desired dataset
|
26 |
+
launcher.bind_(dset='audio/example')
|
27 |
+
# launch xp
|
28 |
+
launcher()
|
audiocraft/grids/compression/encodec_musicgen_32khz.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
+
Any new exp added there will be scheduled.
|
10 |
+
You can cancel and experiment by commenting its line.
|
11 |
+
|
12 |
+
This grid shows how to train a MusicGen EnCodec model at 32 kHz.
|
13 |
+
"""
|
14 |
+
|
15 |
+
from ._explorers import CompressionExplorer
|
16 |
+
from ...environment import AudioCraftEnvironment
|
17 |
+
|
18 |
+
|
19 |
+
@CompressionExplorer
|
20 |
+
def explorer(launcher):
|
21 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
22 |
+
launcher.slurm_(gpus=8, partition=partitions)
|
23 |
+
# use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
|
24 |
+
# MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
|
25 |
+
launcher.bind_(solver='compression/encodec_musicgen_32khz')
|
26 |
+
# replace this by the desired music dataset
|
27 |
+
launcher.bind_(dset='internal/music_400k_32khz')
|
28 |
+
# launch xp
|
29 |
+
launcher()
|
30 |
+
launcher({
|
31 |
+
'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
|
32 |
+
'label': 'visqol',
|
33 |
+
'evaluate.metrics.visqol': True
|
34 |
+
})
|
audiocraft/grids/diffusion/4_bands_base_32khz.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Training of the 4 diffusion models described in
|
9 |
+
"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
|
10 |
+
(paper link).
|
11 |
+
"""
|
12 |
+
|
13 |
+
from ._explorers import DiffusionExplorer
|
14 |
+
|
15 |
+
|
16 |
+
@DiffusionExplorer
|
17 |
+
def explorer(launcher):
|
18 |
+
launcher.slurm_(gpus=4, partition='learnfair')
|
19 |
+
|
20 |
+
launcher.bind_({'solver': 'diffusion/default',
|
21 |
+
'dset': 'internal/music_10k_32khz'})
|
22 |
+
|
23 |
+
with launcher.job_array():
|
24 |
+
launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
|
25 |
+
launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
|
26 |
+
launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
|
27 |
+
launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
|
audiocraft/grids/diffusion/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Diffusion grids."""
|
audiocraft/grids/diffusion/_explorers.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import treetable as tt
|
8 |
+
|
9 |
+
from .._base_explorers import BaseExplorer
|
10 |
+
|
11 |
+
|
12 |
+
class DiffusionExplorer(BaseExplorer):
|
13 |
+
eval_metrics = ["sisnr", "visqol"]
|
14 |
+
|
15 |
+
def stages(self):
|
16 |
+
return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
|
17 |
+
|
18 |
+
def get_grid_meta(self):
|
19 |
+
"""Returns the list of Meta information to display for each XP/job.
|
20 |
+
"""
|
21 |
+
return [
|
22 |
+
tt.leaf("index", align=">"),
|
23 |
+
tt.leaf("name", wrap=140),
|
24 |
+
tt.leaf("state"),
|
25 |
+
tt.leaf("sig", align=">"),
|
26 |
+
]
|
27 |
+
|
28 |
+
def get_grid_metrics(self):
|
29 |
+
"""Return the metrics that should be displayed in the tracking table.
|
30 |
+
"""
|
31 |
+
return [
|
32 |
+
tt.group(
|
33 |
+
"train",
|
34 |
+
[
|
35 |
+
tt.leaf("epoch"),
|
36 |
+
tt.leaf("loss", ".3%"),
|
37 |
+
],
|
38 |
+
align=">",
|
39 |
+
),
|
40 |
+
tt.group(
|
41 |
+
"valid",
|
42 |
+
[
|
43 |
+
tt.leaf("loss", ".3%"),
|
44 |
+
# tt.leaf("loss_0", ".3%"),
|
45 |
+
],
|
46 |
+
align=">",
|
47 |
+
),
|
48 |
+
tt.group(
|
49 |
+
"valid_ema",
|
50 |
+
[
|
51 |
+
tt.leaf("loss", ".3%"),
|
52 |
+
# tt.leaf("loss_0", ".3%"),
|
53 |
+
],
|
54 |
+
align=">",
|
55 |
+
),
|
56 |
+
tt.group(
|
57 |
+
"evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
|
58 |
+
tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
|
59 |
+
tt.leaf("rvm_3", ".4f"), ], align=">"
|
60 |
+
),
|
61 |
+
tt.group(
|
62 |
+
"evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
|
63 |
+
tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
|
64 |
+
tt.leaf("rvm_3", ".4f")], align=">"
|
65 |
+
),
|
66 |
+
]
|
audiocraft/grids/musicgen/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""MusicGen grids."""
|
audiocraft/grids/musicgen/_explorers.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import treetable as tt
|
10 |
+
|
11 |
+
from .._base_explorers import BaseExplorer
|
12 |
+
|
13 |
+
|
14 |
+
class LMExplorer(BaseExplorer):
|
15 |
+
eval_metrics: tp.List[str] = []
|
16 |
+
|
17 |
+
def stages(self) -> tp.List[str]:
|
18 |
+
return ['train', 'valid']
|
19 |
+
|
20 |
+
def get_grid_metrics(self):
|
21 |
+
"""Return the metrics that should be displayed in the tracking table."""
|
22 |
+
return [
|
23 |
+
tt.group(
|
24 |
+
'train',
|
25 |
+
[
|
26 |
+
tt.leaf('epoch'),
|
27 |
+
tt.leaf('duration', '.1f'), # duration in minutes
|
28 |
+
tt.leaf('ping'),
|
29 |
+
tt.leaf('ce', '.4f'), # cross entropy
|
30 |
+
tt.leaf("ppl", '.3f'), # perplexity
|
31 |
+
],
|
32 |
+
align='>',
|
33 |
+
),
|
34 |
+
tt.group(
|
35 |
+
'valid',
|
36 |
+
[
|
37 |
+
tt.leaf('ce', '.4f'),
|
38 |
+
tt.leaf('ppl', '.3f'),
|
39 |
+
tt.leaf('best_ppl', '.3f'),
|
40 |
+
],
|
41 |
+
align='>',
|
42 |
+
),
|
43 |
+
]
|
44 |
+
|
45 |
+
def process_sheep(self, sheep, history):
|
46 |
+
parts = super().process_sheep(sheep, history)
|
47 |
+
|
48 |
+
track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher']
|
49 |
+
best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}
|
50 |
+
|
51 |
+
def comparator(mode, a, b):
|
52 |
+
return a < b if mode == 'lower' else a > b
|
53 |
+
|
54 |
+
for metrics in history:
|
55 |
+
for key, sub in metrics.items():
|
56 |
+
for metric in track_by:
|
57 |
+
# for the validation set, keep track of best metrics (ppl in this example)
|
58 |
+
# this is so we can conveniently compare metrics between runs in the grid
|
59 |
+
if key == 'valid' and metric in sub and comparator(
|
60 |
+
track_by[metric], sub[metric], best_metrics[metric]
|
61 |
+
):
|
62 |
+
best_metrics[metric] = sub[metric]
|
63 |
+
|
64 |
+
if 'valid' in parts:
|
65 |
+
parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
|
66 |
+
return parts
|
67 |
+
|
68 |
+
|
69 |
+
class GenerationEvalExplorer(BaseExplorer):
|
70 |
+
eval_metrics: tp.List[str] = []
|
71 |
+
|
72 |
+
def stages(self) -> tp.List[str]:
|
73 |
+
return ['evaluate']
|
74 |
+
|
75 |
+
def get_grid_metrics(self):
|
76 |
+
"""Return the metrics that should be displayed in the tracking table."""
|
77 |
+
return [
|
78 |
+
tt.group(
|
79 |
+
'evaluate',
|
80 |
+
[
|
81 |
+
tt.leaf('epoch', '.3f'),
|
82 |
+
tt.leaf('duration', '.1f'),
|
83 |
+
tt.leaf('ping'),
|
84 |
+
tt.leaf('ce', '.4f'),
|
85 |
+
tt.leaf('ppl', '.3f'),
|
86 |
+
tt.leaf('fad', '.3f'),
|
87 |
+
tt.leaf('kld', '.3f'),
|
88 |
+
tt.leaf('text_consistency', '.3f'),
|
89 |
+
tt.leaf('chroma_cosine', '.3f'),
|
90 |
+
],
|
91 |
+
align='>',
|
92 |
+
),
|
93 |
+
]
|
audiocraft/grids/musicgen/musicgen_base_32khz.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
|
12 |
+
def explorer(launcher):
|
13 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
+
launcher.slurm_(gpus=32, partition=partitions)
|
15 |
+
launcher.bind_(solver='musicgen/musicgen_base_32khz')
|
16 |
+
# replace this by the desired music dataset
|
17 |
+
launcher.bind_(dset='internal/music_400k_32khz')
|
18 |
+
|
19 |
+
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
+
medium = {'model/lm/model_scale': 'medium'}
|
21 |
+
large = {'model/lm/model_scale': 'large'}
|
22 |
+
|
23 |
+
cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
|
24 |
+
wd_low = {'conditioners.description.t5.word_dropout': 0.2}
|
25 |
+
|
26 |
+
adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
|
27 |
+
|
28 |
+
launcher.bind_(fsdp)
|
29 |
+
|
30 |
+
launcher.slurm_(gpus=32).bind_(label='32gpus')
|
31 |
+
with launcher.job_array():
|
32 |
+
sub = launcher.bind()
|
33 |
+
sub()
|
34 |
+
|
35 |
+
launcher.slurm_(gpus=64).bind_(label='64gpus')
|
36 |
+
with launcher.job_array():
|
37 |
+
sub = launcher.bind()
|
38 |
+
sub(medium, adam)
|
39 |
+
|
40 |
+
launcher.slurm_(gpus=96).bind_(label='96gpus')
|
41 |
+
with launcher.job_array():
|
42 |
+
sub = launcher.bind()
|
43 |
+
sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
|
audiocraft/grids/musicgen/musicgen_base_cached_32khz.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from ._explorers import LMExplorer
|
8 |
+
from ...environment import AudioCraftEnvironment
|
9 |
+
|
10 |
+
|
11 |
+
@LMExplorer
|
12 |
+
def explorer(launcher):
|
13 |
+
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
+
launcher.slurm_(gpus=32, partition=partitions)
|
15 |
+
launcher.bind_(solver='musicgen/musicgen_base_32khz')
|
16 |
+
# replace this by the desired music dataset
|
17 |
+
launcher.bind_(dset='internal/music_400k_32khz')
|
18 |
+
|
19 |
+
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
+
medium = {'model/lm/model_scale': 'medium'}
|
21 |
+
large = {'model/lm/model_scale': 'large'}
|
22 |
+
|
23 |
+
cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
|
24 |
+
wd_low = {'conditioners.description.t5.word_dropout': 0.2}
|
25 |
+
|
26 |
+
adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
|
27 |
+
|
28 |
+
# BEGINNING OF CACHE WRITING JOBS.
|
29 |
+
cache_write = {
|
30 |
+
'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
|
31 |
+
'cache.write': True,
|
32 |
+
'generate.every': 500,
|
33 |
+
'evaluate.every': 500,
|
34 |
+
'logging.log_updates': 50,
|
35 |
+
}
|
36 |
+
|
37 |
+
cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
|
38 |
+
cache_sub.bind_({'deadlock.use': True})
|
39 |
+
cache_sub.slurm_(gpus=8)
|
40 |
+
with launcher.job_array():
|
41 |
+
num_shards = 10 # total number of jobs running in parallel.
|
42 |
+
for shard in range(0, num_shards):
|
43 |
+
launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})
|
44 |
+
|
45 |
+
# REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
|
46 |
+
# OR SUFFICIENTLY AHEAD.
|
47 |
+
return
|
48 |
+
|
49 |
+
cache = {
|
50 |
+
'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
|
51 |
+
}
|
52 |
+
launcher.bind_(fsdp, cache)
|
53 |
+
|
54 |
+
launcher.slurm_(gpus=32).bind_(label='32gpus')
|
55 |
+
with launcher.job_array():
|
56 |
+
sub = launcher.bind()
|
57 |
+
sub()
|
58 |
+
|
59 |
+
launcher.slurm_(gpus=64).bind_(label='64gpus')
|
60 |
+
with launcher.job_array():
|
61 |
+
sub = launcher.bind()
|
62 |
+
sub(medium, adam)
|
63 |
+
|
64 |
+
launcher.slurm_(gpus=96).bind_(label='96gpus')
|
65 |
+
with launcher.job_array():
|
66 |
+
sub = launcher.bind()
|
67 |
+
sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
|