|
|
|
|
|
|
|
""" |
|
Please refer to |
|
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml |
|
for usages. |
|
""" |
|
|
|
import argparse |
|
from datetime import timedelta |
|
from pathlib import Path |
|
from typing import List |
|
|
|
import librosa |
|
import numpy as np |
|
import onnxruntime as ort |
|
import sherpa_onnx |
|
import soundfile as sf |
|
from numpy.lib.stride_tricks import as_strided |
|
|
|
|
|
class Segment: |
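    """A contiguous speech segment [start, end), in seconds, for one speaker."""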
|
def __init__( |
|
self, |
|
start, |
|
end, |
|
speaker, |
|
): |
|
assert start < end |
|
self.start = start |
|
self.end = end |
|
self.speaker = speaker |
|
|
|
def merge(self, other, gap=0.5): |
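        """Merge with ``other`` (same speaker) when the gap between the two
        segments is at most ``gap`` seconds; return None otherwise."""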
|
assert self.speaker == other.speaker, (self.speaker, other.speaker) |
|
if self.end < other.start and self.end + gap >= other.start: |
|
return Segment(start=self.start, end=other.end, speaker=self.speaker) |
|
elif other.end < self.start and other.end + gap >= self.start: |
|
return Segment(start=other.start, end=self.end, speaker=self.speaker) |
|
else: |
|
return None |
|
|
|
@property |
|
def duration(self): |
|
return self.end - self.start |
|
|
|
def __str__(self): |
|
s = f"{timedelta(seconds=self.start)}"[:-3] |
|
s += " --> " |
|
s += f"{timedelta(seconds=self.end)}"[:-3] |
|
s += f" speaker_{self.speaker:02d}" |
|
return s |
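
# A quick illustration of Segment.merge (hypothetical values):
#   Segment(0.0, 1.0, 0).merge(Segment(1.3, 2.0, 0), gap=0.5)
# returns Segment(0.0, 2.0, 0) since the 0.3 s gap is within ``gap``;
# segments farther apart than ``gap`` yield None.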
|
|
|
|
|
def merge_segment_list(in_out: List[Segment], min_duration_off: float): |
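    """Repeatedly merge adjacent same-speaker segments separated by at most
    ``min_duration_off`` seconds. ``in_out`` is modified in place."""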
|
changed = True |
|
while changed: |
|
changed = False |
|
for i in range(len(in_out)): |
|
if i + 1 >= len(in_out): |
|
continue |
|
|
|
new_segment = in_out[i].merge(in_out[i + 1], gap=min_duration_off) |
|
if new_segment is None: |
|
continue |
|
del in_out[i + 1] |
|
in_out[i] = new_segment |
|
changed = True |
|
break |
|
|
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--seg-model", |
|
type=str, |
|
required=True, |
|
help="Path to model.onnx for segmentation", |
|
) |
|
parser.add_argument( |
|
"--speaker-embedding-model", |
|
type=str, |
|
required=True, |
|
help="Path to model.onnx for speaker embedding extractor", |
|
) |
|
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav") |
|
|
|
return parser.parse_args() |
|
|
|
|
|
class OnnxSegmentationModel: |
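    """Wrapper around an ONNX powerset segmentation model (CPU inference)."""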
|
def __init__(self, filename): |
|
session_opts = ort.SessionOptions() |
|
session_opts.inter_op_num_threads = 1 |
|
session_opts.intra_op_num_threads = 1 |
|
|
|
self.session_opts = session_opts |
|
|
|
self.model = ort.InferenceSession( |
|
filename, |
|
sess_options=self.session_opts, |
|
providers=["CPUExecutionProvider"], |
|
) |
|
|
|
meta = self.model.get_modelmeta().custom_metadata_map |
|
print(meta) |
|
|
|
        self.window_size = int(meta["window_size"])  # samples per chunk
        self.sample_rate = int(meta["sample_rate"])
        # Shift between consecutive chunks: 10% of the window (90% overlap)
        self.window_shift = int(0.1 * self.window_size)
        self.receptive_field_size = int(meta["receptive_field_size"])
        self.receptive_field_shift = int(meta["receptive_field_shift"])
        self.num_speakers = int(meta["num_speakers"])  # max speakers per chunk
        self.powerset_max_classes = int(meta["powerset_max_classes"])
        self.num_classes = int(meta["num_classes"])  # number of powerset classes
|
|
|
def __call__(self, x): |
|
""" |
|
Args: |
|
x: (N, num_samples) |
|
Returns: |
|
A tensor of shape (N, num_frames, num_classes) |
|
""" |
|
        x = np.expand_dims(x, axis=1)  # (N, num_samples) -> (N, 1, num_samples)
|
|
|
(y,) = self.model.run( |
|
[self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: x} |
|
) |
|
|
|
return y |
|
|
|
|
|
def load_wav(filename, expected_sample_rate) -> np.ndarray: |
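    """Load a wave file as 1-D float32 samples, resampling if necessary."""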
|
audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True) |
|
    audio = audio[:, 0]  # use only the first channel
|
if sample_rate != expected_sample_rate: |
|
audio = librosa.resample( |
|
audio, |
|
orig_sr=sample_rate, |
|
target_sr=expected_sample_rate, |
|
) |
|
return audio |
|
|
|
|
|
def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes): |
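    """Build a (num_classes, num_speakers) matrix that maps each powerset
    class (empty set, singletons, pairs) to a multi-hot speaker vector.
    Row 0 is the 'no speaker' class."""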
|
mapping = np.zeros((num_classes, num_speakers)) |
|
|
|
k = 1 |
|
for i in range(1, powerset_max_classes + 1): |
|
if i == 1: |
|
for j in range(0, num_speakers): |
|
mapping[k, j] = 1 |
|
k += 1 |
|
elif i == 2: |
|
for j in range(0, num_speakers): |
|
for m in range(j + 1, num_speakers): |
|
mapping[k, j] = 1 |
|
mapping[k, m] = 1 |
|
k += 1 |
|
        else:
            raise RuntimeError("Only powerset_max_classes <= 2 is supported")
|
|
|
return mapping |
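
# For example, num_speakers=3 with powerset_max_classes=2 gives 7 classes:
#   {}, {0}, {1}, {2}, {0,1}, {0,2}, {1,2}
# so the mapping rows are
#   [0,0,0], [1,0,0], [0,1,0], [0,0,1], [1,1,0], [1,0,1], [0,1,1]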
|
|
|
|
|
def to_multi_label(y, mapping): |
|
""" |
|
Args: |
|
y: (num_chunks, num_frames, num_classes) |
|
Returns: |
|
A tensor of shape (num_chunks, num_frames, num_speakers) |
|
""" |
|
y = np.argmax(y, axis=-1) |
|
labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1) |
|
return labels |
|
|
|
|
|
|
|
def speaker_count(labels, seg_m): |
|
""" |
|
Args: |
|
labels: (num_chunks, num_frames, num_speakers) |
|
seg_m: Segmentation model |
|
Returns: |
|
A integer array of shape (num_total_frames,) |
|
""" |
|
    labels = labels.sum(axis=-1)  # (num_chunks, num_frames): speakers per frame
|
|
|
|
|
    # Total number of receptive-field frames spanned by all overlapping chunks
    num_frames = (
        int(
            (seg_m.window_size + (labels.shape[0] - 1) * seg_m.window_shift)
            / seg_m.receptive_field_shift
        )
        + 1
    )
|
ans = np.zeros((num_frames,)) |
|
count = np.zeros((num_frames,)) |
|
|
|
for i in range(labels.shape[0]): |
|
this_chunk = labels[i] |
|
start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5) |
|
end = start + this_chunk.shape[0] |
|
ans[start:end] += this_chunk |
|
count[start:end] += 1 |
|
|
|
    # Average over overlapping chunks, then round to the nearest integer
    ans /= np.maximum(count, 1e-12)

    return (ans + 0.5).astype(np.int8)
|
|
|
|
|
def load_speaker_embedding_model(filename): |
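    """Create a sherpa-onnx speaker embedding extractor from a model file."""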
|
config = sherpa_onnx.SpeakerEmbeddingExtractorConfig( |
|
model=filename, |
|
num_threads=1, |
|
debug=0, |
|
) |
|
if not config.validate(): |
|
raise ValueError(f"Invalid config. {config}") |
|
extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config) |
|
return extractor |
|
|
|
|
|
def get_embeddings(embedding_filename, audio, labels, seg_m, exclude_overlap): |
|
""" |
|
Args: |
|
embedding_filename: Path to the speaker embedding extractor model |
|
audio: (num_samples,) |
|
labels: (num_chunks, num_frames, num_speakers) |
|
seg_m: segmentation model |
|
Returns: |
|
Return (num_chunks, num_speakers, embedding_dim) |
|
""" |
|
if exclude_overlap: |
|
labels = labels * (labels.sum(axis=-1, keepdims=True) < 2) |
|
|
|
extractor = load_speaker_embedding_model(embedding_filename) |
|
buffer = np.empty(seg_m.window_size) |
|
num_chunks, num_frames, num_speakers = labels.shape |
|
|
|
ans_chunk_speaker_pair = [] |
|
ans_embeddings = [] |
|
|
|
for i in range(num_chunks): |
|
labels_T = labels[i].T |
|
|
|
|
|
sample_offset = i * seg_m.window_shift |
|
|
|
for j in range(num_speakers): |
|
frames = labels_T[j] |
|
            # Skip a speaker that is active in too few frames of this chunk
            if frames.sum() < 10:
                continue
|
|
|
start = None |
|
start_samples = 0 |
|
idx = 0 |
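
            # Copy the samples of every contiguous active run for this speaker
            # into ``buffer``; ``idx`` tracks how many samples have been filled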
|
for k in range(num_frames): |
|
if frames[k] != 0: |
|
if start is None: |
|
start = k |
|
elif start is not None: |
|
start_samples = ( |
|
int(start / num_frames * seg_m.window_size) + sample_offset |
|
) |
|
end_samples = ( |
|
int(k / num_frames * seg_m.window_size) + sample_offset |
|
) |
|
num_samples = end_samples - start_samples |
|
buffer[idx : idx + num_samples] = audio[start_samples:end_samples] |
|
idx += num_samples |
|
|
|
start = None |
|
            # Handle an active run that extends to the end of the chunk
            if start is not None:
|
start_samples = ( |
|
int(start / num_frames * seg_m.window_size) + sample_offset |
|
) |
|
end_samples = int(k / num_frames * seg_m.window_size) + sample_offset |
|
num_samples = end_samples - start_samples |
|
buffer[idx : idx + num_samples] = audio[start_samples:end_samples] |
|
idx += num_samples |
|
|
|
stream = extractor.create_stream() |
|
stream.accept_waveform(sample_rate=seg_m.sample_rate, waveform=buffer[:idx]) |
|
stream.input_finished() |
|
|
|
assert extractor.is_ready(stream) |
|
embedding = extractor.compute(stream) |
|
embedding = np.array(embedding) |
|
|
|
ans_chunk_speaker_pair.append([i, j]) |
|
ans_embeddings.append(embedding) |
|
|
|
assert len(ans_chunk_speaker_pair) == len(ans_embeddings), ( |
|
len(ans_chunk_speaker_pair), |
|
len(ans_embeddings), |
|
) |
|
return ans_chunk_speaker_pair, np.array(ans_embeddings) |
|
|
|
|
|
def main(): |
|
args = get_args() |
|
    assert Path(args.seg_model).is_file(), args.seg_model
    assert Path(args.speaker_embedding_model).is_file(), args.speaker_embedding_model
    assert Path(args.wav).is_file(), args.wav
|
|
|
seg_m = OnnxSegmentationModel(args.seg_model) |
|
audio = load_wav(args.wav, seg_m.sample_rate) |
|
|
|
|
|
    # Number of whole chunks that fit in the audio (0 if it is shorter than
    # one window)
    num = max(0, (audio.shape[0] - seg_m.window_size) // seg_m.window_shift + 1)

    # View the audio as overlapping chunks of window_size samples (shifted by
    # window_shift) without copying
    samples = as_strided(
        audio,
        shape=(num, seg_m.window_size),
        strides=(seg_m.window_shift * audio.strides[0], audio.strides[0]),
    )
|
|
|
|
|
|
|
|
|
    # A trailing partial chunk, or a file shorter than one window, needs an
    # extra zero-padded chunk
    has_last_chunk = (
        audio.shape[0] < seg_m.window_size
        or (audio.shape[0] - seg_m.window_size) % seg_m.window_shift > 0
    )
|
|
|
num_chunks = samples.shape[0] |
|
batch_size = 32 |
|
output = [] |
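
    # Run the segmentation model over the chunks in batches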
|
for i in range(0, num_chunks, batch_size): |
|
start = i |
|
end = i + batch_size |
|
|
|
y = seg_m(samples[start:end]) |
|
output.append(y) |
|
|
|
    # Zero-pad the tail to a full window and run it through the model as well
    if has_last_chunk:
|
last_chunk = audio[num_chunks * seg_m.window_shift :] |
|
pad_size = seg_m.window_size - last_chunk.shape[0] |
|
last_chunk = np.pad(last_chunk, (0, pad_size)) |
|
last_chunk = np.expand_dims(last_chunk, axis=0) |
|
y = seg_m(last_chunk) |
|
output.append(y) |
|
|
|
    y = np.vstack(output)  # (num_total_chunks, num_frames, num_classes)
|
|
|
|
|
mapping = get_powerset_mapping( |
|
num_classes=seg_m.num_classes, |
|
num_speakers=seg_m.num_speakers, |
|
powerset_max_classes=seg_m.powerset_max_classes, |
|
) |
|
    labels = to_multi_label(y, mapping=mapping)  # (num_chunks, num_frames, num_speakers)
|
|
|
|
|
    # (num_chunks, num_speakers): 1 if the speaker is silent in the whole chunk
    inactive = (labels.sum(axis=1) == 0).astype(np.int8)
|
|
|
|
|
speakers_per_frame = speaker_count(labels=labels, seg_m=seg_m) |
|
|
|
|
|
if speakers_per_frame.max() == 0: |
|
print("No speakers found in the audio file!") |
|
return |
|
|
|
|
|
|
|
|
|
|
|
chunk_speaker_pair, embeddings = get_embeddings( |
|
args.speaker_embedding_model, |
|
audio=audio, |
|
labels=labels, |
|
seg_m=seg_m, |
|
|
|
        exclude_overlap=False,  # keep overlapped speech when extracting embeddings
|
) |
|
|
|
|
|
|
|
|
|
    # Note: the number of clusters (i.e., speakers) is hard-coded to 2 here
    clustering_config = sherpa_onnx.FastClusteringConfig(num_clusters=2)
|
|
|
clustering = sherpa_onnx.FastClustering(clustering_config) |
|
cluster_labels = clustering(embeddings) |
|
|
|
    # Map each (chunk, speaker) pair to its cluster id
    chunk_speaker_to_cluster = dict()
|
for (chunk_idx, speaker_idx), cluster_idx in zip( |
|
chunk_speaker_pair, cluster_labels |
|
): |
|
if inactive[chunk_idx, speaker_idx] == 1: |
|
print("skip ", chunk_idx, speaker_idx) |
|
continue |
|
chunk_speaker_to_cluster[(chunk_idx, speaker_idx)] = cluster_idx |
|
|
|
    num_speakers = max(cluster_labels) + 1
    # Re-map local per-chunk speaker indexes to global cluster ids
    relabels = np.zeros((labels.shape[0], labels.shape[1], num_speakers))
|
for i in range(labels.shape[0]): |
|
for j in range(labels.shape[1]): |
|
for k in range(labels.shape[2]): |
|
if (i, k) not in chunk_speaker_to_cluster: |
|
continue |
|
t = chunk_speaker_to_cluster[(i, k)] |
|
|
|
if labels[i, j, k] == 1: |
|
relabels[i, j, t] = 1 |
|
|
|
    # Total number of receptive-field frames, computed as in speaker_count()
    num_frames = (
|
int( |
|
(seg_m.window_size + (relabels.shape[0] - 1) * seg_m.window_shift) |
|
/ seg_m.receptive_field_shift |
|
) |
|
+ 1 |
|
) |
|
|
|
    # Accumulate per-frame, per-cluster activity over overlapping chunks
    count = np.zeros((num_frames, relabels.shape[-1]))
|
for i in range(relabels.shape[0]): |
|
this_chunk = relabels[i] |
|
start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5) |
|
end = start + this_chunk.shape[0] |
|
count[start:end] += this_chunk |
|
|
|
    if has_last_chunk:
        # Drop the frames that correspond to the zero-padding of the tail
        stop_frame = int(audio.shape[0] / seg_m.receptive_field_shift)
        count = count[:stop_frame]
|
|
|
    # For each frame keep the speakers_per_frame[i] best-supported clusters
    sorted_count = np.argsort(-count, axis=-1)
    final = np.zeros((count.shape[0], count.shape[1]))
|
|
|
for i, (c, sc) in enumerate(zip(speakers_per_frame, sorted_count)): |
|
for k in range(c): |
|
final[i, sc[k]] = 1 |
|
|
|
    min_duration_off = 0.5  # merge segments separated by <= 0.5 s of silence
    min_duration_on = 0.3  # drop segments shorter than 0.3 s
    onset = 0.5  # threshold for starting a segment
    offset = 0.5  # threshold for ending a segment
|
|
|
|
|
    # Convert per-frame activity to segments, one speaker (cluster) at a time
    final = final.T  # (num_speakers, num_frames)
|
for kk in range(final.shape[0]): |
|
segment_list = [] |
|
frames = final[kk] |
|
|
|
is_active = frames[0] > onset |
|
|
|
start = None |
|
if is_active: |
|
start = 0 |
|
        # Seconds per frame, and the offset to the center of the receptive field
        scale = seg_m.receptive_field_shift / seg_m.sample_rate
        scale_offset = seg_m.receptive_field_size / seg_m.sample_rate * 0.5
|
for i in range(1, len(frames)): |
|
if is_active: |
|
if frames[i] < offset: |
|
segment = Segment( |
|
start=start * scale + scale_offset, |
|
end=i * scale + scale_offset, |
|
speaker=kk, |
|
) |
|
segment_list.append(segment) |
|
is_active = False |
|
else: |
|
if frames[i] > onset: |
|
start = i |
|
is_active = True |
|
|
|
        # Close a segment that is still active at the end of the audio
        if is_active:
|
segment = Segment( |
|
start=start * scale + scale_offset, |
|
end=(len(frames) - 1) * scale + scale_offset, |
|
speaker=kk, |
|
) |
|
segment_list.append(segment) |
|
|
|
        # Merge nearby segments, then drop those that are too short
        if len(segment_list) > 1:
|
merge_segment_list(segment_list, min_duration_off=min_duration_off) |
|
for s in segment_list: |
|
if s.duration < min_duration_on: |
|
continue |
|
print(s) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|