Surn committed on
Commit
de8ae12
·
1 Parent(s): 595ae94

Fix CPU/GPU issue

Browse files
Files changed (2) hide show
  1. app.py +154 -94
  2. audiocraft/models/musicgen.py +5 -0
app.py CHANGED
@@ -22,11 +22,8 @@ import random
22
  MODEL = None
23
  MODELS = None
24
  IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
25
- IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
26
- INTERRUPTED = False
27
  INTERRUPTED = False
28
  UNLOAD_MODEL = False
29
- UNLOAD_MODEL = False
30
  MOVE_TO_CPU = False
31
 
32
  def interrupt():
@@ -44,16 +41,36 @@ def make_waveform(*args, **kwargs):
44
  return out
45
 
46
  def load_model(version):
 
47
  print("Loading model", version)
48
- return MusicGen.get_pretrained(version)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
  def predict(model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color, seed, overlap=1):
52
- global MODEL
53
  output_segments = None
54
  topk = int(topk)
55
  if MODEL is None or MODEL.name != model:
56
  MODEL = load_model(model)
 
 
 
57
 
58
  output = None
59
  segment_duration = duration
@@ -139,98 +156,141 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, c
139
  file.name, output, MODEL.sample_rate, strategy="loudness",
140
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
141
  waveform_video = make_waveform(file.name,bg_image=background, bar_count=40)
 
 
 
 
 
 
142
  return waveform_video, seed
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- css="""
146
- #col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
147
- a {text-decoration-line: underline; font-weight: 600;}
148
- """
149
- with gr.Blocks(title="UnlimitedMusicGen", css=css) as demo:
150
- gr.Markdown(
151
- """
152
- # UnlimitedMusicGen
153
- This is your private demo for [UnlimitedMusicGen](https://github.com/Oncorporation/audiocraft), a simple and controllable model for music generation
154
- presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
155
- """
156
- )
157
- if IS_SHARED_SPACE:
158
- gr.Markdown("""
159
- This Space doesn't work in this shared UI ⚠
160
-
161
- <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
162
- <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
163
- to use it privately, or use the <a href="https://huggingface.co/spaces/facebook/MusicGen">public demo</a>
164
- """)
165
- with gr.Row():
166
- with gr.Column():
167
- with gr.Row():
168
- text = gr.Text(label="Input Text", interactive=True, value="4/4 100bpm 320kbps 48khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi")
169
- melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
170
- with gr.Row():
171
- submit = gr.Button("Submit")
172
- # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
173
- _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
174
- with gr.Row():
175
- background= gr.Image(value="./assets/background.png", source="upload", label="Background", shape=(768,512), type="filepath", interactive=True)
176
- include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
177
- with gr.Row():
178
- title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
179
- settings_font = gr.Text(label="Settings Font", value="arial.ttf", interactive=True)
180
- settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#ffffff", interactive=True)
181
- with gr.Row():
182
- model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
183
- with gr.Row():
184
- duration = gr.Slider(minimum=1, maximum=1000, value=10, label="Duration", interactive=True)
185
- overlap = gr.Slider(minimum=1, maximum=29, value=5, step=1, label="Overlap", interactive=True)
186
- dimension = gr.Slider(minimum=-2, maximum=2, value=2, step=1, label="Dimension", info="determines which direction to add new segements of audio. (1 = stack tracks, 2 = lengthen, -2..0 = ?)", interactive=True)
187
- with gr.Row():
188
- topk = gr.Number(label="Top-k", value=250, interactive=True)
189
- topp = gr.Number(label="Top-p", value=0, interactive=True)
190
- temperature = gr.Number(label="Randomness Temperature", value=1.0, precision=2, interactive=True)
191
- cfg_coef = gr.Number(label="Classifier Free Guidance", value=5.0, precision=2, interactive=True)
192
- with gr.Row():
193
- seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
194
- gr.Button('\U0001f3b2\ufe0f').style(full_width=False).click(fn=lambda: -1, outputs=[seed], queue=False)
195
- reuse_seed = gr.Button('\u267b\ufe0f').style(full_width=False)
196
- with gr.Column() as c:
197
- output = gr.Video(label="Generated Music")
198
- seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
199
-
200
- reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False)
201
- submit.click(predict, inputs=[model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color, seed, overlap], outputs=[output, seed_used])
202
- gr.Examples(
203
- fn=predict,
204
- examples=[
205
- [
206
- "An 80s driving pop song with heavy drums and synth pads in the background",
207
- "./assets/bach.mp3",
208
- "melody"
209
- ],
210
- [
211
- "A cheerful country song with acoustic guitars",
212
- "./assets/bolero_ravel.mp3",
213
- "melody"
214
- ],
215
- [
216
- "90s rock song with electric guitar and heavy drums",
217
- None,
218
- "medium"
219
- ],
220
- [
221
- "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
222
- "./assets/bach.mp3",
223
- "melody"
224
- ],
225
- [
226
- "lofi slow bpm electro chill with organic samples",
227
- None,
228
- "medium",
229
  ],
230
- ],
231
- inputs=[text, melody, model],
232
- outputs=[output]
233
- )
 
 
 
 
 
 
 
 
 
 
 
 
234
 
 
 
 
 
 
 
 
 
 
 
