from typing import Any, Dict, List

import io
import logging

import soundfile as sf
import torch
from transformers import pipeline

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference endpoint handler that turns a text prompt into WAV audio
    bytes using the ``facebook/musicgen-stereo-large`` text-to-audio model.
    """

    def __init__(self, path: str = ""):
        """Load the text-to-audio pipeline.

        Args:
            path: unused here (kept for the standard handler interface,
                which passes the model directory).
        """
        # Device was previously hard-coded to "mps", which raises on any
        # host without Apple Metal support. Prefer mps, then cuda, then cpu.
        # NOTE(review): float16 on cpu may be slow or unsupported for some
        # ops — confirm the cpu fallback is actually exercised in deployment.
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        self.pipeline = pipeline(
            "text-to-audio",
            "facebook/musicgen-stereo-large",
            device=device,
            torch_dtype=torch.float16,
        )

    def generate_audio(self, text: str):
        """Generate audio for ``text`` via the MusicGen pipeline.

        Args:
            text: the text prompt to condition generation on.

        Returns:
            A ``(audio, sampling_rate)`` tuple. ``audio`` is the first
            generated waveform transposed to (samples, channels) layout,
            as expected by ``soundfile.write``.

        Raises:
            Exception: re-raises whatever the pipeline raises, after
                logging the full traceback.
        """
        logger.info("Generating audio for text: %s", text)
        try:
            music = self.pipeline(text, forward_params={"max_new_tokens": 256})
            return music["audio"][0].T, music["sampling_rate"]
        except Exception:
            # logger.exception attaches the traceback; bare `raise`
            # re-raises the active exception without resetting it.
            logger.exception("Error generating audio for text: %s", text)
            raise

    def __call__(self, data: Any) -> bytes:
        """Handle one inference request.

        Args:
            data: a dict whose "input" key holds the text prompt; if the
                key is absent, ``data`` itself is used as the prompt.

        Returns:
            The generated audio encoded as the bytes of a WAV file.
        """
        # The previous docstring/annotation (label/score lists) was copied
        # from a classification template; this handler returns WAV bytes.
        text = data.pop("input", data)
        audio_data, sampling_rate = self.generate_audio(text)
        with io.BytesIO() as buffer:
            sf.write(buffer, audio_data, sampling_rate, format="WAV")
            buffer.seek(0)
            return buffer.getvalue()