adefossez committed on
Commit
c16da55
·
1 Parent(s): 16a7142

new version

Browse files
Files changed (3) hide show
  1. CHANGELOG.md +2 -0
  2. app_batched.py +157 -67
  3. audiocraft/modules/transformer.py +11 -8
CHANGELOG.md CHANGED
@@ -13,6 +13,8 @@ Now repeating the conditioning periodically if it is too short.
13
 
14
  More options when launching Gradio app locally (thanks @ashleykleynhans).
15
 
 
 
16
  ## [0.0.1] - 2023-06-09
17
 
18
  Initial release, with model evaluation only.
 
13
 
14
  More options when launching Gradio app locally (thanks @ashleykleynhans).
15
 
16
+ Testing out PyTorch 2.0 memory efficient attention.
17
+
18
  ## [0.0.1] - 2023-06-09
19
 
20
  Initial release, with model evaluation only.
app_batched.py CHANGED
@@ -6,7 +6,12 @@ This source code is licensed under the license found in the
6
  LICENSE file in the root directory of this source tree.
7
  """
8
 
 
 
 
9
  from tempfile import NamedTemporaryFile
 
 
10
  import torch
11
  import gradio as gr
12
  from audiocraft.data.audio_utils import convert_audio
@@ -16,6 +21,29 @@ from audiocraft.models import MusicGen
16
 
17
  MODEL = None
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def load_model():
21
  print("Loading model")
@@ -28,11 +56,13 @@ def predict(texts, melodies):
28
  MODEL = load_model()
29
 
30
  duration = 12
 
 
31
  MODEL.set_generation_params(duration=duration)
32
 
33
- print(texts, melodies)
 
34
  processed_melodies = []
35
-
36
  target_sr = 32000
37
  target_ac = 1
38
  for melody in melodies:
@@ -60,73 +90,133 @@ def predict(texts, melodies):
60
  audio_write(
61
  file.name, output, MODEL.sample_rate, strategy="loudness",
62
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
63
- waveform_video = gr.make_waveform(file.name)
64
- out_files.append(waveform_video)
65
- return [out_files]
66
-
67
-
68
- with gr.Blocks() as demo:
69
- gr.Markdown(
70
- """
71
- # MusicGen
72
-
73
- This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
74
- presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
75
- <br/>
76
- <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
77
- <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
78
- for longer sequences, more control and no queue.</p>
79
- """
80
- )
81
- with gr.Row():
82
- with gr.Column():
83
- with gr.Row():
84
- text = gr.Text(label="Describe your music", lines=2, interactive=True)
85
- melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
86
- with gr.Row():
87
- submit = gr.Button("Generate")
88
- with gr.Column():
89
- output = gr.Video(label="Generated Music")
90
- submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
91
- gr.Examples(
92
- fn=predict,
93
- examples=[
94
- [
95
- "An 80s driving pop song with heavy drums and synth pads in the background",
96
- "./assets/bach.mp3",
97
- ],
98
- [
99
- "A cheerful country song with acoustic guitars",
100
- "./assets/bolero_ravel.mp3",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  ],
102
- [
103
- "90s rock song with electric guitar and heavy drums",
104
- None,
105
- ],
106
- [
107
- "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
108
- "./assets/bach.mp3",
109
- ],
110
- [
111
- "lofi slow bpm electro chill with organic samples",
112
- None,
113
- ],
114
- ],
115
- inputs=[text, melody],
116
- outputs=[output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  )
118
- gr.Markdown("""
119
- ### More details
120
 
121
- The model will generate 12 seconds of audio based on the description you provided.
122
- You can optionaly provide a reference audio from which a broad melody will be extracted.
123
- The model will then try to follow both the description and melody provided.
124
- All samples are generated with the `melody` model.
125
-
126
- You can also use your own GPU or a Google Colab by following the instructions on our repo.
127
 
128
- See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
129
- for more details.
130
- """)
 
 
 
 
 
131
 
132
- demo.queue(max_size=15).launch()
 
6
  LICENSE file in the root directory of this source tree.
7
  """