235
 
236
- demo.queue(max_size=32).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  MODEL = None
23
  MODELS = None
24
  IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
 
 
25
  INTERRUPTED = False
26
  UNLOAD_MODEL = False
 
27
  MOVE_TO_CPU = False
28
 
29
  def interrupt():
 
41
  return out
42
 
43
  def load_model(version):
44
+ global MODEL, MODELS, UNLOAD_MODEL
45
  print("Loading model", version)
46
+ if MODELS is None:
47
+ return MusicGen.get_pretrained(version)
48
+ else:
49
+ t1 = time.monotonic()
50
+ if MODEL is not None:
51
+ MODEL.to('cpu') # move to cache
52
+ print("Previous model moved to CPU in %.2fs" % (time.monotonic() - t1))
53
+ t1 = time.monotonic()
54
+ if MODELS.get(version) is None:
55
+ print("Loading model %s from disk" % version)
56
+ result = MusicGen.get_pretrained(version)
57
+ MODELS[version] = result
58
+ print("Model loaded in %.2fs" % (time.monotonic() - t1))
59
+ return result
60
+ result = MODELS[version].to('cuda')
61
+ print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
62
+ return result
63
 
64
 
65
  def predict(model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color, seed, overlap=1):
66
+ global MODEL, INTERRUPTED
67
  output_segments = None
68
  topk = int(topk)
69
  if MODEL is None or MODEL.name != model:
70
  MODEL = load_model(model)
71
+ else:
72
+ if MOVE_TO_CPU:
73
+ MODEL.to('cuda')
74
 
75
  output = None
76
  segment_duration = duration
 
156
  file.name, output, MODEL.sample_rate, strategy="loudness",
157
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
158
  waveform_video = make_waveform(file.name,bg_image=background, bar_count=40)
159
+ if MOVE_TO_CPU:
160
+ MODEL.to('cpu')
161
+ if UNLOAD_MODEL:
162
+ MODEL = None
163
+ torch.cuda.empty_cache()
164
+ torch.cuda.ipc_collect()
165
  return waveform_video, seed
166
 
