csukuangfj committed on
Commit
9e8a5d8
·
1 Parent(s): 4777950

add app and model

Browse files
Files changed (4) hide show
  1. app.py +294 -0
  2. examples.py +52 -0
  3. model.py +121 -0
  4. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright 2022-2024 Xiaomi Corp. (authors: Fangjun Kuang)
4
+ #
5
+ # See LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ # References:
20
+ # https://gradio.app/docs/#dropdown
21
+
22
+ import logging
23
+ import os
24
+ import tempfile
25
+ import time
26
+ import urllib.request
27
+ from datetime import datetime
28
+
29
+ from examples import examples
30
+ import gradio as gr
31
+ import soundfile as sf
32
+
33
+ from model import decode, get_pretrained_model, whisper_models
34
+
35
+
36
def convert_to_wav(in_filename: str) -> str:
    """Convert the input audio file to a wave file.

    The conversion is delegated to the ``ffmpeg`` command line tool, which is
    asked to resample the audio to 16000 Hz (``-ar 16000``) and to down-mix it
    to a single channel (``-ac 1``).

    Args:
      in_filename:
        Path to the input audio file.
    Returns:
      Return the path of the output file, which is ``in_filename + ".wav"``.
      The conversion is best effort: the path is returned even if ``ffmpeg``
      is missing or exits with an error; such failures are only logged.
    """
    # Imported locally so that this function does not depend on any change
    # to the import section at the top of the file.
    import subprocess

    out_filename = in_filename + ".wav"
    logging.info(f"Converting '{in_filename}' to '{out_filename}'")

    # Pass the arguments as a list instead of building a shell command.
    # The filename may come from an upload, so it must never be parsed by a
    # shell; otherwise a quote in the name could inject arbitrary commands.
    cmd = [
        "ffmpeg",
        "-hide_banner",
        "-i",
        in_filename,
        "-ar",
        "16000",
        "-ac",
        "1",
        out_filename,
        "-y",
    ]

    try:
        completed = subprocess.run(cmd, check=False)
    except OSError as e:
        # For instance, ffmpeg is not installed.
        logging.error(f"Failed to run ffmpeg: {e}")
    else:
        if completed.returncode != 0:
            logging.warning(f"ffmpeg exited with code {completed.returncode}")

    return out_filename
46
+
47
+
48
def build_html_output(s: str, style: str = "result_item_success"):
    """Wrap a message in the HTML snippet used to display results.

    Args:
      s:
        The content to display. It is inserted into the markup as is,
        without any escaping.
      style:
        Name of the CSS class that is put next to ``result_item`` on the
        inner ``div``.
    Returns:
      Return a string containing the HTML markup.
    """
    html = f"""
    <div class='result'>
        <div class='result_item {style}'>
          {s}
        </div>
    </div>
    """
    return html
56
+
57
+
58
def process_url(
    repo_id: str,
    url: str,
):
    """Download an audio file from a URL and run ``process`` on it.

    Args:
      repo_id:
        The model selected by the user. It is forwarded to ``process``.
      url:
        URL to an audio file. Only ``http`` and ``https`` URLs are accepted.
    Returns:
      Return whatever ``process`` returns on success. If the URL is missing
      or not supported, or if anything raises, return a tuple containing
      an empty string and an HTML snippet describing the problem.
    """
    from urllib.parse import urlparse

    if url is None or url.strip() == "":
        return "", build_html_output(
            "Please first input a URL and then click "
            'the button "submit for recognition"',
            "result_item_error",
        )

    url = url.strip()

    # The URL is typed in by the user. urlretrieve() also understands other
    # schemes such as file://, which would let a visitor make the server read
    # its own local files, so only plain web URLs are allowed.
    if urlparse(url).scheme not in ("http", "https"):
        return "", build_html_output(
            "Only URLs starting with http:// or https:// are supported",
            "result_item_error",
        )

    logging.info(f"Processing URL: {url}")
    with tempfile.NamedTemporaryFile() as f:
        try:
            urllib.request.urlretrieve(url, f.name)

            return process(
                in_filename=f.name,
                repo_id=repo_id,
            )
        except Exception as e:
            logging.info(str(e))
            return "", build_html_output(str(e), "result_item_error")
74
+
75
+
76
def process_uploaded_file(
    repo_id: str,
    in_filename: str,
):
    """Run ``process`` on a file uploaded by the user.

    Args:
      repo_id:
        The model selected by the user. It is forwarded to ``process``.
      in_filename:
        Path to the uploaded file. It may be ``None`` or empty if nothing
        has been uploaded.
    Returns:
      Return whatever ``process`` returns on success. If no file is given
      or if ``process`` raises, return a tuple containing an empty string
      and an HTML snippet describing the problem.
    """
    if not in_filename:
        hint = (
            "Please first upload a file and then click "
            'the button "submit for recognition"'
        )
        return "", build_html_output(hint, "result_item_error")

    logging.info(f"Processing uploaded file: {in_filename}")
    try:
        result = process(in_filename=in_filename, repo_id=repo_id)
    except Exception as e:
        error_message = str(e)
        logging.info(error_message)
        return "", build_html_output(error_message, "result_item_error")

    return result
