|
|
|
""" |
|
Copyright (C) 2021-2022 Intel Corporation |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
you may not use this file except in compliance with the License. |
|
You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
See the License for the specific language governing permissions and |
|
limitations under the License. |
|
""" |
|
|
|
from argparse import ArgumentParser, SUPPRESS |
|
from itertools import groupby |
|
import json |
|
import logging as log |
|
from pathlib import Path |
|
from time import perf_counter |
|
import sys |
|
|
|
import numpy as np |
|
import wave |
|
|
|
from openvino.inference_engine import IECore |
|
|
|
# Module-level OpenVINO Inference Engine core, created at import time.
# Wav2Vec uses it to read networks (read_network) and compile them (load_network).
ie = IECore()
|
|
|
|
|
|
|
class Wav2Vec:
    """Wrapper around an OpenVINO network whose per-frame outputs are decoded as CTC tokens.

    Typical use: ``preprocess`` the raw waveform, ``reshape`` the network to the
    waveform's shape, ``infer`` it, then ``decode`` the chosen output into text.
    """

    # Default output vocabulary; the list index is the token id.
    alphabet = [
        "<pad>", "<s>", "</s>", "<unk>", "|",
        "e", "t", "a", "o", "n", "i", "h", "s", "r", "d", "l", "u",
        "m", "w", "c", "f", "g", "y", "p", "b", "v", "k", "'", "x", "j", "q", "z"]
    # Token that marks a word boundary; decoded as a space.
    words_delimiter = '|'
    # Token removed after merging repeats in ``decode`` (the CTC blank).
    pad_token = '<pad>'

    # Model files used when the caller does not pass explicit paths.
    default_model_xml = "/home/intel/Documents/ASR/wav2vec2-base-ft-keyword-spotting-int8/ov_model.xml"
    default_model_bin = "/home/intel/Documents/ASR/wav2vec2-base-ft-keyword-spotting-int8/ov_model.bin"

    def __init__(self, model_xml=None, model_bin=None, vocab_file=None, device="CPU"):
        """Read the network and set up the decoding vocabulary.

        Args:
            model_xml: path to the IR ``.xml`` file; ``None`` -> ``default_model_xml``.
            model_bin: path to the IR ``.bin`` file; ``None`` -> ``default_model_bin``.
            vocab_file: optional JSON file mapping token -> id. When ``None``,
                ``alphabet`` is used with id == list position.
            device: device name passed to ``load_network`` by ``infer``.
        """
        self.nnet = ie.read_network(
            model_xml if model_xml is not None else self.default_model_xml,
            model_bin if model_bin is not None else self.default_model_bin)
        self.device = device
        self._init_vocab(vocab_file)

    def _init_vocab(self, vocab_file=None):
        """Set ``self.decoding_vocab``, a dict mapping token id -> token string.

        ``vocab_file``, if given, must hold a JSON object of token -> id (ids may
        be ints or numeric strings). Otherwise ``alphabet`` is used.
        """
        if vocab_file is None:
            self.decoding_vocab = dict(enumerate(self.alphabet))
            return
        with Path(vocab_file).open('r', encoding='utf-8') as vf:
            encoding_vocab = json.load(vf)
        self.decoding_vocab = {int(v): k for k, v in encoding_vocab.items()}

    @staticmethod
    def preprocess(sound):
        """Return ``sound`` shifted to zero mean and scaled to unit standard deviation.

        The 1e-15 term keeps the division finite for constant (e.g. silent) input.
        """
        return (sound - np.mean(sound)) / (np.std(sound) + 1e-15)

    def infer(self, audio):
        """Compile the network for ``self.device`` and run it on ``audio``.

        Returns the result of ``ExecutableNetwork.infer``: a dict of output name
        -> array. The network is compiled on every call, so the executable always
        reflects the most recent ``reshape`` (at the cost of one compile per call).
        """
        exec_net = ie.load_network(self.nnet, self.device)
        return exec_net.infer({"input_values": audio})

    def decode(self, logits):
        """Greedily decode per-frame scores ``logits`` into a lower-case string.

        Takes the arg-max over the last axis, merges consecutive repeated tokens,
        drops ``pad_token``, maps ``words_delimiter`` to a space and collapses runs
        of whitespace. Assumes a single utterance (batch of 1): after ``squeeze``
        the ids must form a 1-D sequence.
        """
        # atleast_1d keeps a single-frame result iterable after squeeze().
        token_ids = np.atleast_1d(np.squeeze(np.argmax(logits, -1)))
        tokens = [self.decoding_vocab[idx] for idx in token_ids]
        # Merge repeats *before* removing blanks so "e <pad> e" decodes to "ee".
        tokens = [token for token, _ in groupby(tokens)]
        tokens = [t for t in tokens if t != self.pad_token]
        text = ''.join(' ' if t == self.words_delimiter else t for t in tokens)
        # split() with no argument drops leading/trailing whitespace and collapses
        # repeated spaces (e.g. from "| <pad> |"), unlike split(' ').
        return ' '.join(text.split()).lower()

    def reshape(self, audio):
        """Reshape the network's first input to ``audio.shape``."""
        self.nnet.reshape({next(iter(self.nnet.input_info)): audio.shape})
