fffiloni commited on
Commit
9321dd9
1 Parent(s): 41af220

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +342 -19
app.py CHANGED
@@ -1,28 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- import torchaudio
 
3
  from audiocraft.models import MAGNeT
4
- from audiocraft. data. audio import audio_write
5
 
6
- model = MAGNeT.get_pretrained('facebook/magnet-small-10secs')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
 
9
- def infer(description):
10
- descriptions = ['disco beat', 'energetic EDM']
 
 
11
 
12
- wav = model.generate(descriptions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- for idx, one_wav in enumerate(wav):
15
- print(idx)
16
- audio_write(f'{idx}',
17
- one_wav.cpu(),
18
- model.sample_rate,
19
- strategy="loudness",
20
- loudness_compressor=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- return "done"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- gr.Interface(
25
- fn = infer,
26
- inputs = gr.Textbox(value="gogo"),
27
- outputs = gr.Textbox()
28
- ).launch()
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under thmage license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ from concurrent.futures import ProcessPoolExecutor
9
+ import logging
10
+ import os
11
+ from pathlib import Path
12
+ import subprocess as sp
13
+ import sys
14
+ from tempfile import NamedTemporaryFile
15
+ import time
16
+ import typing as tp
17
+ import warnings
18
+
19
  import gradio as gr
20
+
21
+ from audiocraft.data.audio import audio_write
22
  from audiocraft.models import MAGNeT
 
23
 
24
+
25
+ MODEL = None # Last used model
26
+ SPACE_ID = os.environ.get('SPACE_ID', '')
27
+ MAX_BATCH_SIZE = 12
28
+ N_REPEATS = 2
29
+ INTERRUPTING = False
30
+ MBD = None
31
+ # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
32
+ _old_call = sp.call
33
+
34
+ PROD_STRIDE_1 = "prod-stride1 (new!)"
35
+
36
+
37
+ def _call_nostderr(*args, **kwargs):
38
+ # Avoid ffmpeg vomiting on the logs.
39
+ kwargs['stderr'] = sp.DEVNULL
40
+ kwargs['stdout'] = sp.DEVNULL
41
+ _old_call(*args, **kwargs)
42
+
43
+
44
+ sp.call = _call_nostderr
45
+ # Preallocating the pool of processes.
46
+ pool = ProcessPoolExecutor(4)
47
+ pool.__enter__()
48
+
49
+
50
+ def interrupt():
51
+ global INTERRUPTING
52
+ INTERRUPTING = True
53
+
54
+
55
+ class FileCleaner:
56
+ def __init__(self, file_lifetime: float = 3600):
57
+ self.file_lifetime = file_lifetime
58
+ self.files = []
59
+
60
+ def add(self, path: tp.Union[str, Path]):
61
+ self._cleanup()
62
+ self.files.append((time.time(), Path(path)))
63
+
64
+ def _cleanup(self):
65
+ now = time.time()
66
+ for time_added, path in list(self.files):
67
+ if now - time_added > self.file_lifetime:
68
+ if path.exists():
69
+ path.unlink()
70
+ self.files.pop(0)
71
+ else:
72
+ break
73
+
74
+
75
+ file_cleaner = FileCleaner()
76
+
77
+
78
+ def make_waveform(*args, **kwargs):
79
+ # Further remove some warnings.
80
+ be = time.time()
81
+ with warnings.catch_warnings():
82
+ warnings.simplefilter('ignore')
83
+ out = gr.make_waveform(*args, **kwargs)
84
+ print("Make a video took", time.time() - be)
85
+ return out
86
+
87
+
88
+ def load_model(version='facebook/magnet-small-10secs'):
89
+ global MODEL
90
+ print("Loading model", version)
91
+ if MODEL is None or MODEL.name != version:
92
+ MODEL = None # in case loading would crash
93
+ MODEL = MAGNeT.get_pretrained(version)
94
 
95
 
96
+ def _do_predictions(texts, progress=False, gradio_progress=None, **gen_kwargs):
97
+ MODEL.set_generation_params(**gen_kwargs)
98
+ print("new batch", len(texts), texts)
99
+ be = time.time()
100
 
101
+ try:
102
+ outputs = MODEL.generate(texts, progress=progress, return_tokens=False)
103
+ except RuntimeError as e:
104
+ raise gr.Error("Error while generating " + e.args[0])
105
+ outputs = outputs.detach().cpu().float()
106
+ pending_videos = []
107
+ out_wavs = []
108
+ for i, output in enumerate(outputs):
109
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
110
+ audio_write(
111
+ file.name, output, MODEL.sample_rate, strategy="loudness",
112
+ loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
113
+ if i == 0:
114
+ pending_videos.append(pool.submit(make_waveform, file.name))
115
+ out_wavs.append(file.name)
116
+ file_cleaner.add(file.name)
117
+ out_videos = [pending_video.result() for pending_video in pending_videos]
118
+ for video in out_videos:
119
+ file_cleaner.add(video)
120
+ print("batch finished", len(texts), time.time() - be)
121
+ print("Tempfiles currently stored: ", len(file_cleaner.files))
122
+ return out_videos, out_wavs
123
+
124
+
125
+ def predict_batched(texts, melodies):
126
+ max_text_length = 512
127
+ texts = [text[:max_text_length] for text in texts]
128
+ load_model('facebook/magnet-small-10secs')
129
+ res = _do_predictions(texts, melodies)
130
+ return res
131
+
132
+
133
+ def predict_full(model, model_path, text, temperature, topp,
134
+ max_cfg_coef, min_cfg_coef,
135
+ decoding_steps1, decoding_steps2, decoding_steps3, decoding_steps4,
136
+ span_score,
137
+ progress=gr.Progress()):
138
+ global INTERRUPTING
139
+ INTERRUPTING = False
140
+ progress(0, desc="Loading model...")
141
+ model_path = model_path.strip()
142
+ if model_path:
143
+ if not Path(model_path).exists():
144
+ raise gr.Error(f"Model path {model_path} doesn't exist.")
145
+ if not Path(model_path).is_dir():
146
+ raise gr.Error(f"Model path {model_path} must be a folder containing "
147
+ "state_dict.bin and compression_state_dict_.bin.")
148
+ model = model_path
149
+ if temperature < 0:
150
+ raise gr.Error("Temperature must be >= 0.")
151
+
152
+ load_model(model)
153
+
154
+ max_generated = 0
155
+
156
+ def _progress(generated, to_generate):
157
+ nonlocal max_generated
158
+ max_generated = max(generated, max_generated)
159
+ progress((min(max_generated, to_generate), to_generate))
160
+ if INTERRUPTING:
161
+ raise gr.Error("Interrupted.")
162
+ MODEL.set_custom_progress_callback(_progress)
163
 
164
+ videos, wavs = _do_predictions(
165
+ [text] * N_REPEATS, progress=True,
166
+ temperature=temperature, top_p=topp,
167
+ max_cfg_coef=max_cfg_coef, min_cfg_coef=min_cfg_coef,
168
+ decoding_steps=[decoding_steps1, decoding_steps2, decoding_steps3, decoding_steps4],
169
+ span_arrangement='stride1' if (span_score == PROD_STRIDE_1) else 'nonoverlap',
170
+ gradio_progress=progress)
171
+
172
+ outputs_ = [videos[0]] + [wav for wav in wavs]
173
+ return tuple(outputs_)
174
+
175
+ def ui_full(launch_kwargs):
176
+ with gr.Blocks() as interface:
177
+ gr.Markdown(
178
+ """
179
+ # MAGNeT
180
+ This is your private demo for [MAGNeT](https://github.com/facebookresearch/audiocraft),
181
+ A fast text-to-music model, consists of a single, non-autoregressive transformer.
182
+ presented at: ["Masked Audio Generation using a Single Non-Autoregressive Transformer"] (https://huggingface.co/papers/2401.04577)
183
+ """
184
+ )
185
+ with gr.Row():
186
+ with gr.Column():
187
+ with gr.Row():
188
+ text = gr.Text(label="Input Text", value="80s electronic track with melodic synthesizers, catchy beat and groovy bass", interactive=True)
189
+ with gr.Row():
190
+ submit = gr.Button("Submit")
191
+ # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
192
+ _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
193
+ with gr.Row():
194
+ model = gr.Radio(['facebook/magnet-small-10secs', 'facebook/magnet-medium-10secs',
195
+ 'facebook/magnet-small-30secs', 'facebook/magnet-medium-30secs',
196
+ 'facebook/audio-magnet-small', 'facebook/audio-magnet-medium'],
197
+ label="Model", value='facebook/magnet-small-10secs', interactive=True)
198
+ model_path = gr.Text(label="Model Path (custom models)")
199
+ with gr.Row():
200
+ span_score = gr.Radio(["max-nonoverlap", PROD_STRIDE_1],
201
+ label="Span Scoring", value=PROD_STRIDE_1, interactive=True)
202
+ with gr.Row():
203
+ decoding_steps1 = gr.Number(label="Decoding Steps (stage 1)", value=20, interactive=True)
204
+ decoding_steps2 = gr.Number(label="Decoding Steps (stage 2)", value=10, interactive=True)
205
+ decoding_steps3 = gr.Number(label="Decoding Steps (stage 3)", value=10, interactive=True)
206
+ decoding_steps4 = gr.Number(label="Decoding Steps (stage 4)", value=10, interactive=True)
207
+ with gr.Row():
208
+ temperature = gr.Number(label="Temperature", value=3.0, step=0.25, minimum=0, interactive=True)
209
+ topp = gr.Number(label="Top-p", value=0.9, step=0.1, minimum=0, maximum=1, interactive=True)
210
+ max_cfg_coef = gr.Number(label="Max CFG coefficient", value=10.0, minimum=0, interactive=True)
211
+ min_cfg_coef = gr.Number(label="Min CFG coefficient", value=1.0, minimum=0, interactive=True)
212
+ with gr.Column():
213
+ output = gr.Video(label="Generated Audio - variation 1")
214
+ audio_outputs = [gr.Audio(label=f"Generated Audio - variation {i+1}", type='filepath') for i in range(N_REPEATS)]
215
+ submit.click(fn=predict_full,
216
+ inputs=[model, model_path, text,
217
+ temperature, topp,
218
+ max_cfg_coef, min_cfg_coef,
219
+ decoding_steps1, decoding_steps2, decoding_steps3, decoding_steps4,
220
+ span_score],
221
+ outputs=[output] + [o for o in audio_outputs])
222
+ gr.Examples(
223
+ fn=predict_full,
224
+ examples=[
225
+ [
226
+ "80s electronic track with melodic synthesizers, catchy beat and groovy bass",
227
+ 'facebook/magnet-small-10secs',
228
+ 20, 3.0, 0.9, 10.0,
229
+ ],
230
+ [
231
+ "80s electronic track with melodic synthesizers, catchy beat and groovy bass. 170 bpm",
232
+ 'facebook/magnet-small-10secs',
233
+ 20, 3.0, 0.9, 10.0,
234
+ ],
235
+ [
236
+ "Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves",
237
+ 'facebook/magnet-medium-10secs',
238
+ 20, 3.0, 0.9, 10.0,
239
+ ],
240
+ [ "Funky groove with electric piano playing blue chords rhythmically",
241
+ 'facebook/magnet-medium-10secs',
242
+ 20, 3.0, 0.9, 10.0,
243
+ ],
244
+ [
245
+ "Rock with saturated guitars, a heavy bass line and crazy drum break and fills.",
246
+ 'facebook/magnet-small-30secs',
247
+ 60, 3.0, 0.9, 10.0,
248
+ ],
249
+ [ "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle",
250
+ 'facebook/magnet-medium-30secs',
251
+ 60, 3.0, 0.9, 10.0,
252
+ ],
253
+ [ "Seagulls squawking as ocean waves crash while wind blows heavily into a microphone.",
254
+ 'facebook/audio-magnet-small',
255
+ 20, 3.5, 0.8, 20.0,
256
+ ],
257
+ [ "A toilet flushing as music is playing and a man is singing in the distance.",
258
+ 'facebook/audio-magnet-medium',
259
+ 20, 3.5, 0.8, 20.0,
260
+ ],
261
+ ],
262
+
263
+ inputs=[text, model, decoding_steps1, temperature, topp, max_cfg_coef],
264
+ outputs=[output]
265
+ )
266
+
267
+ gr.Markdown(
268
+ """
269
+ ### More details
270
+
271
+ #### Music Generation
272
+ "magnet" models will generate a short music extract based on the textual description you provided.
273
+ These models can generate either 10 seconds or 30 seconds of music.
274
+ These models were trained with descriptions from a stock music catalog. Descriptions that will work best
275
+ should include some level of details on the instruments present, along with some intended use case
276
+ (e.g. adding "perfect for a commercial" can somehow help).
277
+
278
+ We present 4 model variants:
279
+ 1. facebook/magnet-small-10secs - a 300M non-autoregressive transformer capable of generating 10-second music conditioned
280
+ on text.
281
+ 2. facebook/magnet-medium-10secs - 1.5B parameters, 10 seconds audio.
282
+ 3. facebook/magnet-small-30secs - 300M parameters, 30 seconds audio.
283
+ 4. facebook/magnet-medium-30secs - 1.5B parameters, 30 seconds audio.
284
 
285
+ #### Sound-Effect Generation
286
+ "audio-magnet" models will generate a 10-second sound effect based on the description you provide.
287
+
288
+ These models were trained on the following data sources: a subset of AudioSet (Gemmeke et al., 2017),
289
+ [BBC sound effects](https://sound-effects.bbcrewind.co.uk/), AudioCaps (Kim et al., 2019),
290
+ Clotho v2 (Drossos et al., 2020), VGG-Sound (Chen et al., 2020), FSD50K (Fonseca et al., 2021),
291
+ [Free To Use Sounds](https://www.freetousesounds.com/all-in-one-bundle/), [Sonniss Game Effects](https://sonniss.com/gameaudiogdc),
292
+ [WeSoundEffects](https://wesoundeffects.com/we-sound-effects-bundle-2020/),
293
+ [Paramount Motion - Odeon Cinematic Sound Effects](https://www.paramountmotion.com/odeon-sound-effects).
294
+
295
+ We present 2 model variants:
296
+ 1. facebook/audio-magnet-small - 10 second sound effect generation, 300M parameters.
297
+ 2. facebook/audio-magnet-medium - 10 second sound effect generation, 1.5B parameters.
298
+
299
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MAGNET.md)
300
+ for more details.
301
+ """
302
+ )
303
+
304
+ interface.queue(max_size=10).launch(**launch_kwargs)
305
+
306
+
307
+ if __name__ == "__main__":
308
+ parser = argparse.ArgumentParser()
309
+ parser.add_argument(
310
+ '--listen',
311
+ type=str,
312
+ default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
313
+ help='IP to listen on for connections to Gradio',
314
+ )
315
+ parser.add_argument(
316
+ '--username', type=str, default='', help='Username for authentication'
317
+ )
318
+ parser.add_argument(
319
+ '--password', type=str, default='', help='Password for authentication'
320
+ )
321
+ parser.add_argument(
322
+ '--server_port',
323
+ type=int,
324
+ default=0,
325
+ help='Port to run the server listener on',
326
+ )
327
+ parser.add_argument(
328
+ '--inbrowser', action='store_true', help='Open in browser'
329
+ )
330
+ parser.add_argument(
331
+ '--share', action='store_true', help='Share the gradio UI'
332
+ )
333
+
334
+ args = parser.parse_args()
335
+
336
+ launch_kwargs = {}
337
+ launch_kwargs['server_name'] = args.listen
338
+
339
+ if args.username and args.password:
340
+ launch_kwargs['auth'] = (args.username, args.password)
341
+ if args.server_port:
342
+ launch_kwargs['server_port'] = args.server_port
343
+ if args.inbrowser:
344
+ launch_kwargs['inbrowser'] = args.inbrowser
345
+ if args.share:
346
+ launch_kwargs['share'] = args.share
347
+
348
+ logging.basicConfig(level=logging.INFO, stream=sys.stderr)
349
 
350
+ # Show the interface
351
+ ui_full(launch_kwargs)