96
+
97
+
98
def process_microphone(
    repo_id: str,
    in_filename: str,
):
    """Run ``process`` on a recording made with the microphone.

    Args:
      repo_id:
        The model selected by the user. It is forwarded to ``process``.
      in_filename:
        Path to the recorded file. It may be ``None`` or empty if nothing
        has been recorded.
    Returns:
      Return whatever ``process`` returns on success. If no recording is
      given or if ``process`` raises, return a tuple containing an empty
      string and an HTML snippet describing the problem.
    """
    if not in_filename:
        hint = (
            "Please first click 'Record from microphone', speak, "
            "click 'Stop recording', and then "
            "click the button 'submit for recognition'"
        )
        return "", build_html_output(hint, "result_item_error")

    logging.info(f"Processing microphone: {in_filename}")
    try:
        result = process(in_filename=in_filename, repo_id=repo_id)
    except Exception as e:
        error_message = str(e)
        logging.info(error_message)
        return "", build_html_output(error_message, "result_item_error")

    return result
119
+
120
+
121
def process(
    repo_id: str,
    in_filename: str,
):
    """Identify the language spoken in an audio file.

    The input is first converted with ``convert_to_wav``; the model is
    obtained from ``get_pretrained_model`` and run through ``decode``.

    Args:
      repo_id:
        The model to use. It is passed to ``get_pretrained_model``.
      in_filename:
        Path to the input audio file.
    Returns:
      Return a tuple containing:
       - The detected language, as returned by ``decode``
       - An HTML snippet with the wave duration, the processing time, and
         the real-time factor (RTF)
    """
    logging.info(f"repo_id: {repo_id}")
    logging.info(f"in_filename: {in_filename}")

    filename = convert_to_wav(in_filename)

    try:
        date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
        logging.info(f"Started at {date_time}")

        start = time.time()

        slid = get_pretrained_model(repo_id)

        lang = decode(slid, filename)

        end = time.time()
        # Query the clock again; re-formatting the start time would log the
        # same timestamp for both "Started at" and "Finished at".
        date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")

        duration = sf.info(filename).duration
    finally:
        # The converted file is only an intermediate artifact. Remove it so
        # that repeated requests do not fill up the disk.
        try:
            os.remove(filename)
        except OSError:
            pass

    elapsed = end - start
    # Guard against empty audio so that we do not divide by zero.
    rtf = elapsed / duration if duration > 0 else float("inf")

    logging.info(f"Finished at {date_time}. Elapsed: {elapsed: .3f} s")

    info = f"""
    Wave duration : {duration: .3f} s <br/>
    Processing time: {elapsed: .3f} s <br/>
    RTF: {elapsed: .3f}/{duration: .3f} = {rtf:.3f} <br/>
    """
    if rtf > 1:
        info += (
            "<br/>We are loading the model for the first run. "
            "Please run again to measure the real RTF.<br/>"
        )

    logging.info(info)
    logging.info(f"\nrepo_id: {repo_id}\nDetected language: {lang}")

    # The original code returned an undefined name here, which raised a
    # NameError on every call. The detected language is the result to show.
    return lang, build_html_output(info)
166
+
167
+
168
# Markdown heading of the page. It is rendered with gr.Markdown() as the
# first element of the UI.
title = "# Spoken Language Identification: [Next-gen Kaldi](https://github.com/k2-fsa) + [Whisper](https://github.com/openai/whisper/)"

# Markdown text describing this space. It is rendered with gr.Markdown() as
# the last element of the UI.
description = """
This space shows how to do spoken language identification with [Next-gen Kaldi](https://github.com/k2-fsa)
using [Whisper](https://github.com/openai/whisper/) multilingual models.

It is running on a machine with 2 vCPUs with 16 GB RAM within a docker container provided by Hugging Face.

See more information by visiting the following links:

- <https://github.com/k2-fsa/sherpa-onnx>

If you want to deploy it locally, please see
<https://k2-fsa.github.io/sherpa/onnx>
"""

