#!/usr/bin/env python3
"""
./export-onnx.py
./preprocess.sh

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav

./vad-onnx.py --model ./model.onnx --wav ./lei-jun-test.wav
"""
import argparse
from pathlib import Path

import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf
from numpy.lib.stride_tricks import as_strided


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="Path to model.onnx")
    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
    return parser.parse_args()


class OnnxModel:
    def __init__(self, filename):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        print(meta)

        self.window_size = int(meta["window_size"])
        self.sample_rate = int(meta["sample_rate"])
        # shift windows by 10% of the window size (i.e., 90% overlap)
        self.window_shift = int(0.1 * self.window_size)
        self.receptive_field_size = int(meta["receptive_field_size"])
        self.receptive_field_shift = int(meta["receptive_field_shift"])
        self.num_speakers = int(meta["num_speakers"])
        self.powerset_max_classes = int(meta["powerset_max_classes"])
        self.num_classes = int(meta["num_classes"])

    def __call__(self, x):
        """
        Args:
          x: (N, num_samples)
        Returns:
          A tensor of shape (N, num_frames, num_classes)
        """
        # the model expects (N, num_channels, num_samples) with num_channels == 1
        x = np.expand_dims(x, axis=1)

        (y,) = self.model.run(
            [self.model.get_outputs()[0].name],
            {self.model.get_inputs()[0].name: x},
        )

        return y


def load_wav(filename, expected_sample_rate) -> np.ndarray:
    audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != expected_sample_rate:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=expected_sample_rate,
        )
    return audio


def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes):
    """Build the powerset-to-multilabel mapping.

    Row k of the returned (num_classes, num_speakers) matrix tells which
    speakers are active for powerset class k. For example, with
    num_speakers=3, powerset_max_classes=2, and num_classes=7, the rows are:

        class 0 -> [0, 0, 0]   no speaker (silence)
        class 1 -> [1, 0, 0]   speaker 0
        class 2 -> [0, 1, 0]   speaker 1
        class 3 -> [0, 0, 1]   speaker 2
        class 4 -> [1, 1, 0]   speakers 0 and 1
        class 5 -> [1, 0, 1]   speakers 0 and 2
        class 6 -> [0, 1, 1]   speakers 1 and 2
    """
    mapping = np.zeros((num_classes, num_speakers))

    k = 1
    for i in range(1, powerset_max_classes + 1):
        if i == 1:
            for j in range(0, num_speakers):
                mapping[k, j] = 1
                k += 1
        elif i == 2:
            for j in range(0, num_speakers):
                for m in range(j + 1, num_speakers):
                    mapping[k, j] = 1
                    mapping[k, m] = 1
                    k += 1
        else:
            # only up to 2 simultaneous speakers are supported here
            raise RuntimeError("Unsupported")

    return mapping


def to_multi_label(y, mapping):
    """
    Args:
      y: (num_chunks, num_frames, num_classes)
    Returns:
      A tensor of shape (num_chunks, num_frames, num_speakers)
    """
    y = np.argmax(y, axis=-1)
    labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1)
    return labels


def main():
    args = get_args()
    assert Path(args.model).is_file(), args.model
    assert Path(args.wav).is_file(), args.wav

    m = OnnxModel(args.model)
    audio = load_wav(args.wav, m.sample_rate)
    # audio: (num_samples,)
    print("audio", audio.shape, audio.min(), audio.max(), audio.sum())

    # number of full windows; max(0, ...) guards against audio shorter than
    # a single window, which would otherwise give as_strided a negative shape
    num = max(0, (audio.shape[0] - m.window_size) // m.window_shift + 1)
    samples = as_strided(
        audio,
        shape=(num, m.window_size),
        strides=(m.window_shift * audio.strides[0], audio.strides[0]),
    )
    # or use torch.Tensor.unfold
    # samples = torch.from_numpy(audio).unfold(0, m.window_size, m.window_shift).numpy()

    print(
        "samples",
        samples.shape,
        samples.mean(),
        samples.sum(),
        samples[:3, :3].sum(axis=-1),
    )

    # a tail that does not fill a whole window is handled separately below
    if (
        audio.shape[0] < m.window_size
        or (audio.shape[0] - m.window_size) % m.window_shift > 0
    ):
        has_last_chunk = True
    else:
        has_last_chunk = False

    num_chunks = samples.shape[0]
    batch_size = 32
    output = []
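
    # A shape sketch of the batched inference below (the concrete sizes are
    # assumptions for illustration, not read from a real model): with
    # window_size=160000 (10 s at 16 kHz) and hence window_shift=16000, a
    # 60 s file yields samples of shape (51, 160000); each batch of up to 32
    # windows produces (batch, num_frames, num_classes), and np.vstack()
    # afterwards restores (num_chunks, num_frames, num_classes).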
    for i in range(0, num_chunks, batch_size):
        start = i
        end = i + batch_size
        # it's perfectly ok to use end > num_chunks
        y = m(samples[start:end])
        output.append(y)

    if has_last_chunk:
        # pad the tail with zeros to a full window and run it as one more chunk
        last_chunk = audio[num_chunks * m.window_shift :]  # noqa
        pad_size = m.window_size - last_chunk.shape[0]
        last_chunk = np.pad(last_chunk, (0, pad_size))
        last_chunk = np.expand_dims(last_chunk, axis=0)
        y = m(last_chunk)
        output.append(y)

    y = np.vstack(output)
    # y: (num_chunks, num_frames, num_classes)

    mapping = get_powerset_mapping(
        num_classes=m.num_classes,
        num_speakers=m.num_speakers,
        powerset_max_classes=m.powerset_max_classes,
    )
    labels = to_multi_label(y, mapping=mapping)
    # labels: (num_chunks, num_frames, num_speakers)

    # binary classification: a frame is speech if any speaker is active
    labels = np.max(labels, axis=-1)
    # labels: (num_chunks, num_frames)

    # total number of output frames once all chunks are laid back
    # onto the time axis
    num_frames = (
        int(
            (m.window_size + (labels.shape[0] - 1) * m.window_shift)
            / m.receptive_field_shift
        )
        + 1
    )

    # overlap-add: average the per-chunk scores, weighting each chunk with a
    # Hamming window so that frames near chunk borders contribute less
    count = np.zeros((num_frames,))
    classification = np.zeros((num_frames,))
    weight = np.hamming(labels.shape[1])

    for i in range(labels.shape[0]):
        this_chunk = labels[i]
        start = int(i * m.window_shift / m.receptive_field_shift + 0.5)
        end = start + this_chunk.shape[0]

        classification[start:end] += this_chunk * weight
        count[start:end] += weight

    classification /= np.maximum(count, 1e-12)

    if has_last_chunk:
        # drop the frames that correspond to the zero padding
        stop_frame = int(audio.shape[0] / m.receptive_field_shift)
        classification = classification[:stop_frame]

    classification = classification.tolist()

    # hysteresis thresholds: a segment starts when the score rises above
    # onset and ends when it falls below offset
    onset = 0.5
    offset = 0.5

    is_active = classification[0] > onset
    start = None
    if is_active:
        start = 0

    # convert frame indices to seconds; the half receptive field
    # centers each frame on the audio it actually saw
    scale = m.receptive_field_shift / m.sample_rate
    scale_offset = m.receptive_field_size / m.sample_rate * 0.5

    for i in range(len(classification)):
        if is_active:
            if classification[i] < offset:
                print(
                    f"{start*scale + scale_offset:.3f} -- {i*scale + scale_offset:.3f}"
                )
                is_active = False
        else:
            if classification[i] > onset:
                start = i
                is_active = True

    if is_active:
        print(
            f"{start*scale + scale_offset:.3f} -- {(len(classification)-1)*scale + scale_offset:.3f}"
        )


if __name__ == "__main__":
    main()
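
# Worked example of the onset/offset hysteresis in main() (illustrative
# frame scores, not real model output): with onset = offset = 0.5 and
# classification = [0.1, 0.7, 0.8, 0.4, 0.6, 0.2], a segment opens at
# frame 1 (0.7 > onset) and closes at frame 3 (0.4 < offset), then a second
# one opens at frame 4 and closes at frame 5, so two "start -- end" lines
# are printed after frames are mapped to seconds via scale and scale_offset.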