csukuangfj
commited on
Commit
•
5ec554b
1
Parent(s):
1b239e0
add models
Browse files- LICENSE +48 -0
- README.md +12 -0
- export-onnx.py +130 -0
- model.int8.onnx +3 -0
- model.onnx +3 -0
- run.sh +46 -0
- show-onnx.py +43 -0
- speaker-diarization-onnx.py +488 -0
- speaker-diarization-torch.py +86 -0
- vad-onnx.py +244 -0
- vad-torch.py +38 -0
LICENSE
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Rev Model Non-Production License
|
2 |
+
|
3 |
+
1. Scope and acceptance
|
4 |
+
1.1. Scope of the Agreement. This Agreement applies to any use, modification, or Distribution of any Rev Model by You, regardless of the source from which You obtained a copy of such Rev Model.
|
5 |
+
1.2. Acceptance. By accessing, using, modifying, Distributing a Rev Model, or by creating, using or distributing a Derivative of the Rev Model, You agree to be bound by this Agreement.
|
6 |
+
1.3. Acceptance on Behalf of a Third-Party. If You accept this Agreement on behalf of Your employer or another person or entity, You represent and warrant that You have the authority to act and accept this Agreement on their behalf. In such a case, the word “You” in this Agreement will refer to Your employer or such other person or entity.
|
7 |
+
|
8 |
+
2. License
|
9 |
+
2.1. Grant of Rights. Subject to Section 3 (Limitations) below, Rev hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to (a) use, copy, modify, and Distribute the Rev Model and any Derivatives, subject to the conditions provided in Section 2.2 below; and (b) use the Rev Model and any Derivatives to generate Output.
|
10 |
+
2.2. Distribution of Rev Model. Subject to Section 3 below, You may Distribute copies of the Rev Model under the following conditions: (a) You must make available a copy of this Agreement to third-party recipients of the Rev Model, it being specified that any rights to use the Rev Models shall be directly granted by Rev to said third-party recipients pursuant to the Rev Model Non-Production License agreement executed between these parties; and (b) You must retain in all copies of the Rev Models the following attribution notice within a “Notice” text file distributed as part of such copies: “Licensed by Rev under the Rev Model Non-Production License”.
|
11 |
+
2.3. Distribution of Derivatives. Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: (a) in any event, Your use and modification of Rev Model and/or Derivatives shall remain governed by the terms and conditions of this Agreement; (b) You include in any such Derivatives made by or for You prominent notices stating that You modified the applicable Rev Model; (c) You will impose the same terms and conditions with respect to the Rev Model, Derivatives and Output that apply to You under this Agreement; and (d) any terms and conditions You impose on any third-party recipients relating to Derivatives shall neither limit such third-party recipients’ use of the Rev Model in accordance with this Rev Model Non-Production License.
|
12 |
+
|
13 |
+
3. Limitations
|
14 |
+
3.1. Misrepresentation. You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Rev Model that You Distribute under your name and responsibility is an official product of Rev or has been endorsed, approved or validated by Rev, unless You are authorized by Us to do so in writing.
|
15 |
+
3.2. Usage Limitation. You shall only use the Rev Models, Derivatives (whether or not created by Rev) and Outputs for testing, research, Personal, or evaluation purposes in Non-Production Environments. Subject to the foregoing, You shall not supply the Rev Models, Derivatives, or Outputs in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer.
|
16 |
+
3.3. Usage Not Permitted Under this Agreement. If You want to use a Rev Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Rev, which Rev may grant to You in Rev’s sole discretion. Please contact Rev at the following e-mail address if You want to discuss such a license: licensing@rev.com.
|
17 |
+
|
18 |
+
4. Intellectual Property
|
19 |
+
4.1. Trademarks. No trademark licenses are granted under this Agreement, and in connection with the Rev Models, You may not use any name or mark owned by or associated with Rev or any of its affiliates, except (a) as required for reasonable and customary use in describing and Distributing the Rev Models and Derivatives made by or for Rev and (b) for attribution purposes as required by this Agreement.
|
20 |
+
4.2. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs that You generate and their subsequent uses in accordance with this Agreement. Notwithstanding the foregoing, You agree that Your use of the Outputs is subject to the Usage Limitations set forth in Section 3.2 above.
|
21 |
+
4.3. Derivatives. By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement.
|
22 |
+
|
23 |
+
5. Liability
|
24 |
+
5.1. Limitation of Liability. In no event, unless required by applicable law (such as the deliberate and grossly negligent acts of Rev) or agreed to in writing, shall Rev be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Rev Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Rev has been advised of the possibility of such damages.
|
25 |
+
5.2. Indemnification. You agree to indemnify and hold harmless Rev from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Rev Models, Derivatives and Output.
|
26 |
+
|
27 |
+
6. Warranty Disclaimer. Unless required by applicable law or agreed to in writing, Rev provides the Rev Models and Derivatives on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Rev does not represent nor warrant that the Rev Models, Derivatives or Output will be error-free, meet Your or any third party’s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Rev Models, Derivatives and Output and assume any risks associated with Your exercise of rights under this Agreement.
|
28 |
+
|
29 |
+
7. Termination
|
30 |
+
7.1. Term. This Agreement is effective as of the date of your acceptance of this Agreement or access to the applicable Rev Models or Derivatives and will continue until terminated in accordance with the following terms.
|
31 |
+
7.2. Termination. Rev may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Rev Models and Derivatives and shall permanently delete any copy thereof. Sections 3, 4, 5, 6, 7 and 8 shall survive the termination of this Agreement.
|
32 |
+
7.3. Litigation. If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Rev Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated.
|
33 |
+
|
34 |
+
8. General provisions
|
35 |
+
8.1. Governing Law. This Agreement will be governed by the laws of the State of Texas, without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
36 |
+
8.2. Forum. The United States District Court for the Western District of Texas and any appellate court therefrom shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
37 |
+
8.3. Severability. If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
38 |
+
|
39 |
+
9. Definitions
|
40 |
+
“Agreement” means this Rev Model Non-Production License agreement governing the access, use, and Distribution of the Rev Models, Derivatives and Output.
|
41 |
+
“Derivative” means any (a) modified version of the Rev Model (including but not limited to any customized or fine-tuned version thereof), (b) work based on the Rev Model, or (c) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement.
|
42 |
+
“Distribution”, “Distributing”, “Distribute” or “Distributed” means providing or making available, by any means, a copy of the Rev Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement.
|
43 |
+
“Rev”, “We” or “Us” means Rev.com, Inc.
|
44 |
+
“Rev Model” means the AI model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Rev under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof.
|
45 |
+
“Non-Production Environment” means any setting, use case, or application of the Rev Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the employer’s commercial business activities), and demonstration purposes by You.
|
46 |
+
“Outputs” mean any content generated by the operation of the Rev Models or the Derivatives from a prompt (i.e., audio files) provided by users. For the avoidance of doubt, Outputs do not include any components of a Rev Models, such as any fine-tuned versions of the Rev Models, the weights, or parameters.
|
47 |
+
“Personal” means any use of a Rev Model or a Derivative that is (a) solely for personal, non-profit and non-commercial purposes and (b) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Rev Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity.
|
48 |
+
“You” means the individual or entity entering into this Agreement with Rev.
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Introduction
|
2 |
+
|
3 |
+
Models in this file are converted from
|
4 |
+
https://huggingface.co/Revai/reverb-diarization-v1/tree/main
|
5 |
+
|
6 |
+
Note that it is accessible under a non-commercial license.
|
7 |
+
|
8 |
+
Please see ./LICENSE for details.
|
9 |
+
|
10 |
+
See also
|
11 |
+
https://www.rev.com/blog/speech-to-text-technology/introducing-reverb-open-source-asr-diarization
|
12 |
+
|
export-onnx.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
from typing import Any, Dict
|
4 |
+
|
5 |
+
import onnx
|
6 |
+
import torch
|
7 |
+
from onnxruntime.quantization import QuantType, quantize_dynamic
|
8 |
+
from pyannote.audio import Model
|
9 |
+
from pyannote.audio.core.task import Problem, Resolution
|
10 |
+
|
11 |
+
|
12 |
+
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
13 |
+
"""Add meta data to an ONNX model. It is changed in-place.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
filename:
|
17 |
+
Filename of the ONNX model to be changed.
|
18 |
+
meta_data:
|
19 |
+
Key-value pairs.
|
20 |
+
"""
|
21 |
+
model = onnx.load(filename)
|
22 |
+
|
23 |
+
while len(model.metadata_props):
|
24 |
+
model.metadata_props.pop()
|
25 |
+
|
26 |
+
for key, value in meta_data.items():
|
27 |
+
meta = model.metadata_props.add()
|
28 |
+
meta.key = key
|
29 |
+
meta.value = str(value)
|
30 |
+
|
31 |
+
onnx.save(model, filename)
|
32 |
+
|
33 |
+
|
34 |
+
@torch.no_grad()
|
35 |
+
def main():
|
36 |
+
# You can download ./pytorch_model.bin from
|
37 |
+
# https://hf-mirror.com/csukuangfj/pyannote-models/tree/main/segmentation-3.0
|
38 |
+
# or from
|
39 |
+
# https://huggingface.co/Revai/reverb-diarization-v1/tree/main
|
40 |
+
pt_filename = "./pytorch_model.bin"
|
41 |
+
model = Model.from_pretrained(pt_filename)
|
42 |
+
model.eval()
|
43 |
+
assert model.dimension == 7, model.dimension
|
44 |
+
print(model.specifications)
|
45 |
+
|
46 |
+
assert (
|
47 |
+
model.specifications.problem == Problem.MONO_LABEL_CLASSIFICATION
|
48 |
+
), model.specifications.problem
|
49 |
+
|
50 |
+
assert (
|
51 |
+
model.specifications.resolution == Resolution.FRAME
|
52 |
+
), model.specifications.resolution
|
53 |
+
|
54 |
+
assert model.specifications.duration == 10.0, model.specifications.duration
|
55 |
+
|
56 |
+
assert model.audio.sample_rate == 16000, model.audio.sample_rate
|
57 |
+
|
58 |
+
# (batch, num_channels, num_samples)
|
59 |
+
assert list(model.example_input_array.shape) == [
|
60 |
+
1,
|
61 |
+
1,
|
62 |
+
16000 * 10,
|
63 |
+
], model.example_input_array.shape
|
64 |
+
|
65 |
+
example_output = model(model.example_input_array)
|
66 |
+
|
67 |
+
# (batch, num_frames, num_classes)
|
68 |
+
assert list(example_output.shape) == [1, 589, 7], example_output.shape
|
69 |
+
|
70 |
+
assert model.receptive_field.step == 0.016875, model.receptive_field.step
|
71 |
+
assert model.receptive_field.duration == 0.0619375, model.receptive_field.duration
|
72 |
+
assert model.receptive_field.step * 16000 == 270, model.receptive_field.step * 16000
|
73 |
+
assert model.receptive_field.duration * 16000 == 991, (
|
74 |
+
model.receptive_field.duration * 16000
|
75 |
+
)
|
76 |
+
|
77 |
+
opset_version = 13
|
78 |
+
|
79 |
+
filename = "model.onnx"
|
80 |
+
torch.onnx.export(
|
81 |
+
model,
|
82 |
+
model.example_input_array,
|
83 |
+
filename,
|
84 |
+
opset_version=opset_version,
|
85 |
+
input_names=["x"],
|
86 |
+
output_names=["y"],
|
87 |
+
dynamic_axes={
|
88 |
+
"x": {0: "N", 2: "T"},
|
89 |
+
"y": {0: "N", 1: "T"},
|
90 |
+
},
|
91 |
+
)
|
92 |
+
|
93 |
+
sample_rate = model.audio.sample_rate
|
94 |
+
|
95 |
+
window_size = int(model.specifications.duration) * 16000
|
96 |
+
receptive_field_size = int(model.receptive_field.duration * 16000)
|
97 |
+
receptive_field_shift = int(model.receptive_field.step * 16000)
|
98 |
+
|
99 |
+
meta_data = {
|
100 |
+
"num_speakers": len(model.specifications.classes),
|
101 |
+
"powerset_max_classes": model.specifications.powerset_max_classes,
|
102 |
+
"num_classes": model.dimension,
|
103 |
+
"sample_rate": sample_rate,
|
104 |
+
"window_size": window_size,
|
105 |
+
"receptive_field_size": receptive_field_size,
|
106 |
+
"receptive_field_shift": receptive_field_shift,
|
107 |
+
"model_type": "pyannote-segmentation-3.0",
|
108 |
+
"version": "1",
|
109 |
+
"model_author": "pyannote",
|
110 |
+
"maintainer": "k2-fsa",
|
111 |
+
"url_1": "https://huggingface.co/pyannote/segmentation-3.0",
|
112 |
+
"url_2": "https://huggingface.co/csukuangfj/pyannote-models/tree/main/segmentation-3.0",
|
113 |
+
"license": "https://huggingface.co/pyannote/segmentation-3.0/blob/main/LICENSE",
|
114 |
+
}
|
115 |
+
add_meta_data(filename=filename, meta_data=meta_data)
|
116 |
+
|
117 |
+
print("Generate int8 quantization models")
|
118 |
+
|
119 |
+
filename_int8 = "model.int8.onnx"
|
120 |
+
quantize_dynamic(
|
121 |
+
model_input=filename,
|
122 |
+
model_output=filename_int8,
|
123 |
+
weight_type=QuantType.QUInt8,
|
124 |
+
)
|
125 |
+
|
126 |
+
print(f"Saved to {filename} and {filename_int8}")
|
127 |
+
|
128 |
+
|
129 |
+
if __name__ == "__main__":
|
130 |
+
main()
|
model.int8.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c0e83b2ca69b379eea37b80e2b739b1f6e43f3964c95aaac0bdeb5e2e225ec6e
|
3 |
+
size 2415974
|
model.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d4d3926911a214a1e56df418f8d967f6dee3e139348a0738b6ef982fb3108fd4
|
3 |
+
size 9512165
|
run.sh
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
3 |
+
|
4 |
+
set -ex
|
5 |
+
function install_pyannote() {
|
6 |
+
pip install pyannote.audio onnx onnxruntime
|
7 |
+
}
|
8 |
+
|
9 |
+
function download_test_files() {
|
10 |
+
curl -SL -O https://huggingface.co/Revai/reverb-diarization-v1/resolve/main/pytorch_model.bin
|
11 |
+
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
|
12 |
+
}
|
13 |
+
|
14 |
+
install_pyannote
|
15 |
+
download_test_files
|
16 |
+
|
17 |
+
./export-onnx.py
|
18 |
+
./preprocess.sh
|
19 |
+
|
20 |
+
echo "----------torch----------"
|
21 |
+
./vad-torch.py
|
22 |
+
|
23 |
+
echo "----------onnx model.onnx----------"
|
24 |
+
./vad-onnx.py --model ./model.onnx --wav ./lei-jun-test.wav
|
25 |
+
|
26 |
+
echo "----------onnx model.int8.onnx----------"
|
27 |
+
./vad-onnx.py --model ./model.int8.onnx --wav ./lei-jun-test.wav
|
28 |
+
|
29 |
+
curl -SL -O https://huggingface.co/Revai/reverb-diarization-v1/resolve/main/LICENSE
|
30 |
+
|
31 |
+
cat >README.md << EOF
|
32 |
+
# Introduction
|
33 |
+
|
34 |
+
Models in this file are converted from
|
35 |
+
https://huggingface.co/Revai/reverb-diarization-v1/tree/main
|
36 |
+
|
37 |
+
Note that it is accessible under a non-commercial license.
|
38 |
+
|
39 |
+
Please see ./LICENSE for details.
|
40 |
+
|
41 |
+
See also
|
42 |
+
https://www.rev.com/blog/speech-to-text-technology/introducing-reverb-open-source-asr-diarization
|
43 |
+
|
44 |
+
EOF
|
45 |
+
|
46 |
+
|
show-onnx.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
3 |
+
|
4 |
+
import onnxruntime
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
|
8 |
+
def get_args():
|
9 |
+
parser = argparse.ArgumentParser(
|
10 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
11 |
+
)
|
12 |
+
|
13 |
+
parser.add_argument(
|
14 |
+
"--filename",
|
15 |
+
type=str,
|
16 |
+
required=True,
|
17 |
+
help="Path to model.onnx",
|
18 |
+
)
|
19 |
+
|
20 |
+
return parser.parse_args()
|
21 |
+
|
22 |
+
|
23 |
+
def show(filename):
|
24 |
+
session_opts = onnxruntime.SessionOptions()
|
25 |
+
session_opts.log_severity_level = 3
|
26 |
+
sess = onnxruntime.InferenceSession(filename, session_opts)
|
27 |
+
for i in sess.get_inputs():
|
28 |
+
print(i)
|
29 |
+
|
30 |
+
print("-----")
|
31 |
+
|
32 |
+
for i in sess.get_outputs():
|
33 |
+
print(i)
|
34 |
+
|
35 |
+
|
36 |
+
def main():
|
37 |
+
args = get_args()
|
38 |
+
print(f"========={args.filename}==========")
|
39 |
+
show(args.filename)
|
40 |
+
|
41 |
+
|
42 |
+
if __name__ == "__main__":
|
43 |
+
main()
|
speaker-diarization-onnx.py
ADDED
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
3 |
+
|
4 |
+
"""
|
5 |
+
Please refer to
|
6 |
+
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
|
7 |
+
for usages.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import argparse
|
11 |
+
from datetime import timedelta
|
12 |
+
from pathlib import Path
|
13 |
+
from typing import List
|
14 |
+
|
15 |
+
import librosa
|
16 |
+
import numpy as np
|
17 |
+
import onnxruntime as ort
|
18 |
+
import sherpa_onnx
|
19 |
+
import soundfile as sf
|
20 |
+
from numpy.lib.stride_tricks import as_strided
|
21 |
+
|
22 |
+
|
23 |
+
class Segment:
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
start,
|
27 |
+
end,
|
28 |
+
speaker,
|
29 |
+
):
|
30 |
+
assert start < end
|
31 |
+
self.start = start
|
32 |
+
self.end = end
|
33 |
+
self.speaker = speaker
|
34 |
+
|
35 |
+
def merge(self, other, gap=0.5):
|
36 |
+
assert self.speaker == other.speaker, (self.speaker, other.speaker)
|
37 |
+
if self.end < other.start and self.end + gap >= other.start:
|
38 |
+
return Segment(start=self.start, end=other.end, speaker=self.speaker)
|
39 |
+
elif other.end < self.start and other.end + gap >= self.start:
|
40 |
+
return Segment(start=other.start, end=self.end, speaker=self.speaker)
|
41 |
+
else:
|
42 |
+
return None
|
43 |
+
|
44 |
+
@property
|
45 |
+
def duration(self):
|
46 |
+
return self.end - self.start
|
47 |
+
|
48 |
+
def __str__(self):
|
49 |
+
s = f"{timedelta(seconds=self.start)}"[:-3]
|
50 |
+
s += " --> "
|
51 |
+
s += f"{timedelta(seconds=self.end)}"[:-3]
|
52 |
+
s += f" speaker_{self.speaker:02d}"
|
53 |
+
return s
|
54 |
+
|
55 |
+
|
56 |
+
def merge_segment_list(in_out: List[Segment], min_duration_off: float):
|
57 |
+
changed = True
|
58 |
+
while changed:
|
59 |
+
changed = False
|
60 |
+
for i in range(len(in_out)):
|
61 |
+
if i + 1 >= len(in_out):
|
62 |
+
continue
|
63 |
+
|
64 |
+
new_segment = in_out[i].merge(in_out[i + 1], gap=min_duration_off)
|
65 |
+
if new_segment is None:
|
66 |
+
continue
|
67 |
+
del in_out[i + 1]
|
68 |
+
in_out[i] = new_segment
|
69 |
+
changed = True
|
70 |
+
break
|
71 |
+
|
72 |
+
|
73 |
+
def get_args():
|
74 |
+
parser = argparse.ArgumentParser()
|
75 |
+
parser.add_argument(
|
76 |
+
"--seg-model",
|
77 |
+
type=str,
|
78 |
+
required=True,
|
79 |
+
help="Path to model.onnx for segmentation",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--speaker-embedding-model",
|
83 |
+
type=str,
|
84 |
+
required=True,
|
85 |
+
help="Path to model.onnx for speaker embedding extractor",
|
86 |
+
)
|
87 |
+
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
|
88 |
+
|
89 |
+
return parser.parse_args()
|
90 |
+
|
91 |
+
|
92 |
+
class OnnxSegmentationModel:
|
93 |
+
def __init__(self, filename):
|
94 |
+
session_opts = ort.SessionOptions()
|
95 |
+
session_opts.inter_op_num_threads = 1
|
96 |
+
session_opts.intra_op_num_threads = 1
|
97 |
+
|
98 |
+
self.session_opts = session_opts
|
99 |
+
|
100 |
+
self.model = ort.InferenceSession(
|
101 |
+
filename,
|
102 |
+
sess_options=self.session_opts,
|
103 |
+
providers=["CPUExecutionProvider"],
|
104 |
+
)
|
105 |
+
|
106 |
+
meta = self.model.get_modelmeta().custom_metadata_map
|
107 |
+
print(meta)
|
108 |
+
|
109 |
+
self.window_size = int(meta["window_size"])
|
110 |
+
self.sample_rate = int(meta["sample_rate"])
|
111 |
+
self.window_shift = int(0.1 * self.window_size)
|
112 |
+
self.receptive_field_size = int(meta["receptive_field_size"])
|
113 |
+
self.receptive_field_shift = int(meta["receptive_field_shift"])
|
114 |
+
self.num_speakers = int(meta["num_speakers"])
|
115 |
+
self.powerset_max_classes = int(meta["powerset_max_classes"])
|
116 |
+
self.num_classes = int(meta["num_classes"])
|
117 |
+
|
118 |
+
def __call__(self, x):
|
119 |
+
"""
|
120 |
+
Args:
|
121 |
+
x: (N, num_samples)
|
122 |
+
Returns:
|
123 |
+
A tensor of shape (N, num_frames, num_classes)
|
124 |
+
"""
|
125 |
+
x = np.expand_dims(x, axis=1)
|
126 |
+
|
127 |
+
(y,) = self.model.run(
|
128 |
+
[self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: x}
|
129 |
+
)
|
130 |
+
|
131 |
+
return y
|
132 |
+
|
133 |
+
|
134 |
+
def load_wav(filename, expected_sample_rate) -> np.ndarray:
|
135 |
+
audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True)
|
136 |
+
audio = audio[:, 0] # only use the first channel
|
137 |
+
if sample_rate != expected_sample_rate:
|
138 |
+
audio = librosa.resample(
|
139 |
+
audio,
|
140 |
+
orig_sr=sample_rate,
|
141 |
+
target_sr=expected_sample_rate,
|
142 |
+
)
|
143 |
+
return audio
|
144 |
+
|
145 |
+
|
146 |
+
def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes):
|
147 |
+
mapping = np.zeros((num_classes, num_speakers))
|
148 |
+
|
149 |
+
k = 1
|
150 |
+
for i in range(1, powerset_max_classes + 1):
|
151 |
+
if i == 1:
|
152 |
+
for j in range(0, num_speakers):
|
153 |
+
mapping[k, j] = 1
|
154 |
+
k += 1
|
155 |
+
elif i == 2:
|
156 |
+
for j in range(0, num_speakers):
|
157 |
+
for m in range(j + 1, num_speakers):
|
158 |
+
mapping[k, j] = 1
|
159 |
+
mapping[k, m] = 1
|
160 |
+
k += 1
|
161 |
+
elif i == 3:
|
162 |
+
raise RuntimeError("Unsupported")
|
163 |
+
|
164 |
+
return mapping
|
165 |
+
|
166 |
+
|
167 |
+
def to_multi_label(y, mapping):
|
168 |
+
"""
|
169 |
+
Args:
|
170 |
+
y: (num_chunks, num_frames, num_classes)
|
171 |
+
Returns:
|
172 |
+
A tensor of shape (num_chunks, num_frames, num_speakers)
|
173 |
+
"""
|
174 |
+
y = np.argmax(y, axis=-1)
|
175 |
+
labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1)
|
176 |
+
return labels
|
177 |
+
|
178 |
+
|
179 |
+
# speaker count per frame
|
180 |
+
def speaker_count(labels, seg_m):
|
181 |
+
"""
|
182 |
+
Args:
|
183 |
+
labels: (num_chunks, num_frames, num_speakers)
|
184 |
+
seg_m: Segmentation model
|
185 |
+
Returns:
|
186 |
+
A integer array of shape (num_total_frames,)
|
187 |
+
"""
|
188 |
+
labels = labels.sum(axis=-1)
|
189 |
+
# Now labels: (num_chunks, num_frames)
|
190 |
+
|
191 |
+
num_frames = (
|
192 |
+
int(
|
193 |
+
(seg_m.window_size + (labels.shape[0] - 1) * seg_m.window_shift)
|
194 |
+
/ seg_m.receptive_field_shift
|
195 |
+
)
|
196 |
+
+ 1
|
197 |
+
)
|
198 |
+
ans = np.zeros((num_frames,))
|
199 |
+
count = np.zeros((num_frames,))
|
200 |
+
|
201 |
+
for i in range(labels.shape[0]):
|
202 |
+
this_chunk = labels[i]
|
203 |
+
start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
|
204 |
+
end = start + this_chunk.shape[0]
|
205 |
+
ans[start:end] += this_chunk
|
206 |
+
count[start:end] += 1
|
207 |
+
|
208 |
+
ans /= np.maximum(count, 1e-12)
|
209 |
+
|
210 |
+
return (ans + 0.5).astype(np.int8)
|
211 |
+
|
212 |
+
|
213 |
+
def load_speaker_embedding_model(filename):
|
214 |
+
config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
|
215 |
+
model=filename,
|
216 |
+
num_threads=1,
|
217 |
+
debug=0,
|
218 |
+
)
|
219 |
+
if not config.validate():
|
220 |
+
raise ValueError(f"Invalid config. {config}")
|
221 |
+
extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
|
222 |
+
return extractor
|
223 |
+
|
224 |
+
|
225 |
+
def get_embeddings(embedding_filename, audio, labels, seg_m, exclude_overlap):
|
226 |
+
"""
|
227 |
+
Args:
|
228 |
+
embedding_filename: Path to the speaker embedding extractor model
|
229 |
+
audio: (num_samples,)
|
230 |
+
labels: (num_chunks, num_frames, num_speakers)
|
231 |
+
seg_m: segmentation model
|
232 |
+
Returns:
|
233 |
+
Return (num_chunks, num_speakers, embedding_dim)
|
234 |
+
"""
|
235 |
+
if exclude_overlap:
|
236 |
+
labels = labels * (labels.sum(axis=-1, keepdims=True) < 2)
|
237 |
+
|
238 |
+
extractor = load_speaker_embedding_model(embedding_filename)
|
239 |
+
buffer = np.empty(seg_m.window_size)
|
240 |
+
num_chunks, num_frames, num_speakers = labels.shape
|
241 |
+
|
242 |
+
ans_chunk_speaker_pair = []
|
243 |
+
ans_embeddings = []
|
244 |
+
|
245 |
+
for i in range(num_chunks):
|
246 |
+
labels_T = labels[i].T
|
247 |
+
# t: (num_speakers, num_frames)
|
248 |
+
|
249 |
+
sample_offset = i * seg_m.window_shift
|
250 |
+
|
251 |
+
for j in range(num_speakers):
|
252 |
+
frames = labels_T[j]
|
253 |
+
if frames.sum() < 10:
|
254 |
+
# skip segment less than 20 frames, i.e., about 0.2 seconds
|
255 |
+
continue
|
256 |
+
|
257 |
+
start = None
|
258 |
+
start_samples = 0
|
259 |
+
idx = 0
|
260 |
+
for k in range(num_frames):
|
261 |
+
if frames[k] != 0:
|
262 |
+
if start is None:
|
263 |
+
start = k
|
264 |
+
elif start is not None:
|
265 |
+
start_samples = (
|
266 |
+
int(start / num_frames * seg_m.window_size) + sample_offset
|
267 |
+
)
|
268 |
+
end_samples = (
|
269 |
+
int(k / num_frames * seg_m.window_size) + sample_offset
|
270 |
+
)
|
271 |
+
num_samples = end_samples - start_samples
|
272 |
+
buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
|
273 |
+
idx += num_samples
|
274 |
+
|
275 |
+
start = None
|
276 |
+
if start is not None:
|
277 |
+
start_samples = (
|
278 |
+
int(start / num_frames * seg_m.window_size) + sample_offset
|
279 |
+
)
|
280 |
+
end_samples = int(k / num_frames * seg_m.window_size) + sample_offset
|
281 |
+
num_samples = end_samples - start_samples
|
282 |
+
buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
|
283 |
+
idx += num_samples
|
284 |
+
|
285 |
+
stream = extractor.create_stream()
|
286 |
+
stream.accept_waveform(sample_rate=seg_m.sample_rate, waveform=buffer[:idx])
|
287 |
+
stream.input_finished()
|
288 |
+
|
289 |
+
assert extractor.is_ready(stream)
|
290 |
+
embedding = extractor.compute(stream)
|
291 |
+
embedding = np.array(embedding)
|
292 |
+
|
293 |
+
ans_chunk_speaker_pair.append([i, j])
|
294 |
+
ans_embeddings.append(embedding)
|
295 |
+
|
296 |
+
assert len(ans_chunk_speaker_pair) == len(ans_embeddings), (
|
297 |
+
len(ans_chunk_speaker_pair),
|
298 |
+
len(ans_embeddings),
|
299 |
+
)
|
300 |
+
return ans_chunk_speaker_pair, np.array(ans_embeddings)
|
301 |
+
|
302 |
+
|
303 |
+
def main():
|
304 |
+
args = get_args()
|
305 |
+
assert Path(args.seg_model).is_file(), args.seg_model
|
306 |
+
assert Path(args.wav).is_file(), args.wav
|
307 |
+
|
308 |
+
seg_m = OnnxSegmentationModel(args.seg_model)
|
309 |
+
audio = load_wav(args.wav, seg_m.sample_rate)
|
310 |
+
# audio: (num_samples,)
|
311 |
+
|
312 |
+
num = (audio.shape[0] - seg_m.window_size) // seg_m.window_shift + 1
|
313 |
+
|
314 |
+
samples = as_strided(
|
315 |
+
audio,
|
316 |
+
shape=(num, seg_m.window_size),
|
317 |
+
strides=(seg_m.window_shift * audio.strides[0], audio.strides[0]),
|
318 |
+
)
|
319 |
+
|
320 |
+
# or use torch.Tensor.unfold
|
321 |
+
# samples = torch.from_numpy(audio).unfold(0, seg_m.window_size, seg_m.window_shift).numpy()
|
322 |
+
|
323 |
+
if (
|
324 |
+
audio.shape[0] < seg_m.window_size
|
325 |
+
or (audio.shape[0] - seg_m.window_size) % seg_m.window_shift > 0
|
326 |
+
):
|
327 |
+
has_last_chunk = True
|
328 |
+
else:
|
329 |
+
has_last_chunk = False
|
330 |
+
|
331 |
+
num_chunks = samples.shape[0]
|
332 |
+
batch_size = 32
|
333 |
+
output = []
|
334 |
+
for i in range(0, num_chunks, batch_size):
|
335 |
+
start = i
|
336 |
+
end = i + batch_size
|
337 |
+
# it's perfectly ok to use end > num_chunks
|
338 |
+
y = seg_m(samples[start:end])
|
339 |
+
output.append(y)
|
340 |
+
|
341 |
+
if has_last_chunk:
|
342 |
+
last_chunk = audio[num_chunks * seg_m.window_shift :] # noqa
|
343 |
+
pad_size = seg_m.window_size - last_chunk.shape[0]
|
344 |
+
last_chunk = np.pad(last_chunk, (0, pad_size))
|
345 |
+
last_chunk = np.expand_dims(last_chunk, axis=0)
|
346 |
+
y = seg_m(last_chunk)
|
347 |
+
output.append(y)
|
348 |
+
|
349 |
+
y = np.vstack(output)
|
350 |
+
# y: (num_chunks, num_frames, num_classes)
|
351 |
+
|
352 |
+
mapping = get_powerset_mapping(
|
353 |
+
num_classes=seg_m.num_classes,
|
354 |
+
num_speakers=seg_m.num_speakers,
|
355 |
+
powerset_max_classes=seg_m.powerset_max_classes,
|
356 |
+
)
|
357 |
+
labels = to_multi_label(y, mapping=mapping)
|
358 |
+
# labels: (num_chunks, num_frames, num_speakers)
|
359 |
+
|
360 |
+
inactive = (labels.sum(axis=1) == 0).astype(np.int8)
|
361 |
+
# inactive: (num_chunks, num_speakers)
|
362 |
+
|
363 |
+
speakers_per_frame = speaker_count(labels=labels, seg_m=seg_m)
|
364 |
+
# speakers_per_frame: (num_frames, speakers_per_frame)
|
365 |
+
|
366 |
+
if speakers_per_frame.max() == 0:
|
367 |
+
print("No speakers found in the audio file!")
|
368 |
+
return
|
369 |
+
|
370 |
+
# if users specify only 1 speaker for clustering, then return the
|
371 |
+
# result directly
|
372 |
+
|
373 |
+
# Now, get embeddings
|
374 |
+
chunk_speaker_pair, embeddings = get_embeddings(
|
375 |
+
args.speaker_embedding_model,
|
376 |
+
audio=audio,
|
377 |
+
labels=labels,
|
378 |
+
seg_m=seg_m,
|
379 |
+
# exclude_overlap=True,
|
380 |
+
exclude_overlap=False,
|
381 |
+
)
|
382 |
+
# chunk_speaker_pair: a list of (chunk_idx, speaker_idx)
|
383 |
+
# embeddings: (batch_size, embedding_dim)
|
384 |
+
|
385 |
+
# Please change num_clusters or threshold by yourself.
|
386 |
+
clustering_config = sherpa_onnx.FastClusteringConfig(num_clusters=2)
|
387 |
+
# clustering_config = sherpa_onnx.FastClusteringConfig(threshold=0.8)
|
388 |
+
clustering = sherpa_onnx.FastClustering(clustering_config)
|
389 |
+
cluster_labels = clustering(embeddings)
|
390 |
+
|
391 |
+
chunk_speaker_to_cluster = dict()
|
392 |
+
for (chunk_idx, speaker_idx), cluster_idx in zip(
|
393 |
+
chunk_speaker_pair, cluster_labels
|
394 |
+
):
|
395 |
+
if inactive[chunk_idx, speaker_idx] == 1:
|
396 |
+
print("skip ", chunk_idx, speaker_idx)
|
397 |
+
continue
|
398 |
+
chunk_speaker_to_cluster[(chunk_idx, speaker_idx)] = cluster_idx
|
399 |
+
|
400 |
+
num_speakers = max(cluster_labels) + 1
|
401 |
+
relabels = np.zeros((labels.shape[0], labels.shape[1], num_speakers))
|
402 |
+
for i in range(labels.shape[0]):
|
403 |
+
for j in range(labels.shape[1]):
|
404 |
+
for k in range(labels.shape[2]):
|
405 |
+
if (i, k) not in chunk_speaker_to_cluster:
|
406 |
+
continue
|
407 |
+
t = chunk_speaker_to_cluster[(i, k)]
|
408 |
+
|
409 |
+
if labels[i, j, k] == 1:
|
410 |
+
relabels[i, j, t] = 1
|
411 |
+
|
412 |
+
num_frames = (
|
413 |
+
int(
|
414 |
+
(seg_m.window_size + (relabels.shape[0] - 1) * seg_m.window_shift)
|
415 |
+
/ seg_m.receptive_field_shift
|
416 |
+
)
|
417 |
+
+ 1
|
418 |
+
)
|
419 |
+
|
420 |
+
count = np.zeros((num_frames, relabels.shape[-1]))
|
421 |
+
for i in range(relabels.shape[0]):
|
422 |
+
this_chunk = relabels[i]
|
423 |
+
start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
|
424 |
+
end = start + this_chunk.shape[0]
|
425 |
+
count[start:end] += this_chunk
|
426 |
+
|
427 |
+
if has_last_chunk:
|
428 |
+
stop_frame = int(audio.shape[0] / seg_m.receptive_field_shift)
|
429 |
+
count = count[:stop_frame]
|
430 |
+
|
431 |
+
sorted_count = np.argsort(-count, axis=-1)
|
432 |
+
final = np.zeros((count.shape[0], count.shape[1]))
|
433 |
+
|
434 |
+
for i, (c, sc) in enumerate(zip(speakers_per_frame, sorted_count)):
|
435 |
+
for k in range(c):
|
436 |
+
final[i, sc[k]] = 1
|
437 |
+
|
438 |
+
min_duration_off = 0.5
|
439 |
+
min_duration_on = 0.3
|
440 |
+
onset = 0.5
|
441 |
+
offset = 0.5
|
442 |
+
# final: (num_frames, num_speakers)
|
443 |
+
|
444 |
+
final = final.T
|
445 |
+
for kk in range(final.shape[0]):
|
446 |
+
segment_list = []
|
447 |
+
frames = final[kk]
|
448 |
+
|
449 |
+
is_active = frames[0] > onset
|
450 |
+
|
451 |
+
start = None
|
452 |
+
if is_active:
|
453 |
+
start = 0
|
454 |
+
scale = seg_m.receptive_field_shift / seg_m.sample_rate
|
455 |
+
scale_offset = seg_m.receptive_field_size / seg_m.sample_rate * 0.5
|
456 |
+
for i in range(1, len(frames)):
|
457 |
+
if is_active:
|
458 |
+
if frames[i] < offset:
|
459 |
+
segment = Segment(
|
460 |
+
start=start * scale + scale_offset,
|
461 |
+
end=i * scale + scale_offset,
|
462 |
+
speaker=kk,
|
463 |
+
)
|
464 |
+
segment_list.append(segment)
|
465 |
+
is_active = False
|
466 |
+
else:
|
467 |
+
if frames[i] > onset:
|
468 |
+
start = i
|
469 |
+
is_active = True
|
470 |
+
|
471 |
+
if is_active:
|
472 |
+
segment = Segment(
|
473 |
+
start=start * scale + scale_offset,
|
474 |
+
end=(len(frames) - 1) * scale + scale_offset,
|
475 |
+
speaker=kk,
|
476 |
+
)
|
477 |
+
segment_list.append(segment)
|
478 |
+
|
479 |
+
if len(segment_list) > 1:
|
480 |
+
merge_segment_list(segment_list, min_duration_off=min_duration_off)
|
481 |
+
for s in segment_list:
|
482 |
+
if s.duration < min_duration_on:
|
483 |
+
continue
|
484 |
+
print(s)
|
485 |
+
|
486 |
+
|
487 |
+
if __name__ == "__main__":
|
488 |
+
main()
|
speaker-diarization-torch.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
Please refer to
|
5 |
+
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
|
6 |
+
for usages.
|
7 |
+
"""
|
8 |
+
|
9 |
+
"""
|
10 |
+
1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main
|
11 |
+
wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx
|
12 |
+
|
13 |
+
2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py
|
14 |
+
|
15 |
+
```
|
16 |
+
# self._embedding = PretrainedSpeakerEmbedding(
|
17 |
+
# self.embedding, use_auth_token=use_auth_token
|
18 |
+
# )
|
19 |
+
self._embedding = embedding
|
20 |
+
```
|
21 |
+
"""
|
22 |
+
|
23 |
+
import argparse
|
24 |
+
from pathlib import Path
|
25 |
+
|
26 |
+
import torch
|
27 |
+
from pyannote.audio import Model
|
28 |
+
from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline
|
29 |
+
from pyannote.audio.pipelines.speaker_verification import (
|
30 |
+
ONNXWeSpeakerPretrainedSpeakerEmbedding,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
def get_args():
|
35 |
+
parser = argparse.ArgumentParser()
|
36 |
+
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
|
37 |
+
|
38 |
+
return parser.parse_args()
|
39 |
+
|
40 |
+
|
41 |
+
def build_pipeline():
|
42 |
+
embedding_filename = "./speaker-embedding.onnx"
|
43 |
+
if Path(embedding_filename).is_file():
|
44 |
+
# You need to modify line 166
|
45 |
+
# of pyannote/audio/pipelines/speaker_diarization.py
|
46 |
+
# Please see the comments at the start of this script for details
|
47 |
+
embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename)
|
48 |
+
else:
|
49 |
+
embedding = "hbredin/wespeaker-voxceleb-resnet34-LM"
|
50 |
+
|
51 |
+
pt_filename = "./pytorch_model.bin"
|
52 |
+
segmentation = Model.from_pretrained(pt_filename)
|
53 |
+
segmentation.eval()
|
54 |
+
|
55 |
+
pipeline = SpeakerDiarizationPipeline(
|
56 |
+
segmentation=segmentation,
|
57 |
+
embedding=embedding,
|
58 |
+
embedding_exclude_overlap=True,
|
59 |
+
)
|
60 |
+
|
61 |
+
params = {
|
62 |
+
"clustering": {
|
63 |
+
"method": "centroid",
|
64 |
+
"min_cluster_size": 12,
|
65 |
+
"threshold": 0.7045654963945799,
|
66 |
+
},
|
67 |
+
"segmentation": {"min_duration_off": 0.5},
|
68 |
+
}
|
69 |
+
|
70 |
+
pipeline.instantiate(params)
|
71 |
+
return pipeline
|
72 |
+
|
73 |
+
|
74 |
+
@torch.no_grad()
|
75 |
+
def main():
|
76 |
+
args = get_args()
|
77 |
+
assert Path(args.wav).is_file(), args.wav
|
78 |
+
pipeline = build_pipeline()
|
79 |
+
print(pipeline)
|
80 |
+
t = pipeline(args.wav)
|
81 |
+
print(type(t))
|
82 |
+
print(t)
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
main()
|
vad-onnx.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
./export-onnx.py
|
5 |
+
./preprocess.sh
|
6 |
+
|
7 |
+
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
|
8 |
+
./vad-onnx.py --model ./model.onnx --wav ./lei-jun-test.wav
|
9 |
+
"""
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import librosa
|
15 |
+
import numpy as np
|
16 |
+
import onnxruntime as ort
|
17 |
+
import soundfile as sf
|
18 |
+
from numpy.lib.stride_tricks import as_strided
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument("--model", type=str, required=True, help="Path to model.onnx")
|
24 |
+
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
|
25 |
+
|
26 |
+
return parser.parse_args()
|
27 |
+
|
28 |
+
|
29 |
+
class OnnxModel:
|
30 |
+
def __init__(self, filename):
|
31 |
+
session_opts = ort.SessionOptions()
|
32 |
+
session_opts.inter_op_num_threads = 1
|
33 |
+
session_opts.intra_op_num_threads = 1
|
34 |
+
|
35 |
+
self.session_opts = session_opts
|
36 |
+
|
37 |
+
self.model = ort.InferenceSession(
|
38 |
+
filename,
|
39 |
+
sess_options=self.session_opts,
|
40 |
+
providers=["CPUExecutionProvider"],
|
41 |
+
)
|
42 |
+
|
43 |
+
meta = self.model.get_modelmeta().custom_metadata_map
|
44 |
+
print(meta)
|
45 |
+
|
46 |
+
self.window_size = int(meta["window_size"])
|
47 |
+
self.sample_rate = int(meta["sample_rate"])
|
48 |
+
self.window_shift = int(0.1 * self.window_size)
|
49 |
+
self.receptive_field_size = int(meta["receptive_field_size"])
|
50 |
+
self.receptive_field_shift = int(meta["receptive_field_shift"])
|
51 |
+
self.num_speakers = int(meta["num_speakers"])
|
52 |
+
self.powerset_max_classes = int(meta["powerset_max_classes"])
|
53 |
+
self.num_classes = int(meta["num_classes"])
|
54 |
+
|
55 |
+
def __call__(self, x):
|
56 |
+
"""
|
57 |
+
Args:
|
58 |
+
x: (N, num_samples)
|
59 |
+
Returns:
|
60 |
+
A tensor of shape (N, num_frames, num_classes)
|
61 |
+
"""
|
62 |
+
x = np.expand_dims(x, axis=1)
|
63 |
+
|
64 |
+
(y,) = self.model.run(
|
65 |
+
[self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: x}
|
66 |
+
)
|
67 |
+
|
68 |
+
return y
|
69 |
+
|
70 |
+
|
71 |
+
def load_wav(filename, expected_sample_rate) -> np.ndarray:
|
72 |
+
audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True)
|
73 |
+
audio = audio[:, 0] # only use the first channel
|
74 |
+
if sample_rate != expected_sample_rate:
|
75 |
+
audio = librosa.resample(
|
76 |
+
audio,
|
77 |
+
orig_sr=sample_rate,
|
78 |
+
target_sr=expected_sample_rate,
|
79 |
+
)
|
80 |
+
return audio
|
81 |
+
|
82 |
+
|
83 |
+
def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes):
|
84 |
+
mapping = np.zeros((num_classes, num_speakers))
|
85 |
+
|
86 |
+
k = 1
|
87 |
+
for i in range(1, powerset_max_classes + 1):
|
88 |
+
if i == 1:
|
89 |
+
for j in range(0, num_speakers):
|
90 |
+
mapping[k, j] = 1
|
91 |
+
k += 1
|
92 |
+
elif i == 2:
|
93 |
+
for j in range(0, num_speakers):
|
94 |
+
for m in range(j + 1, num_speakers):
|
95 |
+
mapping[k, j] = 1
|
96 |
+
mapping[k, m] = 1
|
97 |
+
k += 1
|
98 |
+
elif i == 3:
|
99 |
+
raise RuntimeError("Unsupported")
|
100 |
+
|
101 |
+
return mapping
|
102 |
+
|
103 |
+
|
104 |
+
def to_multi_label(y, mapping):
|
105 |
+
"""
|
106 |
+
Args:
|
107 |
+
y: (num_chunks, num_frames, num_classes)
|
108 |
+
Returns:
|
109 |
+
A tensor of shape (num_chunks, num_frames, num_speakers)
|
110 |
+
"""
|
111 |
+
y = np.argmax(y, axis=-1)
|
112 |
+
labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1)
|
113 |
+
return labels
|
114 |
+
|
115 |
+
|
116 |
+
def main():
|
117 |
+
args = get_args()
|
118 |
+
assert Path(args.model).is_file(), args.model
|
119 |
+
assert Path(args.wav).is_file(), args.wav
|
120 |
+
|
121 |
+
m = OnnxModel(args.model)
|
122 |
+
audio = load_wav(args.wav, m.sample_rate)
|
123 |
+
# audio: (num_samples,)
|
124 |
+
print("audio", audio.shape, audio.min(), audio.max(), audio.sum())
|
125 |
+
|
126 |
+
num = (audio.shape[0] - m.window_size) // m.window_shift + 1
|
127 |
+
|
128 |
+
samples = as_strided(
|
129 |
+
audio,
|
130 |
+
shape=(num, m.window_size),
|
131 |
+
strides=(m.window_shift * audio.strides[0], audio.strides[0]),
|
132 |
+
)
|
133 |
+
|
134 |
+
# or use torch.Tensor.unfold
|
135 |
+
# samples = torch.from_numpy(audio).unfold(0, m.window_size, m.window_shift).numpy()
|
136 |
+
|
137 |
+
print(
|
138 |
+
"samples",
|
139 |
+
samples.shape,
|
140 |
+
samples.mean(),
|
141 |
+
samples.sum(),
|
142 |
+
samples[:3, :3].sum(axis=-1),
|
143 |
+
)
|
144 |
+
|
145 |
+
if (
|
146 |
+
audio.shape[0] < m.window_size
|
147 |
+
or (audio.shape[0] - m.window_size) % m.window_shift > 0
|
148 |
+
):
|
149 |
+
has_last_chunk = True
|
150 |
+
else:
|
151 |
+
has_last_chunk = False
|
152 |
+
|
153 |
+
num_chunks = samples.shape[0]
|
154 |
+
batch_size = 32
|
155 |
+
output = []
|
156 |
+
for i in range(0, num_chunks, batch_size):
|
157 |
+
start = i
|
158 |
+
end = i + batch_size
|
159 |
+
# it's perfectly ok to use end > num_chunks
|
160 |
+
y = m(samples[start:end])
|
161 |
+
output.append(y)
|
162 |
+
|
163 |
+
if has_last_chunk:
|
164 |
+
last_chunk = audio[num_chunks * m.window_shift :] # noqa
|
165 |
+
pad_size = m.window_size - last_chunk.shape[0]
|
166 |
+
last_chunk = np.pad(last_chunk, (0, pad_size))
|
167 |
+
last_chunk = np.expand_dims(last_chunk, axis=0)
|
168 |
+
y = m(last_chunk)
|
169 |
+
output.append(y)
|
170 |
+
|
171 |
+
y = np.vstack(output)
|
172 |
+
# y: (num_chunks, num_frames, num_classes)
|
173 |
+
|
174 |
+
mapping = get_powerset_mapping(
|
175 |
+
num_classes=m.num_classes,
|
176 |
+
num_speakers=m.num_speakers,
|
177 |
+
powerset_max_classes=m.powerset_max_classes,
|
178 |
+
)
|
179 |
+
labels = to_multi_label(y, mapping=mapping)
|
180 |
+
# labels: (num_chunks, num_frames, num_speakers)
|
181 |
+
|
182 |
+
# binary classification
|
183 |
+
labels = np.max(labels, axis=-1)
|
184 |
+
# labels: (num_chunk, num_frames)
|
185 |
+
|
186 |
+
num_frames = (
|
187 |
+
int(
|
188 |
+
(m.window_size + (labels.shape[0] - 1) * m.window_shift)
|
189 |
+
/ m.receptive_field_shift
|
190 |
+
)
|
191 |
+
+ 1
|
192 |
+
)
|
193 |
+
|
194 |
+
count = np.zeros((num_frames,))
|
195 |
+
classification = np.zeros((num_frames,))
|
196 |
+
weight = np.hamming(labels.shape[1])
|
197 |
+
|
198 |
+
for i in range(labels.shape[0]):
|
199 |
+
this_chunk = labels[i]
|
200 |
+
start = int(i * m.window_shift / m.receptive_field_shift + 0.5)
|
201 |
+
end = start + this_chunk.shape[0]
|
202 |
+
|
203 |
+
classification[start:end] += this_chunk * weight
|
204 |
+
count[start:end] += weight
|
205 |
+
|
206 |
+
classification /= np.maximum(count, 1e-12)
|
207 |
+
|
208 |
+
if has_last_chunk:
|
209 |
+
stop_frame = int(audio.shape[0] / m.receptive_field_shift)
|
210 |
+
classification = classification[:stop_frame]
|
211 |
+
|
212 |
+
classification = classification.tolist()
|
213 |
+
|
214 |
+
onset = 0.5
|
215 |
+
offset = 0.5
|
216 |
+
|
217 |
+
is_active = classification[0] > onset
|
218 |
+
start = None
|
219 |
+
if is_active:
|
220 |
+
start = 0
|
221 |
+
|
222 |
+
scale = m.receptive_field_shift / m.sample_rate
|
223 |
+
scale_offset = m.receptive_field_size / m.sample_rate * 0.5
|
224 |
+
|
225 |
+
for i in range(len(classification)):
|
226 |
+
if is_active:
|
227 |
+
if classification[i] < offset:
|
228 |
+
print(
|
229 |
+
f"{start*scale + scale_offset:.3f} -- {i*scale + scale_offset:.3f}"
|
230 |
+
)
|
231 |
+
is_active = False
|
232 |
+
else:
|
233 |
+
if classification[i] > onset:
|
234 |
+
start = i
|
235 |
+
is_active = True
|
236 |
+
|
237 |
+
if is_active:
|
238 |
+
print(
|
239 |
+
f"{start*scale + scale_offset:.3f} -- {(len(classification)-1)*scale + scale_offset:.3f}"
|
240 |
+
)
|
241 |
+
|
242 |
+
|
243 |
+
if __name__ == "__main__":
|
244 |
+
main()
|
vad-torch.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from pyannote.audio import Model
|
5 |
+
from pyannote.audio.pipelines import (
|
6 |
+
VoiceActivityDetection as VoiceActivityDetectionPipeline,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
@torch.no_grad()
|
11 |
+
def main():
|
12 |
+
# Please download it from
|
13 |
+
# https://huggingface.co/csukuangfj/pyannote-models/tree/main/segmentation-3.0
|
14 |
+
pt_filename = "./pytorch_model.bin"
|
15 |
+
model = Model.from_pretrained(pt_filename)
|
16 |
+
model.eval()
|
17 |
+
|
18 |
+
pipeline = VoiceActivityDetectionPipeline(segmentation=model)
|
19 |
+
|
20 |
+
# https://huggingface.co/pyannote/voice-activity-detection/blob/main/config.yaml
|
21 |
+
# https://github.com/pyannote/pyannote-audio/issues/1215
|
22 |
+
initial_params = {
|
23 |
+
"min_duration_on": 0.0,
|
24 |
+
"min_duration_off": 0.0,
|
25 |
+
}
|
26 |
+
pipeline.onset = 0.5
|
27 |
+
pipeline.offset = 0.5
|
28 |
+
|
29 |
+
pipeline.instantiate(initial_params)
|
30 |
+
|
31 |
+
# wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
|
32 |
+
t = pipeline("./lei-jun-test.wav")
|
33 |
+
print(type(t))
|
34 |
+
print(t)
|
35 |
+
|
36 |
+
|
37 |
+
if __name__ == "__main__":
|
38 |
+
main()
|