Spaces:
Running
Running
Update to fix Collab launch
Browse files
audiocraft/models/musicgen.py
CHANGED
@@ -412,6 +412,38 @@ class MusicGen:
|
|
412 |
gen_audio = self.compression_model.decode(gen_tokens, None)
|
413 |
return gen_audio
|
414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
def to(self, device: str):
|
416 |
self.compression_model.to(device)
|
417 |
self.lm.to(device)
|
|
|
412 |
gen_audio = self.compression_model.decode(gen_tokens, None)
|
413 |
return gen_audio
|
414 |
|
415 |
+
#def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
|
416 |
+
# prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
|
417 |
+
# """Generate discrete audio tokens given audio prompt and/or conditions.
|
418 |
+
|
419 |
+
# Args:
|
420 |
+
# attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
|
421 |
+
# prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
|
422 |
+
# progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
423 |
+
# Returns:
|
424 |
+
# torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
425 |
+
# """
|
426 |
+
# def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
427 |
+
# print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
|
428 |
+
|
429 |
+
# if prompt_tokens is not None:
|
430 |
+
# assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
|
431 |
+
# "Prompt is longer than audio to generate"
|
432 |
+
|
433 |
+
# callback = None
|
434 |
+
# if progress:
|
435 |
+
# callback = _progress_callback
|
436 |
+
|
437 |
+
# # generate by sampling from LM
|
438 |
+
# with self.autocast:
|
439 |
+
# gen_tokens = self.lm.generate(prompt_tokens, attributes, callback=callback, **self.generation_params)
|
440 |
+
|
441 |
+
# # generate audio
|
442 |
+
# assert gen_tokens.dim() == 3
|
443 |
+
# with torch.no_grad():
|
444 |
+
# gen_audio = self.compression_model.decode(gen_tokens, None)
|
445 |
+
# return gen_audio
|
446 |
+
|
447 |
def to(self, device: str):
|
448 |
self.compression_model.to(device)
|
449 |
self.lm.to(device)
|