from typing import Any, Dict

import torch
from diffusers import AudioLDM2Pipeline, DPMSolverMultistepScheduler


class EndpointHandler:
    def __init__(self, path=""):
        # Load the AudioLDM2 music checkpoint from the Hub in half precision
        # (the local `path` argument is not used for loading here).
        self.pipeline = AudioLDM2Pipeline.from_pretrained(
            "cvssp/audioldm2-music", torch_dtype=torch.float16
        )
        self.pipeline.to("cuda")
        # Swap in the faster DPM-Solver++ multistep scheduler.
        self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )
        # Offload idle sub-models to CPU to reduce peak GPU memory.
        self.pipeline.enable_model_cpu_offload()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (:obj:`dict`):
                The payload with the text prompt and generation parameters.
        Return:
            A :obj:`dict` with the generated waveform as a list of floats.
        """
        # Process the input payload.
        song_description = data.pop("inputs", data)
        duration = data.get("duration", 30)
        negative_prompt = data.get("negative_prompt", "Low quality, average quality.")

        # Generate several candidate waveforms; AudioLDM2 scores them against the
        # prompt, so the first waveform is the best-ranked one.
        audio = self.pipeline(
            song_description,
            negative_prompt=negative_prompt,
            num_waveforms_per_prompt=4,
            audio_length_in_s=duration,
            num_inference_steps=20,
        ).audios[0]

        # Post-process the prediction into a JSON-serializable list.
        prediction = audio.tolist()
        return {"generated_audio": prediction}
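

# --- Local smoke test (sketch) ----------------------------------------------
# A minimal sketch of how the handler is invoked, assuming a CUDA GPU and that
# torch, diffusers, numpy, and scipy are installed and the checkpoint can be
# downloaded. The payload keys ("inputs", "duration", "negative_prompt") mirror
# the ones read in `__call__` above; the prompt and output filename are just
# illustrative.
if __name__ == "__main__":
    import numpy as np
    import scipy.io.wavfile

    handler = EndpointHandler()
    result = handler(
        {
            "inputs": "Techno music with a strong, upbeat tempo and high melodic riffs",
            "duration": 10,
        }
    )

    # AudioLDM2 produces 16 kHz mono audio; write the waveform to a WAV file.
    waveform = np.asarray(result["generated_audio"], dtype=np.float32)
    scipy.io.wavfile.write("generated.wav", rate=16000, data=waveform)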