# Spaces: Runtime error (status banner captured from the hosting page; not part of the program)
import torch | |
import time | |
import typing as tp | |
from audiocraft.models import MusicGen | |
from audiocraft.modules.conditioners import ConditioningAttributes | |
class HijackedMusicGen(MusicGen):
    """MusicGen variant that can report generation progress over a socket.

    When a ``socketio`` object is supplied, the instance installs
    ``_timed_progress_callback`` as its ``_progress_callback``; that callback
    emits a throttled ``'progress'`` event through ``socketio.emit``.  Without
    a ``socketio`` object no callback is installed.

    NOTE(review): ``_progress_callback`` is presumably invoked by the parent
    ``MusicGen`` class during generation -- confirm against the installed
    audiocraft version.
    """

    # Minimum number of seconds between two emitted 'progress' events.
    PROGRESS_INTERVAL_SECONDS = 0.1

    def __init__(self, socketio=None, *args, **kwargs):
        """Build the model and wire up progress reporting.

        Args:
            socketio: Object exposing ``emit(event, payload)``; ``None``
                disables the built-in progress reporting.
            *args: Positional arguments forwarded to ``MusicGen.__init__``.
            **kwargs: Keyword arguments forwarded to ``MusicGen.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.socketio = socketio
        # Must be assigned after super().__init__() so the parent's own
        # initialisation cannot overwrite it.
        self._progress_callback = (
            self._timed_progress_callback if socketio is not None else None
        )
        # Monotonic clock: the value is only used to measure elapsed time, so
        # it must not jump when the wall clock is adjusted.
        self._last_update_time = time.monotonic()

    def _timed_progress_callback(self, generated_tokens: int, tokens_to_generate: int):
        """Emit a ``'progress'`` event, at most once per throttle interval.

        Args:
            generated_tokens: Number of tokens generated so far.
            tokens_to_generate: Total number of tokens to generate.
        """
        current_time = time.monotonic()
        if current_time - self._last_update_time < self.PROGRESS_INTERVAL_SECONDS:
            return  # Too soon since the last event; drop this update.
        self.socketio.emit(
            'progress',
            {
                'generated_tokens': generated_tokens,
                'tokens_to_generate': tokens_to_generate,
            },
        )
        self._last_update_time = current_time

    @staticmethod
    def get_pretrained(socketio, name: str = 'melody', device='cuda'):
        """Load a pretrained MusicGen and re-wrap it as a ``HijackedMusicGen``.

        Declared as a static method because it takes no ``self``/``cls``;
        without the decorator, calling it on an instance would bind that
        instance to ``socketio``.

        Args:
            socketio: Passed through to ``HijackedMusicGen.__init__``.
            name: Pretrained model name handed to ``MusicGen.get_pretrained``.
            device: Device handed to ``MusicGen.get_pretrained``.

        Returns:
            A ``HijackedMusicGen`` sharing the loaded model's ``name``,
            ``compression_model`` and ``lm``.
        """
        music_gen = MusicGen.get_pretrained(name, device)
        return HijackedMusicGen(
            socketio,
            music_gen.name,
            music_gen.compression_model,
            music_gen.lm,
        )

    @property
    def progress_callback(self):
        """Write-only hook for the progress callback.

        Assign a callable (or ``None``) to replace the active callback;
        reading the attribute always raises.

        Raises:
            Exception: Always, on read access.
        """
        raise Exception("Progress callback is write-only")

    @progress_callback.setter
    def progress_callback(self, callback: tp.Optional[tp.Callable[[int, int], None]]):
        # Stored under the private name the rest of the class uses.
        self._progress_callback = callback