File size: 1,335 Bytes
2c7e285
 
 
 
 
 
 
 
 
 
 
ff2924b
 
2c7e285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be41572
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import Any, Dict

import torch
from diffusers import AudioLDM2Pipeline, DPMSolverMultistepScheduler


class EndpointHandler:
    """Custom inference-endpoint handler that generates music with AudioLDM 2.

    The pipeline is loaded once at construction time; each call turns a text
    prompt into a single waveform returned as a plain list of floats.
    """

    def __init__(self, path="", model_id="cvssp/audioldm2-music"):
        """Load the AudioLDM 2 pipeline and configure it for inference.

        Args:
            path: Directory passed in by the serving toolkit. It is not used;
                weights are loaded from ``model_id`` instead.
            model_id: Hub id or local path of the AudioLDM 2 checkpoint.
                Defaults to the checkpoint this handler has always used.
        """
        self.pipeline = AudioLDM2Pipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
        # Replace the default scheduler with DPM-Solver multistep, built from
        # the checkpoint's own scheduler config.
        self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )
        # Model CPU offload moves each sub-model to the GPU only while it runs.
        # The pipeline must NOT be moved to CUDA beforehand (the previous
        # `.to("cuda")` call); doing so largely cancels the memory saving.
        self.pipeline.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate audio for a text prompt.

        Args:
            data: Request payload. Recognised keys:
                ``inputs`` (required): text prompt describing the music.
                ``duration``: clip length in seconds (default 30).
                ``negative_prompt``: what to steer away from
                    (default ``"Low quality, average quality."``).
                ``num_waveforms_per_prompt``: candidate waveforms generated
                    per prompt (default 4); only the first one is returned.
                ``num_inference_steps``: denoising steps (default 20).
                The payload is read but never modified.

        Returns:
            ``{"generated_audio": [...]}`` holding the first waveform of the
            pipeline output converted to a list of floats.

        Raises:
            ValueError: if ``inputs`` is missing or empty, or a numeric
                parameter cannot be converted to a number.
        """
        # Read instead of pop so the caller's payload is left untouched, and
        # fail loudly instead of feeding the whole dict in as the prompt.
        song_description = data.get("inputs")
        if not song_description:
            raise ValueError("Payload must contain a non-empty 'inputs' prompt.")

        # JSON clients may send numbers as strings; normalise them here.
        try:
            duration = float(data.get("duration", 30))
            num_waveforms = int(data.get("num_waveforms_per_prompt", 4))
            num_steps = int(data.get("num_inference_steps", 20))
        except (TypeError, ValueError) as err:
            raise ValueError(f"Invalid generation parameter: {err}") from err
        negative_prompt = data.get("negative_prompt", "Low quality, average quality.")

        audio = self.pipeline(
            song_description,
            negative_prompt=negative_prompt,
            # With several candidates AudioLDM 2 can rank them against the
            # prompt, so index 0 is presumably the best match — TODO confirm
            # for the installed diffusers version.
            num_waveforms_per_prompt=num_waveforms,
            audio_length_in_s=duration,
            num_inference_steps=num_steps,
        ).audios[0]

        # Convert the array into a JSON-serialisable list for the response.
        return {"generated_audio": audio.tolist()}