8
 
9
+ import argparse
10
+ from concurrent.futures import ProcessPoolExecutor
11
+ import subprocess as sp
12
  from tempfile import NamedTemporaryFile
13
+ import time
14
+ import warnings
15
  import torch
16
  import gradio as gr
17
  from audiocraft.data.audio_utils import convert_audio
 
21
 
22
  MODEL = None
23
 
24
+ _old_call = sp.call
25
+
26
+
27
+ def _call_nostderr(*args, **kwargs):
28
+ # Avoid ffmpeg vomiting on the logs.
29
+ kwargs['stderr'] = sp.DEVNULL
30
+ kwargs['stdout'] = sp.DEVNULL
31
+ _old_call(*args, **kwargs)
32
+
33
+
34
+ sp.call = _call_nostderr
35
+ pool = ProcessPoolExecutor(3)
36
+ pool.__enter__()
37
+
38
+
39
+ def make_waveform(*args, **kwargs):
40
+ be = time.time()
41
+ with warnings.catch_warnings():
42
+ warnings.simplefilter('ignore')
43
+ out = gr.make_waveform(*args, **kwargs)
44
+ print("Make a video took", time.time() - be)
45
+ return out
46
+
47
 
48
  def load_model():
49
  print("Loading model")
 
56
  MODEL = load_model()
57
 
58
  duration = 12
59
+ max_text_length = 512
60
+ texts = [text[:max_text_length] for text in texts]
61
  MODEL.set_generation_params(duration=duration)
62
 
63
+ print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
64
+ be = time.time()
65
  processed_melodies = []
 
66
  target_sr = 32000
67
  target_ac = 1
68
  for melody in melodies:
 
90
  audio_write(
91
  file.name, output, MODEL.sample_rate, strategy="loudness",
92
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
93
+ out_files.append(pool.submit(make_waveform, file.name))
94
+ res = [[out_file.result() for out_file in out_files]]
95
+ print("batch finished", len(texts), time.time() - be)
96
+ return res
97
+
98
+
99
+ def ui(**kwargs):
100
+ with gr.Blocks() as demo:
101
+ gr.Markdown(
102
+ """
103
+ # MusicGen
104
+
105
+ This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
106
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
107
+ <br/>
108
+ <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
109
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
110
+ for longer sequences, more control and no queue.</p>
111
+ """
112
+ )
113
+ with gr.Row():
114
+ with gr.Column():
115
+ with gr.Row():
116
+ text = gr.Text(label="Describe your music", lines=2, interactive=True)
117
+ melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
118
+ with gr.Row():
119
+ submit = gr.Button("Generate")
120
+ with gr.Column():
121
+ output = gr.Video(label="Generated Music")
122
+ submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=8)
123
+ gr.Examples(
124
+ fn=predict,
125
+ examples=[
126
+ [
127
+ "An 80s driving pop song with heavy drums and synth pads in the background",
128
+ "./assets/bach.mp3",
129
+ ],
130
+ [
131
+ "A cheerful country song with acoustic guitars",
132
+ "./assets/bolero_ravel.mp3",
133
+ ],
134
+ [
135
+ "90s rock song with electric guitar and heavy drums",
136
+ None,
137
+ ],
138
+ [
139
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
140
+ "./assets/bach.mp3",
141
+ ],
142
+ [
143
+ "lofi slow bpm electro chill with organic samples",
144
+ None,
145
+ ],
146
  ],
147
+ inputs=[text, melody],
148
+ outputs=[output]
149
+ )
150
+ gr.Markdown("""
151
+ ### More details
152
+
153
+ The model will generate 12 seconds of audio based on the description you provided.
154
+ You can optionaly provide a reference audio from which a broad melody will be extracted.
155
+ The model will then try to follow both the description and melody provided.
156
+ All samples are generated with the `melody` model.
157
+
158
+ You can also use your own GPU or a Google Colab by following the instructions on our repo.
159
+
160
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
161
+ for more details.
162
+ """)
163
+
164
+ # Show the interface
165
+ launch_kwargs = {}
166
+ username = kwargs.get('username')
167
+ password = kwargs.get('password')
168
+ server_port = kwargs.get('server_port', 0)
169
+ inbrowser = kwargs.get('inbrowser', False)
170
+ share = kwargs.get('share', False)
171
+ server_name = kwargs.get('listen')
172
+
173
+ launch_kwargs['server_name'] = server_name
174
+
175
+ if username and password:
176
+ launch_kwargs['auth'] = (username, password)
177
+ if server_port > 0:
178
+ launch_kwargs['server_port'] = server_port
179
+ if inbrowser:
180
+ launch_kwargs['inbrowser'] = inbrowser
181
+ if share:
182
+ launch_kwargs['share'] = share
183
+ demo.queue(max_size=60).launch(**launch_kwargs)
184
+
185
+ if __name__ == "__main__":
186
+ parser = argparse.ArgumentParser()
187
+ parser.add_argument(
188
+ '--listen',
189
+ type=str,
190
+ default='127.0.0.1',
191
+ help='IP to listen on for connections to Gradio',
192
+ )
193
+ parser.add_argument(
194
+ '--username', type=str, default='', help='Username for authentication'
195
+ )
196
+ parser.add_argument(
197
+ '--password', type=str, default='', help='Password for authentication'
198
+ )
199
+ parser.add_argument(
200
+ '--server_port',
201
+ type=int,
202
+ default=0,
203
+ help='Port to run the server listener on',
204
+ )
205
+ parser.add_argument(
206
+ '--inbrowser', action='store_true', help='Open in browser'
207
+ )
208
+ parser.add_argument(
209
+ '--share', action='store_true', help='Share the gradio UI'
210
  )
 
 