167
+ def ui(**kwargs):
168
+ css="""
169
+ #col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
170
+ a {text-decoration-line: underline; font-weight: 600;}
171
+ """
172
+ with gr.Blocks(title="UnlimitedMusicGen", css=css) as demo:
173
+ gr.Markdown(
174
+ """
175
+ # UnlimitedMusicGen
176
+ This is your private demo for [UnlimitedMusicGen](https://github.com/Oncorporation/audiocraft), a simple and controllable model for music generation
177
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
178
+ """
179
+ )
180
+ if IS_SHARED_SPACE:
181
+ gr.Markdown("""
182
+ ⚠ This Space doesn't work in this shared UI ⚠
183
 
184
+ <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
185
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
186
+ to use it privately, or use the <a href="https://huggingface.co/spaces/facebook/MusicGen">public demo</a>
187
+ """)
188
+ with gr.Row():
189
+ with gr.Column():
190
+ with gr.Row():
191
+ text = gr.Text(label="Input Text", interactive=True, value="4/4 100bpm 320kbps 48khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi")
192
+ melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
193
+ with gr.Row():
194
+ submit = gr.Button("Submit")
195
+ # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
196
+ _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
197
+ with gr.Row():
198
+ background= gr.Image(value="./assets/background.png", source="upload", label="Background", shape=(768,512), type="filepath", interactive=True)
199
+ include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
200
+ with gr.Row():
201
+ title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
202
+ settings_font = gr.Text(label="Settings Font", value="arial.ttf", interactive=True)
203
+ settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#ffffff", interactive=True)
204
+ with gr.Row():
205
+ model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
206
+ with gr.Row():
207
+ duration = gr.Slider(minimum=1, maximum=1000, value=10, label="Duration", interactive=True)
208
+ overlap = gr.Slider(minimum=1, maximum=29, value=5, step=1, label="Overlap", interactive=True)
209
+ dimension = gr.Slider(minimum=-2, maximum=2, value=2, step=1, label="Dimension", info="determines which direction to add new segements of audio. (1 = stack tracks, 2 = lengthen, -2..0 = ?)", interactive=True)
210
+ with gr.Row():
211
+ topk = gr.Number(label="Top-k", value=250, interactive=True)
212
+ topp = gr.Number(label="Top-p", value=0, interactive=True)
213
+ temperature = gr.Number(label="Randomness Temperature", value=1.0, precision=2, interactive=True)
214
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=5.0, precision=2, interactive=True)
215
+ with gr.Row():
216
+ seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
217
+ gr.Button('\U0001f3b2\ufe0f').style(full_width=False).click(fn=lambda: -1, outputs=[seed], queue=False)
218
+ reuse_seed = gr.Button('\u267b\ufe0f').style(full_width=False)
219
+ with gr.Column() as c:
220
+ output = gr.Video(label="Generated Music")
221
+ seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
222
+
223
+ reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False)
224
+ submit.click(predict, inputs=[model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color, seed, overlap], outputs=[output, seed_used])
225
+ gr.Examples(
226
+ fn=predict,
227
+ examples=[
228
+ [
229
+ "An 80s driving pop song with heavy drums and synth pads in the background",
230
+ "./assets/bach.mp3",
231
+ "melody"
232
+ ],
233
+ [
234
+ "A cheerful country song with acoustic guitars",
235
+ "./assets/bolero_ravel.mp3",
236
+ "melody"
237
+ ],
238
+ [
239
+ "90s rock song with electric guitar and heavy drums",
240
+ None,
241
+ "medium"
242
+ ],
243
+ [
244
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
245
+ "./assets/bach.mp3",
246
+ "melody"
247
+ ],
248
+ [
249
+ "lofi slow bpm electro chill with organic samples",
250
+ None,
251
+ "medium",
252
+ ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  ],
254
+ inputs=[text, melody, model],
255
+ outputs=[output]
256
+ )
257
+
258
+ # Show the interface
259
+ launch_kwargs = {}
260
+ share = kwargs.get('share', False)
261
+ if share:
262
+ launch_kwargs['share'] = share
263
+
264
+
265
+
266
+ demo.queue(max_size=15).launch(**launch_kwargs )
267
+
268
+ if __name__ == "__main__":
269
+ parser = argparse.ArgumentParser()
270
 
271
+ parser.add_argument(
272
+ '--share', action='store_true', help='Share the gradio UI'
273
+ )
274
+ parser.add_argument(
275
+ '--unload_model', action='store_true', help='Unload the model after every generation to save GPU memory'
276
+ )
277
+
278
+ parser.add_argument(
279
+ '--unload_to_cpu', action='store_true', help='Move the model to main RAM after every generation to save GPU memory but reload faster than after full unload (see above)'
280
+ )
281
 
282
+ parser.add_argument(
283
+ '--cache', action='store_true', help='Cache models in RAM to quickly switch between them'
284
+ )
285
+
286
+ args = parser.parse_args()
287
+ UNLOAD_MODEL = args.unload_model
288
+ MOVE_TO_CPU = args.unload_to_cpu
289
+ if args.cache:
290
+ MODELS = {}
291
+
292
+ ui(
293
+ unload_to_cpu = MOVE_TO_CPU,
294
+ share=args.share
295
+
296
+ )
audiocraft/models/musicgen.py CHANGED
@@ -284,3 +284,8 @@ class MusicGen:
284
  with torch.no_grad():
285
  gen_audio = self.compression_model.decode(gen_tokens, None)
286
  return gen_audio
 
 
 
 
 
 
284
  with torch.no_grad():
285
  gen_audio = self.compression_model.decode(gen_tokens, None)
286
  return gen_audio
287
+
288
+ def to(self, device: str):
289
+ self.compression_model.to(device)
290
+ self.lm.to(device)
291
+ return self