# css style is copied from
# https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
#
# The class names defined here are the ones used in the markup generated by
# build_html_output().
css = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
"""
191
+
192
+
193
# The top-level container of the UI. The custom css is attached here.
demo = gr.Blocks(css=css)


# Everything created inside this "with" block becomes part of the page.
with demo:
    gr.Markdown(title)
    # The names offered in the dropdown are the keys of whisper_models.
    model_choices = list(whisper_models.keys())

    # The dropdown is shared by all tabs; its value is passed as the
    # first argument (repo_id) to each process_xxx() handler below.
    model_dropdown = gr.Dropdown(
        choices=model_choices,
        label="Select a model",
        value=model_choices[0],
    )

    with gr.Tabs():
        # Tab 1: the user uploads an audio file.
        with gr.TabItem("Upload from disk"):
            uploaded_file = gr.Audio(
                sources=["upload"],  # Choose between "microphone", "upload"
                type="filepath",
                label="Upload from disk",
            )
            upload_button = gr.Button("Submit for recognition")
            uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
            uploaded_html_info = gr.HTML(label="Info")

            # Each entry of examples is a [model name, wave filename] pair,
            # matching the order of inputs.
            gr.Examples(
                examples=examples,
                inputs=[
                    model_dropdown,
                    uploaded_file,
                ],
                outputs=[uploaded_output, uploaded_html_info],
                fn=process_uploaded_file,
            )

        # Tab 2: the user records audio with the microphone.
        with gr.TabItem("Record from microphone"):
            microphone = gr.Audio(
                sources=["microphone"],  # Choose between "microphone", "upload"
                type="filepath",
                label="Record from microphone",
            )

            record_button = gr.Button("Submit for recognition")
            recorded_output = gr.Textbox(label="Recognized speech from recordings")
            recorded_html_info = gr.HTML(label="Info")

            gr.Examples(
                examples=examples,
                inputs=[
                    model_dropdown,
                    microphone,
                ],
                outputs=[recorded_output, recorded_html_info],
                fn=process_microphone,
            )

        # Tab 3: the user provides a URL to an audio file.
        with gr.TabItem("From URL"):
            url_textbox = gr.Textbox(
                max_lines=1,
                placeholder="URL to an audio file",
                label="URL",
                interactive=True,
            )

            url_button = gr.Button("Submit for recognition")
            url_output = gr.Textbox(label="Recognized speech from URL")
            url_html_info = gr.HTML(label="Info")

        # Connect each button to its handler. Every handler receives
        # (repo_id, filename or URL) and its two return values are shown in
        # the textbox and in the HTML component listed in outputs.
        upload_button.click(
            process_uploaded_file,
            inputs=[
                model_dropdown,
                uploaded_file,
            ],
            outputs=[uploaded_output, uploaded_html_info],
        )

        record_button.click(
            process_microphone,
            inputs=[
                model_dropdown,
                microphone,
            ],
            outputs=[recorded_output, recorded_html_info],
        )

        url_button.click(
            process_url,
            inputs=[
                model_dropdown,
                url_textbox,
            ],
            outputs=[url_output, url_html_info],
        )

    gr.Markdown(description)
289
if __name__ == "__main__":
    # Set up logging first so that messages from the handlers are visible
    # as soon as the app is serving requests.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    demo.launch()
examples.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright 2022-2024 Xiaomi Corp. (authors: Fangjun Kuang)
4
+ #
5
+ # See LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
# Names of the example wave files, one per language. Each name has the form
# <language code>-<language name>.wav
#
# Note: every entry must be followed by a comma. Without the commas, Python
# silently joins adjacent string literals into one single long string.
#
# NOTE(review): "po-polish.wav" is kept as is; the ISO 639-1 code of Polish
# is "pl", so please confirm that the name matches the file on disk.
wavs = [
    "ar-arabic.wav",
    "bg-bulgarian.wav",
    "cs-czech.wav",
    "da-danish.wav",
    "de-german.wav",
    "el-greek.wav",
    "en-english.wav",
    "es-spanish.wav",
    "fa-persian.wav",
    "fi-finnish.wav",
    "fr-french.wav",
    "hi-hindi.wav",
    "hr-croatian.wav",
    "id-indonesian.wav",
    "it-italian.wav",
    "ja-japanese.wav",
    "ko-korean.wav",
    "nl-dutch.wav",
    "no-norwegian.wav",
    "po-polish.wav",
    "pt-portuguese.wav",
    "ro-romanian.wav",
    "ru-russian.wav",
    "sk-slovak.wav",
    "sv-swedish.wav",
    "ta-tamil.wav",
    "tl-tagalog.wav",
    "tr-turkish.wav",
    "uk-ukrainian.wav",
    "zh-chinese.wav",
]

# Each example is a [model name, path to a wave file] pair.
examples = [["tiny", f"./test_wavs/{w}"] for w in wavs]
model.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022-2024 Xiaomi Corp. (authors: Fangjun Kuang)
2
+ #
3
+ # See LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import wave
18
+ from functools import lru_cache
19
+ from typing import Tuple
20
+
21
+ import numpy as np
22
+ import sherpa_onnx
23
+ from huggingface_hub import hf_hub_download
24
+ from iso639 import Lang
25
+
26
# Presumably the expected sample rate in Hz.
# NOTE(review): decode() binds a local variable with the same name and no
# other code visible in this file reads this constant -- confirm whether it
# is still needed.
sample_rate = 16000
27
+
28
+
29
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Read all samples of a wave file.

    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample should
        be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
         normalized to the range [-1, 1].
       - sample rate of the wave file
    Raises:
      ValueError:
        If the file is not single channel or its samples are not 16-bit.
    """

    with wave.open(wave_filename) as f:
        # Raise explicitly instead of using assert: assertions are removed
        # when Python runs with -O, and a file in an unsupported format
        # would then be decoded into garbage without any error.
        if f.getnchannels() != 1:
            raise ValueError(
                f"Expected a single channel wave file. Given: {f.getnchannels()}"
            )

        if f.getsampwidth() != 2:  # it is in bytes
            raise ValueError(
                f"Expected 16-bit samples. Given: {8 * f.getsampwidth()}-bit"
            )

        # Read everything we need while the file is still open.
        frame_rate = f.getframerate()
        num_samples = f.getnframes()
        samples = f.readframes(num_samples)

    samples_int16 = np.frombuffer(samples, dtype=np.int16)
    samples_float32 = samples_int16.astype(np.float32)

    # 32768 is 2^15, the largest magnitude of a 16-bit signed integer.
    samples_float32 = samples_float32 / 32768
    return samples_float32, frame_rate