211
 
212
+ args = parser.parse_args()
 
 
 
 
 
213
 
214
+ ui(
215
+ username=args.username,
216
+ password=args.password,
217
+ inbrowser=args.inbrowser,
218
+ server_port=args.server_port,
219
+ share=args.share,
220
+ listen=args.listen
221
+ )
222
 
 
audiocraft/modules/transformer.py CHANGED
@@ -247,20 +247,20 @@ class StreamingMultiheadAttention(StreamingModule):
247
  # Complete the key/value pair using the streaming state.
248
  if self._streaming_state:
249
  pk = self._streaming_state['past_keys']
250
- nk = torch.cat([pk, k], dim=1)
251
  if v is k:
252
  nv = nk
253
  else:
254
  pv = self._streaming_state['past_values']
255
- nv = torch.cat([pv, v], dim=1)
256
  else:
257
  nk = k
258
  nv = v
259
 
260
- assert nk.shape[1] == nv.shape[1]
261
  offset = 0
262
  if self.past_context is not None:
263
- offset = max(0, nk.shape[1] - self.past_context)
264
  if self._is_streaming:
265
  self._streaming_state['past_keys'] = nk[:, offset:]
266
  if v is not k:
@@ -271,6 +271,7 @@ class StreamingMultiheadAttention(StreamingModule):
271
  self._streaming_state['offset'] = torch.tensor(0)
272
  return nk, nv
273
 
 
274
  def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
275
  # Apply rope embeddings to query and key tensors.
276
  assert self.rope is not None
@@ -325,7 +326,7 @@ class StreamingMultiheadAttention(StreamingModule):
325
  q = self.q_layer_norm(q)
326
  k = self.k_layer_norm(k)
327
  # q, k, v = [rearrange(x, "b t (h d) -> (b h) t d", h=self.num_heads) for x in [q, k, v]]
328
- q, k, v = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k, v]]
329
  else:
330
  if not _is_profiled():
331
  # profiling breaks that property somehow.
@@ -333,7 +334,7 @@ class StreamingMultiheadAttention(StreamingModule):
333
  assert value is key, "specialized implementation"
334
  projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
335
  if self.kv_repeat == 1:
336
- packed = rearrange(projected, "b t (p h d) -> b t p h d", p=3, h=self.num_heads)
337
  q, k, v = ops.unbind(packed, dim=2)
338
  else:
