from typing import Dict, List, Any from transformers import pipeline import soundfile as sf import torch import logging import base64 logger = logging.getLogger(__name__) class EndpointHandler(): def __init__(self, path=""): # load the optimized model # create inference pipeline self.pipeline = pipeline("text-to-audio", "facebook/musicgen-stereo-large", device="cuda", torch_dtype=torch.float16) def generate_audio(self, text: str): # Here you can implement your audio generation logic # For demonstration purposes, let's use your existing code logger.info("Generating audio for text: %s", text) try: music = self.pipeline(text, forward_params={"max_new_tokens": 256}) return music["audio"][0].T, music["sampling_rate"] except Exception as e: logger.error("Error generating audio for text: %s", text, exc_info=True) raise e def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: input = data.pop("inputs", data) # parameters = data.pop("parameters",data) audio_data, sampling_rate = self.generate_audio(input) # Create JSON response response = { "audio_data": audio_data.tolist(), "sampling_rate": sampling_rate } return response