Upload folder using huggingface_hub
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .github/actions/audiocraft_build/action.yml +29 -0
- .github/workflows/audiocraft_docs.yml +32 -0
- .github/workflows/audiocraft_linter.yml +17 -0
- .github/workflows/audiocraft_tests.yml +17 -0
- .gitignore +55 -0
- CHANGELOG.md +23 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +35 -0
- LICENSE +21 -0
- LICENSE_weights +157 -0
- MANIFEST.in +8 -0
- MODEL_CARD.md +81 -0
- Makefile +21 -0
- Qma6mgf0vpW1.mp4 +0 -0
- README.md +125 -1
- app.py +367 -0
- assets/bach.mp3 +0 -0
- assets/bolero_ravel.mp3 +0 -0
- audiocraft/__init__.py +10 -0
- audiocraft/__pycache__/__init__.cpython-310.pyc +0 -0
- audiocraft/data/__init__.py +8 -0
- audiocraft/data/__pycache__/__init__.cpython-310.pyc +0 -0
- audiocraft/data/__pycache__/audio.cpython-310.pyc +0 -0
- audiocraft/data/__pycache__/audio_dataset.cpython-310.pyc +0 -0
- audiocraft/data/__pycache__/audio_utils.cpython-310.pyc +0 -0
- audiocraft/data/__pycache__/zip.cpython-310.pyc +0 -0
- audiocraft/data/audio.py +215 -0
- audiocraft/data/audio_dataset.py +525 -0
- audiocraft/data/audio_utils.py +174 -0
- audiocraft/data/zip.py +74 -0
- audiocraft/models/__init__.py +10 -0
- audiocraft/models/__pycache__/__init__.cpython-310.pyc +0 -0
- audiocraft/models/__pycache__/builders.cpython-310.pyc +0 -0
- audiocraft/models/__pycache__/encodec.cpython-310.pyc +0 -0
- audiocraft/models/__pycache__/lm.cpython-310.pyc +0 -0
- audiocraft/models/__pycache__/loaders.cpython-310.pyc +0 -0
- audiocraft/models/__pycache__/musicgen.cpython-310.pyc +0 -0
- audiocraft/models/builders.py +218 -0
- audiocraft/models/encodec.py +302 -0
- audiocraft/models/lm.py +527 -0
- audiocraft/models/loaders.py +90 -0
- audiocraft/models/musicgen.py +361 -0
- audiocraft/modules/__init__.py +20 -0
- audiocraft/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- audiocraft/modules/__pycache__/activations.cpython-310.pyc +0 -0
- audiocraft/modules/__pycache__/codebooks_patterns.cpython-310.pyc +0 -0
- audiocraft/modules/__pycache__/conditioners.cpython-310.pyc +0 -0
- audiocraft/modules/__pycache__/conv.cpython-310.pyc +0 -0
- audiocraft/modules/__pycache__/lstm.cpython-310.pyc +0 -0
- audiocraft/modules/__pycache__/rope.cpython-310.pyc +0 -0
.github/actions/audiocraft_build/action.yml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: audiocraft_build
|
2 |
+
description: 'Build audiocraft env.'
|
3 |
+
runs:
|
4 |
+
using: "composite"
|
5 |
+
steps:
|
6 |
+
- uses: actions/setup-python@v2
|
7 |
+
with:
|
8 |
+
python-version: 3.8
|
9 |
+
- uses: actions/cache@v2
|
10 |
+
id: cache
|
11 |
+
with:
|
12 |
+
path: env
|
13 |
+
key: audiocraft_env-${{ hashFiles('**/requirements.txt') }}
|
14 |
+
|
15 |
+
- if: ${{ steps.cache.outputs.cache-hit != 'true' }}
|
16 |
+
name: Install dependencies
|
17 |
+
shell: bash
|
18 |
+
run: |
|
19 |
+
sudo apt-get update
|
20 |
+
sudo apt-get install libsndfile1-dev ffmpeg
|
21 |
+
python3 -m venv env
|
22 |
+
. env/bin/activate
|
23 |
+
python -m pip install --upgrade pip
|
24 |
+
pip install -e '.[dev]'
|
25 |
+
- name: System Dependencies
|
26 |
+
shell: bash
|
27 |
+
run: |
|
28 |
+
sudo apt-get update
|
29 |
+
sudo apt-get install libsndfile1-dev ffmpeg
|
.github/workflows/audiocraft_docs.yml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: audiocraft_docs
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: [ main ]
|
5 |
+
|
6 |
+
jobs:
|
7 |
+
run_docs:
|
8 |
+
name: Run docs
|
9 |
+
runs-on: ubuntu-latest
|
10 |
+
steps:
|
11 |
+
- uses: actions/checkout@v2
|
12 |
+
- uses: ./.github/actions/audiocraft_build
|
13 |
+
- name: Config git
|
14 |
+
run: |
|
15 |
+
git config --global user.email "defossez@fb.com"
|
16 |
+
git config --global user.name "Alexandre Défossez (autodoc)"
|
17 |
+
|
18 |
+
- name: Reset branch
|
19 |
+
run: |
|
20 |
+
git branch -f gh-docs main
|
21 |
+
git checkout gh-docs
|
22 |
+
|
23 |
+
- name: Make docs
|
24 |
+
run: |
|
25 |
+
. env/bin/activate
|
26 |
+
make docs
|
27 |
+
git add -f docs
|
28 |
+
git commit -m docs
|
29 |
+
|
30 |
+
- name: Push branch
|
31 |
+
run: |
|
32 |
+
git push -f -u origin gh-docs
|
.github/workflows/audiocraft_linter.yml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: audiocraft_linter
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: [ main ]
|
5 |
+
pull_request:
|
6 |
+
branches: [ main ]
|
7 |
+
|
8 |
+
jobs:
|
9 |
+
run_linter:
|
10 |
+
name: Run linter
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v2
|
14 |
+
- uses: ./.github/actions/audiocraft_build
|
15 |
+
- run: |
|
16 |
+
. env/bin/activate
|
17 |
+
make linter
|
.github/workflows/audiocraft_tests.yml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: audiocraft_tests
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: [ main ]
|
5 |
+
pull_request:
|
6 |
+
branches: [ main ]
|
7 |
+
|
8 |
+
jobs:
|
9 |
+
run_tests:
|
10 |
+
name: Run tests
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v2
|
14 |
+
- uses: ./.github/actions/audiocraft_build
|
15 |
+
- run: |
|
16 |
+
. env/bin/activate
|
17 |
+
make tests
|
.gitignore
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# macOS dir files
|
10 |
+
.DS_Store
|
11 |
+
|
12 |
+
# Distribution / packaging
|
13 |
+
.Python
|
14 |
+
env/
|
15 |
+
build/
|
16 |
+
develop-eggs/
|
17 |
+
dist/
|
18 |
+
downloads/
|
19 |
+
eggs/
|
20 |
+
.eggs/
|
21 |
+
lib/
|
22 |
+
lib64/
|
23 |
+
parts/
|
24 |
+
sdist/
|
25 |
+
var/
|
26 |
+
wheels/
|
27 |
+
*.egg-info/
|
28 |
+
.installed.cfg
|
29 |
+
*.egg
|
30 |
+
.ipynb_checkpoints
|
31 |
+
|
32 |
+
# Tests and linter
|
33 |
+
.pytest_cache/
|
34 |
+
.mypy_cache/
|
35 |
+
.coverage
|
36 |
+
|
37 |
+
# docs
|
38 |
+
/docs
|
39 |
+
|
40 |
+
# dotenv
|
41 |
+
.env
|
42 |
+
.envrc
|
43 |
+
|
44 |
+
# virtualenv
|
45 |
+
.venv
|
46 |
+
venv/
|
47 |
+
ENV/
|
48 |
+
|
49 |
+
# personal notebooks & scripts
|
50 |
+
*/local_scripts
|
51 |
+
*/notes
|
52 |
+
.vscode/
|
53 |
+
/notebooks
|
54 |
+
/local_scripts
|
55 |
+
/notes
|
CHANGELOG.md
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Changelog
|
2 |
+
|
3 |
+
All notable changes to this project will be documented in this file.
|
4 |
+
|
5 |
+
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
6 |
+
|
7 |
+
## [0.0.2a] - TBD
|
8 |
+
|
9 |
+
Improved demo, fixed top p (thanks @jnordberg).
|
10 |
+
|
11 |
+
Compressor tanh on output to avoid clipping with some styles (especially piano).
|
12 |
+
Now repeating the conditioning periodically if it is too short.
|
13 |
+
|
14 |
+
More options when launching Gradio app locally (thanks @ashleykleynhans).
|
15 |
+
|
16 |
+
Testing out PyTorch 2.0 memory efficient attention.
|
17 |
+
|
18 |
+
Added extended generation (infinite length) by slowly moving the windows.
|
19 |
+
Note that other implementations exist: https://github.com/camenduru/MusicGen-colab.
|
20 |
+
|
21 |
+
## [0.0.1] - 2023-06-09
|
22 |
+
|
23 |
+
Initial release, with model evaluation only.
|
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code of Conduct
|
2 |
+
|
3 |
+
## Our Pledge
|
4 |
+
|
5 |
+
In the interest of fostering an open and welcoming environment, we as
|
6 |
+
contributors and maintainers pledge to make participation in our project and
|
7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
9 |
+
level of experience, education, socio-economic status, nationality, personal
|
10 |
+
appearance, race, religion, or sexual identity and orientation.
|
11 |
+
|
12 |
+
## Our Standards
|
13 |
+
|
14 |
+
Examples of behavior that contributes to creating a positive environment
|
15 |
+
include:
|
16 |
+
|
17 |
+
* Using welcoming and inclusive language
|
18 |
+
* Being respectful of differing viewpoints and experiences
|
19 |
+
* Gracefully accepting constructive criticism
|
20 |
+
* Focusing on what is best for the community
|
21 |
+
* Showing empathy towards other community members
|
22 |
+
|
23 |
+
Examples of unacceptable behavior by participants include:
|
24 |
+
|
25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
26 |
+
advances
|
27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
28 |
+
* Public or private harassment
|
29 |
+
* Publishing others' private information, such as a physical or electronic
|
30 |
+
address, without explicit permission
|
31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
32 |
+
professional setting
|
33 |
+
|
34 |
+
## Our Responsibilities
|
35 |
+
|
36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
38 |
+
response to any instances of unacceptable behavior.
|
39 |
+
|
40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
44 |
+
threatening, offensive, or harmful.
|
45 |
+
|
46 |
+
## Scope
|
47 |
+
|
48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
49 |
+
an individual is representing the project or its community in public spaces.
|
50 |
+
Examples of representing a project or community include using an official
|
51 |
+
project e-mail address, posting via an official social media account, or acting
|
52 |
+
as an appointed representative at an online or offline event. Representation of
|
53 |
+
a project may be further defined and clarified by project maintainers.
|
54 |
+
|
55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
57 |
+
the project or its community.
|
58 |
+
|
59 |
+
## Enforcement
|
60 |
+
|
61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
62 |
+
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
63 |
+
complaints will be reviewed and investigated and will result in a response that
|
64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
66 |
+
Further details of specific enforcement policies may be posted separately.
|
67 |
+
|
68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
69 |
+
faith may face temporary or permanent repercussions as determined by other
|
70 |
+
members of the project's leadership.
|
71 |
+
|
72 |
+
## Attribution
|
73 |
+
|
74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
76 |
+
|
77 |
+
[homepage]: https://www.contributor-covenant.org
|
78 |
+
|
79 |
+
For answers to common questions about this code of conduct, see
|
80 |
+
https://www.contributor-covenant.org/faq
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing to Audiocraft
|
2 |
+
|
3 |
+
We want to make contributing to this project as easy and transparent as
|
4 |
+
possible.
|
5 |
+
|
6 |
+
## Pull Requests
|
7 |
+
|
8 |
+
Audiocraft is the implementation of a research paper.
|
9 |
+
Therefore, we do not plan on accepting many pull requests for new features.
|
10 |
+
We certainly welcome them for bug fixes.
|
11 |
+
|
12 |
+
1. Fork the repo and create your branch from `main`.
|
13 |
+
2. If you've added code that should be tested, add tests.
|
14 |
+
3. If you've changed APIs, update the documentation.
|
15 |
+
4. Ensure the test suite passes.
|
16 |
+
5. Make sure your code lints.
|
17 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
18 |
+
|
19 |
+
## Contributor License Agreement ("CLA")
|
20 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
21 |
+
to do this once to work on any of Meta's open source projects.
|
22 |
+
|
23 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
24 |
+
|
25 |
+
## Issues
|
26 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
27 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
28 |
+
|
29 |
+
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
30 |
+
disclosure of security bugs. In those cases, please go through the process
|
31 |
+
outlined on that page and do not file a public issue.
|
32 |
+
|
33 |
+
## License
|
34 |
+
By contributing to Audiocraft, you agree that your contributions will be licensed
|
35 |
+
under the LICENSE file in the root directory of this source tree.
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
LICENSE_weights
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Attribution-NonCommercial-NoDerivatives 4.0 International
|
2 |
+
|
3 |
+
> *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.*
|
4 |
+
>
|
5 |
+
> ### Using Creative Commons Public Licenses
|
6 |
+
>
|
7 |
+
> Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
|
8 |
+
>
|
9 |
+
> * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
|
10 |
+
>
|
11 |
+
> * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
|
12 |
+
|
13 |
+
## Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License
|
14 |
+
|
15 |
+
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
16 |
+
|
17 |
+
### Section 1 – Definitions.
|
18 |
+
|
19 |
+
a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
|
20 |
+
|
21 |
+
b. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
|
22 |
+
|
23 |
+
e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
|
24 |
+
|
25 |
+
f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
|
26 |
+
|
27 |
+
h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
|
28 |
+
|
29 |
+
i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
|
30 |
+
|
31 |
+
h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
|
32 |
+
|
33 |
+
i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
|
34 |
+
|
35 |
+
j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
|
36 |
+
|
37 |
+
k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
|
38 |
+
|
39 |
+
l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
|
40 |
+
|
41 |
+
### Section 2 – Scope.
|
42 |
+
|
43 |
+
a. ___License grant.___
|
44 |
+
|
45 |
+
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
|
46 |
+
|
47 |
+
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
|
48 |
+
|
49 |
+
B. produce and reproduce, but not Share, Adapted Material for NonCommercial purposes only.
|
50 |
+
|
51 |
+
2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
|
52 |
+
|
53 |
+
3. __Term.__ The term of this Public License is specified in Section 6(a).
|
54 |
+
|
55 |
+
4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
|
56 |
+
|
57 |
+
5. __Downstream recipients.__
|
58 |
+
|
59 |
+
A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
|
60 |
+
|
61 |
+
B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
|
62 |
+
|
63 |
+
6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
|
64 |
+
|
65 |
+
b. ___Other rights.___
|
66 |
+
|
67 |
+
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
|
68 |
+
|
69 |
+
2. Patent and trademark rights are not licensed under this Public License.
|
70 |
+
|
71 |
+
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
|
72 |
+
|
73 |
+
### Section 3 – License Conditions.
|
74 |
+
|
75 |
+
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
|
76 |
+
|
77 |
+
a. ___Attribution.___
|
78 |
+
|
79 |
+
1. If You Share the Licensed Material, You must:
|
80 |
+
|
81 |
+
A. retain the following if it is supplied by the Licensor with the Licensed Material:
|
82 |
+
|
83 |
+
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
|
84 |
+
|
85 |
+
ii. a copyright notice;
|
86 |
+
|
87 |
+
iii. a notice that refers to this Public License;
|
88 |
+
|
89 |
+
iv. a notice that refers to the disclaimer of warranties;
|
90 |
+
|
91 |
+
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
|
92 |
+
|
93 |
+
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
|
94 |
+
|
95 |
+
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
|
96 |
+
|
97 |
+
For the avoidance of doubt, You do not have permission under this Public License to Share Adapted Material.
|
98 |
+
|
99 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
|
100 |
+
|
101 |
+
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
|
102 |
+
|
103 |
+
### Section 4 – Sui Generis Database Rights.
|
104 |
+
|
105 |
+
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
|
106 |
+
|
107 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only and provided You do not Share Adapted Material;
|
108 |
+
|
109 |
+
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
|
110 |
+
|
111 |
+
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
|
112 |
+
|
113 |
+
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
|
114 |
+
|
115 |
+
### Section 5 – Disclaimer of Warranties and Limitation of Liability.
|
116 |
+
|
117 |
+
a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
|
118 |
+
|
119 |
+
b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
|
120 |
+
|
121 |
+
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
|
122 |
+
|
123 |
+
### Section 6 – Term and Termination.
|
124 |
+
|
125 |
+
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
|
126 |
+
|
127 |
+
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
|
128 |
+
|
129 |
+
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
|
130 |
+
|
131 |
+
2. upon express reinstatement by the Licensor.
|
132 |
+
|
133 |
+
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
|
134 |
+
|
135 |
+
c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
|
136 |
+
|
137 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
|
138 |
+
|
139 |
+
### Section 7 – Other Terms and Conditions.
|
140 |
+
|
141 |
+
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
|
142 |
+
|
143 |
+
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
|
144 |
+
|
145 |
+
### Section 8 – Interpretation.
|
146 |
+
|
147 |
+
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
|
148 |
+
|
149 |
+
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
|
150 |
+
|
151 |
+
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
|
152 |
+
|
153 |
+
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
|
154 |
+
|
155 |
+
> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
|
156 |
+
>
|
157 |
+
> Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org).
|
MANIFEST.in
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
include Makefile
|
2 |
+
include LICENSE
|
3 |
+
include LICENSE_weights
|
4 |
+
include *.md
|
5 |
+
include *.ini
|
6 |
+
include requirements.txt
|
7 |
+
include audiocraft/py.typed
|
8 |
+
include assets/*.mp3
|
MODEL_CARD.md
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MusicGen Model Card
|
2 |
+
|
3 |
+
## Model details
|
4 |
+
|
5 |
+
**Organization developing the model:** The FAIR team of Meta AI.
|
6 |
+
|
7 |
+
**Model date:** MusicGen was trained between April 2023 and May 2023.
|
8 |
+
|
9 |
+
**Model version:** This is the version 1 of the model.
|
10 |
+
|
11 |
+
**Model type:** MusicGen consists of an EnCodec model for audio tokenization and an auto-regressive language model based on the transformer architecture for music modeling. The model comes in different sizes: 300M, 1.5B and 3.3B parameters; and two variants: a model trained for the text-to-music generation task and a model trained for melody-guided music generation.
|
12 |
+
|
13 |
+
**Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv].
|
14 |
+
|
15 |
+
**Citation details** See [our paper][arxiv]
|
16 |
+
|
17 |
+
**License** Code is released under MIT, model weights are released under CC-BY-NC 4.0.
|
18 |
+
|
19 |
+
**Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [Github repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
|
20 |
+
|
21 |
+
## Intended use
|
22 |
+
**Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including:
|
23 |
+
|
24 |
+
- Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science
|
25 |
+
- Generation of music guided by text or melody to understand current abilities of generative AI models by machine learning amateurs
|
26 |
+
|
27 |
+
**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateurs seeking to better understand those models.
|
28 |
+
|
29 |
+
**Out-of-scope use cases** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
30 |
+
|
31 |
+
## Metrics
|
32 |
+
|
33 |
+
**Models performance measures:** We used the following objective measures to evaluate the model on a standard music benchmark:
|
34 |
+
|
35 |
+
- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish)
|
36 |
+
- Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST)
|
37 |
+
- CLAP Score between audio embedding and text embedding extracted from a pre-trained CLAP model
|
38 |
+
|
39 |
+
Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes:
|
40 |
+
|
41 |
+
- Overall quality of the music samples;
|
42 |
+
- Text relevance to the provided text input;
|
43 |
+
- Adherence to the melody for melody-guided music generation.
|
44 |
+
|
45 |
+
More details on performance measures and human studies can be found in the paper.
|
46 |
+
|
47 |
+
**Decision thresholds:** Not applicable.
|
48 |
+
|
49 |
+
## Evaluation datasets
|
50 |
+
|
51 |
+
The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set.
|
52 |
+
|
53 |
+
## Training datasets
|
54 |
+
|
55 |
+
The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing.
|
56 |
+
|
57 |
+
## Quantitative analysis
|
58 |
+
|
59 |
+
More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Experimental Setup section.
|
60 |
+
|
61 |
+
## Limitations and biases
|
62 |
+
|
63 |
+
**Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data; we believe that scaling the model on larger datasets can further improve the performance of the model.
|
64 |
+
|
65 |
+
**Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using a state-of-the-art music source separation method, namely the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs).
|
66 |
+
|
67 |
+
**Limitations:**
|
68 |
+
|
69 |
+
- The model is not able to generate realistic vocals.
|
70 |
+
- The model has been trained with English descriptions and will not perform as well in other languages.
|
71 |
+
- The model does not perform equally well for all music styles and cultures.
|
72 |
+
- The model sometimes generates end of songs, collapsing to silence.
|
73 |
+
- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
|
74 |
+
|
75 |
+
**Biases:** The source of data is potentially lacking diversity and all music cultures are not equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exists. The generated samples from the model will reflect the biases from the training data. Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive.
|
76 |
+
|
77 |
+
**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will make it possible to broaden the application to new and more representative data.
|
78 |
+
|
79 |
+
**Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
|
80 |
+
|
81 |
+
[arxiv]: https://arxiv.org/abs/2306.05284
|
Makefile
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default: linter tests
|
2 |
+
|
3 |
+
install:
|
4 |
+
pip install -U pip
|
5 |
+
pip install -U -e '.[dev]'
|
6 |
+
|
7 |
+
linter:
|
8 |
+
flake8 audiocraft && mypy audiocraft
|
9 |
+
flake8 tests && mypy tests
|
10 |
+
|
11 |
+
tests:
|
12 |
+
coverage run -m pytest tests
|
13 |
+
coverage report --include 'audiocraft/*'
|
14 |
+
|
15 |
+
docs:
|
16 |
+
pdoc3 --html -o docs -f audiocraft
|
17 |
+
|
18 |
+
dist:
|
19 |
+
python setup.py sdist
|
20 |
+
|
21 |
+
.PHONY: linter tests docs dist
|
Qma6mgf0vpW1.mp4
ADDED
Binary file (174 kB). View file
|
|
README.md
CHANGED
@@ -1,6 +1,130 @@
|
|
1 |
---
|
2 |
title: music
|
3 |
-
app_file:
|
4 |
sdk: gradio
|
5 |
sdk_version: 3.39.0
|
6 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: music
|
3 |
+
app_file: app.py
|
4 |
sdk: gradio
|
5 |
sdk_version: 3.39.0
|
6 |
---
|
7 |
+
# Audiocraft
|
8 |
+
![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
|
9 |
+
![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
|
10 |
+
![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg)
|
11 |
+
|
12 |
+
Audiocraft is a PyTorch library for deep learning research on audio generation. At the moment, it contains the code for MusicGen, a state-of-the-art controllable text-to-music model.
|
13 |
+
|
14 |
+
## MusicGen
|
15 |
+
|
16 |
+
Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive
|
17 |
+
Transformer model trained over a 32kHz <a href="https://github.com/facebookresearch/encodec">EnCodec tokenizer</a> with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require a self-supervised semantic representation, and it generates
|
18 |
+
all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict
|
19 |
+
them in parallel, thus having only 50 auto-regressive steps per second of audio.
|
20 |
+
Check out our [sample page][musicgen_samples] or test the available demo!
|
21 |
+
|
22 |
+
<a target="_blank" href="https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing">
|
23 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
24 |
+
</a>
|
25 |
+
<a target="_blank" href="https://huggingface.co/spaces/facebook/MusicGen">
|
26 |
+
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in HuggingFace"/>
|
27 |
+
</a>
|
28 |
+
<br>
|
29 |
+
|
30 |
+
We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data.
|
31 |
+
|
32 |
+
## Installation
|
33 |
+
Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
|
34 |
+
|
35 |
+
```shell
|
36 |
+
# Best to make sure you have torch installed first, in particular before installing xformers.
|
37 |
+
# Don't run this if you already have PyTorch installed.
|
38 |
+
pip install 'torch>=2.0'
|
39 |
+
# Then proceed to one of the following
|
40 |
+
pip install -U audiocraft # stable release
|
41 |
+
pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
|
42 |
+
pip install -e . # or if you cloned the repo locally
|
43 |
+
```
|
44 |
+
|
45 |
+
## Usage
|
46 |
+
We offer a number of ways to interact with MusicGen:
|
47 |
+
1. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support).
|
48 |
+
2. You can run the extended demo on a Colab: [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing).
|
49 |
+
3. You can use the gradio demo locally by running `python app.py`.
|
50 |
+
4. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally (if you have a GPU).
|
51 |
+
5. Finally, check out [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab) which is regularly
|
52 |
+
updated with contributions from @camenduru and the community.
|
53 |
+
|
54 |
+
## API
|
55 |
+
|
56 |
+
We provide a simple API and 4 pre-trained models. The pre-trained models are:
|
57 |
+
- `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
|
58 |
+
- `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
|
59 |
+
- `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
|
60 |
+
- `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
|
61 |
+
|
62 |
+
We observe the best trade-off between quality and compute with the `medium` or `melody` model.
|
63 |
+
In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
|
64 |
+
GPUs will be able to generate short sequences, or longer sequences with the `small` model.
|
65 |
+
|
66 |
+
**Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using a newer version of `torchaudio`.
|
67 |
+
You can install it with:
|
68 |
+
```
|
69 |
+
apt-get install ffmpeg
|
70 |
+
```
|
71 |
+
|
72 |
+
See below for a quick example of using the API.
|
73 |
+
|
74 |
+
```python
|
75 |
+
import torchaudio
|
76 |
+
from audiocraft.models import MusicGen
|
77 |
+
from audiocraft.data.audio import audio_write
|
78 |
+
|
79 |
+
model = MusicGen.get_pretrained('melody')
|
80 |
+
model.set_generation_params(duration=8) # generate 8 seconds.
|
81 |
+
wav = model.generate_unconditional(4) # generates 4 unconditional audio samples
|
82 |
+
descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
|
83 |
+
wav = model.generate(descriptions) # generates 3 samples.
|
84 |
+
|
85 |
+
melody, sr = torchaudio.load('./assets/bach.mp3')
|
86 |
+
# generates using the melody from the given audio and the provided descriptions.
|
87 |
+
wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr)
|
88 |
+
|
89 |
+
for idx, one_wav in enumerate(wav):
|
90 |
+
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
|
91 |
+
audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
|
92 |
+
```
|
93 |
+
|
94 |
+
|
95 |
+
## Model Card
|
96 |
+
|
97 |
+
See [the model card page](./MODEL_CARD.md).
|
98 |
+
|
99 |
+
## FAQ
|
100 |
+
|
101 |
+
#### Will the training code be released?
|
102 |
+
|
103 |
+
Yes. We will soon release the training code for MusicGen and EnCodec.
|
104 |
+
|
105 |
+
|
106 |
+
#### I need help on Windows
|
107 |
+
|
108 |
+
@FurkanGozukara made a complete tutorial for [Audiocraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4)
|
109 |
+
|
110 |
+
#### I need help for running the demo on Colab
|
111 |
+
|
112 |
+
Check [@camenduru tutorial on Youtube](https://www.youtube.com/watch?v=EGfxuTy9Eeo).
|
113 |
+
|
114 |
+
|
115 |
+
## Citation
|
116 |
+
```
|
117 |
+
@article{copet2023simple,
|
118 |
+
title={Simple and Controllable Music Generation},
|
119 |
+
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
|
120 |
+
year={2023},
|
121 |
+
journal={arXiv preprint arXiv:2306.05284},
|
122 |
+
}
|
123 |
+
```
|
124 |
+
|
125 |
+
## License
|
126 |
+
* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
|
127 |
+
* The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
|
128 |
+
|
129 |
+
[arxiv]: https://arxiv.org/abs/2306.05284
|
130 |
+
[musicgen_samples]: https://ai.honu.io/papers/musicgen/
|
app.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
|
8 |
+
# also released under the MIT license.
|
9 |
+
|
10 |
+
import argparse
|
11 |
+
from concurrent.futures import ProcessPoolExecutor
|
12 |
+
import os
|
13 |
+
import subprocess as sp
|
14 |
+
from tempfile import NamedTemporaryFile
|
15 |
+
import time
|
16 |
+
import warnings
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import gradio as gr
|
20 |
+
|
21 |
+
from audiocraft.data.audio_utils import convert_audio
|
22 |
+
from audiocraft.data.audio import audio_write
|
23 |
+
from audiocraft.models import MusicGen
|
24 |
+
import subprocess, random, string
|
25 |
+
|
26 |
+
|
27 |
+
MODEL = None # Last used model
|
28 |
+
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
|
29 |
+
MAX_BATCH_SIZE = 12
|
30 |
+
BATCHED_DURATION = 15
|
31 |
+
INTERRUPTING = False
|
32 |
+
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
|
33 |
+
_old_call = sp.call
|
34 |
+
|
35 |
+
def generate_random_string(length):
    """Return a random string made of ASCII letters and digits.

    Characters are drawn one at a time with ``random.choice``, so the result
    is not suitable for security-sensitive tokens. A ``length`` of zero or
    less yields an empty string.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
|
38 |
+
|
39 |
+
def resize_video(input_path, output_path, target_width, target_height):
    """Rescale a video to ``target_width`` x ``target_height`` with ffmpeg.

    Args:
        input_path: path of the video to read.
        output_path: path the rescaled video is written to; an existing file
            is overwritten.
        target_width: output width passed to ffmpeg's ``scale`` filter.
        target_height: output height passed to ffmpeg's ``scale`` filter.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    ffmpeg_cmd = [
        'ffmpeg',
        '-y',  # overwrite output_path without prompting
        '-i', input_path,
        '-vf', f'scale={target_width}:{target_height}',
        '-c:a', 'copy',  # copy the audio stream instead of re-encoding it
        output_path,
    ]
    # check=True surfaces an ffmpeg failure here instead of silently handing
    # the caller a path to a missing or truncated file.
    subprocess.run(ffmpeg_cmd, check=True)
|
49 |
+
|
50 |
+
def _call_nostderr(*args, **kwargs):
    """Replacement for ``sp.call`` that discards the child's stdout and stderr.

    Any ``stdout``/``stderr`` keyword supplied by the caller is overridden
    with ``sp.DEVNULL``. The wrapped call's return value is passed through so
    callers that inspect the exit status keep working.
    """
    # Avoid ffmpeg vomiting on the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)
|
55 |
+
|
56 |
+
|
57 |
+
sp.call = _call_nostderr
|
58 |
+
# Preallocating the pool of processes.
|
59 |
+
pool = ProcessPoolExecutor(4)
|
60 |
+
pool.__enter__()
|
61 |
+
|
62 |
+
|
63 |
+
def interrupt():
    """Request that the generation currently in progress be aborted.

    Only raises the module-level ``INTERRUPTING`` flag; the progress callback
    installed by ``predict_full`` checks that flag and raises to stop the run.
    """
    global INTERRUPTING
    INTERRUPTING = True
|
66 |
+
|
67 |
+
|
68 |
+
def make_waveform(*args, **kwargs):
    """Render a waveform video via ``gr.make_waveform`` and resize it to 900x300.

    All positional and keyword arguments are forwarded unchanged to
    ``gr.make_waveform``.

    Returns:
        Path of the resized ``.mp4`` file. The file is created in the system
        temporary directory rather than under a random name in the current
        working directory, so generated videos no longer pile up next to the
        application sources.
    """
    be = time.time()
    # Further remove some warnings.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        waveform_video = gr.make_waveform(*args, **kwargs)
    # Reserve a unique output path; ffmpeg is invoked with -y by resize_video
    # and overwrites the empty placeholder file.
    with NamedTemporaryFile(suffix=".mp4", delete=False) as placeholder:
        out = placeholder.name
    resize_video(waveform_video, out, 900, 300)
    print("Make a video took", time.time() - be)
    return out
|
78 |
+
|
79 |
+
|
80 |
+
def load_model(version='melody'):
    """Ensure the module-level ``MODEL`` holds the requested MusicGen checkpoint.

    Nothing is loaded when ``MODEL`` is already set and its ``name`` equals
    ``version``.

    Args:
        version: checkpoint name handed to ``MusicGen.get_pretrained``.
    """
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        # Drop our reference to the previous model before loading the next
        # one, so both do not have to fit in memory at the same time. This
        # also leaves MODEL as None (forcing a reload on the next call) if
        # loading raises.
        MODEL = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        MODEL = MusicGen.get_pretrained(version)
|
85 |
+
|
86 |
+
|
87 |
+
def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    """Generate one waveform video per text prompt with the loaded ``MODEL``.

    Args:
        texts: list of text descriptions, one per sample.
        melodies: list aligned with ``texts``; each entry is ``None`` or a
            ``(sample_rate, numpy_array)`` pair.
        duration: generation length handed to ``MODEL.set_generation_params``;
            melodies are also cropped to ``sr * duration`` samples.
        progress: forwarded to the model's generate call.
        **gen_kwargs: extra keyword arguments for
            ``MODEL.set_generation_params``.

    Returns:
        List with the result of ``make_waveform`` for each generated sample.
    """
    MODEL.set_generation_params(duration=duration, **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    target_sr = 32000
    target_ac = 1
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            # NOTE(review): the transpose presumably turns (samples, channels)
            # into (channels, samples) -- confirm against the Gradio audio input.
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
            if melody.dim() == 1:
                melody = melody[None]
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)

    if any(m is not None for m in processed_melodies):
        outputs = MODEL.generate_with_chroma(
            descriptions=texts,
            melody_wavs=processed_melodies,
            melody_sample_rate=target_sr,
            progress=progress,
        )
    else:
        outputs = MODEL.generate(texts, progress=progress)

    outputs = outputs.detach().cpu().float()
    wav_paths = []
    out_files = []
    try:
        for output in outputs:
            with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
                wav_paths.append(file.name)
                audio_write(
                    file.name, output, MODEL.sample_rate, strategy="loudness",
                    loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
                out_files.append(pool.submit(
                    make_waveform, file.name, bg_color="#21b0fe",
                    bars_color=('#fe218b', '#fed700'), fg_alpha=1.0, bar_count=75))
        res = [out_file.result() for out_file in out_files]
    finally:
        # The .wav files are only an intermediate input for make_waveform;
        # remove them so they do not accumulate in the temp directory.
        for wav_path in wav_paths:
            try:
                os.remove(wav_path)
            except OSError:
                pass
    print("batch finished", len(texts), time.time() - be)
    return res
|
126 |
+
|
127 |
+
|
128 |
+
def predict_batched(texts, melodies):
    """Run a batch of prompts through the ``melody`` model.

    Each prompt is cut to at most 512 characters, the ``melody`` checkpoint is
    loaded, and generation runs for ``BATCHED_DURATION``. The list of results
    is wrapped in an outer one-element list.
    """
    char_limit = 512
    clipped = []
    for prompt in texts:
        clipped.append(prompt[:char_limit])
    load_model('melody')
    return [_do_predictions(clipped, melodies, BATCHED_DURATION)]
|
134 |
+
|
135 |
+
|
136 |
+
def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
    """Generate a single sample with user-chosen model and sampling settings.

    Clears the ``INTERRUPTING`` flag, validates the sampling parameters, loads
    the requested model, installs a progress callback that aborts when
    ``INTERRUPTING`` is set, and returns the first result of
    ``_do_predictions``.

    Raises:
        gr.Error: if ``temperature``, ``topk`` or ``topp`` is negative, or
            (from the progress callback) when an interrupt was requested.
    """
    global INTERRUPTING
    INTERRUPTING = False

    # Checked in this order so the first offending parameter is the one reported.
    non_negative = (
        (temperature, "Temperature must be >= 0."),
        (topk, "Topk must be non-negative."),
        (topp, "Topp must be non-negative."),
    )
    for value, message in non_negative:
        if value < 0:
            raise gr.Error(message)

    topk = int(topk)
    load_model(model)

    def _report(generated, to_generate):
        progress((generated, to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    MODEL.set_custom_progress_callback(_report)

    videos = _do_predictions(
        [text], [melody], duration, progress=True,
        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
    return videos[0]
|
159 |
+
|
160 |
+
|
161 |
+
def ui_full(launch_kwargs):
|
162 |
+
with gr.Blocks() as interface:
|
163 |
+
gr.Markdown(
|
164 |
+
"""
|
165 |
+
# MusicGen
|
166 |
+
This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
|
167 |
+
presented at: ["Simple and Controllable Music Generation"](https://arxiv.org/abs/2306.05284)
|
168 |
+
"""
|
169 |
+
)
|
170 |
+
with gr.Row():
|
171 |
+
with gr.Column():
|
172 |
+
with gr.Row():
|
173 |
+
text = gr.Text(label="Input Text", interactive=True)
|
174 |
+
melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
|
175 |
+
with gr.Row():
|
176 |
+
submit = gr.Button("Submit")
|
177 |
+
# Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
|
178 |
+
_ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
|
179 |
+
with gr.Row():
|
180 |
+
model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
|
181 |
+
with gr.Row():
|
182 |
+
duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
|
183 |
+
with gr.Row():
|
184 |
+
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
185 |
+
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
186 |
+
temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
|
187 |
+
cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
|
188 |
+
with gr.Column():
|
189 |
+
output = gr.Video(label="Generated Music")
|
190 |
+
submit.click(predict_full, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
|
191 |
+
gr.Examples(
|
192 |
+
fn=predict_full,
|
193 |
+
examples=[
|
194 |
+
[
|
195 |
+
"An 80s driving pop song with heavy drums and synth pads in the background",
|
196 |
+
"./assets/bach.mp3",
|
197 |
+
"melody"
|
198 |
+
],
|
199 |
+
[
|
200 |
+
"A cheerful country song with acoustic guitars",
|
201 |
+
"./assets/bolero_ravel.mp3",
|
202 |
+
"melody"
|
203 |
+
],
|
204 |
+
[
|
205 |
+
"90s rock song with electric guitar and heavy drums",
|
206 |
+
None,
|
207 |
+
"medium"
|
208 |
+
],
|
209 |
+
[
|
210 |
+
"a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
|
211 |
+
"./assets/bach.mp3",
|
212 |
+
"melody"
|
213 |
+
],
|
214 |
+
[
|
215 |
+
"lofi slow bpm electro chill with organic samples",
|
216 |
+
None,
|
217 |
+
"medium",
|
218 |
+
],
|
219 |
+
],
|
220 |
+
inputs=[text, melody, model],
|
221 |
+
outputs=[output]
|
222 |
+
)
|
223 |
+
gr.Markdown(
|
224 |
+
"""
|
225 |
+
### More details
|
226 |
+
|
227 |
+
The model will generate a short music extract based on the description you provided.
|
228 |
+
The model can generate up to 30 seconds of audio in one pass. It is now possible
|
229 |
+
to extend the generation by feeding back the end of the previous chunk of audio.
|
230 |
+
This can take a long time, and the model might lose consistency. The model might also
|
231 |
+
decide at arbitrary positions that the song ends.
|
232 |
+
|
233 |
+
**WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min). An overlap of 12 seconds
|
234 |
+
is kept with the previously generated chunk, and 18 "new" seconds are generated each time.
|
235 |
+
|
236 |
+
We present 4 model variations:
|
237 |
+
1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
|
238 |
+
2. Small -- a 300M transformer decoder conditioned on text only.
|
239 |
+
3. Medium -- a 1.5B transformer decoder conditioned on text only.
|
240 |
+
4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
|
241 |
+
|
242 |
+
When using `melody`, ou can optionaly provide a reference audio from
|
243 |
+
which a broad melody will be extracted. The model will then try to follow both the description and melody provided.
|
244 |
+
|
245 |
+
You can also use your own GPU or a Google Colab by following the instructions on our repo.
|
246 |
+
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
|
247 |
+
for more details.
|
248 |
+
"""
|
249 |
+
)
|
250 |
+
|
251 |
+
interface.queue().launch(**launch_kwargs)
|
252 |
+
|
253 |
+
|
254 |
+
def ui_batched(launch_kwargs):
|
255 |
+
with gr.Blocks() as demo:
|
256 |
+
gr.Markdown(
|
257 |
+
"""
|
258 |
+
# MusicGen
|
259 |
+
|
260 |
+
This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
|
261 |
+
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
|
262 |
+
<br/>
|
263 |
+
<a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
|
264 |
+
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
265 |
+
for longer sequences, more control and no queue.</p>
|
266 |
+
"""
|
267 |
+
)
|
268 |
+
with gr.Row():
|
269 |
+
with gr.Column():
|
270 |
+
with gr.Row():
|
271 |
+
text = gr.Text(label="Describe your music", lines=2, interactive=True)
|
272 |
+
melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
|
273 |
+
with gr.Row():
|
274 |
+
submit = gr.Button("Generate")
|
275 |
+
with gr.Column():
|
276 |
+
output = gr.Video(label="Generated Music")
|
277 |
+
submit.click(predict_batched, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
|
278 |
+
gr.Examples(
|
279 |
+
fn=predict_batched,
|
280 |
+
examples=[
|
281 |
+
[
|
282 |
+
"An 80s driving pop song with heavy drums and synth pads in the background",
|
283 |
+
"./assets/bach.mp3",
|
284 |
+
],
|
285 |
+
[
|
286 |
+
"A cheerful country song with acoustic guitars",
|
287 |
+
"./assets/bolero_ravel.mp3",
|
288 |
+
],
|
289 |
+
[
|
290 |
+
"90s rock song with electric guitar and heavy drums",
|
291 |
+
None,
|
292 |
+
],
|
293 |
+
[
|
294 |
+
"a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
|
295 |
+
"./assets/bach.mp3",
|
296 |
+
],
|
297 |
+
[
|
298 |
+
"lofi slow bpm electro chill with organic samples",
|
299 |
+
None,
|
300 |
+
],
|
301 |
+
],
|
302 |
+
inputs=[text, melody],
|
303 |
+
outputs=[output]
|
304 |
+
)
|
305 |
+
gr.Markdown("""
|
306 |
+
### More details
|
307 |
+
|
308 |
+
The model will generate 12 seconds of audio based on the description you provided.
|
309 |
+
You can optionaly provide a reference audio from which a broad melody will be extracted.
|
310 |
+
The model will then try to follow both the description and melody provided.
|
311 |
+
All samples are generated with the `melody` model.
|
312 |
+
|
313 |
+
You can also use your own GPU or a Google Colab by following the instructions on our repo.
|
314 |
+
|
315 |
+
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
|
316 |
+
for more details.
|
317 |
+
""")
|
318 |
+
|
319 |
+
demo.queue(max_size=8 * 4).launch(**launch_kwargs)
|
320 |
+
|
321 |
+
|
322 |
+
if __name__ == "__main__":
|
323 |
+
parser = argparse.ArgumentParser()
|
324 |
+
parser.add_argument(
|
325 |
+
'--listen',
|
326 |
+
type=str,
|
327 |
+
default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
|
328 |
+
help='IP to listen on for connections to Gradio',
|
329 |
+
)
|
330 |
+
parser.add_argument(
|
331 |
+
'--username', type=str, default='', help='Username for authentication'
|
332 |
+
)
|
333 |
+
parser.add_argument(
|
334 |
+
'--password', type=str, default='', help='Password for authentication'
|
335 |
+
)
|
336 |
+
parser.add_argument(
|
337 |
+
'--server_port',
|
338 |
+
type=int,
|
339 |
+
default=0,
|
340 |
+
help='Port to run the server listener on',
|
341 |
+
)
|
342 |
+
parser.add_argument(
|
343 |
+
'--inbrowser', action='store_true', help='Open in browser'
|
344 |
+
)
|
345 |
+
parser.add_argument(
|
346 |
+
'--share', action='store_true', help='Share the gradio UI'
|
347 |
+
)
|
348 |
+
|
349 |
+
args = parser.parse_args()
|
350 |
+
|
351 |
+
launch_kwargs = {}
|
352 |
+
launch_kwargs['server_name'] = args.listen
|
353 |
+
|
354 |
+
if args.username and args.password:
|
355 |
+
launch_kwargs['auth'] = (args.username, args.password)
|
356 |
+
if args.server_port:
|
357 |
+
launch_kwargs['server_port'] = args.server_port
|
358 |
+
if args.inbrowser:
|
359 |
+
launch_kwargs['inbrowser'] = args.inbrowser
|
360 |
+
if args.share:
|
361 |
+
launch_kwargs['share'] = args.share
|
362 |
+
|
363 |
+
# Show the interface
|
364 |
+
if IS_BATCHED:
|
365 |
+
ui_batched(launch_kwargs)
|
366 |
+
else:
|
367 |
+
ui_full(launch_kwargs)
|
assets/bach.mp3
ADDED
Binary file (160 kB). View file
|
|
assets/bolero_ravel.mp3
ADDED
Binary file (161 kB). View file
|
|
audiocraft/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# flake8: noqa
|
8 |
+
from . import data, modules, models
|
9 |
+
|
10 |
+
# Version string of the audiocraft package.
__version__ = '0.0.2a2'
|
audiocraft/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (227 Bytes). View file
|
|
audiocraft/data/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# flake8: noqa
|
8 |
+
from . import audio, audio_dataset
|
audiocraft/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (194 Bytes). View file
|
|
audiocraft/data/__pycache__/audio.cpython-310.pyc
ADDED
Binary file (7.53 kB). View file
|
|
audiocraft/data/__pycache__/audio_dataset.cpython-310.pyc
ADDED
Binary file (19.1 kB). View file
|
|
audiocraft/data/__pycache__/audio_utils.cpython-310.pyc
ADDED
Binary file (6.23 kB). View file
|
|
audiocraft/data/__pycache__/zip.cpython-310.pyc
ADDED
Binary file (2.54 kB). View file
|
|
audiocraft/data/audio.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Audio IO methods are defined in this module (info, read, write),
|
9 |
+
We rely on av library for faster read when possible, otherwise on torchaudio.
|
10 |
+
"""
|
11 |
+
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from pathlib import Path
|
14 |
+
import logging
|
15 |
+
import typing as tp
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import soundfile
|
19 |
+
import torch
|
20 |
+
from torch.nn import functional as F
|
21 |
+
import torchaudio as ta
|
22 |
+
|
23 |
+
import av
|
24 |
+
|
25 |
+
from .audio_utils import f32_pcm, i16_pcm, normalize_audio
|
26 |
+
|
27 |
+
|
28 |
+
# Set to True once `_init_av` has run, so the logger is configured only once.
_av_initialized = False


def _init_av():
    """Raise the level of the 'libav.mp3' logger to ERROR, once per process.

    Later calls are no-ops thanks to the module-level `_av_initialized` flag.
    """
    global _av_initialized
    if not _av_initialized:
        logging.getLogger('libav.mp3').setLevel(logging.ERROR)
        _av_initialized = True
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass(frozen=True)
class AudioFileInfo:
    """Immutable summary of an audio file as reported by the probing backend.

    Attributes:
        sample_rate (int): Sample rate of the audio stream.
        duration (float): Duration of the audio stream.
        channels (int): Number of audio channels.
    """
    sample_rate: int
    duration: float
    channels: int
|
45 |
+
|
46 |
+
|
47 |
+
def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Probe an audio file with PyAV and describe its first audio stream.

    Args:
        filepath (str or Path): Path to the audio file.
    Returns:
        AudioFileInfo: Sample rate, duration and channel count of the first audio stream.
    """
    _init_av()
    with av.open(str(filepath)) as container:
        audio_stream = container.streams.audio[0]
        return AudioFileInfo(
            sample_rate=audio_stream.codec_context.sample_rate,
            # Stream duration is expressed in `time_base` units; the product is a plain float.
            duration=float(audio_stream.duration * audio_stream.time_base),
            channels=audio_stream.channels,
        )
|
55 |
+
|
56 |
+
|
57 |
+
def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Probe an audio file with `soundfile` and repackage the result.

    Args:
        filepath (str or Path): Path to the audio file.
    Returns:
        AudioFileInfo: Sample rate, duration and channel count reported by `soundfile.info`.
    """
    sf_info = soundfile.info(filepath)
    return AudioFileInfo(
        sample_rate=sf_info.samplerate,
        duration=sf_info.duration,
        channels=sf_info.channels,
    )
|
60 |
+
|
61 |
+
|
62 |
+
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Return basic information about an audio file, choosing the backend from its extension.

    Args:
        filepath (str or Path): Path to the audio file.
    Returns:
        AudioFileInfo: Sample rate, duration and channel count of the file.
    """
    # torchaudio no longer returns useful duration informations for some formats like mp3s.
    filepath = Path(filepath)
    # ffmpeg has some weird issue with flac.
    # TODO: Validate .ogg can be safely read with av_info
    probe = _soundfile_info if filepath.suffix in ('.flac', '.ogg') else _av_info
    return probe(filepath)
|
70 |
+
|
71 |
+
|
72 |
+
def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
    """FFMPEG-based audio file reading using PyAV bindings.
    Soundfile cannot read mp3 and av_read is more efficient than torchaudio.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
    """
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        sr = stream.codec_context.sample_rate
        # Number of samples requested; any value <= 0 disables truncation below.
        num_frames = int(sr * duration) if duration >= 0 else -1
        # Index of the first requested sample.
        frame_offset = int(sr * seek_time)
        # we need a small negative offset otherwise we get some edge artifact
        # from the mp3 decoder.
        af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
        frames = []
        length = 0
        for frame in af.decode(streams=stream.index):
            # Sample index at which this decoded frame starts.
            current_offset = int(frame.rate * frame.pts * frame.time_base)
            # Leading samples to drop because they precede the requested start.
            strip = max(0, frame_offset - current_offset)
            buf = torch.from_numpy(frame.to_ndarray())
            if buf.shape[0] != stream.channels:
                # Reshape so that the first dimension is the channel dimension
                # (presumably the frame used an interleaved layout; confirm against PyAV).
                buf = buf.view(-1, stream.channels).t()
            buf = buf[:, strip:]
            frames.append(buf)
            length += buf.shape[1]
            # Stop decoding as soon as enough samples have been collected.
            if num_frames > 0 and length >= num_frames:
                break
        assert frames
        # If the above assert fails, it is likely because we seeked past the end of file point,
        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
        # This will need proper debugging, in due time.
        wav = torch.cat(frames, dim=1)
        assert wav.shape[0] == stream.channels
        if num_frames > 0:
            # The last decoded frame may overshoot the requested length.
            wav = wav[:, :num_frames]
        return f32_pcm(wav), sr
|
114 |
+
|
115 |
+
|
116 |
+
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Read an audio file, selecting the decoding backend from the file extension.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    extension = fp.suffix
    if extension in ('.flac', '.ogg'):  # TODO: check if we can safely use av_read for .ogg
        # There is some bug with ffmpeg and reading flac
        sf_info = _soundfile_info(filepath)
        native_sr = sf_info.sample_rate
        n_to_read = -1 if duration <= 0 else int(duration * native_sr)
        first_sample = int(seek_time * native_sr)
        data, sr = soundfile.read(filepath, start=first_sample, frames=n_to_read, dtype=np.float32)
        assert native_sr == sr, f"Mismatch of sample rates {native_sr} {sr}"
        wav = torch.from_numpy(data).t().contiguous()
        if wav.dim() == 1:
            # Mono data comes back one-dimensional; add the channel dimension.
            wav = wav.unsqueeze(0)
    elif (
        extension in ('.wav', '.mp3') and extension[1:] in ta.utils.sox_utils.list_read_formats()
        and duration <= 0 and seek_time == 0
    ):
        # Torchaudio is faster if we load an entire file at once.
        wav, sr = ta.load(fp)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)

    if pad and duration > 0:
        # Right-pad so the output has exactly `duration` worth of samples.
        missing = int(duration * sr) - wav.shape[-1]
        wav = F.pad(wav, (0, missing))
    return wav, sr
|
151 |
+
|
152 |
+
|
153 |
+
def audio_write(stem_name: tp.Union[str, Path],
                wav: torch.Tensor, sample_rate: int,
                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                loudness_compressor: bool = False,
                log_clipping: bool = True, make_parent_dir: bool = True,
                add_suffix: bool = True) -> Path:
    """Convenience function for saving audio to disk. Returns the filename the audio was written to.

    Args:
        stem_name (str or Path): Filename without extension which will be added automatically.
        wav (torch.Tensor): Floating point audio with at most 2 dimensions; a 1-dimensional
            input gets a leading dimension added. All values must be finite.
        sample_rate (int): Sample rate of `wav`.
        format (str): Either "wav" or "mp3".
        mp3_rate (int): kbps when using mp3s.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
            NOTE(review): this argument is accepted but not forwarded to `normalize_audio`
            by this function; confirm whether that is intended.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        make_parent_dir (bool): Make parent directory if it doesn't exist.
        add_suffix (bool): If True (default), append the format extension to `stem_name`;
            otherwise `stem_name` is used as the full path.
    Returns:
        Path: Path of the saved audio.
    Raises:
        ValueError: If `wav` has more than 2 dimensions.
        RuntimeError: If `format` is neither 'wav' nor 'mp3'.
    """
    assert wav.dtype.is_floating_point, "wav is not floating point"
    n_dims = wav.dim()
    if n_dims > 2:
        raise ValueError("Input wav should be at most 2 dimension.")
    if n_dims == 1:
        wav = wav[None]
    assert wav.isfinite().all()

    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
                          rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
                          sample_rate=sample_rate, stem_name=str(stem_name))

    # Format-specific options handed to `torchaudio.save`.
    save_options: dict
    if format == 'mp3':
        save_options = {"compression": mp3_rate}
    elif format == 'wav':
        wav = i16_pcm(wav)
        save_options = {"encoding": "PCM_S", "bits_per_sample": 16}
    else:
        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")

    suffix = f'.{format}' if add_suffix else ''
    path = Path(str(stem_name) + suffix)
    if make_parent_dir:
        path.parent.mkdir(exist_ok=True, parents=True)
    try:
        ta.save(path, wav, sample_rate, **save_options)
    except Exception:
        # we do not want to leave half written files around.
        if path.exists():
            path.unlink()
        raise
    return path
|
audiocraft/data/audio_dataset.py
ADDED
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import copy
|
9 |
+
from concurrent.futures import ThreadPoolExecutor, Future
|
10 |
+
from dataclasses import dataclass, fields
|
11 |
+
from contextlib import ExitStack
|
12 |
+
import gzip
|
13 |
+
import json
|
14 |
+
import logging
|
15 |
+
import os
|
16 |
+
from pathlib import Path
|
17 |
+
import random
|
18 |
+
import sys
|
19 |
+
import typing as tp
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from .audio import audio_read, audio_info
|
25 |
+
from .audio_utils import convert_audio
|
26 |
+
from .zip import PathInZip
|
27 |
+
|
28 |
+
try:
|
29 |
+
import dora
|
30 |
+
except ImportError:
|
31 |
+
dora = None # type: ignore
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass(order=True)
class BaseInfo:
    """Base dataclass providing conversion from and to plain dictionaries."""

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        """Keep only the entries of `dictionary` whose key is a dataclass field of `cls`."""
        field_names = (f.name for f in fields(cls))
        return {name: dictionary[name] for name in field_names if name in dictionary}

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an instance from `dictionary`, ignoring keys that are not fields."""
        return cls(**cls._dict2fields(dictionary))

    def to_dict(self):
        """Return a dictionary mapping every field name to its current value."""
        return {f.name: getattr(self, f.name) for f in fields(self)}
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass(order=True)
class AudioMeta(BaseInfo):
    """Metadata describing one audio file, convertible to and from a plain dictionary."""
    path: str
    duration: float
    sample_rate: int
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # info_path is used to load additional information about the audio file that is stored in zip files.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an AudioMeta from `dictionary`, wrapping a non-None 'info_path' in `PathInZip`."""
        kwargs = cls._dict2fields(dictionary)
        if kwargs.get('info_path') is not None:
            kwargs['info_path'] = PathInZip(kwargs['info_path'])
        return cls(**kwargs)

    def to_dict(self):
        """Return the fields as a dictionary, with a non-None 'info_path' converted to `str`."""
        as_dict = super().to_dict()
        info_path = as_dict['info_path']
        if info_path is not None:
            as_dict['info_path'] = str(info_path)
        return as_dict
|
78 |
+
|
79 |
+
|
80 |
+
@dataclass(order=True)
class SegmentInfo(BaseInfo):
    """Description of an audio segment together with the metadata of its source file."""
    meta: AudioMeta  # metadata of the file the segment comes from
    seek_time: float  # position in the file at which the segment starts
    n_frames: int      # actual number of frames without padding
    total_frames: int  # total number of frames, padding included
    sample_rate: int   # actual sample rate
|
87 |
+
|
88 |
+
|
89 |
+
# File extensions (lowercase, with leading dot) considered audio files by default.
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']

# Module-level logger.
logger = logging.getLogger(__name__)
|
92 |
+
|
93 |
+
|
94 |
+
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """Build the AudioMeta of a single audio file.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
            When False, the whole file is read to measure its peak absolute amplitude.
    Returns:
        AudioMeta: Audio file path and its metadata.
    """
    info = audio_info(file_path)
    peak: tp.Optional[float] = None
    if not minimal:
        waveform, _ = audio_read(file_path)
        peak = waveform.abs().max().item()
    return AudioMeta(path=file_path, duration=info.duration,
                     sample_rate=info.sample_rate, amplitude=peak)
|
109 |
+
|
110 |
+
|
111 |
+
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in list of AudioMeta. This method is expected to be used when loading meta from file.

    The audio path and, when present, the zip path of `info_path` are each made absolute
    independently. `m` is modified in place and returned.

    Args:
        m (AudioMeta): Audio meta to resolve.
        fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
            Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(path) -> bool:
        path = str(path)
        if fast:
            # `startswith` also copes with an empty path, unlike indexing the first character.
            return path.startswith('/')
        return os.path.isabs(path)

    if not dora:
        # Without Dora there is nothing to resolve against.
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip path from its own value, not from the audio path.
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m
|
136 |
+
|
137 |
+
|
138 |
+
def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Build a list of AudioMeta from a given path,
    collecting relevant audio files and fetching meta info.

    Files whose metadata cannot be obtained are reported on stderr and skipped.

    Args:
        path (str or Path): Path to folder containing audio files.
        exts (list of str): List of file extensions to consider for audio files.
            The lowercased suffix of each file is compared against this list.
        resolve (bool): Whether to pass each AudioMeta through `_resolve_audio_meta`.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to log progress on audio files collection.
        workers (int): number of parallel workers, if 0, use only the current thread.
    Returns:
        List[AudioMeta]: List of audio file path and its metadata.
    """
    audio_files = []
    futures: tp.List[Future] = []
    pool: tp.Optional[ThreadPoolExecutor] = None
    with ExitStack() as stack:
        if workers > 0:
            pool = ThreadPoolExecutor(workers)
            # Registering the pool makes sure it is shut down when the `with` block exits.
            stack.enter_context(pool)

        if progress:
            print("Finding audio files...")
        for root, folders, files in os.walk(path, followlinks=True):
            for file in files:
                full_path = Path(root) / file
                if full_path.suffix.lower() in exts:
                    audio_files.append(full_path)
                    if pool is not None:
                        # futures[i] always corresponds to audio_files[i].
                        futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
                    if progress:
                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        meta: tp.List[AudioMeta] = []
        for idx, file_path in enumerate(audio_files):
            try:
                if pool is None:
                    # No worker pool: compute the metadata in the current thread.
                    m = _get_audio_meta(str(file_path), minimal)
                else:
                    m = futures[idx].result()
                if resolve:
                    m = _resolve_audio_meta(m)
            except Exception as err:
                # Best effort: a failing file is reported and left out of the result.
                print("Error with", str(file_path), err, file=sys.stderr)
                continue
            meta.append(m)
            if progress:
                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
    # Sorting relies on the ordering generated by the AudioMeta dataclass.
    meta.sort()
    return meta
|
195 |
+
|
196 |
+
|
197 |
+
def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Load a list of AudioMeta from a file holding one JSON document per line.

    The file is read through gzip when its name ends with '.gz' (case-insensitive).

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): activates some tricks to make things faster.
    Returns:
        List[AudioMeta]: List of audio file path and its total duration.
    """
    is_gzip = str(path).lower().endswith('.gz')
    opener = gzip.open if is_gzip else open
    with opener(path, 'rb') as fp:  # type: ignore
        raw_lines = fp.readlines()

    # Lazily decode each line so that decoding and resolving stay interleaved per entry.
    entries = (AudioMeta.from_dict(json.loads(raw)) for raw in raw_lines)
    if resolve:
        return [_resolve_audio_meta(entry, fast=fast) for entry in entries]
    return list(entries)
|
219 |
+
|
220 |
+
|
221 |
+
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Save the audio metadata to `path`, one JSON document per line.

    The parent directory is created if needed, and the output is gzip-compressed
    when the file name ends with '.gz' (case-insensitive).

    Args:
        path (str or Path): Path to JSON file.
        meta (list of AudioMeta): List of audio meta to save.
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    is_gzip = str(path).lower().endswith('.gz')
    opener = gzip.open if is_gzip else open
    with opener(path, 'wb') as fp:  # type: ignore
        for entry in meta:
            line = json.dumps(entry.to_dict()) + '\n'
            fp.write(line.encode('utf-8'))
|
235 |
+
|
236 |
+
|
237 |
+
class AudioDataset:
|
238 |
+
"""Base audio dataset.
|
239 |
+
|
240 |
+
The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
|
241 |
+
and potentially additional information, by creating random segments from the list of audio
|
242 |
+
files referenced in the metadata and applying minimal data pre-processing such as resampling,
|
243 |
+
mixing of channels, padding, etc.
|
244 |
+
|
245 |
+
If no segment_duration value is provided, the AudioDataset will return the full wav for each
|
246 |
+
audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
|
247 |
+
duration, applying padding if required.
|
248 |
+
|
249 |
+
By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
|
250 |
+
allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
|
251 |
+
original audio meta.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
meta (tp.List[AudioMeta]): List of audio files metadata.
|
255 |
+
segment_duration (float): Optional segment duration of audio to load.
|
256 |
+
If not specified, the dataset will load the full audio segment from the file.
|
257 |
+
shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
|
258 |
+
sample_rate (int): Target sample rate of the loaded audio samples.
|
259 |
+
channels (int): Target number of channels of the loaded audio samples.
|
260 |
+
sample_on_duration (bool): Set to `True` to sample segments with probability
|
261 |
+
dependent on audio file duration. This is only used if `segment_duration` is provided.
|
262 |
+
sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
|
263 |
+
`AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
|
264 |
+
of the file duration and file weight. This is only used if `segment_duration` is provided.
|
265 |
+
min_segment_ratio (float): Minimum segment ratio to use when the audio file
|
266 |
+
is shorter than the desired segment.
|
267 |
+
max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
|
268 |
+
return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
|
269 |
+
min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
|
270 |
+
audio shorter than this will be filtered out.
|
271 |
+
max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
|
272 |
+
audio longer than this will be filtered out.
|
273 |
+
"""
|
274 |
+
def __init__(self,
|
275 |
+
meta: tp.List[AudioMeta],
|
276 |
+
segment_duration: tp.Optional[float] = None,
|
277 |
+
shuffle: bool = True,
|
278 |
+
num_samples: int = 10_000,
|
279 |
+
sample_rate: int = 48_000,
|
280 |
+
channels: int = 2,
|
281 |
+
pad: bool = True,
|
282 |
+
sample_on_duration: bool = True,
|
283 |
+
sample_on_weight: bool = True,
|
284 |
+
min_segment_ratio: float = 0.5,
|
285 |
+
max_read_retry: int = 10,
|
286 |
+
return_info: bool = False,
|
287 |
+
min_audio_duration: tp.Optional[float] = None,
|
288 |
+
max_audio_duration: tp.Optional[float] = None
|
289 |
+
):
|
290 |
+
assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
|
291 |
+
assert segment_duration is None or segment_duration > 0
|
292 |
+
assert segment_duration is None or min_segment_ratio >= 0
|
293 |
+
logging.debug(f'sample_on_duration: {sample_on_duration}')
|
294 |
+
logging.debug(f'sample_on_weight: {sample_on_weight}')
|
295 |
+
logging.debug(f'pad: {pad}')
|
296 |
+
logging.debug(f'min_segment_ratio: {min_segment_ratio}')
|
297 |
+
|
298 |
+
self.segment_duration = segment_duration
|
299 |
+
self.min_segment_ratio = min_segment_ratio
|
300 |
+
self.max_audio_duration = max_audio_duration
|
301 |
+
self.min_audio_duration = min_audio_duration
|
302 |
+
if self.min_audio_duration is not None and self.max_audio_duration is not None:
|
303 |
+
assert self.min_audio_duration <= self.max_audio_duration
|
304 |
+
self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
|
305 |
+
assert len(self.meta) # Fail fast if all data has been filtered.
|
306 |
+
self.total_duration = sum(d.duration for d in self.meta)
|
307 |
+
|
308 |
+
if segment_duration is None:
|
309 |
+
num_samples = len(self.meta)
|
310 |
+
self.num_samples = num_samples
|
311 |
+
self.shuffle = shuffle
|
312 |
+
self.sample_rate = sample_rate
|
313 |
+
self.channels = channels
|
314 |
+
self.pad = pad
|
315 |
+
self.sample_on_weight = sample_on_weight
|
316 |
+
self.sample_on_duration = sample_on_duration
|
317 |
+
self.sampling_probabilities = self._get_sampling_probabilities()
|
318 |
+
self.max_read_retry = max_read_retry
|
319 |
+
self.return_info = return_info
|
320 |
+
|
321 |
+
def __len__(self):
|
322 |
+
return self.num_samples
|
323 |
+
|
324 |
+
def _get_sampling_probabilities(self, normalized: bool = True):
|
325 |
+
"""Return the sampling probabilities for each file inside `self.meta`.
|
326 |
+
"""
|
327 |
+
scores: tp.List[float] = []
|
328 |
+
for file_meta in self.meta:
|
329 |
+
score = 1.
|
330 |
+
if self.sample_on_weight and file_meta.weight is not None:
|
331 |
+
score *= file_meta.weight
|
332 |
+
if self.sample_on_duration:
|
333 |
+
score *= file_meta.duration
|
334 |
+
scores.append(score)
|
335 |
+
probabilities = torch.tensor(scores)
|
336 |
+
if normalized:
|
337 |
+
probabilities /= probabilities.sum()
|
338 |
+
return probabilities
|
339 |
+
|
340 |
+
def sample_file(self, rng: torch.Generator) -> AudioMeta:
|
341 |
+
"""Sample a given file from `self.meta`. Can be overriden in subclasses.
|
342 |
+
This is only called if `segment_duration` is not None.
|
343 |
+
|
344 |
+
You must use the provided random number generator `rng` for reproducibility.
|
345 |
+
"""
|
346 |
+
if not self.sample_on_weight and not self.sample_on_duration:
|
347 |
+
file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
|
348 |
+
else:
|
349 |
+
file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
|
350 |
+
|
351 |
+
return self.meta[file_index]
|
352 |
+
|
353 |
+
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
|
354 |
+
if self.segment_duration is None:
|
355 |
+
file_meta = self.meta[index]
|
356 |
+
out, sr = audio_read(file_meta.path)
|
357 |
+
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
358 |
+
n_frames = out.shape[-1]
|
359 |
+
segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
|
360 |
+
sample_rate=self.sample_rate)
|
361 |
+
else:
|
362 |
+
rng = torch.Generator()
|
363 |
+
if self.shuffle:
|
364 |
+
# We use index, plus extra randomness
|
365 |
+
rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
|
366 |
+
else:
|
367 |
+
# We only use index
|
368 |
+
rng.manual_seed(index)
|
369 |
+
|
370 |
+
for retry in range(self.max_read_retry):
|
371 |
+
file_meta = self.sample_file(rng)
|
372 |
+
# We add some variance in the file position even if audio file is smaller than segment
|
373 |
+
# without ending up with empty segments
|
374 |
+
max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
|
375 |
+
seek_time = torch.rand(1, generator=rng).item() * max_seek
|
376 |
+
try:
|
377 |
+
out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
|
378 |
+
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
379 |
+
n_frames = out.shape[-1]
|
380 |
+
target_frames = int(self.segment_duration * self.sample_rate)
|
381 |
+
if self.pad:
|
382 |
+
out = F.pad(out, (0, target_frames - n_frames))
|
383 |
+
segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
|
384 |
+
sample_rate=self.sample_rate)
|
385 |
+
except Exception as exc:
|
386 |
+
logger.warning("Error opening file %s: %r", file_meta.path, exc)
|
387 |
+
if retry == self.max_read_retry - 1:
|
388 |
+
raise
|
389 |
+
else:
|
390 |
+
break
|
391 |
+
|
392 |
+
if self.return_info:
|
393 |
+
# Returns the wav and additional information on the wave segment
|
394 |
+
return out, segment_info
|
395 |
+
else:
|
396 |
+
return out
|
397 |
+
|
398 |
+
def collater(self, samples):
    """Collate a list of dataset items into a single batch.

    The collater function has to be provided to the dataloader
    if AudioDataset has return_info=True in order to properly collate
    the samples of a batch.

    Args:
        samples (list): Items to batch. Each item is a waveform tensor, or a
            `(wav, SegmentInfo)` pair when `self.return_info` is set.
    Returns:
        The stacked waveforms, or a `(stacked waveforms, list of SegmentInfo)`
        tuple when `self.return_info` is set. The returned infos are deep
        copies, so the input items are never modified.
    """
    if self.segment_duration is None and len(samples) > 1:
        assert self.pad, "Must allow padding when batching examples of different durations."

    # In this case the audio reaching the collater is of variable length as segment_duration=None.
    to_pad = self.segment_duration is None and self.pad

    if self.return_info:
        if len(samples) > 0:
            assert len(samples[0]) == 2
            assert isinstance(samples[0][0], torch.Tensor)
            assert isinstance(samples[0][1], SegmentInfo)
        wavs = [wav for wav, _ in samples]
        segment_infos = [copy.deepcopy(info) for _, info in samples]
    else:
        assert isinstance(samples[0], torch.Tensor)
        # Items are bare tensors here: they must NOT be unpacked as (wav, info) pairs.
        wavs = list(samples)
        segment_infos = []

    if to_pad:
        # Each wav could be of a different duration as they are not segmented.
        max_len = max((wav.shape[-1] for wav in wavs), default=0)
        wavs = [F.pad(wav, (0, max_len - wav.shape[-1])) for wav in wavs]
        for info in segment_infos:
            # Total length of the signal including padding, so it is updated as we pad.
            info.total_frames = max_len

    batch = torch.stack(wavs)
    if self.return_info:
        return batch, segment_infos
    return batch
|
437 |
+
|
438 |
+
def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Filters out audio files whose duration is outside the configured bounds.

    Entries shorter than `self.min_audio_duration` or longer than
    `self.max_audio_duration` are dropped; a bound set to `None` is ignored.
    The share of removed entries is logged at debug level when below 10 percent,
    and as a warning otherwise.

    Args:
        meta (list of AudioMeta): Candidate entries; only `duration` is read.
    Returns:
        list of AudioMeta: The entries that satisfy both bounds.
    """
    orig_len = len(meta)

    # Filter data that is too short.
    if self.min_audio_duration is not None:
        meta = [m for m in meta if m.duration >= self.min_audio_duration]

    # Filter data that is too long.
    if self.max_audio_duration is not None:
        meta = [m for m in meta if m.duration <= self.max_audio_duration]

    if orig_len == 0:
        # Empty manifest: there is no removal ratio to report, and computing
        # one would divide by zero.
        return meta

    filtered_len = len(meta)
    removed_percentage = 100 * (1 - float(filtered_len) / orig_len)
    msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
    if removed_percentage < 10:
        logging.debug(msg)
    else:
        logging.warning(msg)
    return meta
|
460 |
+
|
461 |
+
@classmethod
def from_meta(cls, root: tp.Union[str, Path], **kwargs):
    """Build an AudioDataset from a jsonl manifest.

    Args:
        root (str or Path): Either the manifest file itself, or a directory
            holding a `data.jsonl` (checked first) or `data.jsonl.gz` manifest.
        kwargs: Additional keyword arguments for the AudioDataset.
    Raises:
        ValueError: If `root` is a directory that contains neither manifest.
    """
    manifest = Path(root)
    if manifest.is_dir():
        candidates = (manifest / name for name in ('data.jsonl', 'data.jsonl.gz'))
        # Lazily probe the candidates in order; the first existing one wins.
        found = next((path for path in candidates if path.exists()), None)
        if found is None:
            raise ValueError("Don't know where to read metadata from in the dir. "
                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        manifest = found
    return cls(load_audio_meta(manifest), **kwargs)
|
480 |
+
|
481 |
+
@classmethod
def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
              exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
    """Build an AudioDataset from a manifest file or a folder of audio files.

    Args:
        root (str or Path): A manifest file, or a root folder containing
            (possibly nested) audio files to scan.
        minimal_meta (bool): Whether to only load minimal metadata or not
            (only used when scanning a folder).
        exts (list of str): Extensions for audio files (only used when scanning a folder).
        kwargs: Additional keyword arguments for the AudioDataset.
    """
    source = Path(root)
    if source.is_file():
        entries = load_audio_meta(source, resolve=True)
    else:
        entries = find_audio_files(source, exts, minimal=minimal_meta, resolve=True)
    return cls(entries, **kwargs)
|
498 |
+
|
499 |
+
|
500 |
+
def main():
    """Command line entry point: scan a folder and write its audio metadata manifest.

    Parses the command line, collects metadata for every file under `root`
    whose extension is in `DEFAULT_EXTS`, and saves it to `output_meta_file`.
    """
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    parser.add_argument('root', help='Root folder with all the audio files')
    parser.add_argument('output_meta_file',
                        help='Output file to store the metadata.')
    # `--complete` clears `minimal`, i.e. it requests the full metadata.
    parser.add_argument('--complete',
                        action='store_false', dest='minimal', default=True,
                        help='Retrieve all metadata, even the ones that are expensive '
                             'to compute (e.g. normalization).')
    parser.add_argument('--resolve',
                        action='store_true', default=False,
                        help='Resolve the paths to be absolute and with no symlinks.')
    parser.add_argument('--workers',
                        default=10, type=int,
                        help='Number of workers.')
    args = parser.parse_args()
    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
    save_audio_meta(args.output_meta_file, meta)
|
522 |
+
|
523 |
+
|
524 |
+
if __name__ == '__main__':
|
525 |
+
main()
|
audiocraft/data/audio_utils.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import sys
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
import julius
|
11 |
+
import torch
|
12 |
+
import torchaudio
|
13 |
+
|
14 |
+
|
15 |
+
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
    """Convert audio to the given number of channels.

    Args:
        wav (torch.Tensor): Audio wave of shape [..., C, T].
        channels (int): Expected number of channels as output.
    Returns:
        torch.Tensor: Audio wave with `channels` channels; the input itself
            when it already has that many channels.
    Raises:
        ValueError: If the input is not mono and has fewer channels than requested.
    """
    *lead_dims, src_channels, n_frames = wav.shape
    if src_channels == channels:
        # Nothing to do.
        return wav
    if channels == 1:
        # Mono requested from a multichannel stream: average all channels.
        return wav.mean(dim=-2, keepdim=True)
    if src_channels == 1:
        # Multiple channels requested from a mono stream: replicate it.
        return wav.expand(*lead_dims, channels, n_frames)
    if src_channels >= channels:
        # More channels available than requested: keep the first ones.
        return wav[..., :channels, :]
    # Fewer channels than requested and not mono: no reasonable choice here.
    raise ValueError('The audio file has less channels than requested but is not mono.')
|
46 |
+
|
47 |
+
|
48 |
+
def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Resample audio to a new sample rate, then adapt its number of channels.

    Both rates are truncated to integers before resampling.
    """
    resampled = julius.resample_frac(wav, int(from_rate), int(to_rate))
    return convert_audio_channels(resampled, to_channels)
|
55 |
+
|
56 |
+
|
57 |
+
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
    """Rescale a signal so that its measured loudness is `-loudness_headroom_db` dB.

    Loudness is measured with `torchaudio.transforms.Loudness`.

    Args:
        wav (torch.Tensor): Input multichannel audio data.
        sample_rate (int): Sample rate.
        loudness_headroom_db (float): Headroom in dB; the target loudness is its negative.
        loudness_compressor (bool): Uses tanh for soft clipping.
        energy_floor (float): anything below that RMS level will not be rescaled.
    Returns:
        output (torch.Tensor): Loudness normalized output data; the input itself
            when its RMS is below `energy_floor`.
    """
    rms = wav.pow(2).mean().sqrt().item()
    if rms < energy_floor:
        return wav
    meter = torchaudio.transforms.Loudness(sample_rate)
    input_loudness_db = meter(wav).item()
    # Gain (in dB, then linear) needed to reach the desired loudness level.
    gain_db = -loudness_headroom_db - input_loudness_db
    scaled = (10.0 ** (gain_db / 20.0)) * wav
    output = torch.tanh(scaled) if loudness_compressor else scaled
    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
    return output
|
84 |
+
|
85 |
+
|
86 |
+
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
    """Clamp `wav` to [-1, 1] in place, optionally reporting clipping on stderr.

    When `log_clipping` is set and some sample exceeds 1 in magnitude, the
    fraction of clipped samples and the peak value are printed before clamping.
    """
    peak = wav.abs().max()
    if log_clipping and peak > 1:
        clipped_ratio = (wav.abs() > 1).float().mean().item()
        print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
              clipped_ratio, "maximum scale: ", peak.item(), file=sys.stderr)
    wav.clamp_(-1, 1)
|
94 |
+
|
95 |
+
|
96 |
+
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                    loudness_compressor: bool = False, log_clipping: bool = False,
                    sample_rate: tp.Optional[int] = None,
                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
    """Normalize the audio according to the prescribed strategy (see after).

    Silent input (zero peak or zero RMS) is left unscaled by the 'peak' and
    'rms' strategies instead of being turned into NaNs.

    Args:
        wav (torch.Tensor): Audio data.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', 'rms' or 'loudness'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips. 'loudness' delegates
            to `normalize_loudness`. '' or 'none' performs no processing but requires
            the audio to already be strictly within (-1, 1).
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): If True, uses tanh based soft clipping.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms' and 'loudness').
        sample_rate (int): Sample rate for the audio data (required for loudness).
        stem_name (Optional[str]): Stem name for clipping logging.
    Returns:
        torch.Tensor: Normalized audio.
    """
    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
    scale_rms = 10 ** (-rms_headroom_db / 20)
    if strategy == 'peak':
        peak = wav.abs().max()
        # A zero peak would give an infinite gain and 0 * inf = NaN samples.
        if peak > 0:
            rescaling = scale_peak / peak
            if normalize or rescaling < 1:
                wav = wav * rescaling
    elif strategy == 'clip':
        wav = wav.clamp(-scale_peak, scale_peak)
    elif strategy == 'rms':
        mono = wav.mean(dim=0)
        energy = mono.pow(2).mean().sqrt()
        # Same guard as for 'peak': silence cannot be rescaled.
        if energy > 0:
            rescaling = scale_rms / energy
            if normalize or rescaling < 1:
                wav = wav * rescaling
        # NOTE: _clip_wav clamps in place; when no rescaling happened above,
        # `wav` is still the caller's tensor.
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    elif strategy == 'loudness':
        assert sample_rate is not None, "Loudness normalization requires sample rate."
        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    else:
        # Validate the strategy name first so a typo is reported as such,
        # rather than as an unrelated clipping assertion.
        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
        assert wav.abs().max() < 1
    return wav
|
146 |
+
|
147 |
+
|
148 |
+
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to float 32 bits PCM format.

    Floating point input is returned as is; int16 input is converted to
    float and divided by 2**15. Any other dtype is rejected.
    """
    if not wav.dtype.is_floating_point:
        assert wav.dtype == torch.int16
        return wav.float() / 2**15
    return wav
|
156 |
+
|
157 |
+
|
158 |
+
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to int 16 bits PCM format.

    ..Warning:: There exist many formulas for doing this conversion. None are perfect
    due to the asymmetry of the int16 range. One either has possible clipping, DC offset,
    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
    it is possible that `i16_pcm(f32_pcm)) != Identity`.
    """
    if not wav.dtype.is_floating_point:
        assert wav.dtype == torch.int16
        return wav

    assert wav.abs().max() <= 1
    full_scale = 2 ** 15
    quantized = (wav * full_scale).round()
    if quantized.max() >= full_scale:  # clipping would occur
        # Fall back to the largest representable positive value as scale.
        quantized = (wav * (full_scale - 1)).round()
    return quantized.short()
|
audiocraft/data/zip.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing
|
8 |
+
import zipfile
|
9 |
+
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from functools import lru_cache
|
12 |
+
from typing_extensions import Literal
|
13 |
+
|
14 |
+
|
15 |
+
DEFAULT_SIZE = 32
|
16 |
+
MODE = Literal['r', 'w', 'x', 'a']
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass(order=True)
class PathInZip:
    """Location of a file stored inside a zip archive.

    Args:
        path: The convention is <path_to_zip>:<relative_path_inside_zip>
            Let's assume there is a zip file /some/location/foo.zip
            and inside of it is a json file located at /data/file1.json,
            Then we expect path = "/some/location/foo.zip:/data/file1.json"
    """

    # Separator between the archive path and the member path.
    INFO_PATH_SEP = ':'
    zip_path: str
    file_path: str

    def __init__(self, path: str) -> None:
        # Exactly one separator is expected, i.e. exactly two components.
        parts = path.split(self.INFO_PATH_SEP)
        assert len(parts) == 2
        self.zip_path = parts[0]
        self.file_path = parts[1]

    @classmethod
    def from_paths(cls, zip_path: str, file_path: str):
        """Build a PathInZip from its two separate components."""
        return cls(cls.INFO_PATH_SEP.join((zip_path, file_path)))

    def __str__(self) -> str:
        return self.INFO_PATH_SEP.join((self.zip_path, self.file_path))
|
45 |
+
|
46 |
+
|
47 |
+
def _open_zip(path: str, mode: MODE = 'r'):
    """Open the zip archive at `path` with the given `zipfile.ZipFile` mode."""
    archive = zipfile.ZipFile(path, mode)
    return archive


# Memoized opener: archives are kept open and reused across calls.
_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
|
52 |
+
|
53 |
+
|
54 |
+
def set_zip_cache_size(max_size: int):
    """Sets the maximal LRU caching for zip file opening.

    The cached opener is rebuilt from scratch, so previously cached
    archives are no longer reachable through it.

    Args:
        max_size: the maximal LRU cache.
    """
    global _cached_open_zip
    make_cached = lru_cache(max_size)
    _cached_open_zip = make_cached(_open_zip)
|
62 |
+
|
63 |
+
|
64 |
+
def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
    """Opens a file stored inside a zip and returns a file-like object.

    The archive itself is obtained through the cached opener.

    Args:
        path_in_zip: A PathInZip object representing the file to return a file-like object of.
        mode: Accepted for interface compatibility; it is currently not
            forwarded, so the member is opened with `ZipFile.open` defaults.
    Returns:
        A file-like object for PathInZip.
    """
    archive = _cached_open_zip(path_in_zip.zip_path)
    member = archive.open(path_in_zip.file_path)
    return member
|
audiocraft/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# flake8: noqa
|
8 |
+
from .musicgen import MusicGen
|
9 |
+
from .lm import LMModel
|
10 |
+
from .encodec import CompressionModel, EncodecModel
|
audiocraft/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (288 Bytes). View file
|
|
audiocraft/models/__pycache__/builders.cpython-310.pyc
ADDED
Binary file (6.84 kB). View file
|
|
audiocraft/models/__pycache__/encodec.cpython-310.pyc
ADDED
Binary file (10.9 kB). View file
|
|
audiocraft/models/__pycache__/lm.cpython-310.pyc
ADDED
Binary file (18.7 kB). View file
|
|
audiocraft/models/__pycache__/loaders.cpython-310.pyc
ADDED
Binary file (2.83 kB). View file
|
|
audiocraft/models/__pycache__/musicgen.cpython-310.pyc
ADDED
Binary file (13.6 kB). View file
|
|
audiocraft/models/builders.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
All the functions to build the relevant models and modules
|
9 |
+
from the Hydra config.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import typing as tp
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
import audiocraft
|
16 |
+
import omegaconf
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa
|
20 |
+
from .lm import LMModel
|
21 |
+
from ..modules.codebooks_patterns import (
|
22 |
+
CodebooksPatternProvider,
|
23 |
+
DelayedPatternProvider,
|
24 |
+
ParallelPatternProvider,
|
25 |
+
UnrolledPatternProvider,
|
26 |
+
VALLEPattern,
|
27 |
+
MusicLMPattern,
|
28 |
+
)
|
29 |
+
from ..modules.conditioners import (
|
30 |
+
BaseConditioner,
|
31 |
+
ConditioningProvider,
|
32 |
+
LUTConditioner,
|
33 |
+
T5Conditioner,
|
34 |
+
ConditionFuser,
|
35 |
+
ChromaStemConditioner,
|
36 |
+
)
|
37 |
+
from .. import quantization as qt
|
38 |
+
from ..utils.utils import dict_from_config
|
39 |
+
|
40 |
+
|
41 |
+
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
    """Instantiate the quantizer named `quantizer` from its config section.

    Args:
        quantizer (str): Either 'no_quant' or 'rvq'.
        cfg (omegaconf.DictConfig): Config holding a section named after the quantizer.
        dimension (int): Passed as `dimension` to every quantizer except 'no_quant'.
    Raises:
        KeyError: If `quantizer` is not a known quantizer name.
    """
    registry = {
        'no_quant': qt.DummyQuantizer,
        'rvq': qt.ResidualVectorQuantizer,
    }
    quantizer_cls = registry[quantizer]
    options = dict_from_config(getattr(cfg, quantizer))
    if quantizer != 'no_quant':
        options['dimension'] = dimension
    return quantizer_cls(**options)
|
50 |
+
|
51 |
+
|
52 |
+
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    """Instantiate the encoder/decoder pair of an EnCodec model.

    Args:
        encoder_name (str): Name of the autoencoder architecture; only 'seanet' is supported.
        cfg (omegaconf.DictConfig): Config with a `seanet` section. Its `encoder` and
            `decoder` sub-sections override the shared options for the respective module.
    Returns:
        tuple: The `(encoder, decoder)` modules.
    Raises:
        KeyError: If `encoder_name` is not a supported autoencoder.
    """
    if encoder_name != 'seanet':
        # Report the value that was actually rejected (the autoencoder name),
        # not the compression model name.
        raise KeyError(f'Unexpected autoencoder {encoder_name}')

    kwargs = dict_from_config(getattr(cfg, 'seanet'))
    encoder_override_kwargs = kwargs.pop('encoder')
    decoder_override_kwargs = kwargs.pop('decoder')
    # Shared options first, then the module specific overrides.
    encoder_kwargs = {**kwargs, **encoder_override_kwargs}
    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
    encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
    return encoder, decoder
|
64 |
+
|
65 |
+
|
66 |
+
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
    """Instantiate a compression model.

    Args:
        cfg (omegaconf.DictConfig): Config with `compression_model`, `device`, and an
            `encodec` section naming the autoencoder and quantizer to use.
    Returns:
        CompressionModel: The model, moved to `cfg.device`.
    Raises:
        KeyError: If `cfg.compression_model` is not 'encodec'.
    """
    if cfg.compression_model != 'encodec':
        raise KeyError(f'Unexpected compression model {cfg.compression_model}')

    kwargs = dict_from_config(getattr(cfg, 'encodec'))
    encoder_name = kwargs.pop('autoencoder')
    quantizer_name = kwargs.pop('quantizer')
    encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
    quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
    frame_rate = kwargs['sample_rate'] // encoder.hop_length
    renormalize = kwargs.pop('renormalize', None)
    # `renorm` is the deprecated option: it may legitimately be absent from
    # configs that already migrated to `renormalize`.
    renorm = kwargs.pop('renorm', None)
    if renormalize is None:
        renormalize = renorm is not None
        warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
    return EncodecModel(encoder, decoder, quantizer,
                        frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
|
85 |
+
|
86 |
+
|
87 |
+
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
    """Instantiate a transformer LM.

    Args:
        cfg (omegaconf.DictConfig): Config with `lm_model`, `device`, `dtype` and the
            `transformer_lm`, `codebooks_pattern`, `attribute_dropout` and
            `classifier_free_guidance` sections.
    Returns:
        LMModel: The model, moved to `cfg.device`.
    Raises:
        KeyError: If `cfg.lm_model` is not 'transformer_lm'.
    """
    if cfg.lm_model != 'transformer_lm':
        raise KeyError(f'Unexpected LM model {cfg.lm_model}')

    kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
    n_q = kwargs['n_q']
    # `q_modeling` is not an LMModel argument, so it is removed from kwargs.
    q_modeling = kwargs.pop('q_modeling', None)
    pattern_cfg = getattr(cfg, 'codebooks_pattern')
    attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
    guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
    cfg_prob = guidance["training_dropout"]
    cfg_coef = guidance["inference_coef"]
    fuser = get_condition_fuser(cfg)
    condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
    if len(fuser.fuse2cond['cross']) > 0:
        # Enforce cross-attention programmatically.
        kwargs['cross_attention'] = True
    if pattern_cfg.modeling is None:
        assert q_modeling is not None, \
            'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
        # Fall back to a pattern config derived from `q_modeling`.
        pattern_cfg = omegaconf.OmegaConf.create(
            {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
        )
    pattern_provider = get_codebooks_pattern_provider(n_q, pattern_cfg)
    model = LMModel(
        pattern_provider=pattern_provider,
        condition_provider=condition_provider,
        fuser=fuser,
        cfg_dropout=cfg_prob,
        cfg_coef=cfg_coef,
        attribute_dropout=attribute_dropout,
        dtype=getattr(torch, cfg.dtype),
        device=cfg.device,
        **kwargs
    )
    return model.to(cfg.device)
|
122 |
+
|
123 |
+
|
124 |
+
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
    """Instantiate a conditioning model.

    Args:
        output_dim (int): Output dimension passed to every conditioner.
        cfg (omegaconf.DictConfig): Config with `device`, `dataset.segment_duration`
            and a `conditioners` section (may be None). The optional `args` entry of
            that section is forwarded to the ConditioningProvider.
    Raises:
        ValueError: If a conditioner declares an unknown `model` type.
    """
    device = cfg.device
    duration = cfg.dataset.segment_duration
    cond_section = getattr(cfg, "conditioners")
    if cond_section is None:
        cond_section = omegaconf.OmegaConf.create({})
    conditioners: tp.Dict[str, BaseConditioner] = {}
    with omegaconf.open_dict(cond_section):
        # `args` configures the provider itself, not a conditioner.
        provider_args = cond_section.pop('args', {})
    for cond, cond_cfg in cond_section.items():
        name = str(cond)
        model_type = cond_cfg["model"]
        model_args = cond_cfg[model_type]
        if model_type == "t5":
            conditioners[name] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
        elif model_type == "lut":
            conditioners[name] = LUTConditioner(output_dim=output_dim, **model_args)
        elif model_type == "chroma_stem":
            # `cache_path` is dropped before building the conditioner.
            model_args.pop('cache_path', None)
            conditioners[name] = ChromaStemConditioner(
                output_dim=output_dim,
                duration=duration,
                device=device,
                **model_args
            )
        else:
            raise ValueError(f"unrecognized conditioning model: {model_type}")
    return ConditioningProvider(conditioners, device=device, **provider_args)
|
153 |
+
|
154 |
+
|
155 |
+
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Instantiate a condition fuser object.

    The `fuser` config section must define one entry per fusing method
    ('sum', 'cross', 'prepend', 'input_interpolate'); every other entry is
    forwarded to the ConditionFuser as a keyword argument.
    """
    fuser_cfg = getattr(cfg, "fuser")
    fuser_methods = ["sum", "cross", "prepend", "input_interpolate"]
    fuse2cond = {method: fuser_cfg[method] for method in fuser_methods}
    extra_kwargs = {key: value for key, value in fuser_cfg.items() if key not in fuser_methods}
    return ConditionFuser(fuse2cond=fuse2cond, **extra_kwargs)
|
164 |
+
|
165 |
+
|
166 |
+
def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
    """Instantiate a codebooks pattern provider object.

    Args:
        n_q (int): Number of codebooks, passed as first argument to the provider.
        cfg (omegaconf.DictConfig): Config whose `modeling` entry names the pattern;
            an optional section with the same name holds the provider options.
    Raises:
        KeyError: If `cfg.modeling` is not a known pattern name.
    """
    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'valle': VALLEPattern,
        'musiclm': MusicLMPattern,
    }
    name = cfg.modeling
    if hasattr(cfg, name):
        options = dict_from_config(cfg.get(name))
    else:
        options = {}
    provider_cls = pattern_providers[name]
    return provider_cls(n_q, **options)
|
180 |
+
|
181 |
+
|
182 |
+
def get_debug_compression_model(device='cpu'):
    """Instantiate a debug compression model to be used for unit tests.

    Builds a tiny SEANet encoder/decoder with a 4-codebook residual vector
    quantizer, and returns the resulting EncodecModel in eval mode on `device`.
    """
    seanet_kwargs = dict(
        n_filters=4,
        n_residual_layers=1,
        dimension=32,
        ratios=[10, 8, 16],  # 25 Hz at 32kHz
    )
    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
    # One forward pass on random data to initialize kmeans etc.
    warmup_input = torch.randn(8, 32, 128)
    quantizer(warmup_input, 1)
    model = EncodecModel(
        encoder, decoder, quantizer,
        frame_rate=25, sample_rate=32000, channels=1)
    return model.to(device).eval()
|
200 |
+
|
201 |
+
|
202 |
+
def get_debug_lm_model(device='cpu'):
    """Instantiate a debug LM to be used for unit tests.

    Builds a small 2-layer LMModel conditioned on a single 'description'
    attribute through cross-attention, returned in eval mode on `device`.
    """
    pattern = DelayedPatternProvider(n_q=4)
    dim = 16
    description_conditioner = LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace")
    condition_provider = ConditioningProvider({'description': description_conditioner})
    fuse2cond = {
        'cross': ['description'],
        'prepend': [],
        'sum': [],
        'input_interpolate': [],
    }
    fuser = ConditionFuser(fuse2cond)
    lm = LMModel(
        pattern, condition_provider, fuser,
        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
        cross_attention=True, causal=True)
    return lm.to(device).eval()
|
audiocraft/models/encodec.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
from .. import quantization as qt
|
15 |
+
|
16 |
+
|
17 |
+
class CompressionModel(ABC, nn.Module):
    """Abstract interface shared by all compression models.

    Concrete subclasses must implement every method and property below.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Run the model on `x` and return a quantized result."""
        ...

    @abstractmethod
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """See `EncodecModel.encode`"""
        ...

    @abstractmethod
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """See `EncodecModel.decode`"""
        ...

    @property
    @abstractmethod
    def channels(self) -> int:
        """Number of audio channels."""
        ...

    @property
    @abstractmethod
    def frame_rate(self) -> int:
        """Frame rate for the latent representation."""
        ...

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        """Audio sample rate."""
        ...

    @property
    @abstractmethod
    def cardinality(self) -> int:
        """Cardinality of each codebook."""
        ...

    @property
    @abstractmethod
    def num_codebooks(self) -> int:
        """Active number of codebooks used by the quantizer."""
        ...

    @property
    @abstractmethod
    def total_codebooks(self) -> int:
        """Total number of quantizer codebooks available."""
        ...

    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.
        """
        ...
|
68 |
+
|
69 |
+
|
70 |
+
class EncodecModel(CompressionModel):
    """Encodec compression model working directly on raw waveforms.

    The audio goes through the encoder, the quantizer and then the decoder.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # The abstract base class declares these three names as properties.
    # Class-level assignments override them so that `__init__` can store
    # plain instance attributes under the same names.
    frame_rate: int = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: qt.BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        # Renormalization is refused for causal models: supporting it would
        # require the linear overlap of segments handled by the original
        # EnCodec codebase, which is not implemented here.
        assert not (self.causal and self.renormalize), 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally divide `x` by its volume.

        Returns:
            The (possibly rescaled) input and the scale that was applied,
            reshaped with `view(-1, 1)`, or None when `renormalize` is off.
        """
        if not self.renormalize:
            return x, None
        # Volume = RMS over dim 2 of the average over dim 1.
        downmix = x.mean(dim=1, keepdim=True)
        rms = downmix.pow(2).mean(dim=2, keepdim=True).sqrt()
        scale = 1e-8 + rms  # epsilon guards against dividing by zero on silence
        return x / scale, scale.view(-1, 1)

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo the renormalization of `preprocess` when a scale is given."""
        if scale is None:
            return x
        assert self.renormalize
        return x * scale.view(-1, 1, 1)

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Run encoder, quantizer and decoder on a 3-dim input.

        Returns:
            The quantizer result, with its `x` field replaced by the decoded
            output trimmed to the input length on the last dimension.
        """
        assert x.dim() == 3
        num_samples = x.shape[-1]
        normalized, scale = self.preprocess(x)

        latent = self.encoder(normalized)
        quantized = self.quantizer(latent, self.frame_rate)
        decoded = self.decoder(quantized.x)

        # The encoder and decoder may pad: trim back to the input length.
        assert decoded.shape[-1] >= num_samples, (decoded.shape[-1], num_samples)
        decoded = decoded[..., :num_samples]

        quantized.x = self.postprocess(decoded, scale)
        return quantized

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized codes along with the scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]): Tuple composed of:
                codes a tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale a float tensor containing the scale used for audio renormalization,
                    or None when renormalization is disabled.
        """
        assert x.dim() == 3
        normalized, scale = self.preprocess(x)
        latent = self.encoder(normalized)
        return self.quantizer.encode(latent), scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        latent = self.quantizer.decode(codes)
        decoded = self.decoder(latent)
        # Note: the result still contains the extra padding added by the
        # encoder and decoder; it is not trimmed here.
        return self.postprocess(decoded, scale)
|
205 |
+
|
206 |
+
|
207 |
+
class FlattenedCompressionModel(CompressionModel):
    """Wrapper flattening the codebooks of a CompressionModel along time.

    Instead of returning [B, K, T], the wrapper returns [B, S, T * (K // S)] with
    S the number of codebooks per step, and `K // S` the number of 'virtual steps'
    for each real time step.

    Args:
        model (CompressionModel): compression model to wrap.
        codebooks_per_step (int): number of codebooks to keep per step,
            this must divide the number of codebooks provided by the wrapped model.
        extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1,
            if each codebook has a cardinality N, then the first codebook will
            use the range [0, N - 1], and the second [N, 2 N - 1] etc.
            On decoding, this can lead to potentially invalid sequences.
            Any invalid entry will be silently remapped to the proper range
            with a modulo.
    """
    def __init__(self, model: CompressionModel, codebooks_per_step: int = 1,
                 extend_cardinality: bool = True):
        super().__init__()
        self.model = model
        self.codebooks_per_step = codebooks_per_step
        self.extend_cardinality = extend_cardinality

    @property
    def total_codebooks(self):
        """Total number of codebooks of the wrapped model."""
        return self.model.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer.

        ..Warning:: this reports the number of codebooks after the flattening
            of the codebooks!
        """
        assert self.model.num_codebooks % self.codebooks_per_step == 0
        return self.codebooks_per_step

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        ..Warning:: this sets the number of codebooks **before** the flattening
            of the codebooks.
        """
        assert n % self.codebooks_per_step == 0
        self.model.set_num_codebooks(n)

    @property
    def num_virtual_steps(self) -> int:
        """Number of virtual steps each real time step is split into."""
        return self.model.num_codebooks // self.codebooks_per_step

    @property
    def frame_rate(self) -> int:
        """Frame rate of the wrapped model multiplied by the number of virtual steps."""
        return self.model.frame_rate * self.num_virtual_steps

    @property
    def sample_rate(self) -> int:
        """Sample rate of the wrapped model."""
        return self.model.sample_rate

    @property
    def channels(self) -> int:
        """Number of channels of the wrapped model."""
        return self.model.channels

    @property
    def cardinality(self):
        """Cardinality of each codebook, extended when `extend_cardinality` is set."""
        if not self.extend_cardinality:
            return self.model.cardinality
        return self.model.cardinality * self.num_virtual_steps

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Unsupported on the flattened wrapper: always raises."""
        raise NotImplementedError("Not supported, use encode and decode.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode with the wrapped model, then flatten the codebooks along time."""
        codes, scale = self.model.encode(x)
        # Unpacking into three names also requires the codes to be 3-dim.
        _batch, _codebooks, _steps = codes.shape
        codes = rearrange(codes, 'b (k v) t -> b k t v', k=self.codebooks_per_step)
        if self.extend_cardinality:
            # Shift virtual step `v` by `v * cardinality` so each step uses its own range.
            for step in range(1, self.num_virtual_steps):
                codes[..., step] += self.model.cardinality * step
        flat = rearrange(codes, 'b k t v -> b k (t v)')
        return (flat, scale)

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Unflatten the codes and decode them with the wrapped model."""
        _batch, _codebooks, flat_steps = codes.shape
        assert flat_steps % self.num_virtual_steps == 0
        unflat = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps)
        # With extend_cardinality, out-of-range entries produced by the LM are
        # silently folded back into the valid range rather than reported.
        unflat = unflat % self.model.cardinality
        return self.model.decode(unflat, scale)
|
audiocraft/models/lm.py
ADDED
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from functools import partial
|
9 |
+
import logging
|
10 |
+
import math
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
from ..utils import utils
|
17 |
+
from ..modules.streaming import StreamingModule, State
|
18 |
+
from ..modules.transformer import StreamingTransformer, create_norm_fn
|
19 |
+
from ..modules.conditioners import (
|
20 |
+
ConditionFuser,
|
21 |
+
ClassifierFreeGuidanceDropout,
|
22 |
+
AttributeDropout,
|
23 |
+
ConditioningProvider,
|
24 |
+
ConditioningAttributes,
|
25 |
+
ConditionType,
|
26 |
+
)
|
27 |
+
from ..modules.codebooks_patterns import CodebooksPatternProvider
|
28 |
+
from ..modules.activations import get_activation_fn
|
29 |
+
|
30 |
+
|
31 |
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Mapping from a condition name to its `ConditionType` value.
ConditionTensors = tp.Dict[str, ConditionType]
# Either one set of condition tensors or a pair of them; `_sample_next_token`
# unpacks the pair form as (condition tensors, null condition tensors).
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
|
34 |
+
|
35 |
+
|
36 |
+
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """Build the in-place initializer used for LM layers.

    Inspired from xlformers: https://github.com/fairinternal/xlformers

    Args:
        method (str): Method name for init function. Valid options are:
            'gaussian', 'uniform'.
        input_dim (int): Input dimension of the initialized module.
        init_depth (Optional[int]): Optional init depth value used to rescale
            the standard deviation if defined.

    Returns:
        A `functools.partial` around `torch.nn.init.trunc_normal_` ('gaussian')
        or `torch.nn.init.uniform_` ('uniform').

    Raises:
        ValueError: If `method` is not one of the supported names.
    """
    # Target std is 1 / sqrt(input_dim), shrunk by sqrt(2 * depth) when a depth is given.
    depth_factor = 1.0 if init_depth is None else math.sqrt(2 * init_depth)
    std = 1 / math.sqrt(input_dim) / depth_factor

    if method == 'uniform':
        # A uniform law on [-bound, bound] has std bound / sqrt(3).
        bound = math.sqrt(3) * std
        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
    if method == 'gaussian':
        # Normal law truncated at three standard deviations on each side.
        return partial(torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std)
    raise ValueError("Unsupported layer initialization method")
|
62 |
+
|
63 |
+
|
64 |
+
def init_layer(m: nn.Module,
               method: str,
               init_depth: tp.Optional[int] = None,
               zero_bias_init: bool = False):
    """Initialize a `nn.Linear` or `nn.Embedding` in place using ``get_init_fn``.

    Any other module type is left untouched.

    Args:
        m (nn.Module): Module to initialize.
        method (str): Method name for the init function.
        init_depth (Optional[int]): Optional init depth value used to rescale
            the standard deviation if defined. Only applied to `nn.Linear`;
            embeddings are always initialized without depth rescaling.
        zero_bias_init (bool): Whether to initialize the bias to 0 or not
            (only relevant for `nn.Linear` modules that have a bias).
    """
    is_linear = isinstance(m, nn.Linear)
    if is_linear:
        fan_in, depth = m.in_features, init_depth
    elif isinstance(m, nn.Embedding):
        fan_in, depth = m.embedding_dim, None
    else:
        return

    init_fn = get_init_fn(method, fan_in, init_depth=depth)
    weight = m.weight
    if weight.device.type == 'cpu' and weight.dtype == torch.float16:
        # Half-precision weights on CPU are initialized through a float32
        # copy which is then cast back and written into the parameter.
        upcast = weight.float()
        init_fn(upcast)
        weight.data[:] = upcast.half()
    else:
        init_fn(weight)

    if is_linear and zero_bias_init and m.bias is not None:
        nn.init.constant_(m.bias, 0)
|
95 |
+
|
96 |
+
|
97 |
+
class ScaledEmbedding(nn.Embedding):
    """Embedding layer that can carry its own learning rate (`lr`).

    The learning rate is only stored here; it is exposed to the optimizer
    through `make_optim_group`.
    """
    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr

    def make_optim_group(self):
        """Return an optimizer parameter group for this embedding.

        The group always has a "params" entry and gets an "lr" entry only
        when a specific learning rate was provided.
        """
        params = [*self.parameters()]
        if self.lr is None:
            return {"params": params}
        return {"params": params, "lr": self.lr}
|
109 |
+
|
110 |
+
|
111 |
+
@dataclass
class LMOutput:
    """Logits and validity mask returned by `LMModel.compute_predictions`."""
    # The logits are already re-aligned with the input codes
    # hence no extra shift is required, e.g. when computing CE
    logits: torch.Tensor  # [B, K, T, card]
    mask: torch.Tensor  # [B, K, T]
|
117 |
+
|
118 |
+
|
119 |
+
class LMModel(StreamingModule):
|
120 |
+
"""Transformer-based language model on multiple streams of codes.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
|
124 |
+
condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
|
125 |
+
fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
|
126 |
+
n_q (int): Number of parallel streams to model.
|
127 |
+
card (int): Cardinality, vocabulary size.
|
128 |
+
dim (int): Dimension of the transformer encoder.
|
129 |
+
num_heads (int): Number of heads for the transformer encoder.
|
130 |
+
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
|
131 |
+
norm (str): Normalization method.
|
132 |
+
norm_first (bool): Use pre-norm instead of post-norm.
|
133 |
+
emb_lr (Optional[float]): Embedding-specific learning rate.
|
134 |
+
bias_proj (bool): Use bias for output projections.
|
135 |
+
weight_init (Optional[str]): Method for weight initialization.
|
136 |
+
depthwise_init (Optional[str]): Method for depthwise weight initialization.
|
137 |
+
zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
|
138 |
+
cfg_dropout (float): Classifier-free guidance dropout.
|
139 |
+
cfg_coef (float): Classifier-free guidance coefficient.
|
140 |
+
attribute_dropout (dict): Attribute dropout probabilities.
|
141 |
+
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
|
142 |
+
**kwargs: Additional parameters for the transformer encoder.
|
143 |
+
"""
|
144 |
+
def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
             fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
             hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
             emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
             weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
             zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
             attribute_dropout: tp.Optional[tp.Dict[str, tp.Dict[str, float]]] = None,
             two_step_cfg: bool = False, **kwargs):
    """Build the embeddings, transformer and output heads of the LM.

    The arguments are described in the class docstring. `attribute_dropout`
    defaults to None, which is treated as an empty dict; a fresh dict is
    created per instance instead of sharing one mutable default object.
    """
    super().__init__()
    self.cfg_coef = cfg_coef
    self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
    # Never share a single default dict between instances.
    self.att_dropout = AttributeDropout(p={} if attribute_dropout is None else attribute_dropout)
    self.condition_provider = condition_provider
    self.fuser = fuser
    self.card = card
    # One extra embedding row for the special token, whose id is `card`.
    embed_dim = self.card + 1
    self.n_q = n_q
    self.dim = dim
    self.pattern_provider = pattern_provider
    self.two_step_cfg = two_step_cfg
    # One embedding table per codebook.
    self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
    if 'activation' in kwargs:
        # Replace the provided activation value by what `get_activation_fn` returns for it.
        kwargs['activation'] = get_activation_fn(kwargs['activation'])
    self.transformer = StreamingTransformer(
        d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
        norm=norm, norm_first=norm_first, **kwargs)
    self.out_norm: tp.Optional[nn.Module] = None
    if norm_first:
        self.out_norm = create_norm_fn(norm, dim)
    # One output projection per codebook; the special token is never predicted.
    self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
    self._init_weights(weight_init, depthwise_init, zero_bias_init)
    self._fsdp: tp.Optional[nn.Module]
    # Written through `__dict__`, bypassing `nn.Module.__setattr__`.
    # NOTE(review): presumably so that an FSDP wrapper stored here later is
    # not registered as a submodule of this model; confirm with the caller
    # that sets it.
    self.__dict__['_fsdp'] = None
|
177 |
+
|
178 |
+
def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
    """Initialize the embeddings, transformer layers and output projections.

    Args:
        weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options.
            When None, nothing is initialized here.
        depthwise_init (Optional[str]): Depthwise initialization strategy. The following options are valid:
            'current' where the depth corresponds to the current layer index or 'global' where the total number
            of layers is used as depth. If not set, no depthwise initialization strategy is used.
        zero_bias_init (bool): Whether to initialize bias to zero or not.
    """
    assert depthwise_init in (None, 'current', 'global')
    assert weight_init is not None or depthwise_init is None, \
        "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
    assert weight_init is not None or not zero_bias_init, \
        "If 'zero_bias_init', a 'weight_init' method should be provided"

    if weight_init is None:
        return

    # Embeddings and output projections never use a depth-dependent scale.
    init_without_depth = partial(
        init_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

    for embedding in self.emb:
        init_without_depth(embedding)

    num_layers = len(self.transformer.layers)
    for position, block in enumerate(self.transformer.layers, start=1):
        # 'current' -> 1-based index of this layer, 'global' -> layer count, unset -> None.
        depth = {'current': position, 'global': num_layers}.get(depthwise_init)
        block.apply(partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init))

    for projection in self.linears:
        init_without_depth(projection)
|
211 |
+
|
212 |
+
@property
def special_token_id(self) -> int:
    """Id of the special token, equal to `card`.

    `compute_predictions` passes it to the pattern as the value for masked
    tokens. The embedding tables are built with `card + 1` rows in
    `__init__`, so this id has an embedding.
    """
    return self.card
|
215 |
+
|
216 |
+
@property
def num_codebooks(self) -> int:
    """Number of parallel code streams modeled, i.e. `n_q`."""
    return self.n_q
|
219 |
+
|
220 |
+
def forward(self, sequence: torch.Tensor,
            conditions: tp.List[ConditioningAttributes],
            condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
    """Apply language model on sequence and conditions.
    Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
    S the sequence steps, return the logits with shape [B, K, S, card].

    Args:
        sequence (torch.Tensor): indices of the codes to model.
        conditions (list[ConditioningAttributes]): conditionings to use when modeling
            the given codes. Note that when evaluating multiple times with the same conditioning
            you should pre-compute those and pass them as `condition_tensors`.
        condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
            tensors, see `conditions`. Must not be combined with a non-empty `conditions`.
    Returns:
        torch.Tensor: Logits.
    """
    _, K, S = sequence.shape
    assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks'
    # Each codebook has its own embedding table; the embeddings are summed.
    input_ = sum(self.emb[k](sequence[:, k]) for k in range(K))

    if condition_tensors is not None:
        assert not conditions, "Shouldn't pass both conditions and condition_tensors."
    else:
        assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
        # Dropout modules first (CFG, then attribute), then tokenize and encode.
        conditions = self.att_dropout(self.cfg_dropout(conditions))
        tokenized = self.condition_provider.tokenize(conditions)
        # Encoding and fusing both keep a streaming cache to avoid recomputation when generating.
        condition_tensors = self.condition_provider(tokenized)

    input_, cross_attention_input = self.fuser(input_, condition_tensors)

    hidden = self.transformer(input_, cross_attention_src=cross_attention_input)
    if self.out_norm:
        hidden = self.out_norm(hidden)
    per_codebook = [self.linears[k](hidden) for k in range(K)]
    logits = torch.stack(per_codebook, dim=1)  # [B, K, S, card]

    # When conditions were prepended, keep only the last S steps of the output.
    if len(self.fuser.fuse2cond['prepend']) > 0:
        logits = logits[:, :, -S:]

    return logits  # [B, K, S, card]
|
263 |
+
|
264 |
+
def compute_predictions(
        self, codes: torch.Tensor,
        conditions: tp.List[ConditioningAttributes],
        condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
    """Run the model on codes [B, K, T] through the configured codebook
    interleaving pattern and return logits aligned with the input codes.

    Args:
        codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
            K the number of codebooks and T the number of timesteps.
        conditions (list[ConditioningAttributes]): conditionings to use when modeling
            the given codes. Note that when evaluating multiple times with the same conditioning
            you should pre-compute those and pass them as `condition_tensors`.
        condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
            tensors, see `conditions`.
    Returns:
        LMOutput: Language model outputs
            logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
                i.e. the first item corresponds to logits to predict the first code, meaning that
                no additional shifting of codes and logits is required.
            mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
                Given the specified interleaving strategies, parts of the logits and codes should
                not be considered as valid predictions because of invalid context.
    """
    batch_size, _, num_steps = codes.shape
    codes = codes.contiguous()

    # Codes [B, K, T] -> pattern sequence [B, K, S]; masked positions hold the special token.
    pattern = self.pattern_provider.get_pattern(num_steps)
    sequence_codes, _, _ = pattern.build_pattern_sequence(
        codes, self.special_token_id, keep_only_valid_steps=True
    )

    # Go through the FSDP wrapper when one is set, otherwise call this module directly.
    model = self._fsdp if self._fsdp is not None else self
    logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]

    # Pattern sequence -> original layout: [B, K, S, card] -> [B, K, T, card].
    # NaN fills the invalid positions so that unexpected use of them is obvious.
    logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
    logits, _, logits_mask = pattern.revert_pattern_logits(
        logits, float('nan'), keep_only_valid_steps=True
    )
    logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
    logits_mask = logits_mask[None, :, :].expand(batch_size, -1, -1)  # [K, T] -> [B, K, T]
    return LMOutput(logits, logits_mask)
|
308 |
+
|
309 |
+
def _sample_next_token(self,
                       sequence: torch.Tensor,
                       cfg_conditions: CFGConditions,
                       unconditional_state: State,
                       use_sampling: bool = False,
                       temp: float = 1.0,
                       top_k: int = 0,
                       top_p: float = 0.0,
                       cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
    """Sample next token from the model given a sequence and a set of conditions. The model supports
    multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).

    Args:
        sequence (torch.Tensor): Current sequence of shape [B, K, S]
            with K corresponding to the number of codebooks and S the number of sequence steps.
            S = 1 in streaming mode, except for the first step that contains a bigger prompt.
        cfg_conditions (CFGConditions): Set of conditions. With two-step CFG this is a tuple
            (condition tensors, null condition tensors). Otherwise it is a single dict which,
            if CFG is used, should be twice the batch size, being the concatenation of the
            conditions + null conditions.
        unconditional_state (State): Streaming state used for the unconditional pass of
            two-step CFG. It is updated in place after that pass.
        use_sampling (bool): Whether to use a sampling strategy or not.
        temp (float): Sampling temperature.
        top_k (int): K for "top-k" sampling.
        top_p (float): P for "top-p" sampling.
        cfg_coef (float or None): classifier free guidance coefficient. When None,
            `self.cfg_coef` is used. The resolved value applies to both CFG modes.
    Returns:
        next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
    """
    B = sequence.shape[0]
    cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
    model = self if self._fsdp is None else self._fsdp
    if self.two_step_cfg and cfg_conditions != {}:
        assert isinstance(cfg_conditions, tuple)
        condition_tensors, null_condition_tensors = cfg_conditions
        cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
        # Swap in the unconditional streaming state for the second pass, save
        # what that pass produced, then restore the conditional state.
        state = self.get_streaming_state()
        self.set_streaming_state(unconditional_state)
        uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
        unconditional_state.update(self.get_streaming_state())
        self.set_streaming_state(state)
        # Use the resolved coefficient so a per-call `cfg_coef` is honoured
        # here exactly as in the single-pass branch below.
        logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
    else:
        assert isinstance(cfg_conditions, dict)
        condition_tensors = cfg_conditions
        if condition_tensors:
            # Preparing for CFG, predicting both conditional and unconditional logits.
            sequence = torch.cat([sequence, sequence], dim=0)
        all_logits = model(
            sequence,
            conditions=[], condition_tensors=condition_tensors)
        if condition_tensors:
            cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
        else:
            logits = all_logits

    # Keep only the logits of the last sequence step.
    logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
    logits = logits[..., -1]  # [B x K x card]

    # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
    if use_sampling and temp > 0.0:
        probs = torch.softmax(logits / temp, dim=-1)
        # top-p takes precedence over top-k when both are set.
        if top_p > 0.0:
            next_token = utils.sample_top_p(probs, p=top_p)
        elif top_k > 0:
            next_token = utils.sample_top_k(probs, k=top_k)
        else:
            next_token = utils.multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

    return next_token
|
379 |
+
|
380 |
+
@torch.no_grad()
def generate(self,
             prompt: tp.Optional[torch.Tensor] = None,
             conditions: tp.List[ConditioningAttributes] = [],
             num_samples: tp.Optional[int] = None,
             max_gen_len: int = 256,
             use_sampling: bool = True,
             temp: float = 1.0,
             top_k: int = 250,
             top_p: float = 0.0,
             cfg_coef: tp.Optional[float] = None,
             two_step_cfg: bool = False,
             remove_prompts: bool = False,
             check: bool = False,
             callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
    """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
    be performed in a greedy fashion or using sampling with top K and top P strategies.

    Args:
        prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T].
        conditions (List[ConditioningAttributes]): Conditions used for generation. An empty
            list means no conditioning is applied. The default list is never mutated here.
        num_samples (int or None): Number of samples to generate when no prompt and no conditions are given.
        max_gen_len (int): Maximum generation length.
        use_sampling (bool): Whether to use a sampling strategy or not.
        temp (float): Sampling temperature.
        top_k (int): K for "top-k" sampling.
        top_p (float): P for "top-p" sampling.
        cfg_coef (float or None): Classifier free guidance coefficient, forwarded as is
            to `_sample_next_token`.
        two_step_cfg (bool): If True, the conditional and null conditions are computed
            separately instead of being batched together. If None is passed,
            `self.two_step_cfg` is used instead.
        remove_prompts (bool): Whether to remove prompts from generation or not.
        check (bool): Whether to run extra coherence assertions at each generation step.
        callback (Callable or None): If provided, called after each generation step with
            the number of steps done so far and the total number of steps.
    Returns:
        torch.Tensor: Generated tokens.
    """
    assert not self.training, "generation shouldn't be used in training mode."
    first_param = next(iter(self.parameters()))
    device = first_param.device

    # Checking all input shapes are consistent: gather the batch size implied by
    # every input that was actually provided, in order of priority.
    possible_num_samples = []
    if num_samples is not None:
        possible_num_samples.append(num_samples)
    if prompt is not None:
        possible_num_samples.append(prompt.shape[0])
    if conditions:
        possible_num_samples.append(len(conditions))
    if not possible_num_samples:
        possible_num_samples.append(1)
    # `all(...)` is required here: asserting on the list itself is always true.
    assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsitent inputs shapes"
    num_samples = possible_num_samples[0]

    # below we create set of conditions: one conditional and one unconditional
    # to do that we merge the regular condition together with the null condition
    # we then do 1 forward pass instead of 2.
    # the reason for that is two-fold:
    # 1. it is about x2 faster than doing 2 forward passes
    # 2. avoid the streaming API treating the 2 passes as part of different time steps
    # We also support doing two different passes, in particular to ensure that
    # the padding structure is exactly the same between train and test.
    # With a batch size of 1, this can be slower though.
    cfg_conditions: CFGConditions
    two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
    if conditions:
        null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
        if two_step_cfg:
            cfg_conditions = (
                self.condition_provider(self.condition_provider.tokenize(conditions)),
                self.condition_provider(self.condition_provider.tokenize(null_conditions)),
            )
        else:
            conditions = conditions + null_conditions
            tokenized = self.condition_provider.tokenize(conditions)
            cfg_conditions = self.condition_provider(tokenized)
    else:
        cfg_conditions = {}

    if prompt is None:
        assert num_samples > 0
        prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)

    B, K, T = prompt.shape
    start_offset = T
    assert start_offset < max_gen_len

    pattern = self.pattern_provider.get_pattern(max_gen_len)
    # this token is used as default value for codes that are not generated yet
    unknown_token = -1

    # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
    gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
    # filling the gen_codes with the prompt if needed
    gen_codes[..., :start_offset] = prompt
    # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
    gen_sequence, _indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
    # retrieve the start_offset in the sequence:
    # it is the first sequence step that contains the `start_offset` timestep
    start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
    assert start_offset_sequence is not None

    with self.streaming():
        unconditional_state = self.get_streaming_state()
        prev_offset = 0
        gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
        for offset in range(start_offset_sequence, gen_sequence_len):
            # get current sequence (note that the streaming API is providing the caching over previous offsets)
            curr_sequence = gen_sequence[..., prev_offset:offset]
            curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
            if check:
                # check coherence between mask and sequence
                assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
                # should never happen as gen_sequence is filled progressively
                assert not (curr_sequence == unknown_token).any()
            # sample next token from the model, next token shape is [B, K, 1]
            next_token = self._sample_next_token(
                curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
                cfg_coef=cfg_coef)
            # ensure the tokens that should be masked are properly set to special_token_id
            # as the model never output special_token_id
            valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
            next_token[~valid_mask] = self.special_token_id
            # ensure we don't overwrite prompt tokens, we only write over unknown tokens
            # (then mask tokens should be left as is as well, which is correct)
            gen_sequence[..., offset:offset+1] = torch.where(
                gen_sequence[..., offset:offset+1] == unknown_token,
                next_token, gen_sequence[..., offset:offset+1]
            )
            prev_offset = offset
            if callback is not None:
                callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
    unconditional_state.clear()

    # ensure sequence has been entirely filled
    assert not (gen_sequence == unknown_token).any()
    # ensure gen_sequence pattern and mask are matching
    # which means the gen_sequence is valid according to the pattern
    assert (
        gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
    ).all()
    # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
    out_codes, _out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

    # sanity checks over the returned codes and corresponding masks
    assert (out_codes[..., :max_gen_len] != unknown_token).all()
    assert (out_mask[..., :max_gen_len] == 1).all()

    out_start_offset = start_offset if remove_prompts else 0
    out_codes = out_codes[..., out_start_offset:max_gen_len]

    # ensure the returned codes are all valid
    assert (out_codes >= 0).all() and (out_codes <= self.card).all()
    return out_codes
|
audiocraft/models/loaders.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Utility functions to load from the checkpoints.
|
9 |
+
Each checkpoint is a torch.saved dict with the following keys:
|
10 |
+
- 'xp.cfg': the hydra config as dumped during training. This should be used
|
11 |
+
to rebuild the object using the audiocraft.models.builders functions,
|
12 |
+
- 'best_state': a readily loadable best state for the model, including
|
13 |
+
the conditioner. The model obtained from `xp.cfg` should be compatible
|
14 |
+
with this state dict. In the case of a LM, the encodec model would not be
|
15 |
+
bundled along but instead provided separately.
|
16 |
+
|
17 |
+
Those functions also support loading from a remote location with the Torch Hub API.
|
18 |
+
They also support overriding some parameters, in particular the device and dtype
|
19 |
+
of the returned model.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from pathlib import Path
|
23 |
+
from huggingface_hub import hf_hub_download
|
24 |
+
import typing as tp
|
25 |
+
import os
|
26 |
+
|
27 |
+
from omegaconf import OmegaConf
|
28 |
+
import torch
|
29 |
+
|
30 |
+
from . import builders
|
31 |
+
|
32 |
+
|
33 |
+
HF_MODEL_CHECKPOINTS_MAP = {
|
34 |
+
"small": "GrandaddyShmax/musicgen-small",
|
35 |
+
"medium": "GrandaddyShmax/musicgen-medium",
|
36 |
+
"large": "GrandaddyShmax/musicgen-large",
|
37 |
+
"melody": "GrandaddyShmax/musicgen-melody",
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
def _get_state_dict(
    file_or_url_or_id: tp.Union[Path, str],
    filename: tp.Optional[str] = None,
    device='cpu',
    cache_dir: tp.Optional[str] = None,
):
    """Return the checkpoint content from a local file, an https URL or a known model name.

    Args:
        file_or_url_or_id (Path or str): Local file path, `https://` URL, or a key
            of `HF_MODEL_CHECKPOINTS_MAP`.
        filename (str, optional): Name of the file to fetch from the Hugging Face repo.
            Required when a key of `HF_MODEL_CHECKPOINTS_MAP` is given.
        device: Passed as `map_location` when loading the checkpoint.
        cache_dir (str, optional): Cache directory forwarded to `hf_hub_download`.
    Raises:
        ValueError: If the input is neither an existing file, an https URL nor a known name.
    """
    source = str(file_or_url_or_id)

    # Local checkpoint on disk.
    if os.path.isfile(source):
        return torch.load(source, map_location=device)

    # Remote checkpoint fetched through the Torch Hub API.
    if source.startswith('https://'):
        return torch.hub.load_state_dict_from_url(source, map_location=device, check_hash=True)

    if source not in HF_MODEL_CHECKPOINTS_MAP:
        raise ValueError(f"{source} is not a valid name, path or link that can be loaded.")

    # Known model name: download the requested file from the mapped repository.
    assert filename is not None, "filename needs to be defined if using HF checkpoints"
    local_file = hf_hub_download(
        repo_id=HF_MODEL_CHECKPOINTS_MAP[source], filename=filename, cache_dir=cache_dir)
    return torch.load(local_file, map_location=device)
|
66 |
+
|
67 |
+
|
68 |
+
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build the compression model described by a checkpoint and load its best state.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, forwarded to `_get_state_dict`.
        device: Device written into the config used to build the model.
        cache_dir (str, optional): Cache directory forwarded to `_get_state_dict`.
    Returns:
        The compression model, in eval mode.
    """
    checkpoint = _get_state_dict(
        file_or_url_or_id,
        filename="compression_state_dict.bin",
        cache_dir=cache_dir,
    )
    # The training config stored in the checkpoint drives the model construction.
    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)
    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(checkpoint['best_state'])
    compression_model.eval()
    return compression_model
|
76 |
+
|
77 |
+
|
78 |
+
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build the language model described by a checkpoint and load its best state.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, forwarded to `_get_state_dict`.
        device: Device written into the config used to build the model.
        cache_dir (str, optional): Cache directory forwarded to `_get_state_dict`.
    Returns:
        The language model, in eval mode, with the config attached as `model.cfg`.
    """
    checkpoint = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)
    # float32 on CPU, float16 on any other device.
    config.dtype = 'float32' if config.device == 'cpu' else 'float16'
    lm = builders.get_lm_model(config)
    lm.load_state_dict(checkpoint['best_state'])
    lm.eval()
    lm.cfg = config
    return lm
|
audiocraft/models/musicgen.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Main model for using MusicGen. This will combine all the required components
|
9 |
+
and provide easy access to the generation API.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import os
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from .encodec import CompressionModel
|
18 |
+
from .lm import LMModel
|
19 |
+
from .builders import get_debug_compression_model, get_debug_lm_model
|
20 |
+
from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
|
21 |
+
from ..data.audio_utils import convert_audio
|
22 |
+
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
23 |
+
from ..utils.autocast import TorchAutocast
|
24 |
+
|
25 |
+
|
26 |
+
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
27 |
+
MelodyType = tp.Union[torch.Tensor, MelodyList]
|
28 |
+
|
29 |
+
|
30 |
+
class MusicGen:
    """MusicGen main model with convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel): Language model over discrete representations.
        max_duration (float): Maximum duration, in seconds, generated in a single
            call to the language model. Longer durations are generated chunk by chunk.
    """
    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                 max_duration: float = 30):
        self.name = name
        self.compression_model = compression_model
        self.lm = lm
        self.max_duration = max_duration
        # The device is taken from the first parameter of the language model.
        self.device = next(iter(lm.parameters())).device
        self.generation_params: dict = {}
        self.set_generation_params(duration=15)  # 15 seconds by default
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        if self.device.type == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
        else:
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.float16)

    @property
    def frame_rate(self) -> int:
        """Roughly the number of AR steps per seconds."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    @property
    def audio_channels(self) -> int:
        """Audio channels of the generated audio."""
        return self.compression_model.channels

    @staticmethod
    def get_pretrained(name: str = 'melody', device=None):
        """Return pretrained model, we provide four models:
        - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
        - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
        - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
        - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large

        Args:
            name (str): One of the keys of `HF_MODEL_CHECKPOINTS_MAP`, or 'debug'.
            device: Device to load the model on. If None, 'cuda' is used when at least
                one CUDA device is available, else 'cpu'.
        Raises:
            ValueError: If `name` is not a known checkpoint name.
        """

        if device is None:
            if torch.cuda.device_count():
                device = 'cuda'
            else:
                device = 'cpu'

        if name == 'debug':
            # used only for unit tests
            compression_model = get_debug_compression_model(device)
            lm = get_debug_lm_model(device)
            return MusicGen(name, compression_model, lm)

        if name not in HF_MODEL_CHECKPOINTS_MAP:
            raise ValueError(
                f"{name} is not a valid checkpoint name. "
                f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
            )

        # Optional cache location for the downloaded checkpoints.
        cache_dir = os.environ.get('MUSICGEN_ROOT', None)
        compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
        lm = load_lm_model(name, device=device, cache_dir=cache_dir)
        if name == 'melody':
            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True

        return MusicGen(name, compression_model, lm)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                              top_p: float = 0.0, temperature: float = 1.0,
                              duration: float = 30.0, cfg_coef: float = 3.0,
                              two_step_cfg: bool = False, extend_stride: float = 18):
        """Set the generation parameters for MusicGen.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 250.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
            duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
                instead of batching together the two. This has some impact on how things
                are padded but seems to have little impact in practice.
            extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
                should we extend the audio each time. Larger values will mean less context is
                preserved, and shorter value will require extra computations.
        """
        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        self.extend_stride = extend_stride
        self.duration = duration
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'cfg_coef': cfg_coef,
            'two_step_cfg': two_step_cfg,
        }

    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
        """Override the default progress callback."""
        self._progress_callback = progress_callback

    def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
        """Generate samples in an unconditional manner.

        Args:
            num_samples (int): Number of samples to be generated.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        """
        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        return self._generate_tokens(attributes, prompt_tokens, progress)

    def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
        """Generate samples conditioned on text.

        Args:
            descriptions (tp.List[str]): A list of strings used as text conditioning.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        return self._generate_tokens(attributes, prompt_tokens, progress)

    def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
                             melody_sample_rate: int, progress: bool = False) -> torch.Tensor:
        """Generate samples conditioned on text and melody.

        Args:
            descriptions (tp.List[str]): A list of strings used as text conditioning.
            melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
                melody conditioning. Should have shape [B, C, T] with B matching the description length,
                C=1 or 2. It can be [C, T] if there is a single description. It can also be
                a list of [C, T] tensors.
            melody_sample_rate: (int): Sample rate of the melody waveforms.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        """
        if isinstance(melody_wavs, torch.Tensor):
            if melody_wavs.dim() == 2:
                melody_wavs = melody_wavs[None]
            if melody_wavs.dim() != 3:
                raise ValueError("Melody wavs should have a shape [B, C, T].")
            melody_wavs = list(melody_wavs)
        else:
            for melody in melody_wavs:
                if melody is not None:
                    assert melody.dim() == 2, "One melody in the list has the wrong number of dims."

        # Convert every provided melody to the model sample rate and channels.
        melody_wavs = [
            convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
            if wav is not None else None
            for wav in melody_wavs]
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
                                                                        melody_wavs=melody_wavs)
        assert prompt_tokens is None
        return self._generate_tokens(attributes, prompt_tokens, progress)

    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
                              progress: bool = False) -> torch.Tensor:
        """Generate samples conditioned on audio prompts.

        Args:
            prompt (torch.Tensor): A batch of waveforms used for continuation.
                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
            descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        """
        if prompt.dim() == 2:
            prompt = prompt[None]
        if prompt.dim() != 3:
            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
        if descriptions is None:
            descriptions = [None] * len(prompt)
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
        assert prompt_tokens is not None
        return self._generate_tokens(attributes, prompt_tokens, progress)

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            prompt: tp.Optional[torch.Tensor],
            melody_wavs: tp.Optional[MelodyList] = None,
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            descriptions (tp.List[str]): A list of strings used as text conditioning.
            prompt (torch.Tensor): A batch of waveforms used for continuation.
            melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms
                used as melody conditioning. Defaults to None.
        Returns:
            The conditioning attributes, one per description, and the prompt tokens
            (None when no prompt is given).
        """
        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions]

        if melody_wavs is None:
            # No melody at all: every attribute gets an empty placeholder waveform.
            for attr in attributes:
                attr.wav['self_wav'] = WavCondition(
                    torch.zeros((1, 1), device=self.device),
                    torch.tensor([0], device=self.device),
                    path='null_wav')  # type: ignore
        else:
            if self.name != "melody":
                raise RuntimeError("This model doesn't support melody conditioning. "
                                   "Use the `melody` model.")
            assert len(melody_wavs) == len(descriptions), \
                f"number of melody wavs must match number of descriptions! " \
                f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
            for attr, melody in zip(attributes, melody_wavs):
                if melody is None:
                    attr.wav['self_wav'] = WavCondition(
                        torch.zeros((1, 1), device=self.device),
                        torch.tensor([0], device=self.device),
                        path='null_wav')  # type: ignore
                else:
                    attr.wav['self_wav'] = WavCondition(
                        melody.to(device=self.device),
                        torch.tensor([melody.shape[-1]], device=self.device))

        if prompt is not None:
            if descriptions is not None:
                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
            prompt = prompt.to(self.device)
            prompt_tokens, scale = self.compression_model.encode(prompt)
            assert scale is None
        else:
            prompt_tokens = None
        return attributes, prompt_tokens

    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
        """Generate discrete audio tokens given audio prompt and/or conditions.

        Args:
            attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
            prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
        """
        total_gen_len = int(self.duration * self.frame_rate)
        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
        current_gen_offset: int = 0

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            # `current_gen_offset` is read from the enclosing scope, so progress
            # keeps increasing across the chunks of an extended generation.
            generated_tokens += current_gen_offset
            if self._progress_callback is not None:
                # Note that total_gen_len might be quite wrong depending on the
                # codebook pattern used, but with delay it is almost accurate.
                self._progress_callback(generated_tokens, total_gen_len)
            else:
                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')

        if prompt_tokens is not None:
            assert max_prompt_len >= prompt_tokens.shape[-1], \
                "Prompt is longer than audio to generate"

        callback = None
        if progress:
            callback = _progress_callback

        if self.duration <= self.max_duration:
            # generate by sampling from LM, simple case.
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)

        else:
            # now this gets a bit messier, we need to handle prompts,
            # melody conditioning etc.
            ref_wavs = [attr.wav['self_wav'] for attr in attributes]
            all_tokens = []
            if prompt_tokens is None:
                prompt_length = 0
            else:
                all_tokens.append(prompt_tokens)
                prompt_length = prompt_tokens.shape[-1]

            stride_tokens = int(self.frame_rate * self.extend_stride)

            while current_gen_offset + prompt_length < total_gen_len:
                time_offset = current_gen_offset / self.frame_rate
                chunk_duration = min(self.duration - time_offset, self.max_duration)
                max_gen_len = int(chunk_duration * self.frame_rate)
                for attr, ref_wav in zip(attributes, ref_wavs):
                    wav_length = ref_wav.length.item()
                    if wav_length == 0:
                        continue
                    # We will extend the wav periodically if it not long enough.
                    # we have to do it here rather than in conditioners.py as otherwise
                    # we wouldn't have the full wav.
                    initial_position = int(time_offset * self.sample_rate)
                    wav_target_length = int(self.max_duration * self.sample_rate)
                    positions = torch.arange(initial_position,
                                             initial_position + wav_target_length, device=self.device)
                    # The modulo wraps the positions around the reference waveform.
                    attr.wav['self_wav'] = WavCondition(
                        ref_wav[0][:, positions % wav_length],
                        torch.full_like(ref_wav[1], wav_target_length))
                with self.autocast:
                    gen_tokens = self.lm.generate(
                        prompt_tokens, attributes,
                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
                # Only keep the newly generated part, the prompt is already stored.
                if prompt_tokens is None:
                    all_tokens.append(gen_tokens)
                else:
                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
                # The tokens after the stride become the prompt of the next chunk.
                prompt_tokens = gen_tokens[:, :, stride_tokens:]
                prompt_length = prompt_tokens.shape[-1]
                current_gen_offset += stride_tokens

            gen_tokens = torch.cat(all_tokens, dim=-1)

        # generate audio
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio
|
audiocraft/modules/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# flake8: noqa
|
8 |
+
from .conv import (
|
9 |
+
NormConv1d,
|
10 |
+
NormConv2d,
|
11 |
+
NormConvTranspose1d,
|
12 |
+
NormConvTranspose2d,
|
13 |
+
StreamableConv1d,
|
14 |
+
StreamableConvTranspose1d,
|
15 |
+
pad_for_conv1d,
|
16 |
+
pad1d,
|
17 |
+
unpad1d,
|
18 |
+
)
|
19 |
+
from .lstm import StreamableLSTM
|
20 |
+
from .seanet import SEANetEncoder, SEANetDecoder
|
audiocraft/modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (496 Bytes). View file
|
|
audiocraft/modules/__pycache__/activations.cpython-310.pyc
ADDED
Binary file (3.94 kB). View file
|
|
audiocraft/modules/__pycache__/codebooks_patterns.cpython-310.pyc
ADDED
Binary file (25.1 kB). View file
|
|
audiocraft/modules/__pycache__/conditioners.cpython-310.pyc
ADDED
Binary file (41.5 kB). View file
|
|
audiocraft/modules/__pycache__/conv.cpython-310.pyc
ADDED
Binary file (9.04 kB). View file
|
|
audiocraft/modules/__pycache__/lstm.cpython-310.pyc
ADDED
Binary file (1.02 kB). View file
|
|
audiocraft/modules/__pycache__/rope.cpython-310.pyc
ADDED
Binary file (5.16 kB). View file
|
|