trttung1610 commited on
Commit
26246bd
1 Parent(s): 42fc3a0

Upload 233 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +63 -0
  2. CHANGELOG.md +28 -0
  3. CODE_OF_CONDUCT.md +80 -0
  4. CONTRIBUTING.md +35 -0
  5. Dockerfile +26 -0
  6. LICENSE +21 -0
  7. LICENSE_weights +399 -0
  8. MANIFEST.in +9 -0
  9. Makefile +40 -0
  10. README.md +7 -6
  11. app.py +463 -0
  12. app_v2.py +1839 -0
  13. assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 +0 -0
  14. assets/bach.mp3 +0 -0
  15. assets/bolero_ravel.mp3 +0 -0
  16. assets/sirens_and_a_humming_engine_approach_and_pass.mp3 +0 -0
  17. audiocraft/__init__.py +26 -0
  18. audiocraft/adversarial/__init__.py +22 -0
  19. audiocraft/adversarial/discriminators/__init__.py +10 -0
  20. audiocraft/adversarial/discriminators/base.py +34 -0
  21. audiocraft/adversarial/discriminators/mpd.py +106 -0
  22. audiocraft/adversarial/discriminators/msd.py +126 -0
  23. audiocraft/adversarial/discriminators/msstftd.py +134 -0
  24. audiocraft/adversarial/losses.py +228 -0
  25. audiocraft/data/__init__.py +10 -0
  26. audiocraft/data/audio.py +216 -0
  27. audiocraft/data/audio_dataset.py +587 -0
  28. audiocraft/data/audio_utils.py +177 -0
  29. audiocraft/data/info_audio_dataset.py +110 -0
  30. audiocraft/data/music_dataset.py +270 -0
  31. audiocraft/data/sound_dataset.py +330 -0
  32. audiocraft/data/zip.py +76 -0
  33. audiocraft/environment.py +176 -0
  34. audiocraft/grids/__init__.py +6 -0
  35. audiocraft/grids/_base_explorers.py +80 -0
  36. audiocraft/grids/audiogen/__init__.py +6 -0
  37. audiocraft/grids/audiogen/audiogen_base_16khz.py +23 -0
  38. audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py +68 -0
  39. audiocraft/grids/compression/__init__.py +6 -0
  40. audiocraft/grids/compression/_explorers.py +55 -0
  41. audiocraft/grids/compression/debug.py +31 -0
  42. audiocraft/grids/compression/encodec_audiogen_16khz.py +29 -0
  43. audiocraft/grids/compression/encodec_base_24khz.py +28 -0
  44. audiocraft/grids/compression/encodec_musicgen_32khz.py +34 -0
  45. audiocraft/grids/diffusion/4_bands_base_32khz.py +27 -0
  46. audiocraft/grids/diffusion/__init__.py +6 -0
  47. audiocraft/grids/diffusion/_explorers.py +66 -0
  48. audiocraft/grids/musicgen/__init__.py +6 -0
  49. audiocraft/grids/musicgen/_explorers.py +93 -0
  50. audiocraft/grids/musicgen/musicgen_base_32khz.py +43 -0
