"""Whisper speech-to-text entry points.

The same transcription service is exposed through two interfaces:

* ``EndpointHandler`` - a callable object that owns its own ASR pipeline.
* ``init()`` / ``run()`` - module-level functions; ``init()`` reads the model
  location from the ``AZUREML_MODEL_DIR`` environment variable and ``run()``
  serves one request with the pipeline that ``init()`` created.

Request format (JSON object, or an already-parsed ``dict``)::

    {"audio_data": "<base64-encoded audio>", "language": "de"}

``language`` is optional. It is expected to be a Whisper language code such
as ``"de"`` or a ready-made token such as ``"<|de|>"``.
"""

import base64
import json
import os
from io import StringIO  # noqa: F401 - pre-existing import, kept for compatibility
from typing import Any, Dict, Optional, Union

from transformers import pipeline

_DEFAULT_MODEL_PATH = "./whisper-large-v2"

# Language token used by run() when the request does not name a language.
_RUN_DEFAULT_LANGUAGE_TOKEN = "<|de|>"

# Set by init(); remains None until then so run() can fail with a clear error.
asr_pipeline = None


def _load_request(payload: Union[str, bytes, bytearray, Dict[str, Any]]) -> Any:
    """Return the request as a Python object, parsing JSON text if necessary."""
    if isinstance(payload, (str, bytes, bytearray)):
        return json.loads(payload)
    # Already deserialized by the caller (e.g. a dict) - use it as-is.
    return payload


def _decode_audio(audio_data: Any) -> Any:
    """Base64-decode *audio_data* when it is a string; pass anything else through."""
    if isinstance(audio_data, str):
        return base64.b64decode(audio_data)
    return audio_data


def _language_token(language: Optional[str]) -> Optional[str]:
    """Convert a requested language into a Whisper token such as ``"<|de|>"``.

    Returns ``None`` when no language was supplied (``None`` or blank).

    Raises:
        ValueError: if *language* is neither ``None`` nor a string.
    """
    if language is None:
        return None
    if not isinstance(language, str):
        raise ValueError("'language' must be a string")
    language = language.strip()
    if not language:
        return None
    if language.startswith("<|") and language.endswith("|>"):
        return language
    return f"<|{language.lower()}|>"


def _transcribe_to_json(
    asr: Any,
    audio_data: Any,
    language_token: Optional[str],
    **pipeline_kwargs: Any,
) -> str:
    """Run *asr* on *audio_data* and return its result serialized as JSON.

    Args:
        asr: The ASR pipeline callable.
        audio_data: Raw audio, or a base64 string that is decoded first.
        language_token: Whisper language token to force; when ``None`` no
            language is passed to generation.
        **pipeline_kwargs: Extra keyword arguments forwarded to the pipeline.
    """
    generate_kwargs = {"task": "transcribe"}
    if language_token is not None:
        generate_kwargs["language"] = language_token

    transcription = asr(
        _decode_audio(audio_data),
        return_timestamps=False,
        chunk_length_s=30,
        batch_size=8,
        generate_kwargs=generate_kwargs,
        **pipeline_kwargs,
    )
    return json.dumps(transcription)


class EndpointHandler:
    """Callable request handler that transcribes audio with a Whisper pipeline."""

    def __init__(self, asr_model_path: str = "./whisper-large-v2"):
        """Create the ASR pipeline from the model stored at *asr_model_path*."""
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model_path,
        )

    def __call__(self, data: Union[str, bytes, Dict[str, Any]]) -> str:
        """Transcribe the audio contained in *data*.

        Args:
            data: The request, either as JSON text or as an already-parsed
                dict. It must contain ``"audio_data"`` and may contain
                ``"language"``.

        Returns:
            The pipeline's transcription result as a JSON string.

        Raises:
            ValueError: if ``"audio_data"`` is missing or ``"language"`` is
                not a string.
        """
        request = _load_request(data)
        if "audio_data" not in request:
            raise ValueError("Request must contain a top-level key named 'audio_data'")

        # Use the language the caller asked for instead of a fixed placeholder
        # token; a missing language is no longer an error.
        language_token = _language_token(request.get("language"))

        # NOTE(review): `max_length` and `max_new_tokens` are both forwarded
        # unchanged from the previous implementation; confirm the installed
        # transformers version accepts this combination.
        return _transcribe_to_json(
            self.asr_pipeline,
            request["audio_data"],
            language_token,
            max_length=10000,
            max_new_tokens=10000,
        )


def init():
    """Create the module-level ASR pipeline.

    The model directory is taken from the ``AZUREML_MODEL_DIR`` environment
    variable, falling back to ``./whisper-large-v2`` when it is not set.
    """
    global asr_pipeline

    model_path = os.getenv("AZUREML_MODEL_DIR", _DEFAULT_MODEL_PATH)
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model_path,
    )


def run(raw_data):
    """Transcribe the audio in *raw_data* using the pipeline built by ``init()``.

    Args:
        raw_data: The request as JSON text (or an already-parsed dict). It
            must contain ``"audio_data"``; an optional ``"language"`` selects
            the transcription language, which otherwise defaults to German.

    Returns:
        The pipeline's transcription result as a JSON string.

    Raises:
        RuntimeError: if ``init()`` has not been called yet.
        ValueError: if ``"audio_data"`` is missing or ``"language"`` is not a
            string.
    """
    if asr_pipeline is None:
        raise RuntimeError("init() must be called before run()")

    request = _load_request(raw_data)
    if "audio_data" not in request:
        raise ValueError("Request must contain a top level key named 'audio_data'")

    language_token = (
        _language_token(request.get("language")) or _RUN_DEFAULT_LANGUAGE_TOKEN
    )
    return _transcribe_to_json(
        asr_pipeline,
        request["audio_data"],
        language_token,
        max_new_tokens=1000,
    )