# handler.py — Hugging Face Hub upload by Blaxzter (revision a8eab90, 2.71 kB)
import base64
import json
import os
from io import StringIO
from typing import Dict, Any
from transformers import pipeline
class EndpointHandler:
    """Hugging Face Inference Endpoint handler: Whisper speech-to-text."""

    def __init__(self, asr_model_path: str = "./whisper-large-v2"):
        # Build the ASR pipeline once at startup from the local model directory.
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model_path,
        )

    def __call__(self, data: Dict[str, Any]) -> str:
        """Transcribe the audio in `data` and return the result as a JSON string.

        `data` may be an already-parsed dict or a raw JSON string, with keys:
          - "audio_data": base64-encoded audio (a str), or raw audio bytes.
          - "language":   optional language code (e.g. "de"); when omitted the
                          model is left to detect the language itself.

        Raises:
            Exception: if the request has no top-level "audio_data" key.
        """
        # Accept either a raw JSON string or an already-decoded dict; the
        # original json.loads(data) crashed on dict input despite the annotation.
        json_data = json.loads(data) if isinstance(data, str) else data
        if "audio_data" not in json_data:
            raise Exception("Request must contain a top-level key named 'audio_data'")
        audio_data = json_data["audio_data"]
        # Decode the binary audio data if it's provided as a base64 string.
        if isinstance(audio_data, str):
            audio_data = base64.b64decode(audio_data)
        generate_kwargs = {"task": "transcribe"}
        # Bug fix: the language from the request was never used — the literal
        # placeholder "<|language|>" was sent instead of e.g. "<|de|>".
        # NOTE(review): the "<|xx|>" token form mirrors the original code's
        # intent; newer transformers also accept plain codes — confirm version.
        language = json_data.get("language")
        if language:
            generate_kwargs["language"] = f"<|{language}|>"
        # max_length dropped: passing it together with max_new_tokens is
        # conflicting/ignored by generate(); max_new_tokens alone suffices.
        transcription = self.asr_pipeline(
            audio_data,
            return_timestamps=False,
            chunk_length_s=30,
            batch_size=8,
            max_new_tokens=10000,
            generate_kwargs=generate_kwargs,
        )
        # json.dumps replaces the StringIO round-trip; the output is identical.
        return json.dumps(transcription)
def init():
    """Azure ML scoring entry point: load the Whisper ASR pipeline once."""
    global asr_pipeline
    # Azure ML injects the deployed model location via AZUREML_MODEL_DIR;
    # fall back to a local checkout for development.
    model_dir = os.getenv("AZUREML_MODEL_DIR", "./whisper-large-v2")
    asr_pipeline = pipeline("automatic-speech-recognition", model=model_dir)
def run(raw_data):
    """Azure ML scoring entry point: transcribe base64 audio from a JSON request.

    `raw_data` is a JSON string with keys:
      - "audio_data": base64-encoded audio (a str), or raw audio bytes.
      - "language":   optional language code; defaults to "de", matching the
                      previously hard-coded behavior.

    Returns:
        The pipeline's transcription serialized as a JSON string.

    Raises:
        Exception: if the request has no top-level "audio_data" key.
    """
    json_data = json.loads(raw_data)
    if "audio_data" not in json_data:
        raise Exception("Request must contain a top level key named 'audio_data'")
    audio_data = json_data["audio_data"]
    # Decode the binary audio data if it's provided as a base64 string.
    # (The former function-local `import base64` was redundant — the module
    # already imports it at the top of the file.)
    if isinstance(audio_data, str):
        audio_data = base64.b64decode(audio_data)
    # Generalized: the caller may override the language; "de" preserves the
    # old hard-coded "<|de|>" token for requests that do not specify one.
    language = json_data.get("language", "de")
    transcription = asr_pipeline(
        audio_data,
        return_timestamps=False,
        chunk_length_s=30,
        batch_size=8,
        max_new_tokens=1000,
        generate_kwargs={"task": "transcribe", "language": f"<|{language}|>"},
    )
    # json.dumps replaces the StringIO round-trip; the output is identical.
    return json.dumps(transcription)