Spaces:
Runtime error
Runtime error
import os | |
import json | |
import numpy as np | |
import torch | |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
class TranscriberModel:
    """Speech-to-text wrapper around a Wav2Vec2 CTC model.

    The pretrained model identifier is looked up by language code in the
    ``language_model_names`` mapping of the ``config.json`` file located next
    to this module.
    """

    def __init__(self, lang: str = 'en'):
        """Load the processor and model configured for ``lang``.

        Args:
            lang: Key into ``config['language_model_names']``.

        Raises:
            KeyError: If ``language_model_names`` or ``lang`` is missing from
                the config file.
        """
        config_path = os.path.join(os.path.dirname(__file__), 'config.json')
        with open(config_path, encoding='utf-8') as f:
            config = json.load(f)
        model_name = config['language_model_names'][lang]
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Actually place the model on the selected device; previously the
        # device was computed but never used, so inference always ran on CPU.
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(self.device)

    def forward(self, speech_array: np.ndarray, sampling_rate: int = 16000) -> list:
        """Transcribe raw audio.

        Args:
            speech_array: Audio samples handed to the processor as-is.
            sampling_rate: Sampling rate of ``speech_array`` in Hz.

        Returns:
            The list of decoded strings produced by the processor's
            ``batch_decode`` (one entry per item in the batch).
        """
        model_input = self.processor(
            speech_array,
            sampling_rate=sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        # Inputs must live on the same device as the model weights.
        input_values = model_input.input_values.to(self.device)
        # Not every checkpoint's processor emits an attention mask; fall back
        # to None (the model's default) instead of failing on a missing key.
        attention_mask = model_input.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
        # Greedy CTC decoding: most likely token per frame; decode on CPU.
        predicted_ids = torch.argmax(logits, dim=-1).cpu()
        return self.processor.batch_decode(predicted_ids)