.gitignore ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # macOS dir files
10
+ .DS_Store
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ env/
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ .ipynb_checkpoints
31
+
32
+ # Tests and linter
33
+ .pytest_cache/
34
+ .mypy_cache/
35
+ .coverage
36
+
37
+ # docs
38
+ /api_docs
39
+
40
+ # dotenv
41
+ .env
42
+ .envrc
43
+
44
+ # virtualenv
45
+ .venv
46
+ venv/
47
+ ENV/
48
+
49
+ # egs with manifest files
50
+ egs/*
51
+ !egs/example
52
+ # local datasets
53
+ dataset/*
54
+ !dataset/example
55
+
56
+ # personal notebooks & scripts
57
+ */local_scripts
58
+ */notes
59
+ .vscode/
60
+ /notebooks
61
+ /local_scripts
62
+ /notes
63
+ /cache
CHANGELOG.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Changelog
2
+
3
+ All notable changes to this project will be documented in this file.
4
+
5
+ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
6
+
7
+ ## [1.0.0] - 2023-08-02
8
+
9
+ Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
10
+ Added pretrained model for AudioGen and MultiBandDiffusion.
11
+
12
+ ## [0.0.2] - 2023-08-01
13
+
14
+ Improved demo, fixed top p (thanks @jnordberg).
15
+
16
+ Compressor tanh on output to avoid clipping with some style (especially piano).
17
+ Now repeating the conditioning periodically if it is too short.
18
+
19
+ More options when launching Gradio app locally (thanks @ashleykleynhans).
20
+
21
+ Testing out PyTorch 2.0 memory efficient attention.
22
+
23
+ Added extended generation (infinite length) by slowly moving the windows.
24
+ Note that other implementations exist: https://github.com/camenduru/MusicGen-colab.
25
+
26
+ ## [0.0.1] - 2023-06-09
27
+
28
+ Initial release, with model evaluation only.
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@fb.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to AudioCraft
2
+
3
+ We want to make contributing to this project as easy and transparent as
4
+ possible.
5
+
6
+ ## Pull Requests
7
+
8
+ AudioCraft is the implementation of a research paper.
9
+ Therefore, we do not plan on accepting many pull requests for new features.
10
+ We certainly welcome them for bug fixes.
11
+
12
+ 1. Fork the repo and create your branch from `main`.
13
+ 2. If you've added code that should be tested, add tests.
14
+ 3. If you've changed APIs, update the documentation.
15
+ 4. Ensure the test suite passes.
16
+ 5. Make sure your code lints.
17
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
18
+
19
+ ## Contributor License Agreement ("CLA")
20
+ In order to accept your pull request, we need you to submit a CLA. You only need
21
+ to do this once to work on any of Meta's open source projects.
22
+
23
+ Complete your CLA here: <https://code.facebook.com/cla>
24
+
25
+ ## Issues
26
+ We use GitHub issues to track public bugs. Please ensure your description is
27
+ clear and has sufficient instructions to be able to reproduce the issue.
28
+
29
+ Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
30
+ disclosure of security bugs. In those cases, please go through the process
31
+ outlined on that page and do not file a public issue.
32
+
33
+ ## License
34
+ By contributing to encodec, you agree that your contributions will be licensed
35
+ under the LICENSE file in the root directory of this source tree.
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.8.0-base-ubuntu22.04
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive \
4
+ PYTHONUNBUFFERED=1 \
5
+ PYTHONIOENCODING=UTF-8
6
+ RUN --mount=type=cache,target=/var/cache/apt --mount=type=cache,target=/var/lib/apt apt update &&\
7
+ apt install -y \
8
+ wget \
9
+ git \
10
+ pkg-config \
11
+ python3 \
12
+ python3-pip \
13
+ python-is-python3 \
14
+ ffmpeg \
15
+ libnvrtc11.2 \
16
+ libtcmalloc-minimal4
17
+
18
+ RUN useradd -m -u 1000 ac
19
+ RUN --mount=type=cache,target=/root/.cache python -m pip install --upgrade pip wheel
20
+ ENV TORCH_COMMAND="pip install torch==2.0.1+cu118 torchaudio --extra-index-url https://download.pytorch.org/whl/cu118"
21
+ RUN --mount=type=cache,target=/root/.cache python -m $TORCH_COMMAND
22
+ RUN ln -s /usr/lib/x86_64-linux-gnu/libnvrtc.so.11.2 /usr/lib/x86_64-linux-gnu/libnvrtc.so
23
+ USER 1000
24
+ RUN mkdir ~/.cache
25
+ RUN --mount=type=cache,target=/home/ac/.cache --mount=source=.,target=/home/ac/audiocraft python -m pip install -r /home/ac/audiocraft/requirements.txt
26
+ WORKDIR /home/ac/audiocraft
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Meta Platforms, Inc. and affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LICENSE_weights ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
MANIFEST.in ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ include Makefile
2
+ include LICENSE
3
+ include LICENSE_weights
4
+ include *.md
5
+ include *.ini
6
+ include requirements.txt
7
+ include audiocraft/py.typed
8
+ include assets/*.mp3
9
+ recursive-include conf *.yaml
Makefile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
2
+ dataset.train.num_samples=10 dataset.valid.num_samples=10 \
3
+ dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
4
+ logging.level=DEBUG
5
+ INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 616d7b3c
6
+ INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
7
+ transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c
8
+ INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
9
+ transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c
10
+ INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
11
+ checkpoint.save_last=false # Using compression model from 616d7b3c
12
+
13
+ default: linter tests
14
+
15
+ install:
16
+ pip install -U pip
17
+ pip install -U -e '.[dev]'
18
+
19
+ linter:
20
+ flake8 audiocraft && mypy audiocraft
21
+ flake8 tests && mypy tests
22
+
23
+ tests:
24
+ coverage run -m pytest tests
25
+ coverage report
26
+
27
+ tests_integ:
28
+ $(INTEG_COMPRESSION)
29
+ $(INTEG_MBD)
30
+ $(INTEG_MUSICGEN)
31
+ $(INTEG_AUDIOGEN)
32
+
33
+
34
+ api_docs:
35
+ pdoc3 --html -o api_docs -f audiocraft
36
+
37
+ dist:
38
+ python setup.py sdist
39
+
40
+ .PHONY: linter tests api_docs dist
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Musicgen
3
- emoji: 🏢
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.40.1
8
  app_file: app.py
9
- pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: AudioCraft Plus v2.0.0a (MusicGen + AudioGen)
3
+ emoji: 🎶
4
+ colorFrom: yellow
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.39.0
8
  app_file: app.py
9
+ pinned: true
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
8
+ # also released under the MIT license.
9
+
10
+ import argparse
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ import os
13
+ from pathlib import Path
14
+ import subprocess as sp
15
+ from tempfile import NamedTemporaryFile
16
+ import time
17
+ import typing as tp
18
+ import warnings
19
+
20
+ import torch
21
+ import gradio as gr
22
+
23
+ from audiocraft.data.audio_utils import convert_audio
24
+ from audiocraft.data.audio import audio_write
25
+ from audiocraft.models import MusicGen, MultiBandDiffusion
26
+
27
+
28
+ MODEL = None # Last used model
29
+ IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
30
+ print(IS_BATCHED)
31
+ MAX_BATCH_SIZE = 12
32
+ BATCHED_DURATION = 15
33
+ INTERRUPTING = False
34
+ MBD = None
35
+ # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
36
+ _old_call = sp.call
37
+
38
+
39
+ def _call_nostderr(*args, **kwargs):
40
+ # Avoid ffmpeg vomiting on the logs.
41
+ kwargs['stderr'] = sp.DEVNULL
42
+ kwargs['stdout'] = sp.DEVNULL
43
+ _old_call(*args, **kwargs)
44
+
45
+
46
+ sp.call = _call_nostderr
47
+ # Preallocating the pool of processes.
48
+ pool = ProcessPoolExecutor(4)
49
+ pool.__enter__()
50
+
51
+
52
+ def interrupt():
53
+ global INTERRUPTING
54
+ INTERRUPTING = True
55
+
56
+
57
+ class FileCleaner:
58
+ def __init__(self, file_lifetime: float = 3600):
59
+ self.file_lifetime = file_lifetime
60
+ self.files = []
61
+
62
+ def add(self, path: tp.Union[str, Path]):
63
+ self._cleanup()
64
+ self.files.append((time.time(), Path(path)))
65
+
66
+ def _cleanup(self):
67
+ now = time.time()
68
+ for time_added, path in list(self.files):
69
+ if now - time_added > self.file_lifetime:
70
+ if path.exists():
71
+ path.unlink()
72
+ self.files.pop(0)
73
+ else:
74
+ break
75
+
76
+
77
+ file_cleaner = FileCleaner()
78
+
79
+
80
+ def make_waveform(*args, **kwargs):
81
+ # Further remove some warnings.
82
+ be = time.time()
83
+ with warnings.catch_warnings():
84
+ warnings.simplefilter('ignore')
85
+ out = gr.make_waveform(*args, **kwargs)
86
+ print("Make a video took", time.time() - be)
87
+ return out
88
+
89
+
90
+ def load_model(version='facebook/musicgen-melody'):
91
+ global MODEL
92
+ print("Loading model", version)
93
+ if MODEL is None or MODEL.name != version:
94
+ MODEL = MusicGen.get_pretrained(version)
95
+
96
+
97
+ def load_diffusion():
98
+ global MBD
99
+ if MBD is None:
100
+ print("loading MBD")
101
+ MBD = MultiBandDiffusion.get_mbd_musicgen()
102
+
103
+
104
+ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
105
+ MODEL.set_generation_params(duration=duration, **gen_kwargs)
106
+ print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
107
+ be = time.time()
108
+ processed_melodies = []
109
+ target_sr = 32000
110
+ target_ac = 1
111
+ for melody in melodies:
112
+ if melody is None:
113
+ processed_melodies.append(None)
114
+ else:
115
+ sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
116
+ if melody.dim() == 1:
117
+ melody = melody[None]
118
+ melody = melody[..., :int(sr * duration)]
119
+ melody = convert_audio(melody, sr, target_sr, target_ac)
120
+ processed_melodies.append(melody)
121
+
122
+ if any(m is not None for m in processed_melodies):
123
+ outputs = MODEL.generate_with_chroma(
124
+ descriptions=texts,
125
+ melody_wavs=processed_melodies,
126
+ melody_sample_rate=target_sr,
127
+ progress=progress,
128
+ return_tokens=USE_DIFFUSION
129
+ )
130
+ else:
131
+ outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
132
+ if USE_DIFFUSION:
133
+ outputs_diffusion = MBD.tokens_to_wav(outputs[1])
134
+ outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
135
+ outputs = outputs.detach().cpu().float()
136
+ pending_videos = []
137
+ out_wavs = []
138
+ for output in outputs:
139
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
140
+ audio_write(
141
+ file.name, output, MODEL.sample_rate, strategy="loudness",
142
+ loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
143
+ pending_videos.append(pool.submit(make_waveform, file.name))
144
+ out_wavs.append(file.name)
145
+ file_cleaner.add(file.name)
146
+ out_videos = [pending_video.result() for pending_video in pending_videos]
147
+ for video in out_videos:
148
+ file_cleaner.add(video)
149
+ print("batch finished", len(texts), time.time() - be)
150
+ print("Tempfiles currently stored: ", len(file_cleaner.files))
151
+ return out_videos, out_wavs
152
+
153
+
154
+ def predict_batched(texts, melodies):
155
+ max_text_length = 512
156
+ texts = [text[:max_text_length] for text in texts]
157
+ load_model('facebook/musicgen-melody')
158
+ res = _do_predictions(texts, melodies, BATCHED_DURATION)
159
+ return res
160
+
161
+
162
+ def predict_full(model, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
163
+ global INTERRUPTING
164
+ global USE_DIFFUSION
165
+ INTERRUPTING = False
166
+ if temperature < 0:
167
+ raise gr.Error("Temperature must be >= 0.")
168
+ if topk < 0:
169
+ raise gr.Error("Topk must be non-negative.")
170
+ if topp < 0:
171
+ raise gr.Error("Topp must be non-negative.")
172
+
173
+ topk = int(topk)
174
+ if decoder == "MultiBand_Diffusion":
175
+ USE_DIFFUSION = True
176
+ load_diffusion()
177
+ else:
178
+ USE_DIFFUSION = False
179
+ load_model(model)
180
+
181
+ def _progress(generated, to_generate):
182
+ progress((min(generated, to_generate), to_generate))
183
+ if INTERRUPTING:
184
+ raise gr.Error("Interrupted.")
185
+ MODEL.set_custom_progress_callback(_progress)
186
+
187
+ videos, wavs = _do_predictions(
188
+ [text], [melody], duration, progress=True,
189
+ top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
190
+ if USE_DIFFUSION:
191
+ return videos[0], wavs[0], videos[1], wavs[1]
192
+ return videos[0], wavs[0], None, None
193
+
194
+
195
+ def toggle_audio_src(choice):
196
+ if choice == "mic":
197
+ return gr.update(source="microphone", value=None, label="Microphone")
198
+ else:
199
+ return gr.update(source="upload", value=None, label="File")
200
+
201
+
202
+ def toggle_diffusion(choice):
203
+ if choice == "MultiBand_Diffusion":
204
+ return [gr.update(visible=True)] * 2
205
+ else:
206
+ return [gr.update(visible=False)] * 2
207
+
208
+
209
+ def ui_full(launch_kwargs):
210
+ with gr.Blocks() as interface:
211
+ gr.Markdown(
212
+ """
213
+ # MusicGen
214
+ This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
215
+ a simple and controllable model for music generation
216
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
217
+ """
218
+ )
219
+ with gr.Row():
220
+ with gr.Column():
221
+ with gr.Row():
222
+ text = gr.Text(label="Input Text", interactive=True)
223
+ with gr.Column():
224
+ radio = gr.Radio(["file", "mic"], value="file",
225
+ label="Condition on a melody (optional) File or Mic")
226
+ melody = gr.Audio(source="upload", type="numpy", label="File",
227
+ interactive=True, elem_id="melody-input")
228
+ with gr.Row():
229
+ submit = gr.Button("Submit")
230
+ # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
231
+ _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
232
+ with gr.Row():
233
+ model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
234
+ "facebook/musicgen-large"],
235
+ label="Model", value="facebook/musicgen-melody", interactive=True)
236
+ with gr.Row():
237
+ decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
238
+ label="Decoder", value="Default", interactive=True)
239
+ with gr.Row():
240
+ duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
241
+ with gr.Row():
242
+ topk = gr.Number(label="Top-k", value=250, interactive=True)
243
+ topp = gr.Number(label="Top-p", value=0, interactive=True)
244
+ temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
245
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
246
+ with gr.Column():
247
+ output = gr.Video(label="Generated Music")
248
+ audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
249
+ diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
250
+ audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
251
+ submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False,
252
+ show_progress=False).then(predict_full, inputs=[model, decoder, text, melody, duration, topk, topp,
253
+ temperature, cfg_coef],
254
+ outputs=[output, audio_output, diffusion_output, audio_diffusion])
255
+ radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
256
+
257
+ gr.Examples(
258
+ fn=predict_full,
259
+ examples=[
260
+ [
261
+ "An 80s driving pop song with heavy drums and synth pads in the background",
262
+ "./assets/bach.mp3",
263
+ "facebook/musicgen-melody",
264
+ "Default"
265
+ ],
266
+ [
267
+ "A cheerful country song with acoustic guitars",
268
+ "./assets/bolero_ravel.mp3",
269
+ "facebook/musicgen-melody",
270
+ "Default"
271
+ ],
272
+ [
273
+ "90s rock song with electric guitar and heavy drums",
274
+ None,
275
+ "facebook/musicgen-medium",
276
+ "Default"
277
+ ],
278
+ [
279
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
280
+ "./assets/bach.mp3",
281
+ "facebook/musicgen-melody",
282
+ "Default"
283
+ ],
284
+ [
285
+ "lofi slow bpm electro chill with organic samples",
286
+ None,
287
+ "facebook/musicgen-medium",
288
+ "Default"
289
+ ],
290
+ [
291
+ "Punk rock with loud drum and power guitar",
292
+ None,
293
+ "facebook/musicgen-medium",
294
+ "MultiBand_Diffusion"
295
+ ],
296
+ ],
297
+ inputs=[text, melody, model, decoder],
298
+ outputs=[output]
299
+ )
300
+ gr.Markdown(
301
+ """
302
+ ### More details
303
+
304
+ The model will generate a short music extract based on the description you provided.
305
+ The model can generate up to 30 seconds of audio in one pass. It is now possible
306
+ to extend the generation by feeding back the end of the previous chunk of audio.
307
+ This can take a long time, and the model might lose consistency. The model might also
308
+ decide at arbitrary positions that the song ends.
309
+
310
+ **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
311
+ An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
312
+ are generated each time.
313
+
314
+ We present 4 model variations:
315
+ 1. facebook/musicgen-melody -- a music generation model capable of generating music condition
316
+ on text and melody inputs. **Note**, you can also use text only.
317
+ 2. facebook/musicgen-small -- a 300M transformer decoder conditioned on text only.
318
+ 3. facebook/musicgen-medium -- a 1.5B transformer decoder conditioned on text only.
319
+ 4. facebook/musicgen-large -- a 3.3B transformer decoder conditioned on text only.
320
+
321
+ We also present two way of decoding the audio tokens
322
+ 1. Use the default GAN based compression model
323
+ 2. Use MultiBand Diffusion from (paper linknano )
324
+
325
+ When using `facebook/musicgen-melody`, you can optionally provide a reference audio from
326
+ which a broad melody will be extracted. The model will then try to follow both
327
+ the description and melody provided.
328
+
329
+ You can also use your own GPU or a Google Colab by following the instructions on our repo.
330
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
331
+ for more details.
332
+ """
333
+ )
334
+
335
+ interface.queue().launch(**launch_kwargs)
336
+
337
+
338
+ def ui_batched(launch_kwargs):
339
+ with gr.Blocks() as demo:
340
+ gr.Markdown(
341
+ """
342
+ # MusicGen
343
+
344
+ This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
345
+ a simple and controllable model for music generation
346
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
347
+ <br/>
348
+ <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
349
+ style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
350
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
351
+ src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
352
+ for longer sequences, more control and no queue.</p>
353
+ """
354
+ )
355
+ with gr.Row():
356
+ with gr.Column():
357
+ with gr.Row():
358
+ text = gr.Text(label="Describe your music", lines=2, interactive=True)
359
+ with gr.Column():
360
+ radio = gr.Radio(["file", "mic"], value="file",
361
+ label="Condition on a melody (optional) File or Mic")
362
+ melody = gr.Audio(source="upload", type="numpy", label="File",
363
+ interactive=True, elem_id="melody-input")
364
+ with gr.Row():
365
+ submit = gr.Button("Generate")
366
+ with gr.Column():
367
+ output = gr.Video(label="Generated Music")
368
+ audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
369
+ submit.click(predict_batched, inputs=[text, melody],
370
+ outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
371
+ radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
372
+ gr.Examples(
373
+ fn=predict_batched,
374
+ examples=[
375
+ [
376
+ "An 80s driving pop song with heavy drums and synth pads in the background",
377
+ "./assets/bach.mp3",
378
+ ],
379
+ [
380
+ "A cheerful country song with acoustic guitars",
381
+ "./assets/bolero_ravel.mp3",
382
+ ],
383
+ [
384
+ "90s rock song with electric guitar and heavy drums",
385
+ None,
386
+ ],
387
+ [
388
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
389
+ "./assets/bach.mp3",
390
+ ],
391
+ [
392
+ "lofi slow bpm electro chill with organic samples",
393
+ None,
394
+ ],
395
+ ],
396
+ inputs=[text, melody],
397
+ outputs=[output]
398
+ )
399
+ gr.Markdown("""
400
+ ### More details
401
+
402
+ The model will generate 12 seconds of audio based on the description you provided.
403
+ You can optionally provide a reference audio from which a broad melody will be extracted.
404
+ The model will then try to follow both the description and melody provided.
405
+ All samples are generated with the `melody` model.
406
+
407
+ You can also use your own GPU or a Google Colab by following the instructions on our repo.
408
+
409
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
410
+ for more details.
411
+ """)
412
+
413
+ demo.queue(max_size=8 * 4).launch(**launch_kwargs)
414
+
415
+
416
+ if __name__ == "__main__":
417
+ parser = argparse.ArgumentParser()
418
+ parser.add_argument(
419
+ '--listen',
420
+ type=str,
421
+ default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
422
+ help='IP to listen on for connections to Gradio',
423
+ )
424
+ parser.add_argument(
425
+ '--username', type=str, default='', help='Username for authentication'
426
+ )
427
+ parser.add_argument(
428
+ '--password', type=str, default='', help='Password for authentication'
429
+ )
430
+ parser.add_argument(
431
+ '--server_port',
432
+ type=int,
433
+ default=0,
434
+ help='Port to run the server listener on',
435
+ )
436
+ parser.add_argument(
437
+ '--inbrowser', action='store_true', help='Open in browser'
438
+ )
439
+ parser.add_argument(
440
+ '--share', action='store_true', help='Share the gradio UI'
441
+ )
442
+
443
+ args = parser.parse_args()
444
+
445
+ launch_kwargs = {}
446
+ launch_kwargs['server_name'] = args.listen
447
+
448
+ if args.username and args.password:
449
+ launch_kwargs['auth'] = (args.username, args.password)
450
+ if args.server_port:
451
+ launch_kwargs['server_port'] = args.server_port
452
+ if args.inbrowser:
453
+ launch_kwargs['inbrowser'] = args.inbrowser
454
+ if args.share:
455
+ launch_kwargs['share'] = args.share
456
+
457
+ # Show the interface
458
+ if IS_BATCHED:
459
+ global USE_DIFFUSION
460
+ USE_DIFFUSION = False
461
+ ui_batched(launch_kwargs)
462
+ else:
463
+ ui_full(launch_kwargs)
app_v2.py ADDED
@@ -0,0 +1,1839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
8
+ # also released under the MIT license.
9
+
10
+ import argparse
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ import os
13
+ from pathlib import Path
14
+ import subprocess as sp
15
+ from tempfile import NamedTemporaryFile
16
+ import time
17
+ import warnings
18
+ import glob
19
+ import re
20
+ from PIL import Image
21
+ from pydub import AudioSegment
22
+ from datetime import datetime
23
+
24
+ import json
25
+ import shutil
26
+ import taglib
27
+ import torch
28
+ import torchaudio
29
+ import gradio as gr
30
+ import numpy as np
31
+ import typing as tp
32
+
33
+ from audiocraft.data.audio_utils import convert_audio
34
+ from audiocraft.data.audio import audio_write
35
+ from audiocraft.models import AudioGen, MusicGen, MultiBandDiffusion
36
+ from audiocraft.utils import ui
37
+ import random, string
38
+
39
+ version = "2.0.0a"
40
+
41
+ theme = gr.themes.Base(
42
+ primary_hue="lime",
43
+ secondary_hue="lime",
44
+ neutral_hue="neutral",
45
+ ).set(
46
+ button_primary_background_fill_hover='*primary_500',
47
+ button_primary_background_fill_hover_dark='*primary_500',
48
+ button_secondary_background_fill_hover='*primary_500',
49
+ button_secondary_background_fill_hover_dark='*primary_500'
50
+ )
51
+
52
+ MODEL = None # Last used model
53
+ MODELS = None
54
+ UNLOAD_MODEL = False
55
+ MOVE_TO_CPU = False
56
+ IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
57
+ print(IS_BATCHED)
58
+ MAX_BATCH_SIZE = 12
59
+ BATCHED_DURATION = 15
60
+ INTERRUPTING = False
61
+ MBD = None
62
+ # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
63
+ _old_call = sp.call
64
+
65
+
66
+ def generate_random_string(length):
67
+ characters = string.ascii_letters + string.digits
68
+ return ''.join(random.choice(characters) for _ in range(length))
69
+
70
+
71
+ def resize_video(input_path, output_path, target_width, target_height):
72
+ ffmpeg_cmd = [
73
+ 'ffmpeg',
74
+ '-y',
75
+ '-i', input_path,
76
+ '-vf', f'scale={target_width}:{target_height}',
77
+ '-c:a', 'copy',
78
+ output_path
79
+ ]
80
+ sp.run(ffmpeg_cmd)
81
+
82
+
83
+ def _call_nostderr(*args, **kwargs):
84
+ # Avoid ffmpeg vomiting on the logs.
85
+ kwargs['stderr'] = sp.DEVNULL
86
+ kwargs['stdout'] = sp.DEVNULL
87
+ _old_call(*args, **kwargs)
88
+
89
+
90
+ sp.call = _call_nostderr
91
+ # Preallocating the pool of processes.
92
+ pool = ProcessPoolExecutor(4)
93
+ pool.__enter__()
94
+
95
+
96
+ def interrupt():
97
+ global INTERRUPTING
98
+ INTERRUPTING = True
99
+
100
+
101
+ class FileCleaner:
102
+ def __init__(self, file_lifetime: float = 3600):
103
+ self.file_lifetime = file_lifetime
104
+ self.files = []
105
+
106
+ def add(self, path: tp.Union[str, Path]):
107
+ self._cleanup()
108
+ self.files.append((time.time(), Path(path)))
109
+
110
+ def _cleanup(self):
111
+ now = time.time()
112
+ for time_added, path in list(self.files):
113
+ if now - time_added > self.file_lifetime:
114
+ if path.exists():
115
+ path.unlink()
116
+ self.files.pop(0)
117
+ else:
118
+ break
119
+
120
+
121
+ file_cleaner = FileCleaner()
122
+
123
+
124
+ def make_waveform(*args, **kwargs):
125
+ # Further remove some warnings.
126
+ be = time.time()
127
+ with warnings.catch_warnings():
128
+ warnings.simplefilter('ignore')
129
+ height = kwargs.pop('height')
130
+ width = kwargs.pop('width')
131
+ if height < 256:
132
+ height = 256
133
+ if width < 256:
134
+ width = 256
135
+ waveform_video = gr.make_waveform(*args, **kwargs)
136
+ out = f"{generate_random_string(12)}.mp4"
137
+ image = kwargs.get('bg_image', None)
138
+ if image is None:
139
+ resize_video(waveform_video, out, 900, 300)
140
+ else:
141
+ resize_video(waveform_video, out, width, height)
142
+ print("Make a video took", time.time() - be)
143
+ return out
144
+
145
+
146
+ def load_model(version='GrandaddyShmax/musicgen-melody', custom_model=None, base_model='GrandaddyShmax/musicgen-medium', gen_type="music"):
147
+ global MODEL, MODELS
148
+ print("Loading model", version)
149
+ if MODELS is None:
150
+ if version == 'GrandaddyShmax/musicgen-custom':
151
+ MODEL = MusicGen.get_pretrained(base_model)
152
+ file_path = os.path.abspath("models/" + str(custom_model) + ".pt")
153
+ MODEL.lm.load_state_dict(torch.load(file_path))
154
+ else:
155
+ if gen_type == "music":
156
+ MODEL = MusicGen.get_pretrained(version)
157
+ elif gen_type == "audio":
158
+ MODEL = AudioGen.get_pretrained(version)
159
+
160
+ return
161
+
162
+ else:
163
+ t1 = time.monotonic()
164
+ if MODEL is not None:
165
+ MODEL.to('cpu') # move to cache
166
+ print("Previous model moved to CPU in %.2fs" % (time.monotonic() - t1))
167
+ t1 = time.monotonic()
168
+ if version != 'GrandaddyShmax/musicgen-custom' and MODELS.get(version) is None:
169
+ print("Loading model %s from disk" % version)
170
+ if gen_type == "music":
171
+ result = MusicGen.get_pretrained(version)
172
+ elif gen_type == "audio":
173
+ result = AudioGen.get_pretrained(version)
174
+ MODELS[version] = result
175
+ print("Model loaded in %.2fs" % (time.monotonic() - t1))
176
+ MODEL = result
177
+ return
178
+ result = MODELS[version].to('cuda')
179
+ print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
180
+ MODEL = result
181
+
182
+ def get_audio_info(audio_path):
183
+ if audio_path is not None:
184
+ if audio_path.name.endswith(".wav") or audio_path.name.endswith(".mp4") or audio_path.name.endswith(".json"):
185
+ if not audio_path.name.endswith(".json"):
186
+ with taglib.File(audio_path.name, save_on_exit=False) as song:
187
+ if 'COMMENT' not in song.tags:
188
+ return "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted. (Discord removes metadata from mp4 and wav files, so you can't use them)"
189
+ json_string = song.tags['COMMENT'][0]
190
+ data = json.loads(json_string)
191
+ global_prompt = str("\nGlobal Prompt: " + (data['global_prompt'] if data['global_prompt'] != "" else "none")) if 'global_prompt' in data else ""
192
+ bpm = str("\nBPM: " + data['bpm']) if 'bpm' in data else ""
193
+ key = str("\nKey: " + data['key']) if 'key' in data else ""
194
+ scale = str("\nScale: " + data['scale']) if 'scale' in data else ""
195
+ prompts = str("\nPrompts: " + (data['texts'] if data['texts'] != "['']" else "none")) if 'texts' in data else ""
196
+ duration = str("\nDuration: " + data['duration']) if 'duration' in data else ""
197
+ overlap = str("\nOverlap: " + data['overlap']) if 'overlap' in data else ""
198
+ seed = str("\nSeed: " + data['seed']) if 'seed' in data else ""
199
+ audio_mode = str("\nAudio Mode: " + data['audio_mode']) if 'audio_mode' in data else ""
200
+ input_length = str("\nInput Length: " + data['input_length']) if 'input_length' in data else ""
201
+ channel = str("\nChannel: " + data['channel']) if 'channel' in data else ""
202
+ sr_select = str("\nSample Rate: " + data['sr_select']) if 'sr_select' in data else ""
203
+ gen_type = str(data['generator'] + "gen-") if 'generator' in data else ""
204
+ model = str("\nModel: " + gen_type + data['model']) if 'model' in data else ""
205
+ custom_model = str("\nCustom Model: " + data['custom_model']) if 'custom_model' in data else ""
206
+ base_model = str("\nBase Model: " + data['base_model']) if 'base_model' in data else ""
207
+ decoder = str("\nDecoder: " + data['decoder']) if 'decoder' in data else ""
208
+ topk = str("\nTopk: " + data['topk']) if 'topk' in data else ""
209
+ topp = str("\nTopp: " + data['topp']) if 'topp' in data else ""
210
+ temperature = str("\nTemperature: " + data['temperature']) if 'temperature' in data else ""
211
+ cfg_coef = str("\nClassifier Free Guidance: " + data['cfg_coef']) if 'cfg_coef' in data else ""
212
+ version = str("Version: " + data['version']) if 'version' in data else "Version: Unknown"
213
+ info = str(version + global_prompt + bpm + key + scale + prompts + duration + overlap + seed + audio_mode + input_length + channel + sr_select + model + custom_model + base_model + decoder + topk + topp + temperature + cfg_coef)
214
+ if info == "":
215
+ return "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted. (Discord removes metadata from mp4 and wav files, so you can't use them)"
216
+ return info
217
+ else:
218
+ with open(audio_path.name) as json_file:
219
+ data = json.load(json_file)
220
+ #if 'global_prompt' not in data:
221
+ #return "No tags found. Either the file is not generated by MusicGen+ V1.2.8a and higher or the tags are corrupted."
222
+ global_prompt = str("\nGlobal Prompt: " + (data['global_prompt'] if data['global_prompt'] != "" else "none")) if 'global_prompt' in data else ""
223
+ bpm = str("\nBPM: " + data['bpm']) if 'bpm' in data else ""
224
+ key = str("\nKey: " + data['key']) if 'key' in data else ""
225
+ scale = str("\nScale: " + data['scale']) if 'scale' in data else ""
226
+ prompts = str("\nPrompts: " + (data['texts'] if data['texts'] != "['']" else "none")) if 'texts' in data else ""
227
+ duration = str("\nDuration: " + data['duration']) if 'duration' in data else ""
228
+ overlap = str("\nOverlap: " + data['overlap']) if 'overlap' in data else ""
229
+ seed = str("\nSeed: " + data['seed']) if 'seed' in data else ""
230
+ audio_mode = str("\nAudio Mode: " + data['audio_mode']) if 'audio_mode' in data else ""
231
+ input_length = str("\nInput Length: " + data['input_length']) if 'input_length' in data else ""
232
+ channel = str("\nChannel: " + data['channel']) if 'channel' in data else ""
233
+ sr_select = str("\nSample Rate: " + data['sr_select']) if 'sr_select' in data else ""
234
+ gen_type = str(data['generator'] + "gen-") if 'generator' in data else ""
235
+ model = str("\nModel: " + gen_type + data['model']) if 'model' in data else ""
236
+ custom_model = str("\nCustom Model: " + data['custom_model']) if 'custom_model' in data else ""
237
+ base_model = str("\nBase Model: " + data['base_model']) if 'base_model' in data else ""
238
+ decoder = str("\nDecoder: " + data['decoder']) if 'decoder' in data else ""
239
+ topk = str("\nTopk: " + data['topk']) if 'topk' in data else ""
240
+ topp = str("\nTopp: " + data['topp']) if 'topp' in data else ""
241
+ temperature = str("\nTemperature: " + data['temperature']) if 'temperature' in data else ""
242
+ cfg_coef = str("\nClassifier Free Guidance: " + data['cfg_coef']) if 'cfg_coef' in data else ""
243
+ version = str("Version: " + data['version']) if 'version' in data else "Version: Unknown"
244
+ info = str(version + global_prompt + bpm + key + scale + prompts + duration + overlap + seed + audio_mode + input_length + channel + sr_select + model + custom_model + base_model + decoder + topk + topp + temperature + cfg_coef)
245
+ if info == "":
246
+ return "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted."
247
+ return info
248
+ else:
249
+ return "Only .wav ,.mp4 and .json files are supported"
250
+ else:
251
+ return None
252
+
253
+
254
+ def info_to_params(audio_path):
255
+ if audio_path is not None:
256
+ if audio_path.name.endswith(".wav") or audio_path.name.endswith(".mp4") or audio_path.name.endswith(".json"):
257
+ if not audio_path.name.endswith(".json"):
258
+ with taglib.File(audio_path.name, save_on_exit=False) as song:
259
+ if 'COMMENT' not in song.tags:
260
+ return "Default", False, "", 120, "C", "Major", "large", None, "medium", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
261
+ json_string = song.tags['COMMENT'][0]
262
+ data = json.loads(json_string)
263
+ struc_prompt = (False if data['bpm'] == "none" else True) if 'bpm' in data else False
264
+ global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
265
+ bpm = (120 if data['bpm'] == "none" else int(data['bpm'])) if 'bpm' in data else 120
266
+ key = ("C" if data['key'] == "none" else data['key']) if 'key' in data else "C"
267
+ scale = ("Major" if data['scale'] == "none" else data['scale']) if 'scale' in data else "Major"
268
+ model = data['model'] if 'model' in data else "large"
269
+ custom_model = (data['custom_model'] if data['custom_model'] in get_available_models() else None) if 'custom_model' in data else None
270
+ base_model = data['base_model'] if 'base_model' in data else "medium"
271
+ decoder = data['decoder'] if 'decoder' in data else "Default"
272
+ if 'texts' not in data:
273
+ unique_prompts = 1
274
+ text = ["", "", "", "", "", "", "", "", "", ""]
275
+ repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
276
+ else:
277
+ s = data['texts']
278
+ s = re.findall(r"'(.*?)'", s)
279
+ text = []
280
+ repeat = []
281
+ i = 0
282
+ for elem in s:
283
+ if elem.strip():
284
+ if i == 0 or elem != s[i-1]:
285
+ text.append(elem)
286
+ repeat.append(1)
287
+ else:
288
+ repeat[-1] += 1
289
+ i += 1
290
+ text.extend([""] * (10 - len(text)))
291
+ repeat.extend([1] * (10 - len(repeat)))
292
+ unique_prompts = len([t for t in text if t])
293
+ audio_mode = ("sample" if data['audio_mode'] == "none" else data['audio_mode']) if 'audio_mode' in data else "sample"
294
+ duration = int(data['duration']) if 'duration' in data else 10
295
+ topk = float(data['topk']) if 'topk' in data else 250
296
+ topp = float(data['topp']) if 'topp' in data else 0
297
+ temperature = float(data['temperature']) if 'temperature' in data else 1.0
298
+ cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
299
+ seed = int(data['seed']) if 'seed' in data else -1
300
+ overlap = int(data['overlap']) if 'overlap' in data else 12
301
+ channel = data['channel'] if 'channel' in data else "stereo"
302
+ sr_select = data['sr_select'] if 'sr_select' in data else "48000"
303
+ return decoder, struc_prompt, global_prompt, bpm, key, scale, model, custom_model, base_model, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], audio_mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
304
+ else:
305
+ with open(audio_path.name) as json_file:
306
+ data = json.load(json_file)
307
+ struc_prompt = (False if data['bpm'] == "none" else True) if 'bpm' in data else False
308
+ global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
309
+ bpm = (120 if data['bpm'] == "none" else int(data['bpm'])) if 'bpm' in data else 120
310
+ key = ("C" if data['key'] == "none" else data['key']) if 'key' in data else "C"
311
+ scale = ("Major" if data['scale'] == "none" else data['scale']) if 'scale' in data else "Major"
312
+ model = data['model'] if 'model' in data else "large"
313
+ custom_model = (data['custom_model'] if data['custom_model'] in get_available_models() else None) if 'custom_model' in data else None
314
+ base_model = data['base_model'] if 'base_model' in data else "medium"
315
+ decoder = data['decoder'] if 'decoder' in data else "Default"
316
+ if 'texts' not in data:
317
+ unique_prompts = 1
318
+ text = ["", "", "", "", "", "", "", "", "", ""]
319
+ repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
320
+ else:
321
+ s = data['texts']
322
+ s = re.findall(r"'(.*?)'", s)
323
+ text = []
324
+ repeat = []
325
+ i = 0
326
+ for elem in s:
327
+ if elem.strip():
328
+ if i == 0 or elem != s[i-1]:
329
+ text.append(elem)
330
+ repeat.append(1)
331
+ else:
332
+ repeat[-1] += 1
333
+ i += 1
334
+ text.extend([""] * (10 - len(text)))
335
+ repeat.extend([1] * (10 - len(repeat)))
336
+ unique_prompts = len([t for t in text if t])
337
+ audio_mode = ("sample" if data['audio_mode'] == "none" else data['audio_mode']) if 'audio_mode' in data else "sample"
338
+ duration = int(data['duration']) if 'duration' in data else 10
339
+ topk = float(data['topk']) if 'topk' in data else 250
340
+ topp = float(data['topp']) if 'topp' in data else 0
341
+ temperature = float(data['temperature']) if 'temperature' in data else 1.0
342
+ cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
343
+ seed = int(data['seed']) if 'seed' in data else -1
344
+ overlap = int(data['overlap']) if 'overlap' in data else 12
345
+ channel = data['channel'] if 'channel' in data else "stereo"
346
+ sr_select = data['sr_select'] if 'sr_select' in data else "48000"
347
+ return decoder, struc_prompt, global_prompt, bpm, key, scale, model, custom_model, base_model, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], audio_mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
348
+ else:
349
+ return "Default", False, "", 120, "C", "Major", "large", None, "medium", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
350
+ else:
351
+ return "Default", False, "", 120, "C", "Major", "large", None, "medium", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
352
+
353
+
354
+ def info_to_params_a(audio_path):
355
+ if audio_path is not None:
356
+ if audio_path.name.endswith(".wav") or audio_path.name.endswith(".mp4") or audio_path.name.endswith(".json"):
357
+ if not audio_path.name.endswith(".json"):
358
+ with taglib.File(audio_path.name, save_on_exit=False) as song:
359
+ if 'COMMENT' not in song.tags:
360
+ return "Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
361
+ json_string = song.tags['COMMENT'][0]
362
+ data = json.loads(json_string)
363
+ struc_prompt = (False if data['global_prompt'] == "" else True) if 'global_prompt' in data else False
364
+ global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
365
+ decoder = data['decoder'] if 'decoder' in data else "Default"
366
+ if 'texts' not in data:
367
+ unique_prompts = 1
368
+ text = ["", "", "", "", "", "", "", "", "", ""]
369
+ repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
370
+ else:
371
+ s = data['texts']
372
+ s = re.findall(r"'(.*?)'", s)
373
+ text = []
374
+ repeat = []
375
+ i = 0
376
+ for elem in s:
377
+ if elem.strip():
378
+ if i == 0 or elem != s[i-1]:
379
+ text.append(elem)
380
+ repeat.append(1)
381
+ else:
382
+ repeat[-1] += 1
383
+ i += 1
384
+ text.extend([""] * (10 - len(text)))
385
+ repeat.extend([1] * (10 - len(repeat)))
386
+ unique_prompts = len([t for t in text if t])
387
+ duration = int(data['duration']) if 'duration' in data else 10
388
+ topk = float(data['topk']) if 'topk' in data else 250
389
+ topp = float(data['topp']) if 'topp' in data else 0
390
+ temperature = float(data['temperature']) if 'temperature' in data else 1.0
391
+ cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
392
+ seed = int(data['seed']) if 'seed' in data else -1
393
+ overlap = int(data['overlap']) if 'overlap' in data else 12
394
+ channel = data['channel'] if 'channel' in data else "stereo"
395
+ sr_select = data['sr_select'] if 'sr_select' in data else "48000"
396
+ return decoder, struc_prompt, global_prompt, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
397
+ else:
398
+ with open(audio_path.name) as json_file:
399
+ data = json.load(json_file)
400
+ struc_prompt = (False if data['global_prompt'] == "" else True) if 'global_prompt' in data else False
401
+ global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
402
+ decoder = data['decoder'] if 'decoder' in data else "Default"
403
+ if 'texts' not in data:
404
+ unique_prompts = 1
405
+ text = ["", "", "", "", "", "", "", "", "", ""]
406
+ repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
407
+ else:
408
+ s = data['texts']
409
+ s = re.findall(r"'(.*?)'", s)
410
+ text = []
411
+ repeat = []
412
+ i = 0
413
+ for elem in s:
414
+ if elem.strip():
415
+ if i == 0 or elem != s[i-1]:
416
+ text.append(elem)
417
+ repeat.append(1)
418
+ else:
419
+ repeat[-1] += 1
420
+ i += 1
421
+ text.extend([""] * (10 - len(text)))
422
+ repeat.extend([1] * (10 - len(repeat)))
423
+ unique_prompts = len([t for t in text if t])
424
+ duration = int(data['duration']) if 'duration' in data else 10
425
+ topk = float(data['topk']) if 'topk' in data else 250
426
+ topp = float(data['topp']) if 'topp' in data else 0
427
+ temperature = float(data['temperature']) if 'temperature' in data else 1.0
428
+ cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
429
+ seed = int(data['seed']) if 'seed' in data else -1
430
+ overlap = int(data['overlap']) if 'overlap' in data else 12
431
+ channel = data['channel'] if 'channel' in data else "stereo"
432
+ sr_select = data['sr_select'] if 'sr_select' in data else "48000"
433
+ return decoder, struc_prompt, global_prompt, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
434
+
435
+ else:
436
+ return "Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
437
+ else:
438
+ return "Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
439
+
440
+
441
+ def make_pseudo_stereo (filename, sr_select, pan, delay):
442
+ if pan:
443
+ temp = AudioSegment.from_wav(filename)
444
+ if sr_select != "32000":
445
+ temp = temp.set_frame_rate(int(sr_select))
446
+ left = temp.pan(-0.5) - 5
447
+ right = temp.pan(0.6) - 5
448
+ temp = left.overlay(right, position=5)
449
+ temp.export(filename, format="wav")
450
+ if delay:
451
+ waveform, sample_rate = torchaudio.load(filename) # load mono WAV file
452
+ delay_seconds = 0.01 # set delay 10ms
453
+ delay_samples = int(delay_seconds * sample_rate) # Calculating delay value in number of samples
454
+ stereo_waveform = torch.stack([waveform[0], torch.cat((torch.zeros(delay_samples), waveform[0][:-delay_samples]))]) # Generate a stereo file with original mono audio and delayed version
455
+ torchaudio.save(filename, stereo_waveform, sample_rate)
456
+ return
457
+
458
+
459
+ def normalize_audio(audio_data):
460
+ audio_data = audio_data.astype(np.float32)
461
+ max_value = np.max(np.abs(audio_data))
462
+ audio_data /= max_value
463
+ return audio_data
464
+
465
+
466
+ def load_diffusion():
467
+ global MBD
468
+ if MBD is None:
469
+ print("loading MBD")
470
+ MBD = MultiBandDiffusion.get_mbd_musicgen()
471
+
472
+
473
+ def unload_diffusion():
474
+ global MBD
475
+ if MBD is not None:
476
+ print("unloading MBD")
477
+ MBD = None
478
+
479
+
480
+ def _do_predictions(gen_type, texts, melodies, sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=False, **gen_kwargs):
481
+ if gen_type == "music":
482
+ maximum_size = 29.5
483
+ elif gen_type == "audio":
484
+ maximum_size = 9.5
485
+ cut_size = 0
486
+ input_length = 0
487
+ sampleP = None
488
+ if sample is not None:
489
+ globalSR, sampleM = sample[0], sample[1]
490
+ sampleM = normalize_audio(sampleM)
491
+ sampleM = torch.from_numpy(sampleM).t()
492
+ if sampleM.dim() == 1:
493
+ sampleM = sampleM.unsqueeze(0)
494
+ sample_length = sampleM.shape[sampleM.dim() - 1] / globalSR
495
+ if trim_start >= sample_length:
496
+ trim_start = sample_length - 0.5
497
+ if trim_end >= sample_length:
498
+ trim_end = sample_length - 0.5
499
+ if trim_start + trim_end >= sample_length:
500
+ tmp = sample_length - 0.5
501
+ trim_start = tmp / 2
502
+ trim_end = tmp / 2
503
+ sampleM = sampleM[..., int(globalSR * trim_start):int(globalSR * (sample_length - trim_end))]
504
+ sample_length = sample_length - (trim_start + trim_end)
505
+ if sample_length > maximum_size:
506
+ cut_size = sample_length - maximum_size
507
+ sampleP = sampleM[..., :int(globalSR * cut_size)]
508
+ sampleM = sampleM[..., int(globalSR * cut_size):]
509
+ if sample_length >= duration:
510
+ duration = sample_length + 0.5
511
+ input_length = sample_length
512
+ global MODEL
513
+ MODEL.set_generation_params(duration=(duration - cut_size), **gen_kwargs)
514
+ print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies], [None if sample is None else (sample[0], sample[1].shape)])
515
+ be = time.time()
516
+ processed_melodies = []
517
+ if gen_type == "music":
518
+ target_sr = 32000
519
+ elif gen_type == "audio":
520
+ target_sr = 16000
521
+ target_ac = 1
522
+
523
+ for melody in melodies:
524
+ if melody is None:
525
+ processed_melodies.append(None)
526
+ else:
527
+ sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
528
+ if melody.dim() == 1:
529
+ melody = melody[None]
530
+ melody = melody[..., :int(sr * duration)]
531
+ melody = convert_audio(melody, sr, target_sr, target_ac)
532
+ processed_melodies.append(melody)
533
+
534
+ if sample is not None:
535
+ if sampleP is None:
536
+ if gen_type == "music":
537
+ outputs = MODEL.generate_continuation(
538
+ prompt=sampleM,
539
+ prompt_sample_rate=globalSR,
540
+ descriptions=texts,
541
+ progress=progress,
542
+ return_tokens=USE_DIFFUSION
543
+ )
544
+ elif gen_type == "audio":
545
+ outputs = MODEL.generate_continuation(
546
+ prompt=sampleM,
547
+ prompt_sample_rate=globalSR,
548
+ descriptions=texts,
549
+ progress=progress
550
+ )
551
+ else:
552
+ if sampleP.dim() > 1:
553
+ sampleP = convert_audio(sampleP, globalSR, target_sr, target_ac)
554
+ sampleP = sampleP.to(MODEL.device).float().unsqueeze(0)
555
+ if gen_type == "music":
556
+ outputs = MODEL.generate_continuation(
557
+ prompt=sampleM,
558
+ prompt_sample_rate=globalSR,
559
+ descriptions=texts,
560
+ progress=progress,
561
+ return_tokens=USE_DIFFUSION
562
+ )
563
+ elif gen_type == "audio":
564
+ outputs = MODEL.generate_continuation(
565
+ prompt=sampleM,
566
+ prompt_sample_rate=globalSR,
567
+ descriptions=texts,
568
+ progress=progress
569
+ )
570
+ outputs = torch.cat([sampleP, outputs], 2)
571
+
572
+ elif any(m is not None for m in processed_melodies):
573
+ if gen_type == "music":
574
+ outputs = MODEL.generate_with_chroma(
575
+ descriptions=texts,
576
+ melody_wavs=processed_melodies,
577
+ melody_sample_rate=target_sr,
578
+ progress=progress,
579
+ return_tokens=USE_DIFFUSION
580
+ )
581
+ elif gen_type == "audio":
582
+ outputs = MODEL.generate_with_chroma(
583
+ descriptions=texts,
584
+ melody_wavs=processed_melodies,
585
+ melody_sample_rate=target_sr,
586
+ progress=progress
587
+ )
588
+ else:
589
+ if gen_type == "music":
590
+ outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
591
+ elif gen_type == "audio":
592
+ outputs = MODEL.generate(texts, progress=progress)
593
+
594
+ if USE_DIFFUSION:
595
+ print("outputs: " + str(outputs))
596
+ outputs_diffusion = MBD.tokens_to_wav(outputs[1])
597
+ outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
598
+ outputs = outputs.detach().cpu().float()
599
+ backups = outputs
600
+ if channel == "stereo":
601
+ outputs = convert_audio(outputs, target_sr, int(sr_select), 2)
602
+ elif channel == "mono" and sr_select != "32000":
603
+ outputs = convert_audio(outputs, target_sr, int(sr_select), 1)
604
+ out_files = []
605
+ out_audios = []
606
+ out_backup = []
607
+ for output in outputs:
608
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
609
+ audio_write(
610
+ file.name, output, (MODEL.sample_rate if channel == "stereo effect" else int(sr_select)), strategy="loudness",
611
+ loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
612
+
613
+ if channel == "stereo effect":
614
+ make_pseudo_stereo(file.name, sr_select, pan=True, delay=True);
615
+
616
+ out_files.append(pool.submit(make_waveform, file.name, bg_image=image, bg_color=background, bars_color=(bar1, bar2), fg_alpha=1.0, bar_count=75, height=height, width=width))
617
+ out_audios.append(file.name)
618
+ file_cleaner.add(file.name)
619
+ print(f'wav: {file.name}')
620
+ for backup in backups:
621
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
622
+ audio_write(
623
+ file.name, backup, MODEL.sample_rate, strategy="loudness",
624
+ loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
625
+ out_backup.append(file.name)
626
+ file_cleaner.add(file.name)
627
+ res = [out_file.result() for out_file in out_files]
628
+ res_audio = out_audios
629
+ res_backup = out_backup
630
+ for file in res:
631
+ file_cleaner.add(file)
632
+ print(f'video: {file}')
633
+ print("batch finished", len(texts), time.time() - be)
634
+ print("Tempfiles currently stored: ", len(file_cleaner.files))
635
+ if MOVE_TO_CPU:
636
+ MODEL.to('cpu')
637
+ if UNLOAD_MODEL:
638
+ MODEL = None
639
+ torch.cuda.empty_cache()
640
+ torch.cuda.ipc_collect()
641
+ return res, res_audio, res_backup, input_length
642
+
643
+
644
+ def predict_batched(texts, melodies):
645
+ max_text_length = 512
646
+ texts = [text[:max_text_length] for text in texts]
647
+ load_model('melody')
648
+ res = _do_predictions(texts, melodies, BATCHED_DURATION)
649
+ return res
650
+
651
+
652
+ def add_tags(filename, tags):
653
+ json_string = None
654
+
655
+ data = {
656
+ "global_prompt": tags[0],
657
+ "bpm": tags[1],
658
+ "key": tags[2],
659
+ "scale": tags[3],
660
+ "texts": tags[4],
661
+ "duration": tags[5],
662
+ "overlap": tags[6],
663
+ "seed": tags[7],
664
+ "audio_mode": tags[8],
665
+ "input_length": tags[9],
666
+ "channel": tags[10],
667
+ "sr_select": tags[11],
668
+ "model": tags[12],
669
+ "custom_model": tags[13],
670
+ "base_model": tags[14],
671
+ "decoder": tags[15],
672
+ "topk": tags[16],
673
+ "topp": tags[17],
674
+ "temperature": tags[18],
675
+ "cfg_coef": tags[19],
676
+ "generator": tags[20],
677
+ "version": version
678
+ }
679
+
680
+ json_string = json.dumps(data)
681
+
682
+ if os.path.exists(filename):
683
+ with taglib.File(filename, save_on_exit=True) as song:
684
+ song.tags = {'COMMENT': json_string }
685
+
686
+ json_file = open(tags[7] + '.json', 'w')
687
+ json_file.write(json_string)
688
+ json_file.close()
689
+
690
+ return json_file.name;
691
+
692
+
693
+ def save_outputs(mp4, wav_tmp, tags, gen_type):
694
+ # mp4: .mp4 file name in root running folder of app.py
695
+ # wav_tmp: temporary wav file located in %TEMP% folder
696
+ # seed - used seed
697
+ # exanple BgnJtr4Pn1AJ.mp4, C:\Users\Alex\AppData\Local\Temp\tmp4ermrebs.wav, 195123182343465
698
+ # procedure read generated .mp4 and wav files, rename it by using seed as name,
699
+ # and will store it to ./output/today_date/wav and ./output/today_date/mp4 folders.
700
+ # if file with same seed number already exist its make postfix in name like seed(n)
701
+ # where is n - consiqunce number 1-2-3-4 and so on
702
+ # then we store generated mp4 and wav into destination folders.
703
+
704
+ current_date = datetime.now().strftime("%Y%m%d")
705
+ wav_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type,'wav')
706
+ mp4_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type,'mp4')
707
+ json_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type,'json')
708
+ os.makedirs(wav_directory, exist_ok=True)
709
+ os.makedirs(mp4_directory, exist_ok=True)
710
+ os.makedirs(json_directory, exist_ok=True)
711
+
712
+ filename = str(tags[7]) + '.wav'
713
+ target = os.path.join(wav_directory, filename)
714
+ counter = 1
715
+ while os.path.exists(target):
716
+ filename = str(tags[7]) + f'({counter})' + '.wav'
717
+ target = os.path.join(wav_directory, filename)
718
+ counter += 1
719
+
720
+ shutil.copyfile(wav_tmp, target); # make copy of original file
721
+ json_file = add_tags(target, tags);
722
+
723
+ wav_target=target;
724
+ target=target.replace('wav', 'mp4');
725
+ mp4_target=target;
726
+
727
+ mp4=r'./' +mp4;
728
+ shutil.copyfile(mp4, target); # make copy of original file
729
+ _ = add_tags(target, tags);
730
+
731
+ target=target.replace('mp4', 'json'); # change the extension to json
732
+ json_target=target; # store the json target
733
+
734
+ with open(target, 'w') as f: # open a writable file object
735
+ shutil.copyfile(json_file, target); # make copy of original file
736
+
737
+ os.remove(json_file)
738
+
739
+ return wav_target, mp4_target, json_target;
740
+
741
+
742
+ def clear_cash():
743
+ # delete all temporary files genegated my system
744
+ current_date = datetime.now().date()
745
+ current_directory = os.getcwd()
746
+ files = glob.glob(os.path.join(current_directory, '*.mp4'))
747
+ for file in files:
748
+ creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
749
+ if creation_date == current_date:
750
+ os.remove(file)
751
+
752
+ temp_directory = os.environ.get('TEMP')
753
+ files = glob.glob(os.path.join(temp_directory, 'tmp*.mp4'))
754
+ for file in files:
755
+ creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
756
+ if creation_date == current_date:
757
+ os.remove(file)
758
+
759
+ files = glob.glob(os.path.join(temp_directory, 'tmp*.wav'))
760
+ for file in files:
761
+ creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
762
+ if creation_date == current_date:
763
+ os.remove(file)
764
+
765
+ files = glob.glob(os.path.join(temp_directory, 'tmp*.png'))
766
+ for file in files:
767
+ creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
768
+ if creation_date == current_date:
769
+ os.remove(file)
770
+ return
771
+
772
+
773
+ def s2t(seconds, seconds2):
774
+ # convert seconds to time format
775
+ # seconds - time in seconds
776
+ # return time in format 00:00
777
+ m, s = divmod(seconds, 60)
778
+ m2, s2 = divmod(seconds2, 60)
779
+ if seconds != 0 and seconds < seconds2:
780
+ s = s + 1
781
+ return ("%02d:%02d - %02d:%02d" % (m, s, m2, s2))
782
+
783
+
784
+ def calc_time(gen_type, s, duration, overlap, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9):
785
+ # calculate the time of generation
786
+ # overlap - overlap in seconds
787
+ # d0-d9 - drag
788
+ # return time in seconds
789
+ d_amount = [int(d0), int(d1), int(d2), int(d3), int(d4), int(d5), int(d6), int(d7), int(d8), int(d9)]
790
+ calc = []
791
+ tracks = []
792
+ time = 0
793
+ s = s - 1
794
+ max_time = duration
795
+ max_limit = 0
796
+ if gen_type == "music":
797
+ max_limit = 30
798
+ elif gen_type == "audio":
799
+ max_limit = 10
800
+ track_add = max_limit - overlap
801
+ tracks.append(max_limit + ((d_amount[0] - 1) * track_add))
802
+ for i in range(1, 10):
803
+ tracks.append(d_amount[i] * track_add)
804
+
805
+ if tracks[0] >= max_time or s == 0:
806
+ calc.append(s2t(time, max_time))
807
+ time = max_time
808
+ else:
809
+ calc.append(s2t(time, tracks[0]))
810
+ time = tracks[0]
811
+
812
+ for i in range(1, 10):
813
+ if time + tracks[i] >= max_time or i == s:
814
+ calc.append(s2t(time, max_time))
815
+ time = max_time
816
+ else:
817
+ calc.append(s2t(time, time + tracks[i]))
818
+ time = time + tracks[i]
819
+
820
+ return calc[0], calc[1], calc[2], calc[3], calc[4], calc[5], calc[6], calc[7], calc[8], calc[9]
821
+
822
+
823
+ def predict_full(gen_type, model, decoder, custom_model, base_model, prompt_amount, struc_prompt, bpm, key, scale, global_prompt, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select, progress=gr.Progress()):
824
+ global INTERRUPTING
825
+ global USE_DIFFUSION
826
+ INTERRUPTING = False
827
+
828
+ if gen_type == "audio":
829
+ custom_model = None
830
+ base_model = "medium"
831
+
832
+ if temperature < 0:
833
+ raise gr.Error("Temperature must be >= 0.")
834
+ if topk < 0:
835
+ raise gr.Error("Topk must be non-negative.")
836
+ if topp < 0:
837
+ raise gr.Error("Topp must be non-negative.")
838
+
839
+ if trim_start < 0:
840
+ trim_start = 0
841
+ if trim_end < 0:
842
+ trim_end = 0
843
+
844
+ topk = int(topk)
845
+
846
+ if decoder == "MultiBand_Diffusion":
847
+ USE_DIFFUSION = True
848
+ load_diffusion()
849
+ else:
850
+ USE_DIFFUSION = False
851
+ unload_diffusion()
852
+
853
+ if gen_type == "music":
854
+ model_shrt = model
855
+ model = "GrandaddyShmax/musicgen-" + model
856
+ elif gen_type == "audio":
857
+ model_shrt = model
858
+ model = "GrandaddyShmax/audiogen-" + model
859
+ base_model_shrt = base_model
860
+ base_model = "GrandaddyShmax/musicgen-" + base_model
861
+
862
+ if MODEL is None or MODEL.name != (model):
863
+ load_model(model, custom_model, base_model, gen_type)
864
+ else:
865
+ if MOVE_TO_CPU:
866
+ MODEL.to('cuda')
867
+
868
+ if seed < 0:
869
+ seed = random.randint(0, 0xffff_ffff_ffff)
870
+ torch.manual_seed(seed)
871
+
872
+ def _progress(generated, to_generate):
873
+ progress((min(generated, to_generate), to_generate))
874
+ if INTERRUPTING:
875
+ raise gr.Error("Interrupted.")
876
+ MODEL.set_custom_progress_callback(_progress)
877
+
878
+ audio_mode = "none"
879
+ melody = None
880
+ sample = None
881
+ if audio:
882
+ audio_mode = mode
883
+ if mode == "sample":
884
+ sample = audio
885
+ elif mode == "melody":
886
+ melody = audio
887
+
888
+ base_model = "none" if model != "custom" else base_model
889
+ custom_model = "none" if model != "custom" else custom_model
890
+
891
+ text_cat = [p0, p1, p2, p3, p4, p5, p6, p7, p8, p9]
892
+ drag_cat = [d0, d1, d2, d3, d4, d5, d6, d7, d8, d9]
893
+ texts = []
894
+ raw_texts = []
895
+ ind = 0
896
+ ind2 = 0
897
+ while ind < prompt_amount:
898
+ for ind2 in range(int(drag_cat[ind])):
899
+ if not struc_prompt:
900
+ texts.append(text_cat[ind])
901
+ global_prompt = "none"
902
+ bpm = "none"
903
+ key = "none"
904
+ scale = "none"
905
+ raw_texts.append(text_cat[ind])
906
+ else:
907
+ if gen_type == "music":
908
+ bpm_str = str(bpm) + " bpm"
909
+ key_str = ", " + str(key) + " " + str(scale)
910
+ global_str = (", " + str(global_prompt)) if str(global_prompt) != "" else ""
911
+ elif gen_type == "audio":
912
+ bpm_str = ""
913
+ key_str = ""
914
+ global_str = (str(global_prompt)) if str(global_prompt) != "" else ""
915
+ texts_str = (", " + str(text_cat[ind])) if str(text_cat[ind]) != "" else ""
916
+ texts.append(bpm_str + key_str + global_str + texts_str)
917
+ raw_texts.append(text_cat[ind])
918
+ ind2 = 0
919
+ ind = ind + 1
920
+
921
+ outs, outs_audio, outs_backup, input_length = _do_predictions(
922
+ gen_type, [texts], [melody], sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=True,
923
+ top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, extend_stride=MODEL.max_duration-overlap)
924
+ tags = [str(global_prompt), str(bpm), str(key), str(scale), str(raw_texts), str(duration), str(overlap), str(seed), str(audio_mode), str(input_length), str(channel), str(sr_select), str(model_shrt), str(custom_model), str(base_model_shrt), str(decoder), str(topk), str(topp), str(temperature), str(cfg_coef), str(gen_type)]
925
+ wav_target, mp4_target, json_target = save_outputs(outs[0], outs_audio[0], tags, gen_type);
926
+ # Removes the temporary files.
927
+ for out in outs:
928
+ os.remove(out)
929
+ for out in outs_audio:
930
+ os.remove(out)
931
+
932
+ return mp4_target, wav_target, outs_backup[0], [mp4_target, wav_target, json_target], seed
933
+
934
+
935
+ max_textboxes = 10
936
+
937
+
938
+ def get_available_models():
939
+ return sorted([re.sub('.pt$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('.pt')])
940
+
941
+
942
+ def toggle_audio_src(choice):
943
+ if choice == "mic":
944
+ return gr.update(source="microphone", value=None, label="Microphone")
945
+ else:
946
+ return gr.update(source="upload", value=None, label="File")
947
+
948
+
949
+ def ui_full(launch_kwargs):
950
+ with gr.Blocks(title='AudioCraft Plus', theme=theme) as interface:
951
+ gr.Markdown(
952
+ """
953
+ # AudioCraft Plus - v2.0.0a
954
+
955
+ ### An All-in-One AudioCraft WebUI
956
+
957
+ #### **Disclaimer:** This will not run on CPU only. Its best to clone this App and run on GPU instance!
958
+ **Alternatively**, you can run this for free on a google colab:
959
+ https://colab.research.google.com/github/camenduru/MusicGen-colab/blob/main/MusicGen_ClownOfMadness_plus_colab.ipynb
960
+
961
+ **Or**, run this locally on your PC:
962
+ https://github.com/GrandaddyShmax/audiocraft_plus/tree/main
963
+
964
+ Thanks to: facebookresearch, Camenduru, rkfg, oobabooga, AlexHK and GrandaddyShmax
965
+ """
966
+ )
967
+ with gr.Tab("MusicGen"):
968
+ gr.Markdown(
969
+ """
970
+ ### MusicGen
971
+ """
972
+ )
973
+ with gr.Row():
974
+ with gr.Column():
975
+ with gr.Tab("Generation"):
976
+ with gr.Accordion("Structure Prompts", open=False):
977
+ with gr.Column():
978
+ with gr.Row():
979
+ struc_prompts = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
980
+ bpm = gr.Number(label="BPM", value=120, interactive=True, scale=1, precision=0)
981
+ key = gr.Dropdown(["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "Bb", "B"], label="Key", value="C", interactive=True)
982
+ scale = gr.Dropdown(["Major", "Minor"], label="Scale", value="Major", interactive=True)
983
+ with gr.Row():
984
+ global_prompt = gr.Text(label="Global Prompt", interactive=True, scale=3)
985
+ with gr.Row():
986
+ s = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
987
+ #s_mode = gr.Radio(["segmentation", "batch"], value="segmentation", interactive=True, scale=1, label="Generation Mode")
988
+ with gr.Column():
989
+ textboxes = []
990
+ prompts = []
991
+ repeats = []
992
+ calcs = []
993
+ with gr.Row():
994
+ text0 = gr.Text(label="Input Text", interactive=True, scale=4)
995
+ prompts.append(text0)
996
+ drag0 = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
997
+ repeats.append(drag0)
998
+ calc0 = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
999
+ calcs.append(calc0)
1000
+ for i in range(max_textboxes):
1001
+ with gr.Row(visible=False) as t:
1002
+ text = gr.Text(label="Input Text", interactive=True, scale=3)
1003
+ repeat = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
1004
+ calc = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
1005
+ textboxes.append(t)
1006
+ prompts.append(text)
1007
+ repeats.append(repeat)
1008
+ calcs.append(calc)
1009
+ to_calc = gr.Button("Calculate Timings", variant="secondary")
1010
+ with gr.Row():
1011
+ duration = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
1012
+ with gr.Row():
1013
+ overlap = gr.Slider(minimum=1, maximum=29, value=12, step=1, label="Overlap", interactive=True)
1014
+ with gr.Row():
1015
+ seed = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
1016
+ gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed], queue=False)
1017
+ reuse_seed = gr.Button('\u267b\ufe0f', scale=1)
1018
+
1019
+ with gr.Tab("Audio"):
1020
+ with gr.Row():
1021
+ with gr.Column():
1022
+ input_type = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
1023
+ mode = gr.Radio(["melody", "sample"], label="Input Audio Mode (optional)", value="sample", interactive=True)
1024
+ with gr.Row():
1025
+ trim_start = gr.Number(label="Trim Start", value=0, interactive=True)
1026
+ trim_end = gr.Number(label="Trim End", value=0, interactive=True)
1027
+ audio = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)
1028
+
1029
+ with gr.Tab("Customization"):
1030
+ with gr.Row():
1031
+ with gr.Column():
1032
+ background = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
1033
+ bar1 = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
1034
+ bar2 = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
1035
+ with gr.Column():
1036
+ image = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
1037
+ with gr.Row():
1038
+ height = gr.Number(label="Height", value=512, interactive=True)
1039
+ width = gr.Number(label="Width", value=768, interactive=True)
1040
+
1041
+ with gr.Tab("Settings"):
1042
+ with gr.Row():
1043
+ channel = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
1044
+ sr_select = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
1045
+ with gr.Row():
1046
+ model = gr.Radio(["melody", "small", "medium", "large", "custom"], label="Model", value="large", interactive=True, scale=1)
1047
+ with gr.Column():
1048
+ dropdown = gr.Dropdown(choices=get_available_models(), value=("No models found" if len(get_available_models()) < 1 else get_available_models()[0]), label='Custom Model (models folder)', elem_classes='slim-dropdown', interactive=True)
1049
+ ui.create_refresh_button(dropdown, lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
1050
+ basemodel = gr.Radio(["small", "medium", "melody", "large"], label="Base Model", value="medium", interactive=True, scale=1)
1051
+ with gr.Row():
1052
+ decoder = gr.Radio(["Default", "MultiBand_Diffusion"], label="Decoder", value="Default", interactive=True)
1053
+ with gr.Row():
1054
+ topk = gr.Number(label="Top-k", value=250, interactive=True)
1055
+ topp = gr.Number(label="Top-p", value=0, interactive=True)
1056
+ temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
1057
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
1058
+ with gr.Row():
1059
+ submit = gr.Button("Generate", variant="primary")
1060
+ # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
1061
+ _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
1062
+ with gr.Column() as c:
1063
+ with gr.Tab("Output"):
1064
+ output = gr.Video(label="Generated Music", scale=0)
1065
+ with gr.Row():
1066
+ audio_only = gr.Audio(type="numpy", label="Audio Only", interactive=False)
1067
+ backup_only = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
1068
+ send_audio = gr.Button("Send to Input Audio")
1069
+ seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
1070
+ download = gr.File(label="Generated Files", interactive=False)
1071
+ with gr.Tab("Wiki"):
1072
+ gr.Markdown(
1073
+ """
1074
+ - **[Generate (button)]:**
1075
+ Generates the music with the given settings and prompts.
1076
+
1077
+ - **[Interrupt (button)]:**
1078
+ Stops the music generation as soon as it can, providing an incomplete output.
1079
+
1080
+ ---
1081
+
1082
+ ### Generation Tab:
1083
+
1084
+ #### Structure Prompts:
1085
+
1086
+ This feature helps reduce repetetive prompts by allowing you to set global prompts
1087
+ that will be used for all prompt segments.
1088
+
1089
+ - **[Structure Prompts (checkbox)]:**
1090
+ Enable/Disable the structure prompts feature.
1091
+
1092
+ - **[BPM (number)]:**
1093
+ Beats per minute of the generated music.
1094
+
1095
+ - **[Key (dropdown)]:**
1096
+ The key of the generated music.
1097
+
1098
+ - **[Scale (dropdown)]:**
1099
+ The scale of the generated music.
1100
+
1101
+ - **[Global Prompt (text)]:**
1102
+ Here write the prompt that you wish to be used for all prompt segments.
1103
+
1104
+ #### Multi-Prompt:
1105
+
1106
+ This feature allows you to control the music, adding variation to different time segments.
1107
+ You have up to 10 prompt segments. the first prompt will always be 30s long
1108
+ the other prompts will be [30s - overlap].
1109
+ for example if the overlap is 10s, each prompt segment will be 20s.
1110
+
1111
+ - **[Prompt Segments (number)]:**
1112
+ Amount of unique prompt to generate throughout the music generation.
1113
+
1114
+ - **[Prompt/Input Text (prompt)]:**
1115
+ Here describe the music you wish the model to generate.
1116
+
1117
+ - **[Repeat (number)]:**
1118
+ Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).
1119
+
1120
+ - **[Time (text)]:**
1121
+ The time of the prompt segment.
1122
+
1123
+ - **[Calculate Timings (button)]:**
1124
+ Calculates the timings of the prompt segments.
1125
+
1126
+ - **[Duration (number)]:**
1127
+ How long you want the generated music to be (in seconds).
1128
+
1129
+ - **[Overlap (number)]:**
1130
+ How much each new segment will reference the previous segment (in seconds).
1131
+ For example, if you choose 20s: Each new segment after the first one will reference the previous segment 20s
1132
+ and will generate only 10s of new music. The model can only process 30s of music.
1133
+
1134
+ - **[Seed (number)]:**
1135
+ Your generated music id. If you wish to generate the exact same music,
1136
+ place the exact seed with the exact prompts
1137
+ (This way you can also extend specific song that was generated short).
1138
+
1139
+ - **[Random Seed (button)]:**
1140
+ Gives "-1" as a seed, which counts as a random seed.
1141
+
1142
+ - **[Copy Previous Seed (button)]:**
1143
+ Copies the seed from the output seed (if you don't feel like doing it manualy).
1144
+
1145
+ ---
1146
+
1147
+ ### Audio Tab:
1148
+
1149
+ - **[Input Type (selection)]:**
1150
+ `File` mode allows you to upload an audio file to use as input
1151
+ `Mic` mode allows you to use your microphone as input
1152
+
1153
+ - **[Input Audio Mode (selection)]:**
1154
+ `Melody` mode only works with the melody model: it conditions the music generation to reference the melody
1155
+ `Sample` mode works with any model: it gives a music sample to the model to generate its continuation.
1156
+
1157
+ - **[Trim Start and Trim End (numbers)]:**
1158
+ `Trim Start` set how much you'd like to trim the input audio from the start
1159
+ `Trim End` same as the above but from the end
1160
+
1161
+ - **[Input Audio (audio file)]:**
1162
+ Input here the audio you wish to use with "melody" or "sample" mode.
1163
+
1164
+ ---
1165
+
1166
+ ### Customization Tab:
1167
+
1168
+ - **[Background Color (color)]:**
1169
+ Works only if you don't upload image. Color of the background of the waveform.
1170
+
1171
+ - **[Bar Color Start (color)]:**
1172
+ First color of the waveform bars.
1173
+
1174
+ - **[Bar Color End (color)]:**
1175
+ Second color of the waveform bars.
1176
+
1177
+ - **[Background Image (image)]:**
1178
+ Background image that you wish to be attached to the generated video along with the waveform.
1179
+
1180
+ - **[Height and Width (numbers)]:**
1181
+ Output video resolution, only works with image.
1182
+ (minimum height and width is 256).
1183
+
1184
+ ---
1185
+
1186
+ ### Settings Tab:
1187
+
1188
+ - **[Output Audio Channels (selection)]:**
1189
+ With this you can select the amount of channels that you wish for your output audio.
1190
+ `mono` is a straightforward single channel audio
1191
+ `stereo` is a dual channel audio but it will sound more or less like mono
1192
+ `stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.
1193
+
1194
+ - **[Output Audio Sample Rate (dropdown)]:**
1195
+ The output audio sample rate, the model default is 32000.
1196
+
1197
+ - **[Model (selection)]:**
1198
+ Here you can choose which model you wish to use:
1199
+ `melody` model is based on the medium model with a unique feature that lets you use melody conditioning
1200
+ `small` model is trained on 300M parameters
1201
+ `medium` model is trained on 1.5B parameters
1202
+ `large` model is trained on 3.3B parameters
1203
+ `custom` model runs the custom model that you provided.
1204
+
1205
+ - **[Custom Model (selection)]:**
1206
+ This dropdown will show you models that are placed in the `models` folder
1207
+ you must select `custom` in the model options in order to use it.
1208
+
1209
+ - **[Refresh (button)]:**
1210
+ Refreshes the dropdown list for custom model.
1211
+
1212
+ - **[Base Model (selection)]:**
1213
+ Choose here the model that your custom model is based on.
1214
+
1215
+ - **[Decoder (selection)]:**
1216
+ Choose here the decoder that you wish to use:
1217
+ `Default` is the default decoder
1218
+ `MultiBand_Diffusion` is a decoder that uses diffusion to generate the audio.
1219
+
1220
+ - **[Top-k (number)]:**
1221
+ is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.
1222
+
1223
+ - **[Top-p (number)]:**
1224
+ also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
1225
+
1226
+ - **[Temperature (number)]:**
1227
+ is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.
1228
+
1229
+ - **[Classifier Free Guidance (number)]:**
1230
+ refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture.
1231
+ """
1232
+ )
1233
+ with gr.Tab("AudioGen"):
1234
+ gr.Markdown(
1235
+ """
1236
+ ### AudioGen
1237
+ """
1238
+ )
1239
+ with gr.Row():
1240
+ with gr.Column():
1241
+ with gr.Tab("Generation"):
1242
+ with gr.Accordion("Structure Prompts", open=False):
1243
+ with gr.Row():
1244
+ struc_prompts_a = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
1245
+ global_prompt_a = gr.Text(label="Global Prompt", interactive=True, scale=3)
1246
+ with gr.Row():
1247
+ s_a = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
1248
+ with gr.Column():
1249
+ textboxes_a = []
1250
+ prompts_a = []
1251
+ repeats_a = []
1252
+ calcs_a = []
1253
+ with gr.Row():
1254
+ text0_a = gr.Text(label="Input Text", interactive=True, scale=4)
1255
+ prompts_a.append(text0_a)
1256
+ drag0_a = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
1257
+ repeats_a.append(drag0_a)
1258
+ calc0_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
1259
+ calcs_a.append(calc0_a)
1260
+ for i in range(max_textboxes):
1261
+ with gr.Row(visible=False) as t_a:
1262
+ text_a = gr.Text(label="Input Text", interactive=True, scale=3)
1263
+ repeat_a = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
1264
+ calc_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
1265
+ textboxes_a.append(t_a)
1266
+ prompts_a.append(text_a)
1267
+ repeats_a.append(repeat_a)
1268
+ calcs_a.append(calc_a)
1269
+ to_calc_a = gr.Button("Calculate Timings", variant="secondary")
1270
+ with gr.Row():
1271
+ duration_a = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
1272
+ with gr.Row():
1273
+ overlap_a = gr.Slider(minimum=1, maximum=9, value=2, step=1, label="Overlap", interactive=True)
1274
+ with gr.Row():
1275
+ seed_a = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
1276
+ gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed_a], queue=False)
1277
+ reuse_seed_a = gr.Button('\u267b\ufe0f', scale=1)
1278
+
1279
+ with gr.Tab("Audio"):
1280
+ with gr.Row():
1281
+ with gr.Column():
1282
+ input_type_a = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
1283
+ mode_a = gr.Radio(["sample"], label="Input Audio Mode (optional)", value="sample", interactive=False, visible=False)
1284
+ with gr.Row():
1285
+ trim_start_a = gr.Number(label="Trim Start", value=0, interactive=True)
1286
+ trim_end_a = gr.Number(label="Trim End", value=0, interactive=True)
1287
+ audio_a = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)
1288
+
1289
+ with gr.Tab("Customization"):
1290
+ with gr.Row():
1291
+ with gr.Column():
1292
+ background_a = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
1293
+ bar1_a = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
1294
+ bar2_a = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
1295
+ with gr.Column():
1296
+ image_a = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
1297
+ with gr.Row():
1298
+ height_a = gr.Number(label="Height", value=512, interactive=True)
1299
+ width_a = gr.Number(label="Width", value=768, interactive=True)
1300
+
1301
+ with gr.Tab("Settings"):
1302
+ with gr.Row():
1303
+ channel_a = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
1304
+ sr_select_a = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
1305
+ with gr.Row():
1306
+ model_a = gr.Radio(["medium"], label="Model", value="medium", interactive=False, visible=False)
1307
+ decoder_a = gr.Radio(["Default"], label="Decoder", value="Default", interactive=False, visible=False)
1308
+ with gr.Row():
1309
+ topk_a = gr.Number(label="Top-k", value=250, interactive=True)
1310
+ topp_a = gr.Number(label="Top-p", value=0, interactive=True)
1311
+ temperature_a = gr.Number(label="Temperature", value=1.0, interactive=True)
1312
+ cfg_coef_a = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
1313
+ with gr.Row():
1314
+ submit_a = gr.Button("Generate", variant="primary")
1315
+ _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
1316
+ with gr.Column():
1317
+ with gr.Tab("Output"):
1318
+ output_a = gr.Video(label="Generated Audio", scale=0)
1319
+ with gr.Row():
1320
+ audio_only_a = gr.Audio(type="numpy", label="Audio Only", interactive=False)
1321
+ backup_only_a = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
1322
+ send_audio_a = gr.Button("Send to Input Audio")
1323
+ seed_used_a = gr.Number(label='Seed used', value=-1, interactive=False)
1324
+ download_a = gr.File(label="Generated Files", interactive=False)
1325
+ with gr.Tab("Wiki"):
1326
+ gr.Markdown(
1327
+ """
1328
+ - **[Generate (button)]:**
1329
+ Generates the audio with the given settings and prompts.
1330
+
1331
+ - **[Interrupt (button)]:**
1332
+ Stops the audio generation as soon as it can, providing an incomplete output.
1333
+
1334
+ ---
1335
+
1336
+ ### Generation Tab:
1337
+
1338
+ #### Structure Prompts:
1339
+
1340
+ This feature helps reduce repetetive prompts by allowing you to set global prompts
1341
+ that will be used for all prompt segments.
1342
+
1343
+ - **[Structure Prompts (checkbox)]:**
1344
+ Enable/Disable the structure prompts feature.
1345
+
1346
+ - **[Global Prompt (text)]:**
1347
+ Here write the prompt that you wish to be used for all prompt segments.
1348
+
1349
+ #### Multi-Prompt:
1350
+
1351
+ This feature allows you to control the audio, adding variation to different time segments.
1352
+ You have up to 10 prompt segments. the first prompt will always be 10s long
1353
+ the other prompts will be [10s - overlap].
1354
+ for example if the overlap is 2s, each prompt segment will be 8s.
1355
+
1356
+ - **[Prompt Segments (number)]:**
1357
+ Amount of unique prompt to generate throughout the audio generation.
1358
+
1359
+ - **[Prompt/Input Text (prompt)]:**
1360
+ Here describe the audio you wish the model to generate.
1361
+
1362
+ - **[Repeat (number)]:**
1363
+ Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).
1364
+
1365
+ - **[Time (text)]:**
1366
+ The time of the prompt segment.
1367
+
1368
+ - **[Calculate Timings (button)]:**
1369
+ Calculates the timings of the prompt segments.
1370
+
1371
+ - **[Duration (number)]:**
1372
+ How long you want the generated audio to be (in seconds).
1373
+
1374
+ - **[Overlap (number)]:**
1375
+ How much each new segment will reference the previous segment (in seconds).
1376
+ For example, if you choose 2s: Each new segment after the first one will reference the previous segment 2s
1377
+ and will generate only 8s of new audio. The model can only process 10s of music.
1378
+
1379
+ - **[Seed (number)]:**
1380
+ Your generated audio id. If you wish to generate the exact same audio,
1381
+ place the exact seed with the exact prompts
1382
+ (This way you can also extend specific song that was generated short).
1383
+
1384
+ - **[Random Seed (button)]:**
1385
+ Gives "-1" as a seed, which counts as a random seed.
1386
+
1387
+ - **[Copy Previous Seed (button)]:**
1388
+ Copies the seed from the output seed (if you don't feel like doing it manualy).
1389
+
1390
+ ---
1391
+
1392
+ ### Audio Tab:
1393
+
1394
+ - **[Input Type (selection)]:**
1395
+ `File` mode allows you to upload an audio file to use as input
1396
+ `Mic` mode allows you to use your microphone as input
1397
+
1398
+ - **[Trim Start and Trim End (numbers)]:**
1399
+ `Trim Start` set how much you'd like to trim the input audio from the start
1400
+ `Trim End` same as the above but from the end
1401
+
1402
+ - **[Input Audio (audio file)]:**
1403
+ Input here the audio you wish to use.
1404
+
1405
+ ---
1406
+
1407
+ ### Customization Tab:
1408
+
1409
+ - **[Background Color (color)]:**
1410
+ Works only if you don't upload image. Color of the background of the waveform.
1411
+
1412
+ - **[Bar Color Start (color)]:**
1413
+ First color of the waveform bars.
1414
+
1415
+ - **[Bar Color End (color)]:**
1416
+ Second color of the waveform bars.
1417
+
1418
+ - **[Background Image (image)]:**
1419
+ Background image that you wish to be attached to the generated video along with the waveform.
1420
+
1421
+ - **[Height and Width (numbers)]:**
1422
+ Output video resolution, only works with image.
1423
+ (minimum height and width is 256).
1424
+
1425
+ ---
1426
+
1427
+ ### Settings Tab:
1428
+
1429
+ - **[Output Audio Channels (selection)]:**
1430
+ With this you can select the amount of channels that you wish for your output audio.
1431
+ `mono` is a straightforward single channel audio
1432
+ `stereo` is a dual channel audio but it will sound more or less like mono
1433
+ `stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.
1434
+
1435
+ - **[Output Audio Sample Rate (dropdown)]:**
1436
+ The output audio sample rate, the model default is 32000.
1437
+
1438
+ - **[Top-k (number)]:**
1439
+ is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.
1440
+
1441
+ - **[Top-p (number)]:**
1442
+ also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
1443
+
1444
+ - **[Temperature (number)]:**
1445
+ is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.
1446
+
1447
+ - **[Classifier Free Guidance (number)]:**
1448
+ refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture.
1449
+ """
1450
+ )
1451
+ with gr.Tab("Audio Info"):
1452
+ gr.Markdown(
1453
+ """
1454
+ ### Audio Info
1455
+ """
1456
+ )
1457
+ with gr.Row():
1458
+ with gr.Column():
1459
+ in_audio = gr.File(type="file", label="Input Any Audio", interactive=True)
1460
+ with gr.Row():
1461
+ send_gen = gr.Button("Send to MusicGen", variant="primary")
1462
+ send_gen_a = gr.Button("Send to AudioGen", variant="primary")
1463
+ with gr.Column():
1464
+ info = gr.Textbox(label="Audio Info", lines=10, interactive=False)
1465
+ with gr.Tab("Changelog"):
1466
+ gr.Markdown(
1467
+ """
1468
+ ## Changelog:
1469
+
1470
+ ### v2.0.0a
1471
+
1472
+ - Forgot to move all the update to app.py from temp2.py... oops
1473
+
1474
+
1475
+
1476
+ ### v2.0.0
1477
+
1478
+ - Changed name from MusicGen+ to AudioCraft Plus
1479
+
1480
+ - Complete overhaul of the repo "backend" with the latest changes from the main facebookresearch repo
1481
+
1482
+ - Added a new decoder: MultiBand_Diffusion
1483
+
1484
+ - Added AudioGen: a new tab for generating audio
1485
+
1486
+
1487
+
1488
+ ### v1.2.8c
1489
+
1490
+ - Implemented Reverse compatibility for audio info tab with previous versions
1491
+
1492
+
1493
+
1494
+ ### v1.2.8b
1495
+
1496
+ - Fixed the error when loading default models
1497
+
1498
+
1499
+
1500
+ ### v1.2.8a
1501
+
1502
+ - Adapted Audio info tab to work with the new structure prompts feature
1503
+
1504
+ - Now custom models actually work, make sure you select the correct base model
1505
+
1506
+
1507
+
1508
+ ### v1.2.8
1509
+
1510
+ - Now you will also recieve json file with metadata of generated audio
1511
+
1512
+ - Added error messages in Audio Info tab
1513
+
1514
+ - Added structure prompts: you can select bpm, key and global prompt for all prompts
1515
+
1516
+ - Added time display next to each prompt, can be calculated with "Calculate Timings" button
1517
+
1518
+
1519
+
1520
+ ### v1.2.7
1521
+
1522
+ - When sending generated audio to Input Audio, it will send a backup audio with default settings
1523
+ (best for continuos generation)
1524
+
1525
+ - Added Metadata to generated audio (Thanks to AlexHK ♥)
1526
+
1527
+ - Added Audio Info tab that will display the metadata of the input audio
1528
+
1529
+ - Added "send to Text2Audio" button in Audio Info tab
1530
+
1531
+ - Generated audio is now stored in the "output" folder (Thanks to AlexHK ♥)
1532
+
1533
+ - Added an output area with generated files and download buttons
1534
+
1535
+ - Enhanced Stereo effect (Thanks to AlexHK ♥)
1536
+
1537
+
1538
+
1539
+ ### v1.2.6
1540
+
1541
+ - Added option to generate in stereo (instead of only mono)
1542
+
1543
+ - Added dropdown for selecting output sample rate (model default is 32000)
1544
+
1545
+
1546
+
1547
+ ### v1.2.5a
1548
+
1549
+ - Added file cleaner (This comes from the main facebookresearch repo)
1550
+
1551
+ - Reorganized a little, moved audio to a seperate tab
1552
+
1553
+
1554
+
1555
+ ### v1.2.5
1556
+
1557
+ - Gave a unique lime theme to the webui
1558
+
1559
+ - Added additional output for audio only
1560
+
1561
+ - Added button to send generated audio to Input Audio
1562
+
1563
+ - Added option to trim Input Audio
1564
+
1565
+
1566
+
1567
+ ### v1.2.4
1568
+
1569
+ - Added mic input (This comes from the main facebookresearch repo)
1570
+
1571
+
1572
+
1573
+ ### v1.2.3
1574
+
1575
+ - Added option to change video size to fit the image you upload
1576
+
1577
+
1578
+
1579
+ ### v1.2.2
1580
+
1581
+ - Added Wiki, Changelog and About tabs
1582
+
1583
+
1584
+
1585
+ ### v1.2.1
1586
+
1587
+ - Added tabs and organized the entire interface
1588
+
1589
+ - Added option to attach image to the output video
1590
+
1591
+ - Added option to load fine-tuned models (Yet to be tested)
1592
+
1593
+
1594
+
1595
+ ### v1.2.0
1596
+
1597
+ - Added Multi-Prompt
1598
+
1599
+
1600
+
1601
+ ### v1.1.3
1602
+
1603
+ - Added customization options for generated waveform
1604
+
1605
+
1606
+
1607
+ ### v1.1.2
1608
+
1609
+ - Removed sample length limit: now you can input audio of any length as music sample
1610
+
1611
+
1612
+
1613
+ ### v1.1.1
1614
+
1615
+ - Improved music sample audio quality when using music continuation
1616
+
1617
+
1618
+
1619
+ ### v1.1.0
1620
+
1621
+ - Rebuilt the repo on top of the latest structure of the main MusicGen repo
1622
+
1623
+ - Improved Music continuation feature
1624
+
1625
+
1626
+
1627
+ ### v1.0.0 - Stable Version
1628
+
1629
+ - Added Music continuation
1630
+ """
1631
+ )
1632
+ with gr.Tab("About"):
1633
+ gen_type = gr.Text(value="music", interactive=False, visible=False)
1634
+ gen_type_a = gr.Text(value="audio", interactive=False, visible=False)
1635
+ gr.Markdown(
1636
+ """
1637
+ This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
1638
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
1639
+
1640
+ ## MusicGen+ is an extended version of the original MusicGen by facebookresearch.
1641
+
1642
+ ### Repo: https://github.com/GrandaddyShmax/audiocraft_plus/tree/plus
1643
+
1644
+ ---
1645
+
1646
+ ### This project was possible thanks to:
1647
+
1648
+ #### GrandaddyShmax - https://github.com/GrandaddyShmax
1649
+
1650
+ #### Camenduru - https://github.com/camenduru
1651
+
1652
+ #### rkfg - https://github.com/rkfg
1653
+
1654
+ #### oobabooga - https://github.com/oobabooga
1655
+
1656
+ #### AlexHK - https://github.com/alanhk147
1657
+ """
1658
+ )
1659
+
1660
+ send_gen.click(info_to_params, inputs=[in_audio], outputs=[decoder, struc_prompts, global_prompt, bpm, key, scale, model, dropdown, basemodel, s, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select], queue=False)
1661
+ reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False)
1662
+ send_audio.click(fn=lambda x: x, inputs=[backup_only], outputs=[audio], queue=False)
1663
+ submit.click(predict_full, inputs=[gen_type, model, decoder, dropdown, basemodel, s, struc_prompts, bpm, key, scale, global_prompt, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select], outputs=[output, audio_only, backup_only, download, seed_used])
1664
+ input_type.change(toggle_audio_src, input_type, [audio], queue=False, show_progress=False)
1665
+ to_calc.click(calc_time, inputs=[gen_type, s, duration, overlap, repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9]], outputs=[calcs[0], calcs[1], calcs[2], calcs[3], calcs[4], calcs[5], calcs[6], calcs[7], calcs[8], calcs[9]], queue=False)
1666
+
1667
+ send_gen_a.click(info_to_params_a, inputs=[in_audio], outputs=[decoder_a, struc_prompts_a, global_prompt_a, s_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, channel_a, sr_select_a], queue=False)
1668
+ reuse_seed_a.click(fn=lambda x: x, inputs=[seed_used_a], outputs=[seed_a], queue=False)
1669
+ send_audio_a.click(fn=lambda x: x, inputs=[backup_only_a], outputs=[audio_a], queue=False)
1670
+ submit_a.click(predict_full, inputs=[gen_type_a, model_a, decoder_a, dropdown, basemodel, s_a, struc_prompts_a, bpm, key, scale, global_prompt_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], audio_a, mode_a, trim_start_a, trim_end_a, duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, image_a, height_a, width_a, background_a, bar1_a, bar2_a, channel_a, sr_select_a], outputs=[output_a, audio_only_a, backup_only_a, download_a, seed_used_a])
1671
+ input_type_a.change(toggle_audio_src, input_type_a, [audio_a], queue=False, show_progress=False)
1672
+ to_calc_a.click(calc_time, inputs=[gen_type_a, s_a, duration_a, overlap_a, repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9]], outputs=[calcs_a[0], calcs_a[1], calcs_a[2], calcs_a[3], calcs_a[4], calcs_a[5], calcs_a[6], calcs_a[7], calcs_a[8], calcs_a[9]], queue=False)
1673
+
1674
+ in_audio.change(get_audio_info, in_audio, outputs=[info])
1675
+
1676
+ def variable_outputs(k):
1677
+ k = int(k) - 1
1678
+ return [gr.Textbox.update(visible=True)]*k + [gr.Textbox.update(visible=False)]*(max_textboxes-k)
1679
+ def get_size(image):
1680
+ if image is not None:
1681
+ img = Image.open(image)
1682
+ img_height = img.height
1683
+ img_width = img.width
1684
+ if (img_height%2) != 0:
1685
+ img_height = img_height + 1
1686
+ if (img_width%2) != 0:
1687
+ img_width = img_width + 1
1688
+ return img_height, img_width
1689
+ else:
1690
+ return 512, 768
1691
+
1692
+ image.change(get_size, image, outputs=[height, width])
1693
+ image_a.change(get_size, image_a, outputs=[height_a, width_a])
1694
+ s.change(variable_outputs, s, textboxes)
1695
+ s_a.change(variable_outputs, s_a, textboxes_a)
1696
+ interface.queue().launch(**launch_kwargs)
1697
+
1698
+
1699
+ def ui_batched(launch_kwargs):
1700
+ with gr.Blocks() as demo:
1701
+ gr.Markdown(
1702
+ """
1703
+ # MusicGen
1704
+
1705
+ This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
1706
+ a simple and controllable model for music generation
1707
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
1708
+ <br/>
1709
+ <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
1710
+ style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
1711
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
1712
+ src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
1713
+ for longer sequences, more control and no queue.</p>
1714
+ """
1715
+ )
1716
+ with gr.Row():
1717
+ with gr.Column():
1718
+ with gr.Row():
1719
+ text = gr.Text(label="Describe your music", lines=2, interactive=True)
1720
+ with gr.Column():
1721
+ radio = gr.Radio(["file", "mic"], value="file",
1722
+ label="Condition on a melody (optional) File or Mic")
1723
+ melody = gr.Audio(source="upload", type="numpy", label="File",
1724
+ interactive=True, elem_id="melody-input")
1725
+ with gr.Row():
1726
+ submit = gr.Button("Generate")
1727
+ with gr.Column():
1728
+ output = gr.Video(label="Generated Music")
1729
+ audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
1730
+ submit.click(predict_batched, inputs=[text, melody],
1731
+ outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
1732
+ radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
1733
+ gr.Examples(
1734
+ fn=predict_batched,
1735
+ examples=[
1736
+ [
1737
+ "An 80s driving pop song with heavy drums and synth pads in the background",
1738
+ "./assets/bach.mp3",
1739
+ ],
1740
+ [
1741
+ "A cheerful country song with acoustic guitars",
1742
+ "./assets/bolero_ravel.mp3",
1743
+ ],
1744
+ [
1745
+ "90s rock song with electric guitar and heavy drums",
1746
+ None,
1747
+ ],
1748
+ [
1749
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
1750
+ "./assets/bach.mp3",
1751
+ ],
1752
+ [
1753
+ "lofi slow bpm electro chill with organic samples",
1754
+ None,
1755
+ ],
1756
+ ],
1757
+ inputs=[text, melody],
1758
+ outputs=[output]
1759
+ )
1760
+ gr.Markdown("""
1761
+ ### More details
1762
+
1763
+ The model will generate 12 seconds of audio based on the description you provided.
1764
+ You can optionally provide a reference audio from which a broad melody will be extracted.
1765
+ The model will then try to follow both the description and melody provided.
1766
+ All samples are generated with the `melody` model.
1767
+
1768
+ You can also use your own GPU or a Google Colab by following the instructions on our repo.
1769
+
1770
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
1771
+ for more details.
1772
+ """)
1773
+
1774
+ demo.queue(max_size=8 * 4).launch(**launch_kwargs)
1775
+
1776
+
1777
+ if __name__ == "__main__":
1778
+ parser = argparse.ArgumentParser()
1779
+ parser.add_argument(
1780
+ '--listen',
1781
+ type=str,
1782
+ default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
1783
+ help='IP to listen on for connections to Gradio',
1784
+ )
1785
+ parser.add_argument(
1786
+ '--username', type=str, default='', help='Username for authentication'
1787
+ )
1788
+ parser.add_argument(
1789
+ '--password', type=str, default='', help='Password for authentication'
1790
+ )
1791
+ parser.add_argument(
1792
+ '--server_port',
1793
+ type=int,
1794
+ default=0,
1795
+ help='Port to run the server listener on',
1796
+ )
1797
+ parser.add_argument(
1798
+ '--inbrowser', action='store_true', help='Open in browser'
1799
+ )
1800
+ parser.add_argument(
1801
+ '--share', action='store_true', help='Share the gradio UI'
1802
+ )
1803
+ parser.add_argument(
1804
+ '--unload_model', action='store_true', help='Unload the model after every generation to save GPU memory'
1805
+ )
1806
+
1807
+ parser.add_argument(
1808
+ '--unload_to_cpu', action='store_true', help='Move the model to main RAM after every generation to save GPU memory but reload faster than after full unload (see above)'
1809
+ )
1810
+
1811
+ parser.add_argument(
1812
+ '--cache', action='store_true', help='Cache models in RAM to quickly switch between them'
1813
+ )
1814
+
1815
+ args = parser.parse_args()
1816
+ UNLOAD_MODEL = args.unload_model
1817
+ MOVE_TO_CPU = args.unload_to_cpu
1818
+ if args.cache:
1819
+ MODELS = {}
1820
+
1821
+ launch_kwargs = {}
1822
+ launch_kwargs['server_name'] = args.listen
1823
+
1824
+ if args.username and args.password:
1825
+ launch_kwargs['auth'] = (args.username, args.password)
1826
+ if args.server_port:
1827
+ launch_kwargs['server_port'] = args.server_port
1828
+ if args.inbrowser:
1829
+ launch_kwargs['inbrowser'] = args.inbrowser
1830
+ if args.share:
1831
+ launch_kwargs['share'] = args.share
1832
+
1833
+ # Show the interface
1834
+ if IS_BATCHED:
1835
+ global USE_DIFFUSION
1836
+ USE_DIFFUSION = False
1837
+ ui_batched(launch_kwargs)
1838
+ else:
1839
+ ui_full(launch_kwargs)
assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 ADDED
Binary file (15.2 kB). View file
 
assets/bach.mp3 ADDED
Binary file (160 kB). View file
 
assets/bolero_ravel.mp3 ADDED
Binary file (161 kB). View file
 
assets/sirens_and_a_humming_engine_approach_and_pass.mp3 ADDED
Binary file (15.2 kB). View file
 
audiocraft/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ AudioCraft is a general framework for training audio generative models.
8
+ At the moment we provide the training code for:
9
+
10
+ - [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
11
+ text-to-music and melody+text autoregressive generative model.
12
+ For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
13
+ `audiocraft.models.musicgen.MusicGen`.
14
+ - [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
15
+ text-to-general-audio generative model.
16
+ - [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
17
+ neural audio codec which provides an excellent tokenizer for autoregressive language models.
18
+ See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
19
+ - [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
20
+ improves the perceived quality and reduces the artifacts coming from adversarial decoders.
21
+ """
22
+
23
+ # flake8: noqa
24
+ from . import data, modules, models
25
+
26
+ __version__ = '1.0.0'
audiocraft/adversarial/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Adversarial losses and discriminator architectures."""
7
+
8
+ # flake8: noqa
9
+ from .discriminators import (
10
+ MultiPeriodDiscriminator,
11
+ MultiScaleDiscriminator,
12
+ MultiScaleSTFTDiscriminator
13
+ )
14
+ from .losses import (
15
+ AdversarialLoss,
16
+ AdvLossType,
17
+ get_adv_criterion,
18
+ get_fake_criterion,
19
+ get_real_criterion,
20
+ FeatLossType,
21
+ FeatureMatchingLoss
22
+ )
audiocraft/adversarial/discriminators/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .mpd import MultiPeriodDiscriminator
9
+ from .msd import MultiScaleDiscriminator
10
+ from .msstftd import MultiScaleSTFTDiscriminator
audiocraft/adversarial/discriminators/base.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABC, abstractmethod
8
+ import typing as tp
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
+ FeatureMapType = tp.List[torch.Tensor]
15
+ LogitsType = torch.Tensor
16
+ MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
17
+
18
+
19
+ class MultiDiscriminator(ABC, nn.Module):
20
+ """Base implementation for discriminators composed of sub-discriminators acting at different scales.
21
+ """
22
+ def __init__(self):
23
+ super().__init__()
24
+
25
+ @abstractmethod
26
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
27
+ ...
28
+
29
+ @property
30
+ @abstractmethod
31
+ def num_discriminators(self) -> int:
32
+ """Number of discriminators.
33
+ """
34
+ ...
audiocraft/adversarial/discriminators/mpd.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from ...modules import NormConv2d
14
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
15
+
16
+
17
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
18
+ return int((kernel_size * dilation - dilation) / 2)
19
+
20
+
21
+ class PeriodDiscriminator(nn.Module):
22
+ """Period sub-discriminator.
23
+
24
+ Args:
25
+ period (int): Period between samples of audio.
26
+ in_channels (int): Number of input channels.
27
+ out_channels (int): Number of output channels.
28
+ n_layers (int): Number of convolutional layers.
29
+ kernel_sizes (list of int): Kernel sizes for convolutions.
30
+ stride (int): Stride for convolutions.
31
+ filters (int): Initial number of filters in convolutions.
32
+ filters_scale (int): Multiplier of number of filters as we increase depth.
33
+ max_filters (int): Maximum number of filters.
34
+ norm (str): Normalization method.
35
+ activation (str): Activation function.
36
+ activation_params (dict): Parameters to provide to the activation function.
37
+ """
38
+ def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
39
+ n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
40
+ filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
41
+ norm: str = 'weight_norm', activation: str = 'LeakyReLU',
42
+ activation_params: dict = {'negative_slope': 0.2}):
43
+ super().__init__()
44
+ self.period = period
45
+ self.n_layers = n_layers
46
+ self.activation = getattr(torch.nn, activation)(**activation_params)
47
+ self.convs = nn.ModuleList()
48
+ in_chs = in_channels
49
+ for i in range(self.n_layers):
50
+ out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
51
+ eff_stride = 1 if i == self.n_layers - 1 else stride
52
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
53
+ padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
54
+ in_chs = out_chs
55
+ self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
56
+ padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
57
+
58
+ def forward(self, x: torch.Tensor):
59
+ fmap = []
60
+ # 1d to 2d
61
+ b, c, t = x.shape
62
+ if t % self.period != 0: # pad first
63
+ n_pad = self.period - (t % self.period)
64
+ x = F.pad(x, (0, n_pad), 'reflect')
65
+ t = t + n_pad
66
+ x = x.view(b, c, t // self.period, self.period)
67
+
68
+ for conv in self.convs:
69
+ x = conv(x)
70
+ x = self.activation(x)
71
+ fmap.append(x)
72
+ x = self.conv_post(x)
73
+ fmap.append(x)
74
+ # x = torch.flatten(x, 1, -1)
75
+
76
+ return x, fmap
77
+
78
+
79
+ class MultiPeriodDiscriminator(MultiDiscriminator):
80
+ """Multi-Period (MPD) Discriminator.
81
+
82
+ Args:
83
+ in_channels (int): Number of input channels.
84
+ out_channels (int): Number of output channels.
85
+ periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
86
+ **kwargs: Additional args for `PeriodDiscriminator`
87
+ """
88
+ def __init__(self, in_channels: int = 1, out_channels: int = 1,
89
+ periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
90
+ super().__init__()
91
+ self.discriminators = nn.ModuleList([
92
+ PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
93
+ ])
94
+
95
+ @property
96
+ def num_discriminators(self):
97
+ return len(self.discriminators)
98
+
99
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
100
+ logits = []
101
+ fmaps = []
102
+ for disc in self.discriminators:
103
+ logit, fmap = disc(x)
104
+ logits.append(logit)
105
+ fmaps.append(fmap)
106
+ return logits, fmaps
audiocraft/adversarial/discriminators/msd.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from ...modules import NormConv1d
14
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
15
+
16
+
17
+ class ScaleDiscriminator(nn.Module):
18
+ """Waveform sub-discriminator.
19
+
20
+ Args:
21
+ in_channels (int): Number of input channels.
22
+ out_channels (int): Number of output channels.
23
+ kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
24
+ filters (int): Number of initial filters for convolutions.
25
+ max_filters (int): Maximum number of filters.
26
+ downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
27
+ inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
28
+ groups (Sequence[int] or None): Groups for inner convolutions.
29
+ strides (Sequence[int] or None): Strides for inner convolutions.
30
+ paddings (Sequence[int] or None): Paddings for inner convolutions.
31
+ norm (str): Normalization method.
32
+ activation (str): Activation function.
33
+ activation_params (dict): Parameters to provide to the activation function.
34
+ pad (str): Padding for initial convolution.
35
+ pad_params (dict): Parameters to provide to the padding module.
36
+ """
37
+ def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
38
+ filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
39
+ inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
40
+ strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
41
+ norm: str = 'weight_norm', activation: str = 'LeakyReLU',
42
+ activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
43
+ pad_params: dict = {}):
44
+ super().__init__()
45
+ assert len(kernel_sizes) == 2
46
+ assert kernel_sizes[0] % 2 == 1
47
+ assert kernel_sizes[1] % 2 == 1
48
+ assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
49
+ assert (groups is None or len(groups) == len(downsample_scales))
50
+ assert (strides is None or len(strides) == len(downsample_scales))
51
+ assert (paddings is None or len(paddings) == len(downsample_scales))
52
+ self.activation = getattr(torch.nn, activation)(**activation_params)
53
+ self.convs = nn.ModuleList()
54
+ self.convs.append(
55
+ nn.Sequential(
56
+ getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
57
+ NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
58
+ )
59
+ )
60
+
61
+ in_chs = filters
62
+ for i, downsample_scale in enumerate(downsample_scales):
63
+ out_chs = min(in_chs * downsample_scale, max_filters)
64
+ default_kernel_size = downsample_scale * 10 + 1
65
+ default_stride = downsample_scale
66
+ default_padding = (default_kernel_size - 1) // 2
67
+ default_groups = in_chs // 4
68
+ self.convs.append(
69
+ NormConv1d(in_chs, out_chs,
70
+ kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
71
+ stride=strides[i] if strides else default_stride,
72
+ groups=groups[i] if groups else default_groups,
73
+ padding=paddings[i] if paddings else default_padding,
74
+ norm=norm))
75
+ in_chs = out_chs
76
+
77
+ out_chs = min(in_chs * 2, max_filters)
78
+ self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
79
+ padding=(kernel_sizes[0] - 1) // 2, norm=norm))
80
+ self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
81
+ padding=(kernel_sizes[1] - 1) // 2, norm=norm)
82
+
83
+ def forward(self, x: torch.Tensor):
84
+ fmap = []
85
+ for layer in self.convs:
86
+ x = layer(x)
87
+ x = self.activation(x)
88
+ fmap.append(x)
89
+ x = self.conv_post(x)
90
+ fmap.append(x)
91
+ # x = torch.flatten(x, 1, -1)
92
+ return x, fmap
93
+
94
+
95
+ class MultiScaleDiscriminator(MultiDiscriminator):
96
+ """Multi-Scale (MSD) Discriminator,
97
+
98
+ Args:
99
+ in_channels (int): Number of input channels.
100
+ out_channels (int): Number of output channels.
101
+ downsample_factor (int): Downsampling factor between the different scales.
102
+ scale_norms (Sequence[str]): Normalization for each sub-discriminator.
103
+ **kwargs: Additional args for ScaleDiscriminator.
104
+ """
105
+ def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
106
+ scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
107
+ super().__init__()
108
+ self.discriminators = nn.ModuleList([
109
+ ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
110
+ ])
111
+ self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
112
+
113
+ @property
114
+ def num_discriminators(self):
115
+ return len(self.discriminators)
116
+
117
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
118
+ logits = []
119
+ fmaps = []
120
+ for i, disc in enumerate(self.discriminators):
121
+ if i != 0:
122
+ self.downsample(x)
123
+ logit, fmap = disc(x)
124
+ logits.append(logit)
125
+ fmaps.append(fmap)
126
+ return logits, fmaps
audiocraft/adversarial/discriminators/msstftd.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import torchaudio
10
+ import torch
11
+ from torch import nn
12
+ from einops import rearrange
13
+
14
+ from ...modules import NormConv2d
15
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
16
+
17
+
18
+ def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
19
+ return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
20
+
21
+
22
+ class DiscriminatorSTFT(nn.Module):
23
+ """STFT sub-discriminator.
24
+
25
+ Args:
26
+ filters (int): Number of filters in convolutions.
27
+ in_channels (int): Number of input channels.
28
+ out_channels (int): Number of output channels.
29
+ n_fft (int): Size of FFT for each scale.
30
+ hop_length (int): Length of hop between STFT windows for each scale.
31
+ kernel_size (tuple of int): Inner Conv2d kernel sizes.
32
+ stride (tuple of int): Inner Conv2d strides.
33
+ dilations (list of int): Inner Conv2d dilation on the time dimension.
34
+ win_length (int): Window size for each scale.
35
+ normalized (bool): Whether to normalize by magnitude after stft.
36
+ norm (str): Normalization method.
37
+ activation (str): Activation function.
38
+ activation_params (dict): Parameters to provide to the activation function.
39
+ growth (int): Growth factor for the filters.
40
+ """
41
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
42
+ n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
43
+ filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
44
+ stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
45
+ activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
46
+ super().__init__()
47
+ assert len(kernel_size) == 2
48
+ assert len(stride) == 2
49
+ self.filters = filters
50
+ self.in_channels = in_channels
51
+ self.out_channels = out_channels
52
+ self.n_fft = n_fft
53
+ self.hop_length = hop_length
54
+ self.win_length = win_length
55
+ self.normalized = normalized
56
+ self.activation = getattr(torch.nn, activation)(**activation_params)
57
+ self.spec_transform = torchaudio.transforms.Spectrogram(
58
+ n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
59
+ normalized=self.normalized, center=False, pad_mode=None, power=None)
60
+ spec_channels = 2 * self.in_channels
61
+ self.convs = nn.ModuleList()
62
+ self.convs.append(
63
+ NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
64
+ )
65
+ in_chs = min(filters_scale * self.filters, max_filters)
66
+ for i, dilation in enumerate(dilations):
67
+ out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
68
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
69
+ dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
70
+ norm=norm))
71
+ in_chs = out_chs
72
+ out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
73
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
74
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
75
+ norm=norm))
76
+ self.conv_post = NormConv2d(out_chs, self.out_channels,
77
+ kernel_size=(kernel_size[0], kernel_size[0]),
78
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
79
+ norm=norm)
80
+
81
+ def forward(self, x: torch.Tensor):
82
+ fmap = []
83
+ z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
84
+ z = torch.cat([z.real, z.imag], dim=1)
85
+ z = rearrange(z, 'b c w t -> b c t w')
86
+ for i, layer in enumerate(self.convs):
87
+ z = layer(z)
88
+ z = self.activation(z)
89
+ fmap.append(z)
90
+ z = self.conv_post(z)
91
+ return z, fmap
92
+
93
+
94
+ class MultiScaleSTFTDiscriminator(MultiDiscriminator):
95
+ """Multi-Scale STFT (MS-STFT) discriminator.
96
+
97
+ Args:
98
+ filters (int): Number of filters in convolutions.
99
+ in_channels (int): Number of input channels.
100
+ out_channels (int): Number of output channels.
101
+ sep_channels (bool): Separate channels to distinct samples for stereo support.
102
+ n_ffts (Sequence[int]): Size of FFT for each scale.
103
+ hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
104
+ win_lengths (Sequence[int]): Window size for each scale.
105
+ **kwargs: Additional args for STFTDiscriminator.
106
+ """
107
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
108
+ n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
109
+ win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
110
+ super().__init__()
111
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
112
+ self.sep_channels = sep_channels
113
+ self.discriminators = nn.ModuleList([
114
+ DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
115
+ n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
116
+ for i in range(len(n_ffts))
117
+ ])
118
+
119
+ @property
120
+ def num_discriminators(self):
121
+ return len(self.discriminators)
122
+
123
+ def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
124
+ B, C, T = x.shape
125
+ return x.view(-1, 1, T)
126
+
127
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
128
+ logits = []
129
+ fmaps = []
130
+ for disc in self.discriminators:
131
+ logit, fmap = disc(x)
132
+ logits.append(logit)
133
+ fmaps.append(fmap)
134
+ return logits, fmaps
audiocraft/adversarial/losses.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility module to handle adversarial losses without requiring to mess up the main training loop.
9
+ """
10
+
11
+ import typing as tp
12
+
13
+ import flashy
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
20
+
21
+
22
+ AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
23
+ FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
24
+
25
+
26
+ class AdversarialLoss(nn.Module):
27
+ """Adversary training wrapper.
28
+
29
+ Args:
30
+ adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
31
+ We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
32
+ where the first item is a list of logits and the second item is a list of feature maps.
33
+ optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
34
+ loss (AdvLossType): Loss function for generator training.
35
+ loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
36
+ loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
37
+ loss_feat (FeatLossType): Feature matching loss function for generator training.
38
+ normalize (bool): Whether to normalize by number of sub-discriminators.
39
+
40
+ Example of usage:
41
+ adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
42
+ for real in loader:
43
+ noise = torch.randn(...)
44
+ fake = model(noise)
45
+ adv_loss.train_adv(fake, real)
46
+ loss, _ = adv_loss(fake, real)
47
+ loss.backward()
48
+ """
49
+ def __init__(self,
50
+ adversary: nn.Module,
51
+ optimizer: torch.optim.Optimizer,
52
+ loss: AdvLossType,
53
+ loss_real: AdvLossType,
54
+ loss_fake: AdvLossType,
55
+ loss_feat: tp.Optional[FeatLossType] = None,
56
+ normalize: bool = True):
57
+ super().__init__()
58
+ self.adversary: nn.Module = adversary
59
+ flashy.distrib.broadcast_model(self.adversary)
60
+ self.optimizer = optimizer
61
+ self.loss = loss
62
+ self.loss_real = loss_real
63
+ self.loss_fake = loss_fake
64
+ self.loss_feat = loss_feat
65
+ self.normalize = normalize
66
+
67
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
68
+ # Add the optimizer state dict inside our own.
69
+ super()._save_to_state_dict(destination, prefix, keep_vars)
70
+ destination[prefix + 'optimizer'] = self.optimizer.state_dict()
71
+ return destination
72
+
73
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
74
+ # Load optimizer state.
75
+ self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
76
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
77
+
78
+ def get_adversary_pred(self, x):
79
+ """Run adversary model, validating expected output format."""
80
+ logits, fmaps = self.adversary(x)
81
+ assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
82
+ f'Expecting a list of tensors as logits but {type(logits)} found.'
83
+ assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
84
+ for fmap in fmaps:
85
+ assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
86
+ f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
87
+ return logits, fmaps
88
+
89
+ def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
90
+ """Train the adversary with the given fake and real example.
91
+
92
+ We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
93
+ The first item being the logits and second item being a list of feature maps for each sub-discriminator.
94
+
95
+ This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
96
+ and call the optimizer.
97
+ """
98
+ loss = torch.tensor(0., device=fake.device)
99
+ all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
100
+ all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
101
+ n_sub_adversaries = len(all_logits_fake_is_fake)
102
+ for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
103
+ loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
104
+
105
+ if self.normalize:
106
+ loss /= n_sub_adversaries
107
+
108
+ self.optimizer.zero_grad()
109
+ with flashy.distrib.eager_sync_model(self.adversary):
110
+ loss.backward()
111
+ self.optimizer.step()
112
+
113
+ return loss
114
+
115
+ def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
116
+ """Return the loss for the generator, i.e. trying to fool the adversary,
117
+ and feature matching loss if provided.
118
+ """
119
+ adv = torch.tensor(0., device=fake.device)
120
+ feat = torch.tensor(0., device=fake.device)
121
+ with flashy.utils.readonly(self.adversary):
122
+ all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
123
+ all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
124
+ n_sub_adversaries = len(all_logits_fake_is_fake)
125
+ for logit_fake_is_fake in all_logits_fake_is_fake:
126
+ adv += self.loss(logit_fake_is_fake)
127
+ if self.loss_feat:
128
+ for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
129
+ feat += self.loss_feat(fmap_fake, fmap_real)
130
+
131
+ if self.normalize:
132
+ adv /= n_sub_adversaries
133
+ feat /= n_sub_adversaries
134
+
135
+ return adv, feat
136
+
137
+
138
+ def get_adv_criterion(loss_type: str) -> tp.Callable:
139
+ assert loss_type in ADVERSARIAL_LOSSES
140
+ if loss_type == 'mse':
141
+ return mse_loss
142
+ elif loss_type == 'hinge':
143
+ return hinge_loss
144
+ elif loss_type == 'hinge2':
145
+ return hinge2_loss
146
+ raise ValueError('Unsupported loss')
147
+
148
+
149
+ def get_fake_criterion(loss_type: str) -> tp.Callable:
150
+ assert loss_type in ADVERSARIAL_LOSSES
151
+ if loss_type == 'mse':
152
+ return mse_fake_loss
153
+ elif loss_type in ['hinge', 'hinge2']:
154
+ return hinge_fake_loss
155
+ raise ValueError('Unsupported loss')
156
+
157
+
158
+ def get_real_criterion(loss_type: str) -> tp.Callable:
159
+ assert loss_type in ADVERSARIAL_LOSSES
160
+ if loss_type == 'mse':
161
+ return mse_real_loss
162
+ elif loss_type in ['hinge', 'hinge2']:
163
+ return hinge_real_loss
164
+ raise ValueError('Unsupported loss')
165
+
166
+
167
+ def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
168
+ return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
169
+
170
+
171
+ def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
172
+ return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
173
+
174
+
175
+ def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
176
+ return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
177
+
178
+
179
+ def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
180
+ return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
181
+
182
+
183
+ def mse_loss(x: torch.Tensor) -> torch.Tensor:
184
+ if x.numel() == 0:
185
+ return torch.tensor([0.0], device=x.device)
186
+ return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
187
+
188
+
189
+ def hinge_loss(x: torch.Tensor) -> torch.Tensor:
190
+ if x.numel() == 0:
191
+ return torch.tensor([0.0], device=x.device)
192
+ return -x.mean()
193
+
194
+
195
+ def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
196
+ if x.numel() == 0:
197
+ return torch.tensor([0.0])
198
+ return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
199
+
200
+
201
+ class FeatureMatchingLoss(nn.Module):
202
+ """Feature matching loss for adversarial training.
203
+
204
+ Args:
205
+ loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
206
+ normalize (bool): Whether to normalize the loss.
207
+ by number of feature maps.
208
+ """
209
+ def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
210
+ super().__init__()
211
+ self.loss = loss
212
+ self.normalize = normalize
213
+
214
+ def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
215
+ assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
216
+ feat_loss = torch.tensor(0., device=fmap_fake[0].device)
217
+ feat_scale = torch.tensor(0., device=fmap_fake[0].device)
218
+ n_fmaps = 0
219
+ for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
220
+ assert feat_fake.shape == feat_real.shape
221
+ n_fmaps += 1
222
+ feat_loss += self.loss(feat_fake, feat_real)
223
+ feat_scale += torch.mean(torch.abs(feat_real))
224
+
225
+ if self.normalize:
226
+ feat_loss /= n_fmaps
227
+
228
+ return feat_loss
audiocraft/data/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Audio loading and writing support. Datasets for raw audio
7
+ or also including some metadata."""
8
+
9
+ # flake8: noqa
10
+ from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
audiocraft/data/audio.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Audio IO methods are defined in this module (info, read, write),
9
+ We rely on av library for faster read when possible, otherwise on torchaudio.
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ import logging
15
+ import typing as tp
16
+
17
+ import numpy as np
18
+ import soundfile
19
+ import torch
20
+ from torch.nn import functional as F
21
+ import torchaudio as ta
22
+
23
+ import av
24
+
25
+ from .audio_utils import f32_pcm, i16_pcm, normalize_audio
26
+
27
+
28
+ _av_initialized = False
29
+
30
+
31
+ def _init_av():
32
+ global _av_initialized
33
+ if _av_initialized:
34
+ return
35
+ logger = logging.getLogger('libav.mp3')
36
+ logger.setLevel(logging.ERROR)
37
+ _av_initialized = True
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class AudioFileInfo:
42
+ sample_rate: int
43
+ duration: float
44
+ channels: int
45
+
46
+
47
+ def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
48
+ _init_av()
49
+ with av.open(str(filepath)) as af:
50
+ stream = af.streams.audio[0]
51
+ sample_rate = stream.codec_context.sample_rate
52
+ duration = float(stream.duration * stream.time_base)
53
+ channels = stream.channels
54
+ return AudioFileInfo(sample_rate, duration, channels)
55
+
56
+
57
+ def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
58
+ info = soundfile.info(filepath)
59
+ return AudioFileInfo(info.samplerate, info.duration, info.channels)
60
+
61
+
62
+ def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
63
+ # torchaudio no longer returns useful duration informations for some formats like mp3s.
64
+ filepath = Path(filepath)
65
+ if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
66
+ # ffmpeg has some weird issue with flac.
67
+ return _soundfile_info(filepath)
68
+ else:
69
+ return _av_info(filepath)
70
+
71
+
72
+ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
73
+ """FFMPEG-based audio file reading using PyAV bindings.
74
+ Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
75
+
76
+ Args:
77
+ filepath (str or Path): Path to audio file to read.
78
+ seek_time (float): Time at which to start reading in the file.
79
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
80
+ Returns:
81
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate
82
+ """
83
+ _init_av()
84
+ with av.open(str(filepath)) as af:
85
+ stream = af.streams.audio[0]
86
+ sr = stream.codec_context.sample_rate
87
+ num_frames = int(sr * duration) if duration >= 0 else -1
88
+ frame_offset = int(sr * seek_time)
89
+ # we need a small negative offset otherwise we get some edge artifact
90
+ # from the mp3 decoder.
91
+ af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
92
+ frames = []
93
+ length = 0
94
+ for frame in af.decode(streams=stream.index):
95
+ current_offset = int(frame.rate * frame.pts * frame.time_base)
96
+ strip = max(0, frame_offset - current_offset)
97
+ buf = torch.from_numpy(frame.to_ndarray())
98
+ if buf.shape[0] != stream.channels:
99
+ buf = buf.view(-1, stream.channels).t()
100
+ buf = buf[:, strip:]
101
+ frames.append(buf)
102
+ length += buf.shape[1]
103
+ if num_frames > 0 and length >= num_frames:
104
+ break
105
+ assert frames
106
+ # If the above assert fails, it is likely because we seeked past the end of file point,
107
+ # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
108
+ # This will need proper debugging, in due time.
109
+ wav = torch.cat(frames, dim=1)
110
+ assert wav.shape[0] == stream.channels
111
+ if num_frames > 0:
112
+ wav = wav[:, :num_frames]
113
+ return f32_pcm(wav), sr
114
+
115
+
116
+ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
117
+ duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
118
+ """Read audio by picking the most appropriate backend tool based on the audio format.
119
+
120
+ Args:
121
+ filepath (str or Path): Path to audio file to read.
122
+ seek_time (float): Time at which to start reading in the file.
123
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
124
+ pad (bool): Pad output audio if not reaching expected duration.
125
+ Returns:
126
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
127
+ """
128
+ fp = Path(filepath)
129
+ if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
130
+ # There is some bug with ffmpeg and reading flac
131
+ info = _soundfile_info(filepath)
132
+ frames = -1 if duration <= 0 else int(duration * info.sample_rate)
133
+ frame_offset = int(seek_time * info.sample_rate)
134
+ wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
135
+ assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
136
+ wav = torch.from_numpy(wav).t().contiguous()
137
+ if len(wav.shape) == 1:
138
+ wav = torch.unsqueeze(wav, 0)
139
+ elif (
140
+ fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
141
+ and duration <= 0 and seek_time == 0
142
+ ):
143
+ # Torchaudio is faster if we load an entire file at once.
144
+ wav, sr = ta.load(fp)
145
+ else:
146
+ wav, sr = _av_read(filepath, seek_time, duration)
147
+ if pad and duration > 0:
148
+ expected_frames = int(duration * sr)
149
+ wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
150
+ return wav, sr
151
+
152
+
153
+ def audio_write(stem_name: tp.Union[str, Path],
154
+ wav: torch.Tensor, sample_rate: int,
155
+ format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
156
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
157
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
158
+ loudness_compressor: bool = False,
159
+ log_clipping: bool = True, make_parent_dir: bool = True,
160
+ add_suffix: bool = True) -> Path:
161
+ """Convenience function for saving audio to disk. Returns the filename the audio was written to.
162
+
163
+ Args:
164
+ stem_name (str or Path): Filename without extension which will be added automatically.
165
+ format (str): Either "wav" or "mp3".
166
+ mp3_rate (int): kbps when using mp3s.
167
+ normalize (bool): if `True` (default), normalizes according to the prescribed
168
+ strategy (see after). If `False`, the strategy is only used in case clipping
169
+ would happen.
170
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
171
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
172
+ with extra headroom to avoid clipping. 'clip' just clips.
173
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
174
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
175
+ than the `peak_clip` one to avoid further clipping.
176
+ loudness_headroom_db (float): Target loudness for loudness normalization.
177
+ loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
178
+ when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
179
+ occurs despite strategy (only for 'rms').
180
+ make_parent_dir (bool): Make parent directory if it doesn't exist.
181
+ Returns:
182
+ Path: Path of the saved audio.
183
+ """
184
+ assert wav.dtype.is_floating_point, "wav is not floating point"
185
+ if wav.dim() == 1:
186
+ wav = wav[None]
187
+ elif wav.dim() > 2:
188
+ raise ValueError("Input wav should be at most 2 dimension.")
189
+ assert wav.isfinite().all()
190
+ wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
191
+ rms_headroom_db, loudness_headroom_db, loudness_compressor,
192
+ log_clipping=log_clipping, sample_rate=sample_rate,
193
+ stem_name=str(stem_name))
194
+ kwargs: dict = {}
195
+ if format == 'mp3':
196
+ suffix = '.mp3'
197
+ kwargs.update({"compression": mp3_rate})
198
+ elif format == 'wav':
199
+ wav = i16_pcm(wav)
200
+ suffix = '.wav'
201
+ kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
202
+ else:
203
+ raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
204
+ if not add_suffix:
205
+ suffix = ''
206
+ path = Path(str(stem_name) + suffix)
207
+ if make_parent_dir:
208
+ path.parent.mkdir(exist_ok=True, parents=True)
209
+ try:
210
+ ta.save(path, wav, sample_rate, **kwargs)
211
+ except Exception:
212
+ if path.exists():
213
+ # we do not want to leave half written files around.
214
+ path.unlink()
215
+ raise
216
+ return path
audiocraft/data/audio_dataset.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """AudioDataset support. In order to handle a larger number of files
7
+ without having to scan again the folders, we precompute some metadata
8
+ (filename, sample rate, duration), and use that to efficiently sample audio segments.
9
+ """
10
+ import argparse
11
+ import copy
12
+ from concurrent.futures import ThreadPoolExecutor, Future
13
+ from dataclasses import dataclass, fields
14
+ from contextlib import ExitStack
15
+ from functools import lru_cache
16
+ import gzip
17
+ import json
18
+ import logging
19
+ import os
20
+ from pathlib import Path
21
+ import random
22
+ import sys
23
+ import typing as tp
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+
28
+ from .audio import audio_read, audio_info
29
+ from .audio_utils import convert_audio
30
+ from .zip import PathInZip
31
+
32
+ try:
33
+ import dora
34
+ except ImportError:
35
+ dora = None # type: ignore
36
+
37
+
38
+ @dataclass(order=True)
39
+ class BaseInfo:
40
+
41
+ @classmethod
42
+ def _dict2fields(cls, dictionary: dict):
43
+ return {
44
+ field.name: dictionary[field.name]
45
+ for field in fields(cls) if field.name in dictionary
46
+ }
47
+
48
+ @classmethod
49
+ def from_dict(cls, dictionary: dict):
50
+ _dictionary = cls._dict2fields(dictionary)
51
+ return cls(**_dictionary)
52
+
53
+ def to_dict(self):
54
+ return {
55
+ field.name: self.__getattribute__(field.name)
56
+ for field in fields(self)
57
+ }
58
+
59
+
60
+ @dataclass(order=True)
61
+ class AudioMeta(BaseInfo):
62
+ path: str
63
+ duration: float
64
+ sample_rate: int
65
+ amplitude: tp.Optional[float] = None
66
+ weight: tp.Optional[float] = None
67
+ # info_path is used to load additional information about the audio file that is stored in zip files.
68
+ info_path: tp.Optional[PathInZip] = None
69
+
70
+ @classmethod
71
+ def from_dict(cls, dictionary: dict):
72
+ base = cls._dict2fields(dictionary)
73
+ if 'info_path' in base and base['info_path'] is not None:
74
+ base['info_path'] = PathInZip(base['info_path'])
75
+ return cls(**base)
76
+
77
+ def to_dict(self):
78
+ d = super().to_dict()
79
+ if d['info_path'] is not None:
80
+ d['info_path'] = str(d['info_path'])
81
+ return d
82
+
83
+
84
+ @dataclass(order=True)
85
+ class SegmentInfo(BaseInfo):
86
+ meta: AudioMeta
87
+ seek_time: float
88
+ # The following values are given once the audio is processed, e.g.
89
+ # at the target sample rate and target number of channels.
90
+ n_frames: int # actual number of frames without padding
91
+ total_frames: int # total number of frames, padding included
92
+ sample_rate: int # actual sample rate
93
+ channels: int # number of audio channels.
94
+
95
+
96
+ DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
97
+
98
+ logger = logging.getLogger(__name__)
99
+
100
+
101
+ def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
102
+ """AudioMeta from a path to an audio file.
103
+
104
+ Args:
105
+ file_path (str): Resolved path of valid audio file.
106
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
107
+ Returns:
108
+ AudioMeta: Audio file path and its metadata.
109
+ """
110
+ info = audio_info(file_path)
111
+ amplitude: tp.Optional[float] = None
112
+ if not minimal:
113
+ wav, sr = audio_read(file_path)
114
+ amplitude = wav.abs().max().item()
115
+ return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
116
+
117
+
118
+ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
119
+ """If Dora is available as a dependency, try to resolve potential relative paths
120
+ in list of AudioMeta. This method is expected to be used when loading meta from file.
121
+
122
+ Args:
123
+ m (AudioMeta): Audio meta to resolve.
124
+ fast (bool): If True, uses a really fast check for determining if a file
125
+ is already absolute or not. Only valid on Linux/Mac.
126
+ Returns:
127
+ AudioMeta: Audio meta with resolved path.
128
+ """
129
+ def is_abs(m):
130
+ if fast:
131
+ return str(m)[0] == '/'
132
+ else:
133
+ os.path.isabs(str(m))
134
+
135
+ if not dora:
136
+ return m
137
+
138
+ if not is_abs(m.path):
139
+ m.path = dora.git_save.to_absolute_path(m.path)
140
+ if m.info_path is not None and not is_abs(m.info_path.zip_path):
141
+ m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
142
+ return m
143
+
144
+
145
+ def find_audio_files(path: tp.Union[Path, str],
146
+ exts: tp.List[str] = DEFAULT_EXTS,
147
+ resolve: bool = True,
148
+ minimal: bool = True,
149
+ progress: bool = False,
150
+ workers: int = 0) -> tp.List[AudioMeta]:
151
+ """Build a list of AudioMeta from a given path,
152
+ collecting relevant audio files and fetching meta info.
153
+
154
+ Args:
155
+ path (str or Path): Path to folder containing audio files.
156
+ exts (list of str): List of file extensions to consider for audio files.
157
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
158
+ progress (bool): Whether to log progress on audio files collection.
159
+ workers (int): number of parallel workers, if 0, use only the current thread.
160
+ Returns:
161
+ list of AudioMeta: List of audio file path and its metadata.
162
+ """
163
+ audio_files = []
164
+ futures: tp.List[Future] = []
165
+ pool: tp.Optional[ThreadPoolExecutor] = None
166
+ with ExitStack() as stack:
167
+ if workers > 0:
168
+ pool = ThreadPoolExecutor(workers)
169
+ stack.enter_context(pool)
170
+
171
+ if progress:
172
+ print("Finding audio files...")
173
+ for root, folders, files in os.walk(path, followlinks=True):
174
+ for file in files:
175
+ full_path = Path(root) / file
176
+ if full_path.suffix.lower() in exts:
177
+ audio_files.append(full_path)
178
+ if pool is not None:
179
+ futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
180
+ if progress:
181
+ print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
182
+
183
+ if progress:
184
+ print("Getting audio metadata...")
185
+ meta: tp.List[AudioMeta] = []
186
+ for idx, file_path in enumerate(audio_files):
187
+ try:
188
+ if pool is None:
189
+ m = _get_audio_meta(str(file_path), minimal)
190
+ else:
191
+ m = futures[idx].result()
192
+ if resolve:
193
+ m = _resolve_audio_meta(m)
194
+ except Exception as err:
195
+ print("Error with", str(file_path), err, file=sys.stderr)
196
+ continue
197
+ meta.append(m)
198
+ if progress:
199
+ print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
200
+ meta.sort()
201
+ return meta
202
+
203
+
204
+ def load_audio_meta(path: tp.Union[str, Path],
205
+ resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
206
+ """Load list of AudioMeta from an optionally compressed json file.
207
+
208
+ Args:
209
+ path (str or Path): Path to JSON file.
210
+ resolve (bool): Whether to resolve the path from AudioMeta (default=True).
211
+ fast (bool): activates some tricks to make things faster.
212
+ Returns:
213
+ list of AudioMeta: List of audio file path and its total duration.
214
+ """
215
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
216
+ with open_fn(path, 'rb') as fp: # type: ignore
217
+ lines = fp.readlines()
218
+ meta = []
219
+ for line in lines:
220
+ d = json.loads(line)
221
+ m = AudioMeta.from_dict(d)
222
+ if resolve:
223
+ m = _resolve_audio_meta(m, fast=fast)
224
+ meta.append(m)
225
+ return meta
226
+
227
+
228
+ def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
229
+ """Save the audio metadata to the file pointer as json.
230
+
231
+ Args:
232
+ path (str or Path): Path to JSON file.
233
+ metadata (list of BaseAudioMeta): List of audio meta to save.
234
+ """
235
+ Path(path).parent.mkdir(exist_ok=True, parents=True)
236
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
237
+ with open_fn(path, 'wb') as fp: # type: ignore
238
+ for m in meta:
239
+ json_str = json.dumps(m.to_dict()) + '\n'
240
+ json_bytes = json_str.encode('utf-8')
241
+ fp.write(json_bytes)
242
+
243
+
244
+ class AudioDataset:
245
+ """Base audio dataset.
246
+
247
+ The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
248
+ and potentially additional information, by creating random segments from the list of audio
249
+ files referenced in the metadata and applying minimal data pre-processing such as resampling,
250
+ mixing of channels, padding, etc.
251
+
252
+ If no segment_duration value is provided, the AudioDataset will return the full wav for each
253
+ audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
254
+ duration, applying padding if required.
255
+
256
+ By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
257
+ allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
258
+ original audio meta.
259
+
260
+ Note that you can call `start_epoch(epoch)` in order to get
261
+ a deterministic "randomization" for `shuffle=True`.
262
+ For a given epoch and dataset index, this will always return the same extract.
263
+ You can get back some diversity by setting the `shuffle_seed` param.
264
+
265
+ Args:
266
+ meta (list of AudioMeta): List of audio files metadata.
267
+ segment_duration (float, optional): Optional segment duration of audio to load.
268
+ If not specified, the dataset will load the full audio segment from the file.
269
+ shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
270
+ sample_rate (int): Target sample rate of the loaded audio samples.
271
+ channels (int): Target number of channels of the loaded audio samples.
272
+ sample_on_duration (bool): Set to `True` to sample segments with probability
273
+ dependent on audio file duration. This is only used if `segment_duration` is provided.
274
+ sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
275
+ `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
276
+ of the file duration and file weight. This is only used if `segment_duration` is provided.
277
+ min_segment_ratio (float): Minimum segment ratio to use when the audio file
278
+ is shorter than the desired segment.
279
+ max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
280
+ return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
281
+ min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
282
+ audio shorter than this will be filtered out.
283
+ max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
284
+ audio longer than this will be filtered out.
285
+ shuffle_seed (int): can be used to further randomize
286
+ load_wav (bool): if False, skip loading the wav but returns a tensor of 0
287
+ with the expected segment_duration (which must be provided if load_wav is False).
288
+ permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
289
+ are False. Will ensure a permutation on files when going through the dataset.
290
+ In that case the epoch number must be provided in order for the model
291
+ to continue the permutation across epochs. In that case, it is assumed
292
+ that `num_samples = total_batch_size * num_updates_per_epoch`, with
293
+ `total_batch_size` the overall batch size accounting for all gpus.
294
+ """
295
+ def __init__(self,
296
+ meta: tp.List[AudioMeta],
297
+ segment_duration: tp.Optional[float] = None,
298
+ shuffle: bool = True,
299
+ num_samples: int = 10_000,
300
+ sample_rate: int = 48_000,
301
+ channels: int = 2,
302
+ pad: bool = True,
303
+ sample_on_duration: bool = True,
304
+ sample_on_weight: bool = True,
305
+ min_segment_ratio: float = 0.5,
306
+ max_read_retry: int = 10,
307
+ return_info: bool = False,
308
+ min_audio_duration: tp.Optional[float] = None,
309
+ max_audio_duration: tp.Optional[float] = None,
310
+ shuffle_seed: int = 0,
311
+ load_wav: bool = True,
312
+ permutation_on_files: bool = False,
313
+ ):
314
+ assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
315
+ assert segment_duration is None or segment_duration > 0
316
+ assert segment_duration is None or min_segment_ratio >= 0
317
+ self.segment_duration = segment_duration
318
+ self.min_segment_ratio = min_segment_ratio
319
+ self.max_audio_duration = max_audio_duration
320
+ self.min_audio_duration = min_audio_duration
321
+ if self.min_audio_duration is not None and self.max_audio_duration is not None:
322
+ assert self.min_audio_duration <= self.max_audio_duration
323
+ self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
324
+ assert len(self.meta) # Fail fast if all data has been filtered.
325
+ self.total_duration = sum(d.duration for d in self.meta)
326
+
327
+ if segment_duration is None:
328
+ num_samples = len(self.meta)
329
+ self.num_samples = num_samples
330
+ self.shuffle = shuffle
331
+ self.sample_rate = sample_rate
332
+ self.channels = channels
333
+ self.pad = pad
334
+ self.sample_on_weight = sample_on_weight
335
+ self.sample_on_duration = sample_on_duration
336
+ self.sampling_probabilities = self._get_sampling_probabilities()
337
+ self.max_read_retry = max_read_retry
338
+ self.return_info = return_info
339
+ self.shuffle_seed = shuffle_seed
340
+ self.current_epoch: tp.Optional[int] = None
341
+ self.load_wav = load_wav
342
+ if not load_wav:
343
+ assert segment_duration is not None
344
+ self.permutation_on_files = permutation_on_files
345
+ if permutation_on_files:
346
+ assert not self.sample_on_duration
347
+ assert not self.sample_on_weight
348
+ assert self.shuffle
349
+
350
+ def start_epoch(self, epoch: int):
351
+ self.current_epoch = epoch
352
+
353
+ def __len__(self):
354
+ return self.num_samples
355
+
356
+ def _get_sampling_probabilities(self, normalized: bool = True):
357
+ """Return the sampling probabilities for each file inside `self.meta`."""
358
+ scores: tp.List[float] = []
359
+ for file_meta in self.meta:
360
+ score = 1.
361
+ if self.sample_on_weight and file_meta.weight is not None:
362
+ score *= file_meta.weight
363
+ if self.sample_on_duration:
364
+ score *= file_meta.duration
365
+ scores.append(score)
366
+ probabilities = torch.tensor(scores)
367
+ if normalized:
368
+ probabilities /= probabilities.sum()
369
+ return probabilities
370
+
371
+ @staticmethod
372
+ @lru_cache(16)
373
+ def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
374
+ # Used to keep the most recent files permutation in memory implicitely.
375
+ # will work unless someone is using a lot of Datasets in parallel.
376
+ rng = torch.Generator()
377
+ rng.manual_seed(base_seed + permutation_index)
378
+ return torch.randperm(num_files, generator=rng)
379
+
380
+ def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
381
+ """Sample a given file from `self.meta`. Can be overridden in subclasses.
382
+ This is only called if `segment_duration` is not None.
383
+
384
+ You must use the provided random number generator `rng` for reproducibility.
385
+ You can further make use of the index accessed.
386
+ """
387
+ if self.permutation_on_files:
388
+ assert self.current_epoch is not None
389
+ total_index = self.current_epoch * len(self) + index
390
+ permutation_index = total_index // len(self.meta)
391
+ relative_index = total_index % len(self.meta)
392
+ permutation = AudioDataset._get_file_permutation(
393
+ len(self.meta), permutation_index, self.shuffle_seed)
394
+ file_index = permutation[relative_index]
395
+ return self.meta[file_index]
396
+
397
+ if not self.sample_on_weight and not self.sample_on_duration:
398
+ file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
399
+ else:
400
+ file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
401
+
402
+ return self.meta[file_index]
403
+
404
+ def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
405
+ # Override this method in subclass if needed.
406
+ if self.load_wav:
407
+ return audio_read(path, seek_time, duration, pad=False)
408
+ else:
409
+ assert self.segment_duration is not None
410
+ n_frames = int(self.sample_rate * self.segment_duration)
411
+ return torch.zeros(self.channels, n_frames), self.sample_rate
412
+
413
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
414
+ if self.segment_duration is None:
415
+ file_meta = self.meta[index]
416
+ out, sr = audio_read(file_meta.path)
417
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
418
+ n_frames = out.shape[-1]
419
+ segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
420
+ sample_rate=self.sample_rate, channels=out.shape[0])
421
+ else:
422
+ rng = torch.Generator()
423
+ if self.shuffle:
424
+ # We use index, plus extra randomness, either totally random if we don't know the epoch.
425
+ # otherwise we make use of the epoch number and optional shuffle_seed.
426
+ if self.current_epoch is None:
427
+ rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
428
+ else:
429
+ rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
430
+ else:
431
+ # We only use index
432
+ rng.manual_seed(index)
433
+
434
+ for retry in range(self.max_read_retry):
435
+ file_meta = self.sample_file(index, rng)
436
+ # We add some variance in the file position even if audio file is smaller than segment
437
+ # without ending up with empty segments
438
+ max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
439
+ seek_time = torch.rand(1, generator=rng).item() * max_seek
440
+ try:
441
+ out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
442
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
443
+ n_frames = out.shape[-1]
444
+ target_frames = int(self.segment_duration * self.sample_rate)
445
+ if self.pad:
446
+ out = F.pad(out, (0, target_frames - n_frames))
447
+ segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
448
+ sample_rate=self.sample_rate, channels=out.shape[0])
449
+ except Exception as exc:
450
+ logger.warning("Error opening file %s: %r", file_meta.path, exc)
451
+ if retry == self.max_read_retry - 1:
452
+ raise
453
+ else:
454
+ break
455
+
456
+ if self.return_info:
457
+ # Returns the wav and additional information on the wave segment
458
+ return out, segment_info
459
+ else:
460
+ return out
461
+
462
+ def collater(self, samples):
463
+ """The collater function has to be provided to the dataloader
464
+ if AudioDataset has return_info=True in order to properly collate
465
+ the samples of a batch.
466
+ """
467
+ if self.segment_duration is None and len(samples) > 1:
468
+ assert self.pad, "Must allow padding when batching examples of different durations."
469
+
470
+ # In this case the audio reaching the collater is of variable length as segment_duration=None.
471
+ to_pad = self.segment_duration is None and self.pad
472
+ if to_pad:
473
+ max_len = max([wav.shape[-1] for wav, _ in samples])
474
+
475
+ def _pad_wav(wav):
476
+ return F.pad(wav, (0, max_len - wav.shape[-1]))
477
+
478
+ if self.return_info:
479
+ if len(samples) > 0:
480
+ assert len(samples[0]) == 2
481
+ assert isinstance(samples[0][0], torch.Tensor)
482
+ assert isinstance(samples[0][1], SegmentInfo)
483
+
484
+ wavs = [wav for wav, _ in samples]
485
+ segment_infos = [copy.deepcopy(info) for _, info in samples]
486
+
487
+ if to_pad:
488
+ # Each wav could be of a different duration as they are not segmented.
489
+ for i in range(len(samples)):
490
+ # Determines the total length of the signal with padding, so we update here as we pad.
491
+ segment_infos[i].total_frames = max_len
492
+ wavs[i] = _pad_wav(wavs[i])
493
+
494
+ wav = torch.stack(wavs)
495
+ return wav, segment_infos
496
+ else:
497
+ assert isinstance(samples[0], torch.Tensor)
498
+ if to_pad:
499
+ samples = [_pad_wav(s) for s in samples]
500
+ return torch.stack(samples)
501
+
502
+ def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
503
+ """Filters out audio files with audio durations that will not allow to sample examples from them."""
504
+ orig_len = len(meta)
505
+
506
+ # Filter data that is too short.
507
+ if self.min_audio_duration is not None:
508
+ meta = [m for m in meta if m.duration >= self.min_audio_duration]
509
+
510
+ # Filter data that is too long.
511
+ if self.max_audio_duration is not None:
512
+ meta = [m for m in meta if m.duration <= self.max_audio_duration]
513
+
514
+ filtered_len = len(meta)
515
+ removed_percentage = 100*(1-float(filtered_len)/orig_len)
516
+ msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
517
+ if removed_percentage < 10:
518
+ logging.debug(msg)
519
+ else:
520
+ logging.warning(msg)
521
+ return meta
522
+
523
+ @classmethod
524
+ def from_meta(cls, root: tp.Union[str, Path], **kwargs):
525
+ """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
526
+
527
+ Args:
528
+ root (str or Path): Path to root folder containing audio files.
529
+ kwargs: Additional keyword arguments for the AudioDataset.
530
+ """
531
+ root = Path(root)
532
+ if root.is_dir():
533
+ if (root / 'data.jsonl').exists():
534
+ root = root / 'data.jsonl'
535
+ elif (root / 'data.jsonl.gz').exists():
536
+ root = root / 'data.jsonl.gz'
537
+ else:
538
+ raise ValueError("Don't know where to read metadata from in the dir. "
539
+ "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
540
+ meta = load_audio_meta(root)
541
+ return cls(meta, **kwargs)
542
+
543
+ @classmethod
544
+ def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
545
+ exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
546
+ """Instantiate AudioDataset from a path containing (possibly nested) audio files.
547
+
548
+ Args:
549
+ root (str or Path): Path to root folder containing audio files.
550
+ minimal_meta (bool): Whether to only load minimal metadata or not.
551
+ exts (list of str): Extensions for audio files.
552
+ kwargs: Additional keyword arguments for the AudioDataset.
553
+ """
554
+ root = Path(root)
555
+ if root.is_file():
556
+ meta = load_audio_meta(root, resolve=True)
557
+ else:
558
+ meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
559
+ return cls(meta, **kwargs)
560
+
561
+
562
+ def main():
563
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
564
+ parser = argparse.ArgumentParser(
565
+ prog='audio_dataset',
566
+ description='Generate .jsonl files by scanning a folder.')
567
+ parser.add_argument('root', help='Root folder with all the audio files')
568
+ parser.add_argument('output_meta_file',
569
+ help='Output file to store the metadata, ')
570
+ parser.add_argument('--complete',
571
+ action='store_false', dest='minimal', default=True,
572
+ help='Retrieve all metadata, even the one that are expansive '
573
+ 'to compute (e.g. normalization).')
574
+ parser.add_argument('--resolve',
575
+ action='store_true', default=False,
576
+ help='Resolve the paths to be absolute and with no symlinks.')
577
+ parser.add_argument('--workers',
578
+ default=10, type=int,
579
+ help='Number of workers.')
580
+ args = parser.parse_args()
581
+ meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
582
+ resolve=args.resolve, minimal=args.minimal, workers=args.workers)
583
+ save_audio_meta(args.output_meta_file, meta)
584
+
585
+
586
+ if __name__ == '__main__':
587
+ main()
audiocraft/data/audio_utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Various utilities for audio convertion (pcm format, sample rate and channels),
7
+ and volume normalization."""
8
+ import sys
9
+ import typing as tp
10
+
11
+ import julius
12
+ import torch
13
+ import torchaudio
14
+
15
+
16
+ def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
17
+ """Convert audio to the given number of channels.
18
+
19
+ Args:
20
+ wav (torch.Tensor): Audio wave of shape [B, C, T].
21
+ channels (int): Expected number of channels as output.
22
+ Returns:
23
+ torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
24
+ """
25
+ *shape, src_channels, length = wav.shape
26
+ if src_channels == channels:
27
+ pass
28
+ elif channels == 1:
29
+ # Case 1:
30
+ # The caller asked 1-channel audio, and the stream has multiple
31
+ # channels, downmix all channels.
32
+ wav = wav.mean(dim=-2, keepdim=True)
33
+ elif src_channels == 1:
34
+ # Case 2:
35
+ # The caller asked for multiple channels, but the input file has
36
+ # a single channel, replicate the audio over all channels.
37
+ wav = wav.expand(*shape, channels, length)
38
+ elif src_channels >= channels:
39
+ # Case 3:
40
+ # The caller asked for multiple channels, and the input file has
41
+ # more channels than requested. In that case return the first channels.
42
+ wav = wav[..., :channels, :]
43
+ else:
44
+ # Case 4: What is a reasonable choice here?
45
+ raise ValueError('The audio file has less channels than requested but is not mono.')
46
+ return wav
47
+
48
+
49
+ def convert_audio(wav: torch.Tensor, from_rate: float,
50
+ to_rate: float, to_channels: int) -> torch.Tensor:
51
+ """Convert audio to new sample rate and number of audio channels."""
52
+ wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
53
+ wav = convert_audio_channels(wav, to_channels)
54
+ return wav
55
+
56
+
57
+ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
58
+ loudness_compressor: bool = False, energy_floor: float = 2e-3):
59
+ """Normalize an input signal to a user loudness in dB LKFS.
60
+ Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
61
+
62
+ Args:
63
+ wav (torch.Tensor): Input multichannel audio data.
64
+ sample_rate (int): Sample rate.
65
+ loudness_headroom_db (float): Target loudness of the output in dB LUFS.
66
+ loudness_compressor (bool): Uses tanh for soft clipping.
67
+ energy_floor (float): anything below that RMS level will not be rescaled.
68
+ Returns:
69
+ torch.Tensor: Loudness normalized output data.
70
+ """
71
+ energy = wav.pow(2).mean().sqrt().item()
72
+ if energy < energy_floor:
73
+ return wav
74
+ transform = torchaudio.transforms.Loudness(sample_rate)
75
+ input_loudness_db = transform(wav).item()
76
+ # calculate the gain needed to scale to the desired loudness level
77
+ delta_loudness = -loudness_headroom_db - input_loudness_db
78
+ gain = 10.0 ** (delta_loudness / 20.0)
79
+ output = gain * wav
80
+ if loudness_compressor:
81
+ output = torch.tanh(output)
82
+ assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
83
+ return output
84
+
85
+
86
+ def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
87
+ """Utility function to clip the audio with logging if specified."""
88
+ max_scale = wav.abs().max()
89
+ if log_clipping and max_scale > 1:
90
+ clamp_prob = (wav.abs() > 1).float().mean().item()
91
+ print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
92
+ clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
93
+ #wav.clamp_(-1, 1)
94
+ wav = wav.clone().clamp_(-1, 1)
95
+
96
+
97
+ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
98
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
99
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
100
+ loudness_compressor: bool = False, log_clipping: bool = False,
101
+ sample_rate: tp.Optional[int] = None,
102
+ stem_name: tp.Optional[str] = None) -> torch.Tensor:
103
+ """Normalize the audio according to the prescribed strategy (see after).
104
+
105
+ Args:
106
+ wav (torch.Tensor): Audio data.
107
+ normalize (bool): if `True` (default), normalizes according to the prescribed
108
+ strategy (see after). If `False`, the strategy is only used in case clipping
109
+ would happen.
110
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
111
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
112
+ with extra headroom to avoid clipping. 'clip' just clips.
113
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
114
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
115
+ than the `peak_clip` one to avoid further clipping.
116
+ loudness_headroom_db (float): Target loudness for loudness normalization.
117
+ loudness_compressor (bool): If True, uses tanh based soft clipping.
118
+ log_clipping (bool): If True, basic logging on stderr when clipping still
119
+ occurs despite strategy (only for 'rms').
120
+ sample_rate (int): Sample rate for the audio data (required for loudness).
121
+ stem_name (str, optional): Stem name for clipping logging.
122
+ Returns:
123
+ torch.Tensor: Normalized audio.
124
+ """
125
+ scale_peak = 10 ** (-peak_clip_headroom_db / 20)
126
+ scale_rms = 10 ** (-rms_headroom_db / 20)
127
+ if strategy == 'peak':
128
+ rescaling = (scale_peak / wav.abs().max())
129
+ if normalize or rescaling < 1:
130
+ wav = wav * rescaling
131
+ elif strategy == 'clip':
132
+ wav = wav.clamp(-scale_peak, scale_peak)
133
+ elif strategy == 'rms':
134
+ mono = wav.mean(dim=0)
135
+ rescaling = scale_rms / mono.pow(2).mean().sqrt()
136
+ if normalize or rescaling < 1:
137
+ wav = wav * rescaling
138
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
139
+ elif strategy == 'loudness':
140
+ assert sample_rate is not None, "Loudness normalization requires sample rate."
141
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
142
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
143
+ else:
144
+ assert wav.abs().max() < 1
145
+ assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
146
+ return wav
147
+
148
+
149
+ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
150
+ """Convert audio to float 32 bits PCM format.
151
+ """
152
+ if wav.dtype.is_floating_point:
153
+ return wav
154
+ elif wav.dtype == torch.int16:
155
+ return wav.float() / 2**15
156
+ elif wav.dtype == torch.int32:
157
+ return wav.float() / 2**31
158
+ raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
159
+
160
+
161
+ def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
162
+ """Convert audio to int 16 bits PCM format.
163
+
164
+ ..Warning:: There exist many formula for doing this conversion. None are perfect
165
+ due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
166
+ or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
167
+ it is possible that `i16_pcm(f32_pcm)) != Identity`.
168
+ """
169
+ if wav.dtype.is_floating_point:
170
+ assert wav.abs().max() <= 1
171
+ candidate = (wav * 2 ** 15).round()
172
+ if candidate.max() >= 2 ** 15: # clipping would occur
173
+ candidate = (wav * (2 ** 15 - 1)).round()
174
+ return candidate.short()
175
+ else:
176
+ assert wav.dtype == torch.int16
177
+ return wav
audiocraft/data/info_audio_dataset.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Base classes for the datasets that also provide non-audio metadata,
7
+ e.g. description, text transcription etc.
8
+ """
9
+ from dataclasses import dataclass
10
+ import logging
11
+ import math
12
+ import re
13
+ import typing as tp
14
+
15
+ import torch
16
+
17
+ from .audio_dataset import AudioDataset, AudioMeta
18
+ from ..environment import AudioCraftEnvironment
19
+ from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
26
+ """Monkey-patch meta to match cluster specificities."""
27
+ meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
28
+ if meta.info_path is not None:
29
+ meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
30
+ return meta
31
+
32
+
33
+ def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
34
+ """Monkey-patch all meta to match cluster specificities."""
35
+ return [_clusterify_meta(m) for m in meta]
36
+
37
+
38
+ @dataclass
39
+ class AudioInfo(SegmentWithAttributes):
40
+ """Dummy SegmentInfo with empty attributes.
41
+
42
+ The InfoAudioDataset is expected to return metadata that inherits
43
+ from SegmentWithAttributes class and can return conditioning attributes.
44
+
45
+ This basically guarantees all datasets will be compatible with current
46
+ solver that contain conditioners requiring this.
47
+ """
48
+ audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
49
+
50
+ def to_condition_attributes(self) -> ConditioningAttributes:
51
+ return ConditioningAttributes()
52
+
53
+
54
+ class InfoAudioDataset(AudioDataset):
55
+ """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
56
+
57
+ See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
58
+ """
59
+ def __init__(self, meta: tp.List[AudioMeta], **kwargs):
60
+ super().__init__(clusterify_all_meta(meta), **kwargs)
61
+
62
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
63
+ if not self.return_info:
64
+ wav = super().__getitem__(index)
65
+ assert isinstance(wav, torch.Tensor)
66
+ return wav
67
+ wav, meta = super().__getitem__(index)
68
+ return wav, AudioInfo(**meta.to_dict())
69
+
70
+
71
+ def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
72
+ """Preprocess a single keyword or possible a list of keywords."""
73
+ if isinstance(value, list):
74
+ return get_keyword_list(value)
75
+ else:
76
+ return get_keyword(value)
77
+
78
+
79
+ def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
80
+ """Preprocess a single keyword."""
81
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
82
+ return None
83
+ else:
84
+ return value.strip()
85
+
86
+
87
+ def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
88
+ """Preprocess a single keyword."""
89
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
90
+ return None
91
+ else:
92
+ return value.strip().lower()
93
+
94
+
95
+ def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
96
+ """Preprocess a list of keywords."""
97
+ if isinstance(values, str):
98
+ values = [v.strip() for v in re.split(r'[,\s]', values)]
99
+ elif isinstance(values, float) and math.isnan(values):
100
+ values = []
101
+ if not isinstance(values, list):
102
+ logger.debug(f"Unexpected keyword list {values}")
103
+ values = [str(values)]
104
+
105
+ kws = [get_keyword(v) for v in values]
106
+ kw_list = [k for k in kws if k is not None]
107
+ if len(kw_list) == 0:
108
+ return None
109
+ else:
110
+ return kw_list
audiocraft/data/music_dataset.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of music tracks with rich metadata.
7
+ """
8
+ from dataclasses import dataclass, field, fields, replace
9
+ import gzip
10
+ import json
11
+ import logging
12
+ from pathlib import Path
13
+ import random
14
+ import typing as tp
15
+
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ AudioInfo,
21
+ get_keyword_list,
22
+ get_keyword,
23
+ get_string
24
+ )
25
+ from ..modules.conditioners import (
26
+ ConditioningAttributes,
27
+ JointEmbedCondition,
28
+ WavCondition,
29
+ )
30
+ from ..utils.utils import warn_once
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class MusicInfo(AudioInfo):
38
+ """Segment info augmented with music metadata.
39
+ """
40
+ # music-specific metadata
41
+ title: tp.Optional[str] = None
42
+ artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
43
+ key: tp.Optional[str] = None
44
+ bpm: tp.Optional[float] = None
45
+ genre: tp.Optional[str] = None
46
+ moods: tp.Optional[list] = None
47
+ keywords: tp.Optional[list] = None
48
+ description: tp.Optional[str] = None
49
+ name: tp.Optional[str] = None
50
+ instrument: tp.Optional[str] = None
51
+ # original wav accompanying the metadata
52
+ self_wav: tp.Optional[WavCondition] = None
53
+ # dict mapping attributes names to tuple of wav, text and metadata
54
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
55
+
56
+ @property
57
+ def has_music_meta(self) -> bool:
58
+ return self.name is not None
59
+
60
+ def to_condition_attributes(self) -> ConditioningAttributes:
61
+ out = ConditioningAttributes()
62
+ for _field in fields(self):
63
+ key, value = _field.name, getattr(self, _field.name)
64
+ if key == 'self_wav':
65
+ out.wav[key] = value
66
+ elif key == 'joint_embed':
67
+ for embed_attribute, embed_cond in value.items():
68
+ out.joint_embed[embed_attribute] = embed_cond
69
+ else:
70
+ if isinstance(value, list):
71
+ value = ' '.join(value)
72
+ out.text[key] = value
73
+ return out
74
+
75
+ @staticmethod
76
+ def attribute_getter(attribute):
77
+ if attribute == 'bpm':
78
+ preprocess_func = get_bpm
79
+ elif attribute == 'key':
80
+ preprocess_func = get_musical_key
81
+ elif attribute in ['moods', 'keywords']:
82
+ preprocess_func = get_keyword_list
83
+ elif attribute in ['genre', 'name', 'instrument']:
84
+ preprocess_func = get_keyword
85
+ elif attribute in ['title', 'artist', 'description']:
86
+ preprocess_func = get_string
87
+ else:
88
+ preprocess_func = None
89
+ return preprocess_func
90
+
91
+ @classmethod
92
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
93
+ _dictionary: tp.Dict[str, tp.Any] = {}
94
+
95
+ # allow a subset of attributes to not be loaded from the dictionary
96
+ # these attributes may be populated later
97
+ post_init_attributes = ['self_wav', 'joint_embed']
98
+ optional_fields = ['keywords']
99
+
100
+ for _field in fields(cls):
101
+ if _field.name in post_init_attributes:
102
+ continue
103
+ elif _field.name not in dictionary:
104
+ if fields_required and _field.name not in optional_fields:
105
+ raise KeyError(f"Unexpected missing key: {_field.name}")
106
+ else:
107
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
108
+ value = dictionary[_field.name]
109
+ if preprocess_func:
110
+ value = preprocess_func(value)
111
+ _dictionary[_field.name] = value
112
+ return cls(**_dictionary)
113
+
114
+
115
+ def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
116
+ drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
117
+ """Augment MusicInfo description with additional metadata fields and potential dropout.
118
+ Additional textual attributes are added given probability 'merge_text_conditions_p' and
119
+ the original textual description is dropped from the augmented description given probability drop_desc_p.
120
+
121
+ Args:
122
+ music_info (MusicInfo): The music metadata to augment.
123
+ merge_text_p (float): Probability of merging additional metadata to the description.
124
+ If provided value is 0, then no merging is performed.
125
+ drop_desc_p (float): Probability of dropping the original description on text merge.
126
+ if provided value is 0, then no drop out is performed.
127
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
128
+ Returns:
129
+ MusicInfo: The MusicInfo with augmented textual description.
130
+ """
131
+ def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
132
+ valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
133
+ valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
134
+ keep_field = random.uniform(0, 1) < drop_other_p
135
+ return valid_field_name and valid_field_value and keep_field
136
+
137
+ def process_value(v: tp.Any) -> str:
138
+ if isinstance(v, (int, float, str)):
139
+ return str(v)
140
+ if isinstance(v, list):
141
+ return ", ".join(v)
142
+ else:
143
+ raise ValueError(f"Unknown type for text value! ({type(v), v})")
144
+
145
+ description = music_info.description
146
+
147
+ metadata_text = ""
148
+ if random.uniform(0, 1) < merge_text_p:
149
+ meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
150
+ for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
151
+ random.shuffle(meta_pairs)
152
+ metadata_text = ". ".join(meta_pairs)
153
+ description = description if not random.uniform(0, 1) < drop_desc_p else None
154
+ logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
155
+
156
+ if description is None:
157
+ description = metadata_text if len(metadata_text) > 1 else None
158
+ else:
159
+ description = ". ".join([description.rstrip('.'), metadata_text])
160
+ description = description.strip() if description else None
161
+
162
+ music_info = replace(music_info)
163
+ music_info.description = description
164
+ return music_info
165
+
166
+
167
+ class Paraphraser:
168
+ def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
169
+ self.paraphrase_p = paraphrase_p
170
+ open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
171
+ with open_fn(paraphrase_source, 'rb') as f: # type: ignore
172
+ self.paraphrase_source = json.loads(f.read())
173
+ logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
174
+
175
+ def sample_paraphrase(self, audio_path: str, description: str):
176
+ if random.random() >= self.paraphrase_p:
177
+ return description
178
+ info_path = Path(audio_path).with_suffix('.json')
179
+ if info_path not in self.paraphrase_source:
180
+ warn_once(logger, f"{info_path} not in paraphrase source!")
181
+ return description
182
+ new_desc = random.choice(self.paraphrase_source[info_path])
183
+ logger.debug(f"{description} -> {new_desc}")
184
+ return new_desc
185
+
186
+
187
+ class MusicDataset(InfoAudioDataset):
188
+ """Music dataset is an AudioDataset with music-related metadata.
189
+
190
+ Args:
191
+ info_fields_required (bool): Whether to enforce having required fields.
192
+ merge_text_p (float): Probability of merging additional metadata to the description.
193
+ drop_desc_p (float): Probability of dropping the original description on text merge.
194
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
195
+ joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
196
+ paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
197
+ paraphrases for the description. The json should be a dict with keys are the
198
+ original info path (e.g. track_path.json) and each value is a list of possible
199
+ paraphrased.
200
+ paraphrase_p (float): probability of taking a paraphrase.
201
+
202
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
203
+ """
204
+ def __init__(self, *args, info_fields_required: bool = True,
205
+ merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
206
+ joint_embed_attributes: tp.List[str] = [],
207
+ paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
208
+ **kwargs):
209
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
210
+ super().__init__(*args, **kwargs)
211
+ self.info_fields_required = info_fields_required
212
+ self.merge_text_p = merge_text_p
213
+ self.drop_desc_p = drop_desc_p
214
+ self.drop_other_p = drop_other_p
215
+ self.joint_embed_attributes = joint_embed_attributes
216
+ self.paraphraser = None
217
+ if paraphrase_source is not None:
218
+ self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
219
+
220
+ def __getitem__(self, index):
221
+ wav, info = super().__getitem__(index)
222
+ info_data = info.to_dict()
223
+ music_info_path = Path(info.meta.path).with_suffix('.json')
224
+
225
+ if Path(music_info_path).exists():
226
+ with open(music_info_path, 'r') as json_file:
227
+ music_data = json.load(json_file)
228
+ music_data.update(info_data)
229
+ music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
230
+ if self.paraphraser is not None:
231
+ music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
232
+ if self.merge_text_p:
233
+ music_info = augment_music_info_description(
234
+ music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
235
+ else:
236
+ music_info = MusicInfo.from_dict(info_data, fields_required=False)
237
+
238
+ music_info.self_wav = WavCondition(
239
+ wav=wav[None], length=torch.tensor([info.n_frames]),
240
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
241
+
242
+ for att in self.joint_embed_attributes:
243
+ att_value = getattr(music_info, att)
244
+ joint_embed_cond = JointEmbedCondition(
245
+ wav[None], [att_value], torch.tensor([info.n_frames]),
246
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
247
+ music_info.joint_embed[att] = joint_embed_cond
248
+
249
+ return wav, music_info
250
+
251
+
252
+ def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
253
+ """Preprocess key keywords, discarding them if there are multiple key defined."""
254
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
255
+ return None
256
+ elif ',' in value:
257
+ # For now, we discard when multiple keys are defined separated with comas
258
+ return None
259
+ else:
260
+ return value.strip().lower()
261
+
262
+
263
+ def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
264
+ """Preprocess to a float."""
265
+ if value is None:
266
+ return None
267
+ try:
268
+ return float(value)
269
+ except ValueError:
270
+ return None
audiocraft/data/sound_dataset.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of audio with a simple description.
7
+ """
8
+
9
+ from dataclasses import dataclass, fields, replace
10
+ import json
11
+ from pathlib import Path
12
+ import random
13
+ import typing as tp
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ get_keyword_or_keyword_list
21
+ )
22
+ from ..modules.conditioners import (
23
+ ConditioningAttributes,
24
+ SegmentWithAttributes,
25
+ WavCondition,
26
+ )
27
+
28
+
29
+ EPS = torch.finfo(torch.float32).eps
30
+ TARGET_LEVEL_LOWER = -35
31
+ TARGET_LEVEL_UPPER = -15
32
+
33
+
34
+ @dataclass
35
+ class SoundInfo(SegmentWithAttributes):
36
+ """Segment info augmented with Sound metadata.
37
+ """
38
+ description: tp.Optional[str] = None
39
+ self_wav: tp.Optional[torch.Tensor] = None
40
+
41
+ @property
42
+ def has_sound_meta(self) -> bool:
43
+ return self.description is not None
44
+
45
+ def to_condition_attributes(self) -> ConditioningAttributes:
46
+ out = ConditioningAttributes()
47
+
48
+ for _field in fields(self):
49
+ key, value = _field.name, getattr(self, _field.name)
50
+ if key == 'self_wav':
51
+ out.wav[key] = value
52
+ else:
53
+ out.text[key] = value
54
+ return out
55
+
56
+ @staticmethod
57
+ def attribute_getter(attribute):
58
+ if attribute == 'description':
59
+ preprocess_func = get_keyword_or_keyword_list
60
+ else:
61
+ preprocess_func = None
62
+ return preprocess_func
63
+
64
+ @classmethod
65
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
66
+ _dictionary: tp.Dict[str, tp.Any] = {}
67
+
68
+ # allow a subset of attributes to not be loaded from the dictionary
69
+ # these attributes may be populated later
70
+ post_init_attributes = ['self_wav']
71
+
72
+ for _field in fields(cls):
73
+ if _field.name in post_init_attributes:
74
+ continue
75
+ elif _field.name not in dictionary:
76
+ if fields_required:
77
+ raise KeyError(f"Unexpected missing key: {_field.name}")
78
+ else:
79
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
80
+ value = dictionary[_field.name]
81
+ if preprocess_func:
82
+ value = preprocess_func(value)
83
+ _dictionary[_field.name] = value
84
+ return cls(**_dictionary)
85
+
86
+
87
+ class SoundDataset(InfoAudioDataset):
88
+ """Sound audio dataset: Audio dataset with environmental sound-specific metadata.
89
+
90
+ Args:
91
+ info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
92
+ external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
93
+ The metadata files contained in this folder are expected to match the stem of the audio file with
94
+ a json extension.
95
+ aug_p (float): Probability of performing audio mixing augmentation on the batch.
96
+ mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
97
+ mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
98
+ mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
99
+ mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
100
+ kwargs: Additional arguments for AudioDataset.
101
+
102
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
103
+ """
104
+ def __init__(
105
+ self,
106
+ *args,
107
+ info_fields_required: bool = True,
108
+ external_metadata_source: tp.Optional[str] = None,
109
+ aug_p: float = 0.,
110
+ mix_p: float = 0.,
111
+ mix_snr_low: int = -5,
112
+ mix_snr_high: int = 5,
113
+ mix_min_overlap: float = 0.5,
114
+ **kwargs
115
+ ):
116
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
117
+ super().__init__(*args, **kwargs)
118
+ self.info_fields_required = info_fields_required
119
+ self.external_metadata_source = external_metadata_source
120
+ self.aug_p = aug_p
121
+ self.mix_p = mix_p
122
+ if self.aug_p > 0:
123
+ assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
124
+ assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
125
+ self.mix_snr_low = mix_snr_low
126
+ self.mix_snr_high = mix_snr_high
127
+ self.mix_min_overlap = mix_min_overlap
128
+
129
+ def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
130
+ """Get path of JSON with metadata (description, etc.).
131
+ If there exists a JSON with the same name as 'path.name', then it will be used.
132
+ Else, such JSON will be searched for in an external json source folder if it exists.
133
+ """
134
+ info_path = Path(path).with_suffix('.json')
135
+ if Path(info_path).exists():
136
+ return info_path
137
+ elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
138
+ return Path(self.external_metadata_source) / info_path.name
139
+ else:
140
+ raise Exception(f"Unable to find a metadata JSON for path: {path}")
141
+
142
+ def __getitem__(self, index):
143
+ wav, info = super().__getitem__(index)
144
+ info_data = info.to_dict()
145
+ info_path = self._get_info_path(info.meta.path)
146
+ if Path(info_path).exists():
147
+ with open(info_path, 'r') as json_file:
148
+ sound_data = json.load(json_file)
149
+ sound_data.update(info_data)
150
+ sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
151
+ # if there are multiple descriptions, sample one randomly
152
+ if isinstance(sound_info.description, list):
153
+ sound_info.description = random.choice(sound_info.description)
154
+ else:
155
+ sound_info = SoundInfo.from_dict(info_data, fields_required=False)
156
+
157
+ sound_info.self_wav = WavCondition(
158
+ wav=wav[None], length=torch.tensor([info.n_frames]),
159
+ sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
160
+
161
+ return wav, sound_info
162
+
163
+ def collater(self, samples):
164
+ # when training, audio mixing is performed in the collate function
165
+ wav, sound_info = super().collater(samples) # SoundDataset always returns infos
166
+ if self.aug_p > 0:
167
+ wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
168
+ snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
169
+ min_overlap=self.mix_min_overlap)
170
+ return wav, sound_info
171
+
172
+
173
+ def rms_f(x: torch.Tensor) -> torch.Tensor:
174
+ return (x ** 2).mean(1).pow(0.5)
175
+
176
+
177
+ def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
178
+ """Normalize the signal to the target level."""
179
+ rms = rms_f(audio)
180
+ scalar = 10 ** (target_level / 20) / (rms + EPS)
181
+ audio = audio * scalar.unsqueeze(1)
182
+ return audio
183
+
184
+
185
+ def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
186
+ return (abs(audio) > clipping_threshold).any(1)
187
+
188
+
189
+ def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
190
+ start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
191
+ remainder = src.shape[1] - start
192
+ if dst.shape[1] > remainder:
193
+ src[:, start:] = src[:, start:] + dst[:, :remainder]
194
+ else:
195
+ src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
196
+ return src
197
+
198
+
199
+ def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
200
+ target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
201
+ """Function to mix clean speech and noise at various SNR levels.
202
+
203
+ Args:
204
+ clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
205
+ noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
206
+ snr (int): SNR level when mixing.
207
+ min_overlap (float): Minimum overlap between the two mixed sources.
208
+ target_level (int): Gain level in dB.
209
+ clipping_threshold (float): Threshold for clipping the audio.
210
+ Returns:
211
+ torch.Tensor: The mixed audio, of shape [B, T].
212
+ """
213
+ if clean.shape[1] > noise.shape[1]:
214
+ noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
215
+ else:
216
+ noise = noise[:, :clean.shape[1]]
217
+
218
+ # normalizing to -25 dB FS
219
+ clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
220
+ clean = normalize(clean, target_level)
221
+ rmsclean = rms_f(clean)
222
+
223
+ noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
224
+ noise = normalize(noise, target_level)
225
+ rmsnoise = rms_f(noise)
226
+
227
+ # set the noise level for a given SNR
228
+ noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
229
+ noisenewlevel = noise * noisescalar
230
+
231
+ # mix noise and clean speech
232
+ noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
233
+
234
+ # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
235
+ # there is a chance of clipping that might happen with very less probability, which is not a major issue.
236
+ noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
237
+ rmsnoisy = rms_f(noisyspeech)
238
+ scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
239
+ noisyspeech = noisyspeech * scalarnoisy
240
+ clean = clean * scalarnoisy
241
+ noisenewlevel = noisenewlevel * scalarnoisy
242
+
243
+ # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
244
+ clipped = is_clipped(noisyspeech)
245
+ if clipped.any():
246
+ noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
247
+ noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
248
+
249
+ return noisyspeech
250
+
251
+
252
+ def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
253
+ if snr_low == snr_high:
254
+ snr = snr_low
255
+ else:
256
+ snr = np.random.randint(snr_low, snr_high)
257
+ mix = snr_mixer(src, dst, snr, min_overlap)
258
+ return mix
259
+
260
+
261
+ def mix_text(src_text: str, dst_text: str):
262
+ """Mix text from different sources by concatenating them."""
263
+ if src_text == dst_text:
264
+ return src_text
265
+ return src_text + " " + dst_text
266
+
267
+
268
+ def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
269
+ snr_low: int, snr_high: int, min_overlap: float):
270
+ """Mix samples within a batch, summing the waveforms and concatenating the text infos.
271
+
272
+ Args:
273
+ wavs (torch.Tensor): Audio tensors of shape [B, C, T].
274
+ infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
275
+ aug_p (float): Augmentation probability.
276
+ mix_p (float): Proportion of items in the batch to mix (and merge) together.
277
+ snr_low (int): Lowerbound for sampling SNR.
278
+ snr_high (int): Upperbound for sampling SNR.
279
+ min_overlap (float): Minimum overlap between mixed samples.
280
+ Returns:
281
+ tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
282
+ and mixed SoundInfo for the given batch.
283
+ """
284
+ # no mixing to perform within the batch
285
+ if mix_p == 0:
286
+ return wavs, infos
287
+
288
+ if random.uniform(0, 1) < aug_p:
289
+ # perform all augmentations on waveforms as [B, T]
290
+ # randomly picking pairs of audio to mix
291
+ assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
292
+ wavs = wavs.mean(dim=1, keepdim=False)
293
+ B, T = wavs.shape
294
+ k = int(mix_p * B)
295
+ mixed_sources_idx = torch.randperm(B)[:k]
296
+ mixed_targets_idx = torch.randperm(B)[:k]
297
+ aug_wavs = snr_mix(
298
+ wavs[mixed_sources_idx],
299
+ wavs[mixed_targets_idx],
300
+ snr_low,
301
+ snr_high,
302
+ min_overlap,
303
+ )
304
+ # mixing textual descriptions in metadata
305
+ descriptions = [info.description for info in infos]
306
+ aug_infos = []
307
+ for i, j in zip(mixed_sources_idx, mixed_targets_idx):
308
+ text = mix_text(descriptions[i], descriptions[j])
309
+ m = replace(infos[i])
310
+ m.description = text
311
+ aug_infos.append(m)
312
+
313
+ # back to [B, C, T]
314
+ aug_wavs = aug_wavs.unsqueeze(1)
315
+ assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
316
+ assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
317
+ assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
318
+
319
+ return aug_wavs, aug_infos # [B, C, T]
320
+ else:
321
+ # randomly pick samples in the batch to match
322
+ # the batch size when performing audio mixing
323
+ B, C, T = wavs.shape
324
+ k = int(mix_p * B)
325
+ wav_idx = torch.randperm(B)[:k]
326
+ wavs = wavs[wav_idx]
327
+ infos = [infos[i] for i in wav_idx]
328
+ assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
329
+
330
+ return wavs, infos # [B, C, T]
audiocraft/data/zip.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Utility for reading some info from inside a zip file.
7
+ """
8
+
9
+ import typing
10
+ import zipfile
11
+
12
+ from dataclasses import dataclass
13
+ from functools import lru_cache
14
+ from typing_extensions import Literal
15
+
16
+
17
+ DEFAULT_SIZE = 32
18
+ MODE = Literal['r', 'w', 'x', 'a']
19
+
20
+
21
+ @dataclass(order=True)
22
+ class PathInZip:
23
+ """Hold a path of file within a zip file.
24
+
25
+ Args:
26
+ path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
27
+ Let's assume there is a zip file /some/location/foo.zip
28
+ and inside of it is a json file located at /data/file1.json,
29
+ Then we expect path = "/some/location/foo.zip:/data/file1.json".
30
+ """
31
+
32
+ INFO_PATH_SEP = ':'
33
+ zip_path: str
34
+ file_path: str
35
+
36
+ def __init__(self, path: str) -> None:
37
+ split_path = path.split(self.INFO_PATH_SEP)
38
+ assert len(split_path) == 2
39
+ self.zip_path, self.file_path = split_path
40
+
41
+ @classmethod
42
+ def from_paths(cls, zip_path: str, file_path: str):
43
+ return cls(zip_path + cls.INFO_PATH_SEP + file_path)
44
+
45
+ def __str__(self) -> str:
46
+ return self.zip_path + self.INFO_PATH_SEP + self.file_path
47
+
48
+
49
+ def _open_zip(path: str, mode: MODE = 'r'):
50
+ return zipfile.ZipFile(path, mode)
51
+
52
+
53
+ _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
54
+
55
+
56
+ def set_zip_cache_size(max_size: int):
57
+ """Sets the maximal LRU caching for zip file opening.
58
+
59
+ Args:
60
+ max_size (int): the maximal LRU cache.
61
+ """
62
+ global _cached_open_zip
63
+ _cached_open_zip = lru_cache(max_size)(_open_zip)
64
+
65
+
66
+ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
67
+ """Opens a file stored inside a zip and returns a file-like object.
68
+
69
+ Args:
70
+ path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
71
+ mode (str): The mode in which to open the file with.
72
+ Returns:
73
+ A file-like object for PathInZip.
74
+ """
75
+ zf = _cached_open_zip(path_in_zip.zip_path)
76
+ return zf.open(path_in_zip.file_path)
audiocraft/environment.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Provides cluster and tools configuration across clusters (slurm, dora, utilities).
9
+ """
10
+
11
+ import logging
12
+ import os
13
+ from pathlib import Path
14
+ import re
15
+ import typing as tp
16
+
17
+ import omegaconf
18
+
19
+ from .utils.cluster import _guess_cluster_type
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class AudioCraftEnvironment:
26
+ """Environment configuration for teams and clusters.
27
+
28
+ AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
29
+ or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
30
+ provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
31
+ allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
32
+ map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
33
+
34
+ The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
35
+ Use the following environment variables to specify the cluster, team or configuration:
36
+
37
+ AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
38
+ cannot be inferred automatically.
39
+ AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
40
+ If not set, configuration is read from config/teams.yaml.
41
+ AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
42
+ Cluster configuration are shared across teams to match compute allocation,
43
+ specify your cluster configuration in the configuration file under a key mapping
44
+ your team name.
45
+ """
46
+ _instance = None
47
+ DEFAULT_TEAM = "default"
48
+
49
+ def __init__(self) -> None:
50
+ """Loads configuration."""
51
+ self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
52
+ cluster_type = _guess_cluster_type()
53
+ cluster = os.getenv(
54
+ "AUDIOCRAFT_CLUSTER", cluster_type.value
55
+ )
56
+ logger.info("Detecting cluster type %s", cluster_type)
57
+
58
+ self.cluster: str = cluster
59
+
60
+ config_path = os.getenv(
61
+ "AUDIOCRAFT_CONFIG",
62
+ Path(__file__)
63
+ .parent.parent.joinpath("config/teams", self.team)
64
+ .with_suffix(".yaml"),
65
+ )
66
+ self.config = omegaconf.OmegaConf.load(config_path)
67
+ self._dataset_mappers = []
68
+ cluster_config = self._get_cluster_config()
69
+ if "dataset_mappers" in cluster_config:
70
+ for pattern, repl in cluster_config["dataset_mappers"].items():
71
+ regex = re.compile(pattern)
72
+ self._dataset_mappers.append((regex, repl))
73
+
74
+ def _get_cluster_config(self) -> omegaconf.DictConfig:
75
+ assert isinstance(self.config, omegaconf.DictConfig)
76
+ return self.config[self.cluster]
77
+
78
+ @classmethod
79
+ def instance(cls):
80
+ if cls._instance is None:
81
+ cls._instance = cls()
82
+ return cls._instance
83
+
84
+ @classmethod
85
+ def reset(cls):
86
+ """Clears the environment and forces a reload on next invocation."""
87
+ cls._instance = None
88
+
89
+ @classmethod
90
+ def get_team(cls) -> str:
91
+ """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
92
+ If not defined, defaults to "labs".
93
+ """
94
+ return cls.instance().team
95
+
96
+ @classmethod
97
+ def get_cluster(cls) -> str:
98
+ """Gets the detected cluster.
99
+ This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
100
+ """
101
+ return cls.instance().cluster
102
+
103
+ @classmethod
104
+ def get_dora_dir(cls) -> Path:
105
+ """Gets the path to the dora directory for the current team and cluster.
106
+ Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
107
+ """
108
+ cluster_config = cls.instance()._get_cluster_config()
109
+ dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
110
+ logger.warning(f"Dora directory: {dora_dir}")
111
+ return Path(dora_dir)
112
+
113
+ @classmethod
114
+ def get_reference_dir(cls) -> Path:
115
+ """Gets the path to the reference directory for the current team and cluster.
116
+ Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
117
+ """
118
+ cluster_config = cls.instance()._get_cluster_config()
119
+ return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
120
+
121
+ @classmethod
122
+ def get_slurm_exclude(cls) -> tp.Optional[str]:
123
+ """Get the list of nodes to exclude for that cluster."""
124
+ cluster_config = cls.instance()._get_cluster_config()
125
+ return cluster_config.get("slurm_exclude")
126
+
127
+ @classmethod
128
+ def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
129
+ """Gets the requested partitions for the current team and cluster as a comma-separated string.
130
+
131
+ Args:
132
+ partition_types (list[str], optional): partition types to retrieve. Values must be
133
+ from ['global', 'team']. If not provided, the global partition is returned.
134
+ """
135
+ if not partition_types:
136
+ partition_types = ["global"]
137
+
138
+ cluster_config = cls.instance()._get_cluster_config()
139
+ partitions = [
140
+ cluster_config["partitions"][partition_type]
141
+ for partition_type in partition_types
142
+ ]
143
+ return ",".join(partitions)
144
+
145
+ @classmethod
146
+ def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
147
+ """Converts reference placeholder in path with configured reference dir to resolve paths.
148
+
149
+ Args:
150
+ path (str or Path): Path to resolve.
151
+ Returns:
152
+ Path: Resolved path.
153
+ """
154
+ path = str(path)
155
+
156
+ if path.startswith("//reference"):
157
+ reference_dir = cls.get_reference_dir()
158
+ logger.warn(f"Reference directory: {reference_dir}")
159
+ assert (
160
+ reference_dir.exists() and reference_dir.is_dir()
161
+ ), f"Reference directory does not exist: {reference_dir}."
162
+ path = re.sub("^//reference", str(reference_dir), path)
163
+
164
+ return Path(path)
165
+
166
+ @classmethod
167
+ def apply_dataset_mappers(cls, path: str) -> str:
168
+ """Applies dataset mapping regex rules as defined in the configuration.
169
+ If no rules are defined, the path is returned as-is.
170
+ """
171
+ instance = cls.instance()
172
+
173
+ for pattern, repl in instance._dataset_mappers:
174
+ path = pattern.sub(repl, path)
175
+
176
+ return path
audiocraft/grids/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dora Grids."""
audiocraft/grids/_base_explorers.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABC, abstractmethod
8
+ import time
9
+ import typing as tp
10
+ from dora import Explorer
11
+ import treetable as tt
12
+
13
+
14
+ def get_sheep_ping(sheep) -> tp.Optional[str]:
15
+ """Return the amount of time since the Sheep made some update
16
+ to its log. Returns a str using the relevant time unit."""
17
+ ping = None
18
+ if sheep.log is not None and sheep.log.exists():
19
+ delta = time.time() - sheep.log.stat().st_mtime
20
+ if delta > 3600 * 24:
21
+ ping = f'{delta / (3600 * 24):.1f}d'
22
+ elif delta > 3600:
23
+ ping = f'{delta / (3600):.1f}h'
24
+ elif delta > 60:
25
+ ping = f'{delta / 60:.1f}m'
26
+ else:
27
+ ping = f'{delta:.1f}s'
28
+ return ping
29
+
30
+
31
+ class BaseExplorer(ABC, Explorer):
32
+ """Base explorer for AudioCraft grids.
33
+
34
+ All task specific solvers are expected to implement the `get_grid_metrics`
35
+ method to specify logic about metrics to display for a given task.
36
+
37
+ If additional stages are used, the child explorer must define how to handle
38
+ these new stages in the `process_history` and `process_sheep` methods.
39
+ """
40
+ def stages(self):
41
+ return ["train", "valid", "evaluate"]
42
+
43
+ def get_grid_meta(self):
44
+ """Returns the list of Meta information to display for each XP/job.
45
+ """
46
+ return [
47
+ tt.leaf("index", align=">"),
48
+ tt.leaf("name", wrap=140),
49
+ tt.leaf("state"),
50
+ tt.leaf("sig", align=">"),
51
+ tt.leaf("sid", align="<"),
52
+ ]
53
+
54
+ @abstractmethod
55
+ def get_grid_metrics(self):
56
+ """Return the metrics that should be displayed in the tracking table.
57
+ """
58
+ ...
59
+
60
+ def process_sheep(self, sheep, history):
61
+ train = {
62
+ "epoch": len(history),
63
+ }
64
+ parts = {"train": train}
65
+ for metrics in history:
66
+ for key, sub in metrics.items():
67
+ part = parts.get(key, {})
68
+ if 'duration' in sub:
69
+ # Convert to minutes for readability.
70
+ sub['duration'] = sub['duration'] / 60.
71
+ part.update(sub)
72
+ parts[key] = part
73
+ ping = get_sheep_ping(sheep)
74
+ if ping is not None:
75
+ for name in self.stages():
76
+ if name not in parts:
77
+ parts[name] = {}
78
+ # Add the ping to each part for convenience.
79
+ parts[name]['ping'] = ping
80
+ return parts
audiocraft/grids/audiogen/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """AudioGen grids."""
audiocraft/grids/audiogen/audiogen_base_16khz.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ..musicgen._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=64, partition=partitions)
15
+ launcher.bind_(solver='audiogen/audiogen_base_16khz')
16
+ # replace this by the desired environmental sound dataset
17
+ launcher.bind_(dset='internal/sounds_16khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+
22
+ launcher.bind_(fsdp)
23
+ launcher(medium)
audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Evaluation with objective metrics for the pretrained AudioGen models.
9
+ This grid takes signature from the training grid and runs evaluation-only stage.
10
+
11
+ When running the grid for the first time, please use:
12
+ REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
13
+ and re-use the REGEN=1 option when the grid is changed to force regenerating it.
14
+
15
+ Note that you need the proper metrics external libraries setup to use all
16
+ the objective metrics activated in this grid. Refer to the README for more information.
17
+ """
18
+
19
+ import os
20
+
21
+ from ..musicgen._explorers import GenerationEvalExplorer
22
+ from ...environment import AudioCraftEnvironment
23
+ from ... import train
24
+
25
+
26
+ def eval(launcher, batch_size: int = 32):
27
+ opts = {
28
+ 'dset': 'audio/audiocaps_16khz',
29
+ 'solver/audiogen/evaluation': 'objective_eval',
30
+ 'execute_only': 'evaluate',
31
+ '+dataset.evaluate.batch_size': batch_size,
32
+ '+metrics.fad.tf.batch_size': 32,
33
+ }
34
+ # binary for FAD computation: replace this path with your own path
35
+ metrics_opts = {
36
+ 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
37
+ }
38
+ opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
39
+ opt2 = {'transformer_lm.two_step_cfg': True}
40
+
41
+ sub = launcher.bind(opts)
42
+ sub.bind_(metrics_opts)
43
+
44
+ # base objective metrics
45
+ sub(opt1, opt2)
46
+
47
+
48
+ @GenerationEvalExplorer
49
+ def explorer(launcher):
50
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
51
+ launcher.slurm_(gpus=4, partition=partitions)
52
+
53
+ if 'REGEN' not in os.environ:
54
+ folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
55
+ with launcher.job_array():
56
+ for sig in folder.iterdir():
57
+ if not sig.is_symlink():
58
+ continue
59
+ xp = train.main.get_xp_from_sig(sig.name)
60
+ launcher(xp.argv)
61
+ return
62
+
63
+ audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
64
+ audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
65
+
66
+ audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
67
+ audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
68
+ eval(audiogen_base_medium, batch_size=128)
audiocraft/grids/compression/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """EnCodec grids."""
audiocraft/grids/compression/_explorers.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import treetable as tt
8
+
9
+ from .._base_explorers import BaseExplorer
10
+
11
+
12
+ class CompressionExplorer(BaseExplorer):
13
+ eval_metrics = ["sisnr", "visqol"]
14
+
15
+ def stages(self):
16
+ return ["train", "valid", "evaluate"]
17
+
18
+ def get_grid_meta(self):
19
+ """Returns the list of Meta information to display for each XP/job.
20
+ """
21
+ return [
22
+ tt.leaf("index", align=">"),
23
+ tt.leaf("name", wrap=140),
24
+ tt.leaf("state"),
25
+ tt.leaf("sig", align=">"),
26
+ ]
27
+
28
+ def get_grid_metrics(self):
29
+ """Return the metrics that should be displayed in the tracking table.
30
+ """
31
+ return [
32
+ tt.group(
33
+ "train",
34
+ [
35
+ tt.leaf("epoch"),
36
+ tt.leaf("bandwidth", ".2f"),
37
+ tt.leaf("adv", ".4f"),
38
+ tt.leaf("d_loss", ".4f"),
39
+ ],
40
+ align=">",
41
+ ),
42
+ tt.group(
43
+ "valid",
44
+ [
45
+ tt.leaf("bandwidth", ".2f"),
46
+ tt.leaf("adv", ".4f"),
47
+ tt.leaf("msspec", ".4f"),
48
+ tt.leaf("sisnr", ".2f"),
49
+ ],
50
+ align=">",
51
+ ),
52
+ tt.group(
53
+ "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
54
+ ),
55
+ ]
audiocraft/grids/compression/debug.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel and experiment by commenting its line.
11
+
12
+ This grid is a minimal example for debugging compression task
13
+ and how to override parameters directly in a grid.
14
+ Learn more about dora grids: https://github.com/facebookresearch/dora
15
+ """
16
+
17
+ from ._explorers import CompressionExplorer
18
+ from ...environment import AudioCraftEnvironment
19
+
20
+
21
+ @CompressionExplorer
22
+ def explorer(launcher):
23
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
24
+ launcher.slurm_(gpus=2, partition=partitions)
25
+ launcher.bind_(solver='compression/debug')
26
+
27
+ with launcher.job_array():
28
+ # base debug task using config from solver=compression/debug
29
+ launcher()
30
+ # we can override parameters in the grid to launch additional xps
31
+ launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
audiocraft/grids/compression/encodec_audiogen_16khz.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel and experiment by commenting its line.
11
+
12
+ This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
13
+ """
14
+
15
+ from ._explorers import CompressionExplorer
16
+ from ...environment import AudioCraftEnvironment
17
+
18
+
19
+ @CompressionExplorer
20
+ def explorer(launcher):
21
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
+ launcher.slurm_(gpus=8, partition=partitions)
23
+ # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
24
+ # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
25
+ launcher.bind_(solver='compression/encodec_audiogen_16khz')
26
+ # replace this by the desired sound dataset
27
+ launcher.bind_(dset='internal/sounds_16khz')
28
+ # launch xp
29
+ launcher()
audiocraft/grids/compression/encodec_base_24khz.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel and experiment by commenting its line.
11
+
12
+ This grid shows how to train a base causal EnCodec model at 24 kHz.
13
+ """
14
+
15
+ from ._explorers import CompressionExplorer
16
+ from ...environment import AudioCraftEnvironment
17
+
18
+
19
+ @CompressionExplorer
20
+ def explorer(launcher):
21
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
+ launcher.slurm_(gpus=8, partition=partitions)
23
+ # base causal EnCodec trained on monophonic audio sampled at 24 kHz
24
+ launcher.bind_(solver='compression/encodec_base_24khz')
25
+ # replace this by the desired dataset
26
+ launcher.bind_(dset='audio/example')
27
+ # launch xp
28
+ launcher()
audiocraft/grids/compression/encodec_musicgen_32khz.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel and experiment by commenting its line.
11
+
12
+ This grid shows how to train a MusicGen EnCodec model at 32 kHz.
13
+ """
14
+
15
+ from ._explorers import CompressionExplorer
16
+ from ...environment import AudioCraftEnvironment
17
+
18
+
19
+ @CompressionExplorer
20
+ def explorer(launcher):
21
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
+ launcher.slurm_(gpus=8, partition=partitions)
23
+ # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
24
+ # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
25
+ launcher.bind_(solver='compression/encodec_musicgen_32khz')
26
+ # replace this by the desired music dataset
27
+ launcher.bind_(dset='internal/music_400k_32khz')
28
+ # launch xp
29
+ launcher()
30
+ launcher({
31
+ 'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
32
+ 'label': 'visqol',
33
+ 'evaluate.metrics.visqol': True
34
+ })
audiocraft/grids/diffusion/4_bands_base_32khz.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Training of the 4 diffusion models described in
9
+ "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
10
+ (paper link).
11
+ """
12
+
13
+ from ._explorers import DiffusionExplorer
14
+
15
+
16
+ @DiffusionExplorer
17
+ def explorer(launcher):
18
+ launcher.slurm_(gpus=4, partition='learnfair')
19
+
20
+ launcher.bind_({'solver': 'diffusion/default',
21
+ 'dset': 'internal/music_10k_32khz'})
22
+
23
+ with launcher.job_array():
24
+ launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
25
+ launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
26
+ launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
27
+ launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
audiocraft/grids/diffusion/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Diffusion grids."""
audiocraft/grids/diffusion/_explorers.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import treetable as tt
8
+
9
+ from .._base_explorers import BaseExplorer
10
+
11
+
12
+ class DiffusionExplorer(BaseExplorer):
13
+ eval_metrics = ["sisnr", "visqol"]
14
+
15
+ def stages(self):
16
+ return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
17
+
18
+ def get_grid_meta(self):
19
+ """Returns the list of Meta information to display for each XP/job.
20
+ """
21
+ return [
22
+ tt.leaf("index", align=">"),
23
+ tt.leaf("name", wrap=140),
24
+ tt.leaf("state"),
25
+ tt.leaf("sig", align=">"),
26
+ ]
27
+
28
+ def get_grid_metrics(self):
29
+ """Return the metrics that should be displayed in the tracking table.
30
+ """
31
+ return [
32
+ tt.group(
33
+ "train",
34
+ [
35
+ tt.leaf("epoch"),
36
+ tt.leaf("loss", ".3%"),
37
+ ],
38
+ align=">",
39
+ ),
40
+ tt.group(
41
+ "valid",
42
+ [
43
+ tt.leaf("loss", ".3%"),
44
+ # tt.leaf("loss_0", ".3%"),
45
+ ],
46
+ align=">",
47
+ ),
48
+ tt.group(
49
+ "valid_ema",
50
+ [
51
+ tt.leaf("loss", ".3%"),
52
+ # tt.leaf("loss_0", ".3%"),
53
+ ],
54
+ align=">",
55
+ ),
56
+ tt.group(
57
+ "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
58
+ tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
59
+ tt.leaf("rvm_3", ".4f"), ], align=">"
60
+ ),
61
+ tt.group(
62
+ "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
63
+ tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
64
+ tt.leaf("rvm_3", ".4f")], align=">"
65
+ ),
66
+ ]
audiocraft/grids/musicgen/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """MusicGen grids."""
audiocraft/grids/musicgen/_explorers.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import treetable as tt
10
+
11
+ from .._base_explorers import BaseExplorer
12
+
13
+
14
+ class LMExplorer(BaseExplorer):
15
+ eval_metrics: tp.List[str] = []
16
+
17
+ def stages(self) -> tp.List[str]:
18
+ return ['train', 'valid']
19
+
20
+ def get_grid_metrics(self):
21
+ """Return the metrics that should be displayed in the tracking table."""
22
+ return [
23
+ tt.group(
24
+ 'train',
25
+ [
26
+ tt.leaf('epoch'),
27
+ tt.leaf('duration', '.1f'), # duration in minutes
28
+ tt.leaf('ping'),
29
+ tt.leaf('ce', '.4f'), # cross entropy
30
+ tt.leaf("ppl", '.3f'), # perplexity
31
+ ],
32
+ align='>',
33
+ ),
34
+ tt.group(
35
+ 'valid',
36
+ [
37
+ tt.leaf('ce', '.4f'),
38
+ tt.leaf('ppl', '.3f'),
39
+ tt.leaf('best_ppl', '.3f'),
40
+ ],
41
+ align='>',
42
+ ),
43
+ ]
44
+
45
+ def process_sheep(self, sheep, history):
46
+ parts = super().process_sheep(sheep, history)
47
+
48
+ track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher']
49
+ best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}
50
+
51
+ def comparator(mode, a, b):
52
+ return a < b if mode == 'lower' else a > b
53
+
54
+ for metrics in history:
55
+ for key, sub in metrics.items():
56
+ for metric in track_by:
57
+ # for the validation set, keep track of best metrics (ppl in this example)
58
+ # this is so we can conveniently compare metrics between runs in the grid
59
+ if key == 'valid' and metric in sub and comparator(
60
+ track_by[metric], sub[metric], best_metrics[metric]
61
+ ):
62
+ best_metrics[metric] = sub[metric]
63
+
64
+ if 'valid' in parts:
65
+ parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
66
+ return parts
67
+
68
+
69
+ class GenerationEvalExplorer(BaseExplorer):
70
+ eval_metrics: tp.List[str] = []
71
+
72
+ def stages(self) -> tp.List[str]:
73
+ return ['evaluate']
74
+
75
+ def get_grid_metrics(self):
76
+ """Return the metrics that should be displayed in the tracking table."""
77
+ return [
78
+ tt.group(
79
+ 'evaluate',
80
+ [
81
+ tt.leaf('epoch', '.3f'),
82
+ tt.leaf('duration', '.1f'),
83
+ tt.leaf('ping'),
84
+ tt.leaf('ce', '.4f'),
85
+ tt.leaf('ppl', '.3f'),
86
+ tt.leaf('fad', '.3f'),
87
+ tt.leaf('kld', '.3f'),
88
+ tt.leaf('text_consistency', '.3f'),
89
+ tt.leaf('chroma_cosine', '.3f'),
90
+ ],
91
+ align='>',
92
+ ),
93
+ ]
audiocraft/grids/musicgen/musicgen_base_32khz.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=32, partition=partitions)
15
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
16
+ # replace this by the desired music dataset
17
+ launcher.bind_(dset='internal/music_400k_32khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+ large = {'model/lm/model_scale': 'large'}
22
+
23
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25
+
26
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27
+
28
+ launcher.bind_(fsdp)
29
+
30
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
31
+ with launcher.job_array():
32
+ sub = launcher.bind()
33
+ sub()
34
+
35
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
36
+ with launcher.job_array():
37
+ sub = launcher.bind()
38
+ sub(medium, adam)
39
+
40
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
41
+ with launcher.job_array():
42
+ sub = launcher.bind()
43
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})