|
|
|
def _build_argparser():
    """Return the command-line parser; defaults reproduce the original hard-coded run."""
    parser = ArgumentParser(
        description="Transcribe a 16 kHz mono 16-bit PCM WAV file with Wav2Vec.")
    parser.add_argument(
        "-i", "--input",
        default="/home/intel/Documents/ASR/applications.ai.conversational-ai.asr-grpc-security/client_sample_examples/python/audio_data_samples/how_are_you_doing.wav",
        help="Path to the input WAV file.")
    parser.add_argument(
        "--output-name", default="3761",
        help="Network output to decode when the model has more than one output.")
    return parser


def _read_wav(path):
    """Read a 16 kHz mono 16-bit linear-PCM WAV file into a (1, n_frames) float array.

    Samples are divided by the int16 maximum (32767).
    Raises ValueError when the file is not in that format.
    """
    with wave.open(str(path), 'rb') as wave_read:
        channel_num, sample_width, sampling_rate, pcm_length, compression_type, _ = wave_read.getparams()
        # Explicit checks instead of assert, which is stripped under `python -O`.
        if sample_width != 2:
            raise ValueError("Only 16-bit WAV PCM supported")
        if compression_type != 'NONE':
            raise ValueError("Only linear PCM WAV files supported")
        if channel_num != 1:
            raise ValueError("Only mono WAV PCM supported")
        if sampling_rate != 16000:
            raise ValueError("Only 16 KHz audio supported")
        # readframes() takes a frame count; for mono audio a frame is one sample.
        raw = wave_read.readframes(pcm_length)
    # -1 tolerates files whose header over-reports the frame count.
    audio = np.frombuffer(raw, dtype=np.int16).reshape((1, -1))
    return audio.astype(float) / np.iinfo(np.int16).max


def main(argv=None):
    """Transcribe one WAV file; print the text, then the latency in milliseconds.

    Args:
        argv: command-line arguments without the program name; ``None`` means
            ``sys.argv[1:]``. With no arguments the built-in sample file is used.
    """
    args = _build_argparser().parse_args(argv)
    model = Wav2Vec()

    # Latency covers audio reading through decoding, but not reading the model above.
    start_time = perf_counter()
    audio = _read_wav(args.input)
    normalized_audio = model.preprocess(audio)
    model.reshape(normalized_audio)
    character_probs = model.infer(normalized_audio)
    log.debug("Inference result type: %s", type(character_probs))
    log.debug("Inference output names: %s", list(character_probs.keys()))
    # Output names depend on how the model was exported; decode a sole output
    # whatever it is called, otherwise select it by name.
    if len(character_probs) == 1:
        logits = next(iter(character_probs.values()))
    else:
        logits = character_probs[args.output_name]
    transcription = model.decode(logits)
    total_latency = (perf_counter() - start_time) * 1e3

    print(transcription)
    print(total_latency)
|
|
|
if __name__ == '__main__':
    # A falsy result from main() (e.g. None) is reported as exit status 0.
    status = main()
    sys.exit(status if status else 0)
|
|