csukuangfj committed on
Commit 16d0f41 · 1 Parent(s): e6c05b9

first commit

README.md CHANGED
@@ -1,3 +1,7 @@
 ---
 license: apache-2.0
 ---
+
+# Introduction
+
+See https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx/file/view/master/quickstart.md
add-model-metadata.py ADDED
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2023 Xiaomi Corporation
+# Author: Fangjun Kuang
+
+from pathlib import Path
+from typing import Dict
+
+import numpy as np
+import onnx
+import yaml
+
+
+def load_cmvn():
+    neg_mean = None
+    inv_stddev = None
+
+    with open("am.mvn") as f:
+        for line in f:
+            if not line.startswith("<LearnRateCoef>"):
+                continue
+            t = line.split()[3:-1]
+
+            if neg_mean is None:
+                neg_mean = ",".join(t)
+            else:
+                inv_stddev = ",".join(t)
+
+    return neg_mean, inv_stddev
+
+
+def load_lfr_params(config):
+    # lfr_m and lfr_n are stored in the frontend_conf section of config.yaml
+    lfr_window_size = config["frontend_conf"]["lfr_m"]
+    lfr_window_shift = config["frontend_conf"]["lfr_n"]
+
+    return lfr_window_size, lfr_window_shift
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = value
+
+    onnx.save(model, filename)
+    print(f"Updated {filename}")
+
+
+def main():
+    if Path(".done").is_file():
+        print("already added model metadata - skipping")
+        return
+    with open("config.yaml", "r") as stream:
+        config = yaml.safe_load(stream)
+
+    lfr_window_size, lfr_window_shift = load_lfr_params(config)
+    neg_mean, inv_stddev = load_cmvn()
+    vocab_size = len(config["token_list"])
+
+    meta_data = {
+        "lfr_window_size": str(lfr_window_size),
+        "lfr_window_shift": str(lfr_window_shift),
+        "neg_mean": neg_mean,
+        "inv_stddev": inv_stddev,
+        "model_type": "paraformer",
+        "version": "1",
+        "model_author": "damo",
+        "vocab_size": str(vocab_size),
+        "comment": "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+    }
+    add_meta_data("model.int8.onnx", meta_data)
+
+    Path(".done").touch()
+
+
+if __name__ == "__main__":
+    main()
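For a quick sanity check, the metadata written by `add_meta_data()` can be read back at runtime through onnxruntime's `get_modelmeta()`. A minimal sketch (not part of this commit), assuming `model.int8.onnx` sits in the current directory:

```python
import onnxruntime

sess = onnxruntime.InferenceSession("model.int8.onnx")

# custom_metadata_map holds the key/value pairs added by add-model-metadata.py
meta = sess.get_modelmeta().custom_metadata_map
for key in ("model_type", "lfr_window_size", "lfr_window_shift", "vocab_size"):
    print(key, "=", meta.get(key))
```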
convert-tokens.py ADDED
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+import sys
+from pathlib import Path
+from typing import List
+import yaml
+
+
+def write_tokens(tokens: List[str]):
+    with open("tokens.txt", "w", encoding="utf-8") as f:
+        for idx, s in enumerate(tokens):
+            f.write(f"{s} {idx}\n")
+
+
+def main():
+    if Path("./tokens.txt").is_file():
+        print("./tokens.txt already exists - skipping")
+        return
+
+    with open("config.yaml", "r") as stream:
+        config = yaml.safe_load(stream)
+
+    tokens = config["token_list"]
+    write_tokens(tokens)
+
+
+if __name__ == "__main__":
+    main()
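The generated tokens.txt has one `symbol index` pair per line, which is the layout `load_tokens()` in test-paraformer-onnx.py expects. A round-trip check (an illustrative sketch, not part of this commit; it assumes `config.yaml` and `tokens.txt` are in the current directory and that no symbol contains a space):

```python
import yaml

with open("config.yaml") as stream:
    token_list = yaml.safe_load(stream)["token_list"]

with open("tokens.txt", encoding="utf-8") as f:
    for line in f:
        symbol, idx = line.rstrip("\n").rsplit(maxsplit=1)
        assert token_list[int(idx)] == symbol, (symbol, idx)

print("tokens.txt is consistent with config.yaml")
```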
download-model.py ADDED
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+
+# Please first run
+# pip install modelscope
+
+# See https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx/file/view/master/quickstart.md
+
+
+from modelscope.hub.file_download import model_file_download
+
+files = [
+    "model_quant.onnx",
+    "am.mvn",
+    "config.yaml",
+    "configuration.json",
+]
+for f in files:
+    model_dir = model_file_download(
+        model_id="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx",
+        file_path=f,
+        revision="v1.2.4",
+    )
+    print(model_dir)
+
+# /Users/fangjun/.cache/modelscope/hub/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx/model_quant.onnx
+#
+# mv model_quant.onnx model.int8.onnx
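`model_file_download()` saves the files into the local modelscope cache (see the path printed in the comment above), so they still need to be copied next to these scripts and the quantized model renamed to model.int8.onnx. A possible follow-up step (a sketch, not part of this commit; it assumes the function returns the local file path, as the printed path above suggests):

```python
import shutil

from modelscope.hub.file_download import model_file_download

files = ["model_quant.onnx", "am.mvn", "config.yaml", "configuration.json"]
for f in files:
    local_path = model_file_download(
        model_id="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx",
        file_path=f,
        revision="v1.2.4",
    )
    # the other scripts in this repo expect the quantized model to be named model.int8.onnx
    dst = "model.int8.onnx" if f == "model_quant.onnx" else f
    shutil.copy(local_path, dst)
```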
test-paraformer-onnx.py ADDED
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2023 Xiaomi Corporation
+# Author: Fangjun Kuang
+
+import kaldi_native_fbank as knf
+import librosa
+import numpy as np
+import onnxruntime
+
+
+"""
+---------inputs----------
+speech ['batch_size', 'feats_length', 560] tensor(float)
+speech_lengths ['batch_size'] tensor(int32)
+---------outputs----------
+logits ['batch_size', 'logits_length', 8404] tensor(float)
+token_num ['Casttoken_num_dim_0'] tensor(int32)
+us_alphas ['batch_size', 'alphas_length'] tensor(float)
+us_cif_peak ['batch_size', 'alphas_length'] tensor(float)
+"""
+
+
+def show_model_info():
+    session_opts = onnxruntime.SessionOptions()
+    session_opts.log_severity_level = 3  # error level
+    sess = onnxruntime.InferenceSession("model.int8.onnx", session_opts)
+    print("---------inputs----------")
+    for n in sess.get_inputs():
+        print(n.name, n.shape, n.type)
+
+    print("---------outputs----------")
+    for n in sess.get_outputs():
+        print(n.name, n.shape, n.type)
+
+    import sys
+
+    sys.exit(0)
+
+
+def load_cmvn():
+    neg_mean = None
+    inv_std = None
+
+    with open("am.mvn") as f:
+        for line in f:
+            if not line.startswith("<LearnRateCoef>"):
+                continue
+            t = line.split()[3:-1]
+            t = list(map(lambda x: float(x), t))
+
+            if neg_mean is None:
+                neg_mean = np.array(t, dtype=np.float32)
+            else:
+                inv_std = np.array(t, dtype=np.float32)
+
+    return neg_mean, inv_std
+
+
+def compute_feat(filename):
+    sample_rate = 16000
+    samples, _ = librosa.load(filename, sr=sample_rate)
+    opts = knf.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = sample_rate
+    opts.mel_opts.num_bins = 80
+
+    online_fbank = knf.OnlineFbank(opts)
+    online_fbank.accept_waveform(sample_rate, (samples * 32768).tolist())
+    online_fbank.input_finished()
+
+    features = np.stack(
+        [online_fbank.get_frame(i) for i in range(online_fbank.num_frames_ready)]
+    )
+    assert features.data.contiguous is True
+    assert features.dtype == np.float32, features.dtype
+    print("features sum", features.sum(), features.size)
+
+    window_size = 7  # lfr_m
+    window_shift = 6  # lfr_n
+
+    T = (features.shape[0] - window_size) // window_shift + 1
+    features = np.lib.stride_tricks.as_strided(
+        features,
+        shape=(T, features.shape[1] * window_size),
+        strides=((window_shift * features.shape[1]) * 4, 4),
+    )
+    neg_mean, inv_std = load_cmvn()
+    features = (features + neg_mean) * inv_std
+    return features
+
+
+# tokens.txt in paraformer has only one column,
+# while it has two columns in sherpa-onnx.
+# This function can handle tokens.txt from both paraformer and sherpa-onnx.
+def load_tokens():
+    ans = dict()
+    i = 0
+    with open("tokens.txt", encoding="utf-8") as f:
+        for line in f:
+            ans[i] = line.strip().split()[0]
+            i += 1
+    return ans
+
+
+def main():
+    # show_model_info()
+    features = compute_feat("1.wav")
+    features = np.expand_dims(features, axis=0)
+    print(np.sum(features), features.size, features.shape)
+    features_length = np.array([features.shape[1]], dtype=np.int32)
+
+    features2 = compute_feat("2.wav")
+    print(np.sum(features2), features2.size, features2.shape)
+    features2 = np.expand_dims(features2, axis=0)
+    features2_length = np.array([features2.shape[1]], dtype=np.int32)
+    print(features.shape, features2.shape)
+
+    pad = np.ones((1, 10, 560), dtype=np.float32) * -23.0258
+    features3 = np.concatenate([features2, pad], axis=1)
+
+    features4 = np.concatenate([features, features3], axis=0)
+    features4_length = np.array([features.shape[1], features2.shape[1]], dtype=np.int32)
+    print(features4.shape, features4_length)
+
+    session_opts = onnxruntime.SessionOptions()
+    session_opts.log_severity_level = 3  # error level
+    sess = onnxruntime.InferenceSession("model.int8.onnx", session_opts)
+
+    inputs = {
+        "speech": features4,
+        "speech_lengths": features4_length,
+    }
+    output_names = ["logits", "token_num", "us_alphas", "us_cif_peak"]
+
+    try:
+        outputs = sess.run(output_names, input_feed=inputs)
+    except Exception:  # onnxruntime raises an error if the input wav is silence or noise
+        print("Input wav is silence or noise")
+        return
+
+    print("0", outputs[0].shape)
+    print("1", outputs[1].shape)
+    print("2", outputs[2].shape)
+    print("3", outputs[3].shape)
+    log_probs = outputs[0][0]
+    log_probs1 = outputs[0][1]
+    y = log_probs.argmax(axis=-1)[: outputs[1][0]]
+    y1 = log_probs1.argmax(axis=-1)[: outputs[1][1]]
+    print(outputs[1])
+    print(y)
+    print(y1)
+
+    tokens = load_tokens()
+    text = "".join([tokens[i] for i in y if i not in (0, 2)])
+    print(text)
+
+    text1 = "".join([tokens[i] for i in y1 if i not in (0, 2)])
+    print(text1)
+
+    token_num = outputs[1]
+
+    print([i for i in outputs[-1][0] if i > (1 - 1e-4)])
+    print(len([i for i in outputs[-1][0] if i > (1 - 1e-4)]))
+    print(token_num[0])
+
+    print([i for i in outputs[-1][1] if i > (1 - 1e-4)])
+    print(len([i for i in outputs[-1][1] if i > (1 - 1e-4)]))
+    print(token_num[1])
+
+
+if __name__ == "__main__":
+    main()
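In compute_feat() above, the np.lib.stride_tricks.as_strided call implements low-frame-rate (LFR) stacking: each output row is lfr_m = 7 consecutive 80-dim fbank frames flattened into a 560-dim vector, and consecutive rows advance by lfr_n = 6 input frames. For readers unfamiliar with the stride trick, an equivalent plain-loop version (illustrative only, not part of this commit):

```python
import numpy as np


def lfr_reference(feats: np.ndarray, lfr_m: int = 7, lfr_n: int = 6) -> np.ndarray:
    """Plain-loop equivalent of the as_strided call in compute_feat()."""
    num_out = (feats.shape[0] - lfr_m) // lfr_n + 1
    out = np.empty((num_out, feats.shape[1] * lfr_m), dtype=feats.dtype)
    for i in range(num_out):
        # stack lfr_m consecutive frames, starting at frame i * lfr_n
        out[i] = feats[i * lfr_n : i * lfr_n + lfr_m].reshape(-1)
    return out
```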