csukuangfj commited on
Commit
f0a085b
·
1 Parent(s): 6781708

small fixes

Browse files
Files changed (3) hide show
  1. app.py +17 -2
  2. decode.py +117 -0
  3. model.py +34 -0
app.py CHANGED
@@ -21,8 +21,12 @@
21
 
22
 
23
  import logging
 
 
24
  import gradio as gr
25
- from model import language_to_models
 
 
26
 
27
  title = "# Next-gen Kaldi: Generate subtitles for videos"
28
 
@@ -70,6 +74,11 @@ def build_html_output(s: str, style: str = "result_item_success"):
70
  """
71
 
72
 
 
 
 
 
 
73
  def process_uploaded_file(
74
  language: str,
75
  repo_id: str,
@@ -84,7 +93,12 @@ def process_uploaded_file(
84
 
85
  logging.info(f"Processing uploaded file: {in_filename}")
86
 
87
- return "Done", build_html_output("ok", "result_item_success")
 
 
 
 
 
88
 
89
 
90
  demo = gr.Blocks(css=css)
@@ -118,6 +132,7 @@ with demo:
118
  source="upload",
119
  interactive=True,
120
  label="Upload from disk",
 
121
  )
122
  upload_button = gr.Button("Submit for recognition")
123
  uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
 
21
 
22
 
23
  import logging
24
+ import os
25
+
26
  import gradio as gr
27
+
28
+ from decode import decode
29
+ from model import get_pretrained_model, get_vad, language_to_models
30
 
31
  title = "# Next-gen Kaldi: Generate subtitles for videos"
32
 
 
74
  """
75
 
76
 
77
+ def show_file_info(in_filename: str):
78
+ logging.info(f"Input file: {in_filename}")
79
+ _ = os.system(f"ffprob -hide_banner -i '{in_filename}'")
80
+
81
+
82
  def process_uploaded_file(
83
  language: str,
84
  repo_id: str,
 
93
 
94
  logging.info(f"Processing uploaded file: {in_filename}")
95
 
96
+ recognizer = get_pretrained_model(repo_id)
97
+ vad = get_vad()
98
+
99
+ result = decode(recognizer, vad, in_filename)
100
+
101
+ return result, build_html_output("ok", "result_item_success")
102
 
103
 
104
  demo = gr.Blocks(css=css)
 
132
  source="upload",
133
  interactive=True,
134
  label="Upload from disk",
135
+ show_share_button=True,
136
  )
137
  upload_button = gr.Button("Submit for recognition")
138
  uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
decode.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang)
2
+ #
3
+ # See LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import subprocess
18
+ from dataclasses import dataclass
19
+ from datetime import timedelta
20
+ import logging
21
+
22
+ import numpy as np
23
+ import sherpa_onnx
24
+
25
+ from model import sample_rate
26
+
27
+
28
+ @dataclass
29
+ class Segment:
30
+ start: float
31
+ duration: float
32
+ text: str = ""
33
+
34
+ @property
35
+ def end(self):
36
+ return self.start + self.duration
37
+
38
+ def __str__(self):
39
+ s = f"{timedelta(seconds=self.start)}"[:-3]
40
+ s += " --> "
41
+ s += f"{timedelta(seconds=self.end)}"[:-3]
42
+ s = s.replace(".", ",")
43
+ s += "\n"
44
+ s += self.text
45
+ return s
46
+
47
+
48
+ def decode(
49
+ recognizer: sherpa_onnx.OfflineRecognizer,
50
+ vad: sherpa_onnx.VoiceActivityDetector,
51
+ filename: str,
52
+ ) -> str:
53
+ ffmpeg_cmd = [
54
+ "ffmpeg",
55
+ "-i",
56
+ filename,
57
+ "-f",
58
+ "s16le",
59
+ "-acodec",
60
+ "pcm_s16le",
61
+ "-ac",
62
+ "1",
63
+ "-ar",
64
+ str(sample_rate),
65
+ "-",
66
+ ]
67
+
68
+ process = subprocess.Popen(
69
+ ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
70
+ )
71
+
72
+ frames_per_read = int(sample_rate * 100) # 100 second
73
+
74
+ window_size = 512
75
+
76
+ buffer = []
77
+
78
+ segment_list = []
79
+
80
+ logging.info("Started!")
81
+
82
+ while True:
83
+ # *2 because int16_t has two bytes
84
+ data = process.stdout.read(frames_per_read * 2)
85
+ if not data:
86
+ break
87
+
88
+ samples = np.frombuffer(data, dtype=np.int16)
89
+ samples = samples.astype(np.float32) / 32768
90
+
91
+ buffer = np.concatenate([buffer, samples])
92
+ while len(buffer) > window_size:
93
+ vad.accept_waveform(buffer[:window_size])
94
+ buffer = buffer[window_size:]
95
+
96
+ streams = []
97
+ segments = []
98
+ while not vad.empty():
99
+ segment = Segment(
100
+ start=vad.front.start / sample_rate,
101
+ duration=len(vad.front.samples) / sample_rate,
102
+ )
103
+ segments.append(segment)
104
+
105
+ stream = recognizer.create_stream()
106
+ stream.accept_waveform(sample_rate, vad.front.samples)
107
+
108
+ streams.append(stream)
109
+
110
+ vad.pop()
111
+
112
+ recognizer.decode_streams(streams)
113
+ for seg, stream in zip(segments, streams):
114
+ seg.text = stream.result.text
115
+ segment_list.append(seg)
116
+
117
+ return "\n\n".join(f"{i}\n{seg} " for i, seg in enumerate(segment_list, 1))
model.py CHANGED
@@ -165,6 +165,40 @@ def _get_russian_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognize
165
  return recognizer
166
 
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  english_models = {
169
  "whisper-tiny.en": _get_whisper_model,
170
  "whisper-base.en": _get_whisper_model,
 
165
  return recognizer
166
 
167
 
168
+ @lru_cache(maxsize=2)
169
+ def get_vad() -> sherpa_onnx.VoiceActivityDetector:
170
+ vad_model = _get_nn_model_filename(
171
+ repo_id="csukuangfj/vad",
172
+ filename="silero_vad.onnx",
173
+ subfolder=".",
174
+ )
175
+
176
+ config = sherpa_onnx.VadModelConfig()
177
+ config.silero_vad.model = vad_model
178
+ config.silero_vad.min_silence_duration = 0.15
179
+ config.silero_vad.min_speech_duration = 0.25
180
+ config.sample_rate = sample_rate
181
+
182
+ vad = sherpa_onnx.VoiceActivityDetector(
183
+ config,
184
+ buffer_size_in_seconds=180,
185
+ )
186
+
187
+ return vad
188
+
189
+
190
+ @lru_cache(maxsize=10)
191
+ def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
192
+ if repo_id in english_models:
193
+ return english_models[repo_id](repo_id)
194
+ elif repo_id in chinese_english_mixed_models:
195
+ return chinese_english_mixed_models[repo_id](repo_id)
196
+ elif repo_id in russian_models:
197
+ return russian_models[repo_id](repo_id)
198
+ else:
199
+ raise ValueError(f"Unsupported repo_id: {repo_id}")
200
+
201
+
202
  english_models = {
203
  "whisper-tiny.en": _get_whisper_model,
204
  "whisper-base.en": _get_whisper_model,