|
from typing import Dict |
|
import numpy as np |
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC |
|
import torch |
|
|
|
|
|
|
|
|
|
# Phonological features known to the model's vocabulary, in group order.
# Each feature appears in the vocabulary as a token pair
# 'p_<feature>' / 'n_<feature>' (presumably the positive / negative value of
# the feature -- TODO confirm against the tokenizer's vocab).
_FEATURE_NAMES = (
    'alveolar', 'palatal', 'dental', 'glottal', 'labial', 'velar',
    'anterior', 'posterior', 'retroflex', 'mid', 'high_v', 'low',
    'front', 'back', 'central', 'consonant', 'sonorant', 'long',
    'short', 'vowel', 'semivowel', 'fricative', 'nasal', 'stop',
    'approximant', 'affricate', 'liquid', 'continuant', 'monophthong',
    'diphthong', 'round', 'voiced', 'bilabial', 'coronal', 'dorsal',
)

# One [positive_token, negative_token] list per feature, in the order above.
groups = [['p_' + feature, 'n_' + feature] for feature in _FEATURE_NAMES]

# Individual aliases g1..g35; each is the very same list object as the
# corresponding entry of ``groups`` (g1 is groups[0], ..., g35 is groups[34]).
(g1, g2, g3, g4, g5, g6, g7, g8, g9, g10,
 g11, g12, g13, g14, g15, g16, g17, g18, g19, g20,
 g21, g22, g23, g24, g25, g26, g27, g28, g29, g30,
 g31, g32, g33, g34, g35) = groups
|
|
|
class PreTrainedPipeline():
    """Wav2Vec2-CTC inference pipeline restricted to one feature group.

    The tokens of every group in the module-level ``groups`` list are looked up
    in the processor's tokenizer. At inference time the logits are masked down
    to column 0 plus the two token columns of a single selected group, a greedy
    argmax decode is run over that reduced set, and the result is decoded back
    to text with ``'p_'`` rendered as ``'+'`` and ``'n_'`` rendered as ``'-'``.
    """

    # Index into ``groups`` decoded by default: groups[31] is
    # ['p_voiced', 'n_voiced'].
    DEFAULT_GROUP_INDEX = 31

    def __init__(self, path: str = "", group_index: int = DEFAULT_GROUP_INDEX):
        """Load the processor and model and precompute per-group id maps.

        Args:
            path: Directory or hub id passed to ``from_pretrained`` for both
                the processor and the CTC model.
            group_index: Which entry of ``groups`` ``__call__`` decodes.
                Defaults to 31 (the 'voiced' group), matching the original
                hard-coded behaviour.
        """
        self.sampling_rate = 16000
        self.group_index = group_index

        self.processor = Wav2Vec2Processor.from_pretrained(path)
        self.model = Wav2Vec2ForCTC.from_pretrained(path)

        # Move the model once here instead of re-issuing .to() on every call.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # For each group: its vocabulary ids sorted ascending, keyed from 1.
        # Key i is the i-th column after column 0 in the masked logits, because
        # boolean-mask indexing keeps columns in ascending id order.
        sorted_ids = [
            sorted(self.processor.tokenizer.convert_tokens_to_ids(group))
            for group in groups
        ]
        self.group_ids = [dict(enumerate(ids, start=1)) for ids in sorted_ids]

    @staticmethod
    def _restore_vocab_ids(pred_ids, group_map):
        """Map argmax indices over the masked logits back to vocabulary ids.

        Args:
            pred_ids: Integer tensor of indices into the masked logit columns.
            group_map: ``{masked_column_index: vocabulary_id}`` for one group.

        Returns:
            A CPU tensor of the same shape as ``pred_ids``. Every index that
            appears in ``group_map`` is replaced by its vocabulary id; any other
            index (notably 0) is passed through unchanged.
        """
        pred_ids = pred_ids.cpu()
        top = max(group_map, default=0)
        if pred_ids.numel():
            top = max(top, int(pred_ids.max()))
        # Identity table, then overwrite the group's columns: a vectorized
        # equivalent of ``group_map.get(x, x)`` applied element-wise.
        lookup = torch.arange(top + 1, dtype=pred_ids.dtype)
        for masked_idx, vocab_id in group_map.items():
            lookup[masked_idx] = vocab_id
        return lookup[pred_ids]

    def __call__(self, inputs: np.ndarray) -> Dict[str, str]:
        """Transcribe ``inputs`` for the selected feature group.

        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of the received audio, expected at
                ``self.sampling_rate`` (16 kHz).

        Return:
            A :obj:`dict` of the form ``{"text": "XXX"}`` containing the text
            decoded from the first item of the batch, with ``'p_'`` replaced by
            ``'+'`` and ``'n_'`` replaced by ``'-'``.
        """
        input_values = self.processor(
            audio=inputs, sampling_rate=self.sampling_rate, return_tensors="pt"
        ).input_values
        input_values = input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        group_map = self.group_ids[self.group_index]

        # Keep column 0 (presumably the CTC blank/pad token -- TODO confirm it
        # is id 0 in this tokenizer) plus the selected group's token columns.
        mask = torch.zeros(logits.size(-1), dtype=torch.bool, device=logits.device)
        mask[0] = True
        mask[list(group_map.values())] = True

        pred_ids = torch.argmax(logits[:, :, mask], dim=-1)
        pred_ids = self._restore_vocab_ids(pred_ids, group_map)

        pred = self.processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
        pred = pred.replace('p_', '+').replace('n_', '-')

        return {
            "text": pred
        }
|
|
|
|