# NOTE(review): removed "Spaces: Running" page-scrape residue that made this file invalid Python.
from typing import Dict

import numpy as np
import pyctcdecode
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
class PreTrainedPipeline:
    """Speech-to-text pipeline: a Wav2Vec2 CTC acoustic model whose logits are
    decoded with a KenLM language model via ``pyctcdecode``.

    Attributes:
        language_model_fp: Path to the KenLM language-model file.
        device: "cuda" when a GPU is available, else "cpu".
        model: The ``Wav2Vec2ForCTC`` acoustic model, moved to ``device``.
        sampling_rate: Sampling rate the feature extractor expects (Hz).
        decoder: ``pyctcdecode`` beam-search decoder built from the tokenizer
            vocabulary plus the KenLM model.
        processor_with_lm: Combined feature-extractor/tokenizer/LM-decoder.
    """

    def __init__(self, model_path: str, language_model_fp: str):
        """
        Args:
            model_path (:obj:`str`):
                Path or hub id of the pretrained Wav2Vec2 model and processor.
            language_model_fp (:obj:`str`):
                Filesystem path to the KenLM language model (.arpa or .bin).
        """
        self.language_model_fp = language_model_fp
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
        self.model.to(self.device)

        processor = Wav2Vec2Processor.from_pretrained(model_path)
        self.sampling_rate = processor.feature_extractor.sampling_rate

        # pyctcdecode expects the vocabulary as a list of tokens ordered by
        # token id, so sort the (token, id) pairs by id before extracting.
        vocab = processor.tokenizer.get_vocab()
        sorted_labels = [
            token for token, _ in sorted(vocab.items(), key=lambda item: item[1])
        ]
        self.decoder = pyctcdecode.build_ctcdecoder(
            labels=sorted_labels,
            kenlm_model_path=self.language_model_fp,
        )
        self.processor_with_lm = Wav2Vec2ProcessorWithLM(
            feature_extractor=processor.feature_extractor,
            tokenizer=processor.tokenizer,
            decoder=self.decoder,
        )

    def __call__(self, inputs: np.ndarray) -> Dict[str, str]:
        """
        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of audio received. By default at 16KHz.
        Return:
            A :obj:`dict`. The returned object should be like {"text": "XXX"}
            containing the detected text from the input audio.
        """
        input_values = self.processor_with_lm(
            inputs, return_tensors="pt", sampling_rate=self.sampling_rate
        )["input_values"]
        input_values = input_values.to(self.device)
        with torch.no_grad():
            # input_values is 2D here: (batch, samples), with batch size 1.
            # .detach() is unnecessary inside no_grad(); no graph is recorded.
            logits = self.model(input_values).logits.cpu().numpy()
        # batch_decode returns one transcription per batch item; take the
        # first so the value is a single string, as the docstring promises
        # (the original returned a one-element list here).
        text_predicted = self.processor_with_lm.batch_decode(logits)["text"][0]
        return {"text": text_predicted}