# Commit e5983bf (keess): add custom endpoint handler
import logging
import subprocess
import time
from typing import Dict, List, Any

import gradio as gr
import numpy as np
import pandas as pd
import torch as torch
from datasets import Audio, Dataset
from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration
class EndpointHandler():
    """Callable request handler that transcribes audio with Whisper.

    The processor and model are loaded once at construction time, and
    decoding is pinned to Dutch ("nl") transcription through the forced
    decoder prompt ids.
    """

    # Checkpoint loaded from the Hugging Face Hub for both processor and model.
    MODEL_ID = "openai/whisper-large"
    # Rate (Hz) the audio is resampled to before feature extraction.
    SAMPLING_RATE = 16_000

    _logger = logging.getLogger(__name__)

    def __init__(self, path=""):
        """Load the processor and model and place the model on the best device.

        Args:
            path: Accepted for interface compatibility but not read; the
                weights always come from ``MODEL_ID``.
        """
        # Use the GPU when one is available; the model and the input features
        # must live on the same device for generation to work.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = WhisperProcessor.from_pretrained(self.MODEL_ID)
        self.model = WhisperForConditionalGeneration.from_pretrained(self.MODEL_ID)
        self.model.to(self.device)

        # Pin language and task so the model does not auto-detect them.
        self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
            language="nl", task="transcribe"
        )

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """Transcribe the audio referenced by a request payload.

        Args:
            data: Request payload. The value stored under ``"inputs"`` is
                removed from the payload and used as the audio source; when
                that key is absent the payload itself is used instead.
                NOTE(review): presumably a file path or encoded audio that
                the ``datasets.Audio`` feature can decode - confirm against
                the callers.

        Returns:
            The result of ``batch_decode`` on the generated ids. Special
            tokens are not stripped from the decoded text.
        """
        inputs = data.pop("inputs", data)
        # Log only the payload type: the payload may be raw audio and is far
        # too large (and too sensitive) to dump on every request.
        self._logger.debug("transcription request, input type: %s", type(inputs).__name__)

        # Wrap the single input in a one-row dataset so the Audio feature
        # handles decoding and resampling to SAMPLING_RATE.
        frame = pd.DataFrame([inputs], columns=["audio"])
        dataset = Dataset.from_pandas(frame)
        dataset = dataset.cast_column("audio", Audio(sampling_rate=self.SAMPLING_RATE))
        waveform = dataset[0]["audio"]["array"]

        # Passing the sampling rate lets the feature extractor verify that it
        # matches what the model expects instead of silently assuming it.
        input_features = self.processor(
            waveform,
            sampling_rate=self.SAMPLING_RATE,
            return_tensors="pt",
        ).input_features

        predicted_ids = self.model.generate(
            input_features.to(self.device),
            forced_decoder_ids=self.model.config.forced_decoder_ids,
        )
        transcription = self.processor.batch_decode(predicted_ids)
        self._logger.debug("transcription result: %s", transcription)
        return transcription