sherpa-onnx-reverb-diarization-v1 / speaker-diarization-onnx.py
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""
Please refer to
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
for usage.
"""
import argparse
from datetime import timedelta
from pathlib import Path
from typing import List
import librosa
import numpy as np
import onnxruntime as ort
import sherpa_onnx
import soundfile as sf
from numpy.lib.stride_tricks import as_strided
class Segment:
def __init__(
self,
start,
end,
speaker,
):
assert start < end
self.start = start
self.end = end
self.speaker = speaker
def merge(self, other, gap=0.5):
assert self.speaker == other.speaker, (self.speaker, other.speaker)
if self.end < other.start and self.end + gap >= other.start:
return Segment(start=self.start, end=other.end, speaker=self.speaker)
elif other.end < self.start and other.end + gap >= self.start:
return Segment(start=other.start, end=self.end, speaker=self.speaker)
else:
return None
@property
def duration(self):
return self.end - self.start
def __str__(self):
s = f"{timedelta(seconds=self.start)}"[:-3]
s += " --> "
s += f"{timedelta(seconds=self.end)}"[:-3]
s += f" speaker_{self.speaker:02d}"
return s
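# Example (illustrative): Segment(0.0, 1.0, 0).merge(Segment(1.3, 2.0, 0), gap=0.5)
# returns Segment(0.0, 2.0, 0), while gap=0.2 returns None because the 0.3 s
# pause between the two segments exceeds the allowed gap.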
def merge_segment_list(in_out: List[Segment], min_duration_off: float):
changed = True
while changed:
changed = False
for i in range(len(in_out)):
if i + 1 >= len(in_out):
continue
new_segment = in_out[i].merge(in_out[i + 1], gap=min_duration_off)
if new_segment is None:
continue
del in_out[i + 1]
in_out[i] = new_segment
changed = True
break
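# Example (illustrative): with min_duration_off=0.5, same-speaker segments
# [0.0, 1.0] and [1.2, 2.0] are merged into [0.0, 2.0], whereas [3.0, 4.0]
# stays separate because its 1.0 s gap exceeds min_duration_off.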
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--seg-model",
type=str,
required=True,
help="Path to model.onnx for segmentation",
)
parser.add_argument(
"--speaker-embedding-model",
type=str,
required=True,
help="Path to model.onnx for speaker embedding extractor",
)
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
return parser.parse_args()
class OnnxSegmentationModel:
def __init__(self, filename):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
meta = self.model.get_modelmeta().custom_metadata_map
print(meta)
self.window_size = int(meta["window_size"])
self.sample_rate = int(meta["sample_rate"])
self.window_shift = int(0.1 * self.window_size)
self.receptive_field_size = int(meta["receptive_field_size"])
self.receptive_field_shift = int(meta["receptive_field_shift"])
self.num_speakers = int(meta["num_speakers"])
self.powerset_max_classes = int(meta["powerset_max_classes"])
self.num_classes = int(meta["num_classes"])
def __call__(self, x):
"""
Args:
x: (N, num_samples)
Returns:
A tensor of shape (N, num_frames, num_classes)
"""
x = np.expand_dims(x, axis=1)
(y,) = self.model.run(
[self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: x}
)
return y
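# Example (illustrative): for a pyannote-style segmentation model with
# window_size=160000 (10 s at 16 kHz), calling the model on a batch of shape
# (2, 160000) returns per-frame powerset class scores of shape
# (2, num_frames, num_classes).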
def load_wav(filename, expected_sample_rate) -> np.ndarray:
audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel
if sample_rate != expected_sample_rate:
audio = librosa.resample(
audio,
orig_sr=sample_rate,
target_sr=expected_sample_rate,
)
return audio
def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes):
mapping = np.zeros((num_classes, num_speakers))
k = 1
for i in range(1, powerset_max_classes + 1):
if i == 1:
for j in range(0, num_speakers):
mapping[k, j] = 1
k += 1
elif i == 2:
for j in range(0, num_speakers):
for m in range(j + 1, num_speakers):
mapping[k, j] = 1
mapping[k, m] = 1
k += 1
        else:
            raise RuntimeError("Unsupported: powerset_max_classes must be <= 2")
return mapping
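# Example (illustrative): with num_speakers=3 and powerset_max_classes=2 there
# are 1 + 3 + 3 = 7 powerset classes: {}, {0}, {1}, {2}, {0,1}, {0,2}, {1,2};
# the mapping row for class {0,2} is the multi-hot vector [1, 0, 1].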
def to_multi_label(y, mapping):
"""
Args:
      y: (num_chunks, num_frames, num_classes)
      mapping: (num_classes, num_speakers), as returned by get_powerset_mapping()
Returns:
A tensor of shape (num_chunks, num_frames, num_speakers)
"""
y = np.argmax(y, axis=-1)
labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1)
return labels
# speaker count per frame
def speaker_count(labels, seg_m):
"""
Args:
labels: (num_chunks, num_frames, num_speakers)
seg_m: Segmentation model
Returns:
      An integer array of shape (num_total_frames,)
"""
labels = labels.sum(axis=-1)
# Now labels: (num_chunks, num_frames)
num_frames = (
int(
(seg_m.window_size + (labels.shape[0] - 1) * seg_m.window_shift)
/ seg_m.receptive_field_shift
)
+ 1
)
ans = np.zeros((num_frames,))
count = np.zeros((num_frames,))
for i in range(labels.shape[0]):
this_chunk = labels[i]
start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
end = start + this_chunk.shape[0]
ans[start:end] += this_chunk
count[start:end] += 1
ans /= np.maximum(count, 1e-12)
return (ans + 0.5).astype(np.int8)
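# Note: consecutive chunks overlap by 90% (window_shift is 10% of window_size),
# so each output frame averages the per-chunk speaker counts of every chunk
# covering it and then rounds to the nearest integer.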
def load_speaker_embedding_model(filename):
config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
model=filename,
num_threads=1,
debug=0,
)
if not config.validate():
raise ValueError(f"Invalid config. {config}")
extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
return extractor
def get_embeddings(embedding_filename, audio, labels, seg_m, exclude_overlap):
"""
Args:
embedding_filename: Path to the speaker embedding extractor model
audio: (num_samples,)
labels: (num_chunks, num_frames, num_speakers)
      seg_m: Segmentation model
      exclude_overlap: If True, zero out frames where more than one speaker is active
    Returns:
      A list of (chunk_idx, speaker_idx) pairs and a 2-D array of embeddings
      of shape (num_pairs, embedding_dim), one row per pair.
    """
if exclude_overlap:
labels = labels * (labels.sum(axis=-1, keepdims=True) < 2)
extractor = load_speaker_embedding_model(embedding_filename)
buffer = np.empty(seg_m.window_size)
num_chunks, num_frames, num_speakers = labels.shape
ans_chunk_speaker_pair = []
ans_embeddings = []
for i in range(num_chunks):
labels_T = labels[i].T
        # labels_T: (num_speakers, num_frames)
sample_offset = i * seg_m.window_shift
for j in range(num_speakers):
frames = labels_T[j]
if frames.sum() < 10:
                # skip speakers with fewer than 10 active frames in this chunk,
                # i.e., about 0.2 seconds of speech
continue
start = None
start_samples = 0
idx = 0
for k in range(num_frames):
if frames[k] != 0:
if start is None:
start = k
elif start is not None:
start_samples = (
int(start / num_frames * seg_m.window_size) + sample_offset
)
end_samples = (
int(k / num_frames * seg_m.window_size) + sample_offset
)
num_samples = end_samples - start_samples
buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
idx += num_samples
start = None
if start is not None:
start_samples = (
int(start / num_frames * seg_m.window_size) + sample_offset
)
end_samples = int(k / num_frames * seg_m.window_size) + sample_offset
num_samples = end_samples - start_samples
buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
idx += num_samples
stream = extractor.create_stream()
stream.accept_waveform(sample_rate=seg_m.sample_rate, waveform=buffer[:idx])
stream.input_finished()
assert extractor.is_ready(stream)
embedding = extractor.compute(stream)
embedding = np.array(embedding)
ans_chunk_speaker_pair.append([i, j])
ans_embeddings.append(embedding)
assert len(ans_chunk_speaker_pair) == len(ans_embeddings), (
len(ans_chunk_speaker_pair),
len(ans_embeddings),
)
return ans_chunk_speaker_pair, np.array(ans_embeddings)
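# Example (illustrative): ans_chunk_speaker_pair might look like
# [[0, 0], [0, 2], [1, 0], ...] and the returned embeddings then have shape
# (num_pairs, embedding_dim), where embedding_dim (e.g. 192 or 256) depends on
# the chosen speaker embedding model.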
def main():
args = get_args()
assert Path(args.seg_model).is_file(), args.seg_model
assert Path(args.wav).is_file(), args.wav
seg_m = OnnxSegmentationModel(args.seg_model)
audio = load_wav(args.wav, seg_m.sample_rate)
# audio: (num_samples,)
num = (audio.shape[0] - seg_m.window_size) // seg_m.window_shift + 1
samples = as_strided(
audio,
shape=(num, seg_m.window_size),
strides=(seg_m.window_shift * audio.strides[0], audio.strides[0]),
)
# or use torch.Tensor.unfold
# samples = torch.from_numpy(audio).unfold(0, seg_m.window_size, seg_m.window_shift).numpy()
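    # Example (illustrative): with window_size=160000 (10 s at 16 kHz) and
    # window_shift=16000, a 20-second file (320000 samples) yields
    # num = (320000 - 160000) // 16000 + 1 = 11 overlapping chunks.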
if (
audio.shape[0] < seg_m.window_size
or (audio.shape[0] - seg_m.window_size) % seg_m.window_shift > 0
):
has_last_chunk = True
else:
has_last_chunk = False
num_chunks = samples.shape[0]
batch_size = 32
output = []
for i in range(0, num_chunks, batch_size):
start = i
end = i + batch_size
        # it is fine if end > num_chunks; slicing clamps it to the array length
y = seg_m(samples[start:end])
output.append(y)
if has_last_chunk:
last_chunk = audio[num_chunks * seg_m.window_shift :] # noqa
pad_size = seg_m.window_size - last_chunk.shape[0]
last_chunk = np.pad(last_chunk, (0, pad_size))
last_chunk = np.expand_dims(last_chunk, axis=0)
y = seg_m(last_chunk)
output.append(y)
y = np.vstack(output)
# y: (num_chunks, num_frames, num_classes)
mapping = get_powerset_mapping(
num_classes=seg_m.num_classes,
num_speakers=seg_m.num_speakers,
powerset_max_classes=seg_m.powerset_max_classes,
)
labels = to_multi_label(y, mapping=mapping)
# labels: (num_chunks, num_frames, num_speakers)
inactive = (labels.sum(axis=1) == 0).astype(np.int8)
# inactive: (num_chunks, num_speakers)
speakers_per_frame = speaker_count(labels=labels, seg_m=seg_m)
    # speakers_per_frame: (num_total_frames,)
if speakers_per_frame.max() == 0:
print("No speakers found in the audio file!")
return
    # Note: if only one speaker were requested for clustering, the result
    # could be returned directly at this point.
# Now, get embeddings
chunk_speaker_pair, embeddings = get_embeddings(
args.speaker_embedding_model,
audio=audio,
labels=labels,
seg_m=seg_m,
# exclude_overlap=True,
exclude_overlap=False,
)
# chunk_speaker_pair: a list of (chunk_idx, speaker_idx)
    # embeddings: (num_chunk_speaker_pairs, embedding_dim)
    # Adjust num_clusters or threshold to suit your data.
clustering_config = sherpa_onnx.FastClusteringConfig(num_clusters=2)
# clustering_config = sherpa_onnx.FastClusteringConfig(threshold=0.8)
clustering = sherpa_onnx.FastClustering(clustering_config)
cluster_labels = clustering(embeddings)
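    # cluster_labels[i] is the global cluster index assigned to
    # chunk_speaker_pair[i]; the pairs and embeddings are aligned element-wise.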
chunk_speaker_to_cluster = dict()
for (chunk_idx, speaker_idx), cluster_idx in zip(
chunk_speaker_pair, cluster_labels
):
if inactive[chunk_idx, speaker_idx] == 1:
print("skip ", chunk_idx, speaker_idx)
continue
chunk_speaker_to_cluster[(chunk_idx, speaker_idx)] = cluster_idx
num_speakers = max(cluster_labels) + 1
relabels = np.zeros((labels.shape[0], labels.shape[1], num_speakers))
for i in range(labels.shape[0]):
for j in range(labels.shape[1]):
for k in range(labels.shape[2]):
if (i, k) not in chunk_speaker_to_cluster:
continue
t = chunk_speaker_to_cluster[(i, k)]
if labels[i, j, k] == 1:
relabels[i, j, t] = 1
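    # relabels: (num_chunks, num_frames, num_clusters); per-chunk local speaker
    # indices are remapped to the global cluster indices found above.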
num_frames = (
int(
(seg_m.window_size + (relabels.shape[0] - 1) * seg_m.window_shift)
/ seg_m.receptive_field_shift
)
+ 1
)
count = np.zeros((num_frames, relabels.shape[-1]))
for i in range(relabels.shape[0]):
this_chunk = relabels[i]
start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
end = start + this_chunk.shape[0]
count[start:end] += this_chunk
if has_last_chunk:
stop_frame = int(audio.shape[0] / seg_m.receptive_field_shift)
count = count[:stop_frame]
sorted_count = np.argsort(-count, axis=-1)
final = np.zeros((count.shape[0], count.shape[1]))
for i, (c, sc) in enumerate(zip(speakers_per_frame, sorted_count)):
for k in range(c):
final[i, sc[k]] = 1
min_duration_off = 0.5
min_duration_on = 0.3
onset = 0.5
offset = 0.5
# final: (num_frames, num_speakers)
final = final.T
for kk in range(final.shape[0]):
segment_list = []
frames = final[kk]
is_active = frames[0] > onset
start = None
if is_active:
start = 0
scale = seg_m.receptive_field_shift / seg_m.sample_rate
scale_offset = seg_m.receptive_field_size / seg_m.sample_rate * 0.5
for i in range(1, len(frames)):
if is_active:
if frames[i] < offset:
segment = Segment(
start=start * scale + scale_offset,
end=i * scale + scale_offset,
speaker=kk,
)
segment_list.append(segment)
is_active = False
else:
if frames[i] > onset:
start = i
is_active = True
if is_active:
segment = Segment(
start=start * scale + scale_offset,
end=(len(frames) - 1) * scale + scale_offset,
speaker=kk,
)
segment_list.append(segment)
if len(segment_list) > 1:
merge_segment_list(segment_list, min_duration_off=min_duration_off)
for s in segment_list:
if s.duration < min_duration_on:
continue
print(s)
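# Expected output: one line per detected segment, e.g. (illustrative values):
#
#   0:00:00.533 --> 0:00:05.289 speaker_00
#   0:00:05.946 --> 0:00:09.114 speaker_01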
if __name__ == "__main__":
main()