52
+
53
+
54
def decode(
    slid: sherpa_onnx.SpokenLanguageIdentification,
    filename: str,
) -> str:
    """Identify the language spoken in a wave file.

    Args:
      slid:
        The spoken language identification model.
      filename:
        Path to a wave file. It is read with ``read_wave``.
    Returns:
      Return the name of the detected language if ``iso639.Lang`` knows the
      value computed by the model. Otherwise, return that value unchanged.
      Return "Unknown" if the model outputs an empty string.
    """
    # The stream must be created from the model passed in. The original code
    # used an undefined name here and raised a NameError on every call.
    stream = slid.create_stream()

    # Use a distinct name to avoid shadowing the module-level sample_rate.
    samples, wave_sample_rate = read_wave(filename)
    stream.accept_waveform(wave_sample_rate, samples)

    lang = slid.compute(stream)
    if lang == "":
        return "Unknown"

    try:
        return Lang(lang).name
    except Exception:
        # The value is not known to iso639; fall back to the raw value.
        # Exception is used instead of a bare "except" so that
        # KeyboardInterrupt and SystemExit are not swallowed.
        return lang
69
+
70
+
71
def _get_nn_model_filename(
    repo_id: str,
    filename: str,
    subfolder: str = ".",
) -> str:
    """Thin wrapper around ``hf_hub_download``.

    Args:
      repo_id:
        ID of the repository, forwarded to ``hf_hub_download``.
      filename:
        Name of the file, forwarded to ``hf_hub_download``.
      subfolder:
        The subfolder, forwarded to ``hf_hub_download``.
    Returns:
      Return the value returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
82
+
83
+
84
@lru_cache(maxsize=8)
def get_pretrained_model(name: str) -> sherpa_onnx.SpokenLanguageIdentification:
    """Create a spoken language identification model from a whisper model.

    The result is cached by ``lru_cache`` so that calling this function again
    with the same name returns the same object.

    Args:
      name:
        Name of the whisper model. It must be one of tiny, base, small,
        and medium.
    Returns:
      Return an instance of ``sherpa_onnx.SpokenLanguageIdentification``.
    """
    supported_names = ("tiny", "base", "small", "medium")
    assert name in supported_names, name

    full_repo_id = "csukuangfj/sherpa-onnx-whisper-" + name

    # Fetch the encoder first and then the decoder; both are int8 models.
    encoder, decoder = (
        _get_nn_model_filename(
            repo_id=full_repo_id,
            filename=f"{name}-{part}.int8.onnx",
        )
        for part in ("encoder", "decoder")
    )

    whisper_config = sherpa_onnx.SpokenLanguageIdentificationWhisperConfig(
        encoder=encoder,
        decoder=decoder,
    )
    config = sherpa_onnx.SpokenLanguageIdentificationConfig(
        whisper=whisper_config,
        num_threads=2,
        debug=1,
        provider="cpu",
    )

    return sherpa_onnx.SpokenLanguageIdentification(config)
114
+
115
+
116
# Map the name of each supported whisper model to the function that creates
# it. All names share the same factory.
whisper_models = dict.fromkeys(
    ("tiny", "base", "small", "medium"),
    get_pretrained_model,
)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ soundfile
2
+ numpy
3
+
4
+ huggingface_hub
5
+ sherpa-onnx>=1.9.12
6
+ iso639-lang