339
  embed_dim = self.embed_dim
@@ -355,6 +356,7 @@ class StreamingMultiheadAttention(StreamingModule):
355
  k = self.k_layer_norm(k)
356
  q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
357
  if self.rope:
 
358
  q, k = self._apply_rope(q, k)
359
  k, v = self._complete_kv(k, v)
360
  if self.kv_repeat > 1:
@@ -364,7 +366,8 @@ class StreamingMultiheadAttention(StreamingModule):
364
  q, k, v = [x.float() for x in [q, k, v]]
365
  if self.memory_efficient:
366
  p = self.dropout if self.training else 0
367
- x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
 
368
  else:
369
  # We include the dot product as float32, for consistency
370
  # with the other implementations that include that step
@@ -385,7 +388,7 @@ class StreamingMultiheadAttention(StreamingModule):
385
  w = F.dropout(w, self.dropout, training=self.training).to(v)
386
  x = torch.einsum("bhqk,bkhc->bqhc", w, v)
387
  x = x.to(dtype)
388
- x = rearrange(x, "b t h d -> b t (h d)", h=self.num_heads)
389
  x = self.out_proj(x)
390
  else:
391
  key, value = self._complete_kv(key, value)
 
247
  # Complete the key/value pair using the streaming state.
248
  if self._streaming_state:
249
  pk = self._streaming_state['past_keys']
250
+ nk = torch.cat([pk, k], dim=2)
251
  if v is k:
252
  nv = nk
253
  else:
254
  pv = self._streaming_state['past_values']
255
+ nv = torch.cat([pv, v], dim=2)
256
  else:
257
  nk = k
258
  nv = v
259
 
260
+ assert nk.shape[2] == nv.shape[2]
261
  offset = 0
262
  if self.past_context is not None:
263
+ offset = max(0, nk.shape[2] - self.past_context)
264
  if self._is_streaming:
265
  self._streaming_state['past_keys'] = nk[:, offset:]
266
  if v is not k:
 
271
  self._streaming_state['offset'] = torch.tensor(0)
272
  return nk, nv
273
 
274
+
275
  def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
276
  # Apply rope embeddings to query and key tensors.
277
  assert self.rope is not None
 
326
  q = self.q_layer_norm(q)
327
  k = self.k_layer_norm(k)
328
  # q, k, v = [rearrange(x, "b t (h d) -> (b h) t d", h=self.num_heads) for x in [q, k, v]]
329
+ q, k, v = [rearrange(x, "b t (h d) -> b h t d", h=self.num_heads) for x in [q, k, v]]
330
  else:
331
  if not _is_profiled():
332
  # profiling breaks that property somehow.
 
334
  assert value is key, "specialized implementation"
335
  projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
336
  if self.kv_repeat == 1:
337
+ packed = rearrange(projected, "b t (p h d) -> b h p t d", p=3, h=self.num_heads)
338
  q, k, v = ops.unbind(packed, dim=2)
339
  else:
340
  embed_dim = self.embed_dim
 
356
  k = self.k_layer_norm(k)
357
  q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
358
  if self.rope:
359
+ assert False, "Not supported for now"
360
  q, k = self._apply_rope(q, k)
361
  k, v = self._complete_kv(k, v)
362
  if self.kv_repeat > 1:
 
366
  q, k, v = [x.float() for x in [q, k, v]]
367
  if self.memory_efficient:
368
  p = self.dropout if self.training else 0
369
+ x = torch.nn.functional.scaled_dot_product_attention(
370
+ q, k, v, is_causal=attn_mask is not None, dropout_p=p)
371
  else:
372
  # We include the dot product as float32, for consistency
373
  # with the other implementations that include that step
 
388
  w = F.dropout(w, self.dropout, training=self.training).to(v)
389
  x = torch.einsum("bhqk,bkhc->bqhc", w, v)
390
  x = x.to(dtype)
391
+ x = rearrange(x, "b h t d -> b t (h d)", h=self.num_heads)
392
  x = self.out_proj(x)
393
  else:
394
  key, value = self._complete_kv(key, value)