|
from typing import Dict, List, Any |
|
from scipy.io import wavfile |
|
from transformers import AutoProcessor, MusicgenForConditionalGeneration |
|
import torch |
|
import io |
|
|
|
def create_params(params, fr) -> Dict[str, Any]:
    """Build the keyword arguments passed to ``model.generate``.

    Starts from fixed defaults (``do_sample=True``, ``guidance_scale=3``,
    ``max_new_tokens=256``) and overrides them with matching keys from
    ``params``. Keys that are not one of those defaults are ignored. When
    ``params`` contains ``duration``, the token budget is computed as
    ``duration * fr`` and takes precedence over any explicit
    ``max_new_tokens``.

    Args:
        params: Optional mapping of user-supplied generation parameters, or
            ``None`` to use the defaults unchanged.
        fr: Frame rate of the audio encoder. It converts ``duration`` into a
            token count, so ``duration`` is presumably in seconds.

    Returns:
        A new dict of generation keyword arguments.
    """
    out = {
        "do_sample": True,
        "guidance_scale": 3,
        "max_new_tokens": 256,
    }

    if params is None:
        return out

    has_duration = "duration" in params
    if has_duration:
        # generate() needs an integer token count. A fractional duration or
        # frame rate would otherwise produce a float here, and a string
        # duration (e.g. "5") would be repeated instead of multiplied.
        out["max_new_tokens"] = int(round(float(params["duration"]) * fr))

    for key, value in params.items():
        if key not in out:
            continue  # only the default generation knobs are user-overridable
        if has_duration and key == "max_new_tokens":
            continue  # an explicit duration wins over max_new_tokens
        out[key] = value

    return out
|
|
|
|
|
class EndpointHandler:
    """Endpoint handler that turns a text prompt into WAV audio with MusicGen."""

    def __init__(self, path="pbotsaris/musicgen-small"):
        """Load the processor and a float16 MusicGen model, then move the model to CUDA.

        Args:
            path: Model id or local directory passed to ``from_pretrained``.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        )
        self.model.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> bytes:
        """Generate audio for the prompt in ``data`` and return it as WAV bytes.

        Args:
            data: Request payload. ``"inputs"`` holds the text prompt; if it is
                absent, the payload itself is used as the prompt.
                ``"parameters"`` optionally holds generation overrides (see
                ``create_params``). Both keys are popped, so the caller's dict
                is mutated.

        Returns:
            The first generated clip, encoded as a WAV file.
        """
        prompt = data.pop("inputs", data)
        params = data.pop("parameters", None)

        model_inputs = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        gen_kwargs = create_params(params, self.model.config.audio_encoder.frame_rate)

        # torch.autocast("cuda") is the non-deprecated equivalent of
        # torch.cuda.amp.autocast().
        with torch.autocast(device_type="cuda"):
            audio_values = self.model.generate(**model_inputs, **gen_kwargs)

        return self._to_wav_bytes(audio_values[0], self._sampling_rate())

    def _sampling_rate(self) -> int:
        """Return the audio encoder's sampling rate, or 32000 if the config lacks it."""
        try:
            return self.model.config.audio_encoder.sampling_rate
        except AttributeError:
            return 32000

    @staticmethod
    def _to_wav_bytes(audio: torch.Tensor, sampling_rate: int) -> bytes:
        """Encode one generated waveform as in-memory WAV bytes.

        ``audio`` is assumed to be laid out as (channels, samples), which is
        what MusicGen's ``generate`` yields per batch item according to the
        transformers docs. A 1-D (samples,) tensor is also accepted.
        """
        # scipy's wavfile.write needs an ndarray, not a list. It expects a
        # (samples, channels) layout and rejects float16, so upcast to float32
        # and transpose. Transposing a 1-D array is a no-op.
        samples = audio.detach().cpu().float().numpy().T
        buffer = io.BytesIO()
        wavfile.write(buffer, sampling_rate, samples)
        return buffer.getvalue()
|
|
|
|
|
if __name__ == "__main__":
    # Running the module directly only builds the handler. Construction loads
    # the processor and model and moves the model to CUDA, so this works as a
    # quick load-time check.
    handler = EndpointHandler()
|
|