import base64
import io
from typing import Any, Dict, List, Optional

import torch
from scipy.io import wavfile
from transformers import AutoProcessor, MusicgenForConditionalGeneration


def create_params(params, fr):
    """Build keyword arguments for ``model.generate`` from request parameters.

    Args:
        params: Request parameters dict, or ``None``. Only keys that already
            exist in the defaults (``do_sample``, ``guidance_scale``,
            ``max_new_tokens``) are copied over; other keys are ignored.
            A ``duration`` key is converted to ``max_new_tokens`` as
            ``duration * fr`` and takes precedence over an explicit
            ``max_new_tokens``.
        fr: Audio-encoder frame rate; ``duration * fr`` gives the token count.

    Returns:
        A new dict of generation keyword arguments.

    Raises:
        ValueError: if ``duration`` is not a positive number.
    """
    # default
    out = {
        "do_sample": True,
        "guidance_scale": 3,
        "max_new_tokens": 256,
    }
    if not params:
        return out

    # Whitelist: only override keys we already have defaults for.
    for key, value in params.items():
        if key in out:
            out[key] = value

    if "duration" in params:
        duration = float(params["duration"])
        if duration <= 0:
            raise ValueError(f"'duration' must be positive, got {duration!r}")
        # generate() needs an integer token budget; fractional durations
        # (e.g. 5.5 s) would otherwise produce a float max_new_tokens.
        out["max_new_tokens"] = max(1, round(duration * fr))

    return out


class EndpointHandler:
    """Inference endpoint that turns a text prompt into a base64-encoded WAV."""

    def __init__(self, path="pbotsaris/musicgen-small", device: Optional[str] = None):
        """Load the processor and MusicGen model from ``path``.

        Args:
            path: Model id or local directory passed to ``from_pretrained``.
            device: Torch device string to run on. Defaults to ``"cuda:0"``
                when CUDA is available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(path)
        self.model.to(self.device)  # type: ignore

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for a text prompt.

        Args:
            data: The payload. ``data["inputs"]`` is the text prompt; if that
                key is missing, the payload itself is used as the prompt.
                The optional ``data["parameters"]`` dict is forwarded to
                :func:`create_params`. Both keys are popped from ``data``.

        Returns:
            ``[{"audio": <base64-encoded WAV string>, "sr": <sampling rate>}]``.
        """
        inputs = data.pop("inputs", data)
        params = data.pop("parameters", None)

        encoded = self.processor(
            text=[inputs],
            padding=True,
            return_tensors="pt",
        )

        gen_kwargs = create_params(params, self.model.config.audio_encoder.frame_rate)  # type: ignore
        outputs = self.model.generate(**encoded.to(self.device), **gen_kwargs)  # type: ignore

        # Take outputs[0, 0]; presumably (batch, channels, samples) -> first
        # item, first channel — confirm against the model's output shape.
        # .float() guards against half-precision tensors, which scipy's WAV
        # writer cannot handle.
        pred = outputs[0, 0].float().cpu().numpy()
        sr = self.model.config.audio_encoder.sampling_rate  # type: ignore

        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, rate=sr, data=pred)
        audio_b64 = base64.b64encode(wav_buffer.getvalue()).decode("utf-8")

        return [{"audio": audio_b64, "sr": sr}]


if __name__ == "__main__":
    handler = EndpointHandler()