hbs2 committed on
Commit 9361d1d · verified · 1 Parent(s): fd87e9f

Upload 9 files

Files changed (9)
  1. __init__.py +0 -0
  2. audio.py +119 -0
  3. feature_extractor.py +170 -0
  4. silero_vad.onnx +3 -0
  5. tokenizer.py +278 -0
  6. transcribe.py +1272 -0
  7. utils.py +157 -0
  8. vad.py +291 -0
  9. version.py +3 -0
__init__.py ADDED
File without changes
audio.py ADDED
@@ -0,0 +1,119 @@
1
+ """We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
2
+
3
+ The advantage of PyAV is that it bundles the FFmpeg libraries, so there are no additional
4
+ system dependencies. FFmpeg does not need to be installed on the system.
5
+
6
+ However, the API is quite low-level, so we need to manipulate audio frames directly.
7
+ """
8
+
9
+ import gc
10
+ import io
11
+ import itertools
12
+
13
+ from typing import BinaryIO, Union
14
+
15
+ import av
16
+ import numpy as np
17
+
18
+
19
+ def decode_audio(
20
+ input_file: Union[str, BinaryIO],
21
+ sampling_rate: int = 16000,
22
+ split_stereo: bool = False,
23
+ ):
24
+ """Decodes the audio.
25
+
26
+ Args:
27
+ input_file: Path to the input file or a file-like object.
28
+ sampling_rate: Resample the audio to this sample rate.
29
+ split_stereo: Return separate left and right channels.
30
+
31
+ Returns:
32
+ A float32 Numpy array.
33
+
34
+ If `split_stereo` is enabled, the function returns a 2-tuple with the
35
+ separated left and right channels.
36
+ """
37
+ resampler = av.audio.resampler.AudioResampler(
38
+ format="s16",
39
+ layout="mono" if not split_stereo else "stereo",
40
+ rate=sampling_rate,
41
+ )
42
+
43
+ raw_buffer = io.BytesIO()
44
+ dtype = None
45
+
46
+ with av.open(input_file, mode="r", metadata_errors="ignore") as container:
47
+ frames = container.decode(audio=0)
48
+ frames = _ignore_invalid_frames(frames)
49
+ frames = _group_frames(frames, 500000)
50
+ frames = _resample_frames(frames, resampler)
51
+
52
+ for frame in frames:
53
+ array = frame.to_ndarray()
54
+ dtype = array.dtype
55
+ raw_buffer.write(array)
56
+
57
+ # It appears that some objects related to the resampler are not freed
58
+ # unless the garbage collector is manually run.
59
+ del resampler
60
+ gc.collect()
61
+
62
+ audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
63
+
64
+ # Convert s16 back to f32.
65
+ audio = audio.astype(np.float32) / 32768.0
66
+
67
+ if split_stereo:
68
+ left_channel = audio[0::2]
69
+ right_channel = audio[1::2]
70
+ return left_channel, right_channel
71
+
72
+ return audio
73
+
74
+
75
+ def _ignore_invalid_frames(frames):
76
+ iterator = iter(frames)
77
+
78
+ while True:
79
+ try:
80
+ yield next(iterator)
81
+ except StopIteration:
82
+ break
83
+ except av.error.InvalidDataError:
84
+ continue
85
+
86
+
87
+ def _group_frames(frames, num_samples=None):
88
+ fifo = av.audio.fifo.AudioFifo()
89
+
90
+ for frame in frames:
91
+ frame.pts = None # Ignore timestamp check.
92
+ fifo.write(frame)
93
+
94
+ if num_samples is not None and fifo.samples >= num_samples:
95
+ yield fifo.read()
96
+
97
+ if fifo.samples > 0:
98
+ yield fifo.read()
99
+
100
+
101
+ def _resample_frames(frames, resampler):
102
+ # Add None to flush the resampler.
103
+ for frame in itertools.chain(frames, [None]):
104
+ yield from resampler.resample(frame)
105
+
106
+
107
+ def pad_or_trim(array, length: int, *, axis: int = -1):
108
+ """
109
+ Pad or trim the audio array along `axis` to exactly `length` samples, as expected by the encoder.
110
+ """
111
+ if array.shape[axis] > length:
112
+ array = array.take(indices=range(length), axis=axis)
113
+
114
+ if array.shape[axis] < length:
115
+ pad_widths = [(0, 0)] * array.ndim
116
+ pad_widths[axis] = (0, length - array.shape[axis])
117
+ array = np.pad(array, pad_widths)
118
+
119
+ return array
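
A minimal usage sketch of the decoding helpers in audio.py above. It assumes the package is importable as `faster_whisper` (matching the imports used later in transcribe.py); the input path "speech.mp3" is a placeholder, not a file shipped in this commit.

```python
# Decode an arbitrary media file to a 16 kHz mono float32 waveform, then
# pad/trim it to the 30-second window the Whisper encoder expects.
from faster_whisper.audio import decode_audio, pad_or_trim

audio = decode_audio("speech.mp3", sampling_rate=16000)  # placeholder path
print(audio.dtype, audio.shape)  # float32, (num_samples,)

chunk = pad_or_trim(audio, length=30 * 16000)  # 480000 samples = 30 s at 16 kHz
assert chunk.shape[-1] == 480000
```
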
feature_extractor.py ADDED
@@ -0,0 +1,170 @@
1
+ import numpy as np
2
+
3
+
4
+ # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
5
+ class FeatureExtractor:
6
+ def __init__(
7
+ self,
8
+ feature_size=80,
9
+ sampling_rate=16000,
10
+ hop_length=160,
11
+ chunk_length=30,
12
+ n_fft=400,
13
+ ):
14
+ self.n_fft = n_fft
15
+ self.hop_length = hop_length
16
+ self.chunk_length = chunk_length
17
+ self.n_samples = chunk_length * sampling_rate
18
+ self.nb_max_frames = self.n_samples // hop_length
19
+ self.time_per_frame = hop_length / sampling_rate
20
+ self.sampling_rate = sampling_rate
21
+ self.mel_filters = self.get_mel_filters(
22
+ sampling_rate, n_fft, n_mels=feature_size
23
+ )
24
+
25
+ def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
26
+ # Initialize the weights
27
+ n_mels = int(n_mels)
28
+ weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
29
+
30
+ # Center freqs of each FFT bin
31
+ fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
32
+
33
+ # 'Center freqs' of mel bands - uniformly spaced between limits
34
+ min_mel = 0.0
35
+ max_mel = 45.245640471924965
36
+
37
+ mels = np.linspace(min_mel, max_mel, n_mels + 2)
38
+
39
+ mels = np.asanyarray(mels)
40
+
41
+ # Fill in the linear scale
42
+ f_min = 0.0
43
+ f_sp = 200.0 / 3
44
+ freqs = f_min + f_sp * mels
45
+
46
+ # And now the nonlinear scale
47
+ min_log_hz = 1000.0 # beginning of log region (Hz)
48
+ min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
49
+ logstep = np.log(6.4) / 27.0 # step size for log region
50
+
51
+ # If we have vector data, vectorize
52
+ log_t = mels >= min_log_mel
53
+ freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
54
+
55
+ mel_f = freqs
56
+
57
+ fdiff = np.diff(mel_f)
58
+ ramps = np.subtract.outer(mel_f, fftfreqs)
59
+
60
+ for i in range(n_mels):
61
+ # lower and upper slopes for all bins
62
+ lower = -ramps[i] / fdiff[i]
63
+ upper = ramps[i + 2] / fdiff[i + 1]
64
+
65
+ # .. then intersect them with each other and zero
66
+ weights[i] = np.maximum(0, np.minimum(lower, upper))
67
+
68
+ # Slaney-style mel is scaled to be approx constant energy per channel
69
+ enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
70
+ weights *= enorm[:, np.newaxis]
71
+
72
+ return weights
73
+
74
+ def fram_wave(self, waveform, center=True):
75
+ """
76
+ Transform a raw waveform into a list of smaller waveforms.
77
+ The window length defines how much of the signal is
78
+ contained in each frame (smaller waveform), while the hop length defines the step
79
+ between the beginning of each new frame.
80
+ Centering is done by reflecting the waveform which is first centered around
81
+ `frame_idx * hop_length`.
82
+ """
83
+ frames = []
84
+ for i in range(0, waveform.shape[0] + 1, self.hop_length):
85
+ half_window = (self.n_fft - 1) // 2 + 1
86
+ if center:
87
+ start = i - half_window if i > half_window else 0
88
+ end = (
89
+ i + half_window
90
+ if i < waveform.shape[0] - half_window
91
+ else waveform.shape[0]
92
+ )
93
+
94
+ frame = waveform[start:end]
95
+
96
+ if start == 0:
97
+ padd_width = (-i + half_window, 0)
98
+ frame = np.pad(frame, pad_width=padd_width, mode="reflect")
99
+
100
+ elif end == waveform.shape[0]:
101
+ padd_width = (0, (i - waveform.shape[0] + half_window))
102
+ frame = np.pad(frame, pad_width=padd_width, mode="reflect")
103
+
104
+ else:
105
+ frame = waveform[i : i + self.n_fft]
106
+ frame_width = frame.shape[0]
107
+ if frame_width < waveform.shape[0]:
108
+ frame = np.lib.pad(
109
+ frame,
110
+ pad_width=(0, self.n_fft - frame_width),
111
+ mode="constant",
112
+ constant_values=0,
113
+ )
114
+
115
+ frames.append(frame)
116
+ return np.stack(frames, 0)
117
+
118
+ def stft(self, frames, window):
119
+ """
120
+ Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
121
+ Should give the same results as `torch.stft`.
122
+ """
123
+ frame_size = frames.shape[1]
124
+ fft_size = self.n_fft
125
+
126
+ if fft_size is None:
127
+ fft_size = frame_size
128
+
129
+ if fft_size < frame_size:
130
+ raise ValueError("FFT size must be greater than or equal to the frame size")
131
+ # number of FFT bins to store
132
+ num_fft_bins = (fft_size >> 1) + 1
133
+
134
+ data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
135
+ fft_signal = np.zeros(fft_size)
136
+
137
+ for f, frame in enumerate(frames):
138
+ if window is not None:
139
+ np.multiply(frame, window, out=fft_signal[:frame_size])
140
+ else:
141
+ fft_signal[:frame_size] = frame
142
+ data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
143
+ return data.T
144
+
145
+ def __call__(self, waveform, padding=True, chunk_length=None):
146
+ """
147
+ Compute the log-Mel spectrogram of the provided audio; gives results similar to
148
+ Whisper's original torch implementation, within a 1e-5 tolerance.
149
+ """
150
+ if chunk_length is not None:
151
+ self.n_samples = chunk_length * self.sampling_rate
152
+ self.nb_max_frames = self.n_samples // self.hop_length
153
+
154
+ if padding:
155
+ waveform = np.pad(waveform, [(0, self.n_samples)])
156
+
157
+ window = np.hanning(self.n_fft + 1)[:-1]
158
+
159
+ frames = self.fram_wave(waveform)
160
+ stft = self.stft(frames, window=window)
161
+ magnitudes = np.abs(stft[:, :-1]) ** 2
162
+
163
+ filters = self.mel_filters
164
+ mel_spec = filters @ magnitudes
165
+
166
+ log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
167
+ log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
168
+ log_spec = (log_spec + 4.0) / 4.0
169
+
170
+ return log_spec
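
For orientation, a short sketch of how the FeatureExtractor above is typically driven (this mirrors the way transcribe.py calls it; the one-second silent waveform is only an illustrative input):

```python
import numpy as np

from faster_whisper.feature_extractor import FeatureExtractor

# Defaults: 80 mel bins, 16 kHz sampling rate, 160-sample hop, 30-second chunks.
extractor = FeatureExtractor()

waveform = np.zeros(16000, dtype=np.float32)  # one second of silence
features = extractor(waveform)  # padding=True appends a full 30-second window

print(features.shape)  # (80, n_frames) log-Mel spectrogram
```
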
silero_vad.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:591f853590d11ddde2f2a54f9e7ccecb2533a8af7716330e8adfa6f3849787a9
3
+ size 1807524
tokenizer.py ADDED
@@ -0,0 +1,278 @@
1
+ import string
2
+
3
+ from functools import cached_property
4
+ from typing import List, Optional, Tuple
5
+
6
+ import tokenizers
7
+
8
+
9
+ class Tokenizer:
10
+ """Simple wrapper around a tokenizers.Tokenizer."""
11
+
12
+ def __init__(
13
+ self,
14
+ tokenizer: tokenizers.Tokenizer,
15
+ multilingual: bool,
16
+ task: Optional[str] = None,
17
+ language: Optional[str] = None,
18
+ ):
19
+ self.tokenizer = tokenizer
20
+
21
+ if multilingual:
22
+ if task not in _TASKS:
23
+ raise ValueError(
24
+ "'%s' is not a valid task (accepted tasks: %s)"
25
+ % (task, ", ".join(_TASKS))
26
+ )
27
+
28
+ if language not in _LANGUAGE_CODES:
29
+ raise ValueError(
30
+ "'%s' is not a valid language code (accepted language codes: %s)"
31
+ % (language, ", ".join(_LANGUAGE_CODES))
32
+ )
33
+
34
+ self.task = self.tokenizer.token_to_id("<|%s|>" % task)
35
+ self.language = self.tokenizer.token_to_id("<|%s|>" % language)
36
+ self.language_code = language
37
+ else:
38
+ self.task = None
39
+ self.language = None
40
+ self.language_code = "en"
41
+
42
+ @cached_property
43
+ def transcribe(self) -> int:
44
+ return self.tokenizer.token_to_id("<|transcribe|>")
45
+
46
+ @cached_property
47
+ def translate(self) -> int:
48
+ return self.tokenizer.token_to_id("<|translate|>")
49
+
50
+ @cached_property
51
+ def sot(self) -> int:
52
+ return self.tokenizer.token_to_id("<|startoftranscript|>")
53
+
54
+ @cached_property
55
+ def sot_lm(self) -> int:
56
+ return self.tokenizer.token_to_id("<|startoflm|>")
57
+
58
+ @cached_property
59
+ def sot_prev(self) -> int:
60
+ return self.tokenizer.token_to_id("<|startofprev|>")
61
+
62
+ @cached_property
63
+ def eot(self) -> int:
64
+ return self.tokenizer.token_to_id("<|endoftext|>")
65
+
66
+ @cached_property
67
+ def no_timestamps(self) -> int:
68
+ return self.tokenizer.token_to_id("<|notimestamps|>")
69
+
70
+ @property
71
+ def timestamp_begin(self) -> int:
72
+ return self.no_timestamps + 1
73
+
74
+ @property
75
+ def sot_sequence(self) -> List[int]:
76
+ sequence = [self.sot]
77
+
78
+ if self.language is not None:
79
+ sequence.append(self.language)
80
+
81
+ if self.task is not None:
82
+ sequence.append(self.task)
83
+
84
+ return sequence
85
+
86
+ def encode(self, text: str) -> List[int]:
87
+ return self.tokenizer.encode(text, add_special_tokens=False).ids
88
+
89
+ def decode(self, tokens: List[int]) -> str:
90
+ text_tokens = [token for token in tokens if token < self.eot]
91
+ return self.tokenizer.decode(text_tokens)
92
+
93
+ def decode_with_timestamps(self, tokens: List[int]) -> str:
94
+ outputs = [[]]
95
+
96
+ for token in tokens:
97
+ if token >= self.timestamp_begin:
98
+ timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
99
+ outputs.append(timestamp)
100
+ outputs.append([])
101
+ else:
102
+ outputs[-1].append(token)
103
+
104
+ return "".join(
105
+ [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
106
+ )
107
+
108
+ def split_to_word_tokens(
109
+ self, tokens: List[int]
110
+ ) -> Tuple[List[str], List[List[int]]]:
111
+ if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}:
112
+ # These languages don't typically use spaces, so it is difficult to split words
113
+ # without morpheme analysis. Here, we instead split words at any
114
+ # position where the tokens are decoded as valid unicode points
115
+ return self.split_tokens_on_unicode(tokens)
116
+
117
+ return self.split_tokens_on_spaces(tokens)
118
+
119
+ def split_tokens_on_unicode(
120
+ self, tokens: List[int]
121
+ ) -> Tuple[List[str], List[List[int]]]:
122
+ decoded_full = self.decode_with_timestamps(tokens)
123
+ replacement_char = "\ufffd"
124
+
125
+ words = []
126
+ word_tokens = []
127
+ current_tokens = []
128
+ unicode_offset = 0
129
+
130
+ for token in tokens:
131
+ current_tokens.append(token)
132
+ decoded = self.decode_with_timestamps(current_tokens)
133
+
134
+ try:
135
+ replacement_char_index = decoded.index(replacement_char)
136
+ replacement_char_index += unicode_offset
137
+ except ValueError:
138
+ replacement_char_index = None
139
+
140
+ if replacement_char_index is None or (
141
+ replacement_char_index < len(decoded_full)
142
+ and decoded_full[replacement_char_index] == replacement_char
143
+ ):
144
+ words.append(decoded)
145
+ word_tokens.append(current_tokens)
146
+ current_tokens = []
147
+ unicode_offset += len(decoded)
148
+
149
+ return words, word_tokens
150
+
151
+ def split_tokens_on_spaces(
152
+ self, tokens: List[int]
153
+ ) -> Tuple[List[str], List[List[int]]]:
154
+ subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
155
+ words = []
156
+ word_tokens = []
157
+
158
+ for subword, subword_tokens in zip(subwords, subword_tokens_list):
159
+ special = subword_tokens[0] >= self.eot
160
+ with_space = subword.startswith(" ")
161
+ punctuation = subword.strip() in string.punctuation
162
+ if special or with_space or punctuation or len(words) == 0:
163
+ words.append(subword)
164
+ word_tokens.append(subword_tokens)
165
+ else:
166
+ words[-1] = words[-1] + subword
167
+ word_tokens[-1].extend(subword_tokens)
168
+
169
+ return words, word_tokens
170
+
171
+
172
+ _TASKS = (
173
+ "transcribe",
174
+ "translate",
175
+ )
176
+
177
+ _LANGUAGE_CODES = (
178
+ "af",
179
+ "am",
180
+ "ar",
181
+ "as",
182
+ "az",
183
+ "ba",
184
+ "be",
185
+ "bg",
186
+ "bn",
187
+ "bo",
188
+ "br",
189
+ "bs",
190
+ "ca",
191
+ "cs",
192
+ "cy",
193
+ "da",
194
+ "de",
195
+ "el",
196
+ "en",
197
+ "es",
198
+ "et",
199
+ "eu",
200
+ "fa",
201
+ "fi",
202
+ "fo",
203
+ "fr",
204
+ "gl",
205
+ "gu",
206
+ "ha",
207
+ "haw",
208
+ "he",
209
+ "hi",
210
+ "hr",
211
+ "ht",
212
+ "hu",
213
+ "hy",
214
+ "id",
215
+ "is",
216
+ "it",
217
+ "ja",
218
+ "jw",
219
+ "ka",
220
+ "kk",
221
+ "km",
222
+ "kn",
223
+ "ko",
224
+ "la",
225
+ "lb",
226
+ "ln",
227
+ "lo",
228
+ "lt",
229
+ "lv",
230
+ "mg",
231
+ "mi",
232
+ "mk",
233
+ "ml",
234
+ "mn",
235
+ "mr",
236
+ "ms",
237
+ "mt",
238
+ "my",
239
+ "ne",
240
+ "nl",
241
+ "nn",
242
+ "no",
243
+ "oc",
244
+ "pa",
245
+ "pl",
246
+ "ps",
247
+ "pt",
248
+ "ro",
249
+ "ru",
250
+ "sa",
251
+ "sd",
252
+ "si",
253
+ "sk",
254
+ "sl",
255
+ "sn",
256
+ "so",
257
+ "sq",
258
+ "sr",
259
+ "su",
260
+ "sv",
261
+ "sw",
262
+ "ta",
263
+ "te",
264
+ "tg",
265
+ "th",
266
+ "tk",
267
+ "tl",
268
+ "tr",
269
+ "tt",
270
+ "uk",
271
+ "ur",
272
+ "uz",
273
+ "vi",
274
+ "yi",
275
+ "yo",
276
+ "zh",
277
+ "yue",
278
+ )
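
A hedged sketch of building the Tokenizer wrapper above on its own. transcribe.py normally constructs it from the model's bundled tokenizer.json; here the "openai/whisper-tiny" repository (the same fallback used in transcribe.py) is loaded instead, which requires network access on first use:

```python
import tokenizers

from faster_whisper.tokenizer import Tokenizer

hf_tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")
tokenizer = Tokenizer(
    hf_tokenizer, multilingual=True, task="transcribe", language="en"
)

ids = tokenizer.encode(" Hello world")  # no special tokens added
print(tokenizer.decode(ids))            # decodes back to the input text
print(tokenizer.sot_sequence)           # ids of <|startoftranscript|>, <|en|>, <|transcribe|>
```
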
transcribe.py ADDED
@@ -0,0 +1,1272 @@
1
+ import itertools
2
+ import json
3
+ import logging
4
+ import os
5
+ import zlib
6
+
7
+ from inspect import signature
8
+ from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
9
+
10
+ import ctranslate2
11
+ import numpy as np
12
+ import tokenizers
13
+
14
+ from faster_whisper.audio import decode_audio, pad_or_trim
15
+ from faster_whisper.feature_extractor import FeatureExtractor
16
+ from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
17
+ from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
18
+ from faster_whisper.vad import (
19
+ SpeechTimestampsMap,
20
+ VadOptions,
21
+ collect_chunks,
22
+ get_speech_timestamps,
23
+ )
24
+
25
+
26
+ class Word(NamedTuple):
27
+ start: float
28
+ end: float
29
+ word: str
30
+ probability: float
31
+
32
+
33
+ class Segment(NamedTuple):
34
+ id: int
35
+ seek: int
36
+ start: float
37
+ end: float
38
+ text: str
39
+ tokens: List[int]
40
+ temperature: float
41
+ avg_logprob: float
42
+ compression_ratio: float
43
+ no_speech_prob: float
44
+ words: Optional[List[Word]]
45
+
46
+
47
+ class TranscriptionOptions(NamedTuple):
48
+ beam_size: int
49
+ best_of: int
50
+ patience: float
51
+ length_penalty: float
52
+ repetition_penalty: float
53
+ no_repeat_ngram_size: int
54
+ log_prob_threshold: Optional[float]
55
+ no_speech_threshold: Optional[float]
56
+ compression_ratio_threshold: Optional[float]
57
+ condition_on_previous_text: bool
58
+ prompt_reset_on_temperature: float
59
+ temperatures: List[float]
60
+ initial_prompt: Optional[Union[str, Iterable[int]]]
61
+ prefix: Optional[str]
62
+ suppress_blank: bool
63
+ suppress_tokens: Optional[List[int]]
64
+ without_timestamps: bool
65
+ max_initial_timestamp: float
66
+ word_timestamps: bool
67
+ prepend_punctuations: str
68
+ append_punctuations: str
69
+ max_new_tokens: Optional[int]
70
+ clip_timestamps: Union[str, List[float]]
71
+ hallucination_silence_threshold: Optional[float]
72
+ hotwords: Optional[str]
73
+
74
+
75
+ class TranscriptionInfo(NamedTuple):
76
+ language: str
77
+ language_probability: float
78
+ duration: float
79
+ duration_after_vad: float
80
+ all_language_probs: Optional[List[Tuple[str, float]]]
81
+ transcription_options: TranscriptionOptions
82
+ vad_options: VadOptions
83
+
84
+
85
+ class WhisperModel:
86
+ def __init__(
87
+ self,
88
+ model_size_or_path: str,
89
+ device: str = "auto",
90
+ device_index: Union[int, List[int]] = 0,
91
+ compute_type: str = "default",
92
+ cpu_threads: int = 0,
93
+ num_workers: int = 1,
94
+ download_root: Optional[str] = None,
95
+ local_files_only: bool = False,
96
+ files: dict = None,
97
+ **model_kwargs,
98
+ ):
99
+ """Initializes the Whisper model.
100
+
101
+ Args:
102
+ model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
103
+ small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
104
+ large-v2, large-v3, large, distil-large-v2 or distil-large-v3), a path to a
105
+ converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub.
106
+ When a size or a model ID is configured, the converted model is downloaded
107
+ from the Hugging Face Hub.
108
+ device: Device to use for computation ("cpu", "cuda", "auto").
109
+ device_index: Device ID to use.
110
+ The model can also be loaded on multiple GPUs by passing a list of IDs
111
+ (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
112
+ when transcribe() is called from multiple Python threads (see also num_workers).
113
+ compute_type: Type to use for computation.
114
+ See https://opennmt.net/CTranslate2/quantization.html.
115
+ cpu_threads: Number of threads to use when running on CPU (4 by default).
116
+ A non-zero value overrides the OMP_NUM_THREADS environment variable.
117
+ num_workers: When transcribe() is called from multiple Python threads,
118
+ having multiple workers enables true parallelism when running the model
119
+ (concurrent calls to self.model.generate() will run in parallel).
120
+ This can improve the global throughput at the cost of increased memory usage.
121
+ download_root: Directory where the models should be saved. If not set, the models
122
+ are saved in the standard Hugging Face cache directory.
123
+ local_files_only: If True, avoid downloading the file and return the path to the
124
+ local cached file if it exists.
125
+ files: Load model files from memory. This argument is a dictionary mapping file names
126
+ to file contents as file-like or bytes objects. If this is set, model_path acts as an
127
+ identifier for this model.
128
+ """
129
+ self.logger = get_logger()
130
+
131
+ tokenizer_bytes, preprocessor_bytes = None, None
132
+ if files:
133
+ model_path = model_size_or_path
134
+ tokenizer_bytes = files.pop("tokenizer.json", None)
135
+ preprocessor_bytes = files.pop("preprocessor_config.json", None)
136
+ elif os.path.isdir(model_size_or_path):
137
+ model_path = model_size_or_path
138
+ else:
139
+ model_path = download_model(
140
+ model_size_or_path,
141
+ local_files_only=local_files_only,
142
+ cache_dir=download_root,
143
+ )
144
+
145
+ self.model = ctranslate2.models.Whisper(
146
+ model_path,
147
+ device=device,
148
+ device_index=device_index,
149
+ compute_type=compute_type,
150
+ intra_threads=cpu_threads,
151
+ inter_threads=num_workers,
152
+ files=files,
153
+ **model_kwargs,
154
+ )
155
+
156
+ tokenizer_file = os.path.join(model_path, "tokenizer.json")
157
+ if tokenizer_bytes:
158
+ self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
159
+ elif os.path.isfile(tokenizer_file):
160
+ self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
161
+ else:
162
+ self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
163
+ "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
164
+ )
165
+ self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
166
+ self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
167
+ self.num_samples_per_token = self.feature_extractor.hop_length * 2
168
+ self.frames_per_second = (
169
+ self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
170
+ )
171
+ self.tokens_per_second = (
172
+ self.feature_extractor.sampling_rate // self.num_samples_per_token
173
+ )
174
+ self.input_stride = 2
175
+ self.time_precision = 0.02
176
+ self.max_length = 448
177
+
178
+ @property
179
+ def supported_languages(self) -> List[str]:
180
+ """The languages supported by the model."""
181
+ return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
182
+
183
+ def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
184
+ config = {}
185
+ try:
186
+ config_path = os.path.join(model_path, "preprocessor_config.json")
187
+ if preprocessor_bytes:
188
+ config = json.loads(preprocessor_bytes)
189
+ elif os.path.isfile(config_path):
190
+ with open(config_path, "r", encoding="utf-8") as file:
191
+ config = json.load(file)
192
+ else:
193
+ return config
194
+ valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
195
+ return {k: v for k, v in config.items() if k in valid_keys}
196
+ except json.JSONDecodeError as e:
197
+ self.logger.warning("Could not load preprocessor config: %s", e)
198
+
199
+ return config
200
+
201
+ def transcribe(
202
+ self,
203
+ audio: Union[str, BinaryIO, np.ndarray],
204
+ language: Optional[str] = None,
205
+ task: str = "transcribe",
206
+ beam_size: int = 5,
207
+ best_of: int = 5,
208
+ patience: float = 1,
209
+ length_penalty: float = 1,
210
+ repetition_penalty: float = 1,
211
+ no_repeat_ngram_size: int = 0,
212
+ temperature: Union[float, List[float], Tuple[float, ...]] = [
213
+ 0.0,
214
+ 0.2,
215
+ 0.4,
216
+ 0.6,
217
+ 0.8,
218
+ 1.0,
219
+ ],
220
+ compression_ratio_threshold: Optional[float] = 2.4,
221
+ log_prob_threshold: Optional[float] = -1.0,
222
+ no_speech_threshold: Optional[float] = 0.6,
223
+ condition_on_previous_text: bool = True,
224
+ prompt_reset_on_temperature: float = 0.5,
225
+ initial_prompt: Optional[Union[str, Iterable[int]]] = None,
226
+ prefix: Optional[str] = None,
227
+ suppress_blank: bool = True,
228
+ suppress_tokens: Optional[List[int]] = [-1],
229
+ without_timestamps: bool = False,
230
+ max_initial_timestamp: float = 1.0,
231
+ word_timestamps: bool = False,
232
+ prepend_punctuations: str = "\"'“¿([{-",
233
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
234
+ vad_filter: bool = False,
235
+ vad_parameters: Optional[Union[dict, VadOptions]] = None,
236
+ max_new_tokens: Optional[int] = None,
237
+ chunk_length: Optional[int] = None,
238
+ clip_timestamps: Union[str, List[float]] = "0",
239
+ hallucination_silence_threshold: Optional[float] = None,
240
+ hotwords: Optional[str] = None,
241
+ language_detection_threshold: Optional[float] = None,
242
+ language_detection_segments: int = 1,
243
+ ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
244
+ """Transcribes an input file.
245
+
246
+ Arguments:
247
+ audio: Path to the input file (or a file-like object), or the audio waveform.
248
+ language: The language spoken in the audio. It should be a language code such
249
+ as "en" or "fr". If not set, the language will be detected in the first 30 seconds
250
+ of audio.
251
+ task: Task to execute (transcribe or translate).
252
+ beam_size: Beam size to use for decoding.
253
+ best_of: Number of candidates when sampling with non-zero temperature.
254
+ patience: Beam search patience factor.
255
+ length_penalty: Exponential length penalty constant.
256
+ repetition_penalty: Penalty applied to the score of previously generated tokens
257
+ (set > 1 to penalize).
258
+ no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
259
+ temperature: Temperature for sampling. It can be a tuple of temperatures,
260
+ which will be successively used upon failures according to either
261
+ `compression_ratio_threshold` or `log_prob_threshold`.
262
+ compression_ratio_threshold: If the gzip compression ratio is above this value,
263
+ treat as failed.
264
+ log_prob_threshold: If the average log probability over sampled tokens is
265
+ below this value, treat as failed.
266
+ no_speech_threshold: If the no_speech probability is higher than this value AND
267
+ the average log probability over sampled tokens is below `log_prob_threshold`,
268
+ consider the segment as silent.
269
+ condition_on_previous_text: If True, the previous output of the model is provided
270
+ as a prompt for the next window; disabling may make the text inconsistent across
271
+ windows, but the model becomes less prone to getting stuck in a failure loop,
272
+ such as repetition looping or timestamps going out of sync.
273
+ prompt_reset_on_temperature: Resets prompt if temperature is above this value.
274
+ Arg has effect only if condition_on_previous_text is True.
275
+ initial_prompt: Optional text string or iterable of token ids to provide as a
276
+ prompt for the first window.
277
+ prefix: Optional text to provide as a prefix for the first window.
278
+ suppress_blank: Suppress blank outputs at the beginning of the sampling.
279
+ suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
280
+ of symbols as defined in the model config.json file.
281
+ without_timestamps: Only sample text tokens.
282
+ max_initial_timestamp: The initial timestamp cannot be later than this.
283
+ word_timestamps: Extract word-level timestamps using the cross-attention pattern
284
+ and dynamic time warping, and include the timestamps for each word in each segment.
285
+ prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
286
+ with the next word
287
+ append_punctuations: If word_timestamps is True, merge these punctuation symbols
288
+ with the previous word
289
+ vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
290
+ without speech. This step uses the Silero VAD model
291
+ https://github.com/snakers4/silero-vad.
292
+ vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
293
+ parameters and default values in the class `VadOptions`).
294
+ max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
295
+ the maximum will be set by the default max_length.
296
+ chunk_length: The length of audio segments. If it is not None, it will override the
297
+ default chunk_length of the FeatureExtractor.
298
+ clip_timestamps:
299
+ Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
300
+ process. The last end timestamp defaults to the end of the file.
301
+ vad_filter will be ignored if clip_timestamps is used.
302
+ hallucination_silence_threshold:
303
+ When word_timestamps is True, skip silent periods longer than this threshold
304
+ (in seconds) when a possible hallucination is detected
305
+ hotwords:
306
+ Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
307
+ language_detection_threshold: If the maximum probability of the language tokens is higher
308
+ than this value, the language is detected.
309
+ language_detection_segments: Number of segments to consider for the language detection.
310
+ Returns:
311
+ A tuple with:
312
+
313
+ - a generator over transcribed segments
314
+ - an instance of TranscriptionInfo
315
+ """
316
+ sampling_rate = self.feature_extractor.sampling_rate
317
+
318
+ if not isinstance(audio, np.ndarray):
319
+ audio = decode_audio(audio, sampling_rate=sampling_rate)
320
+
321
+ duration = audio.shape[0] / sampling_rate
322
+ duration_after_vad = duration
323
+
324
+ self.logger.info(
325
+ "Processing audio with duration %s", format_timestamp(duration)
326
+ )
327
+
328
+ if vad_filter and clip_timestamps == "0":
329
+ if vad_parameters is None:
330
+ vad_parameters = VadOptions()
331
+ elif isinstance(vad_parameters, dict):
332
+ vad_parameters = VadOptions(**vad_parameters)
333
+ speech_chunks = get_speech_timestamps(audio, vad_parameters)
334
+ audio = collect_chunks(audio, speech_chunks)
335
+ duration_after_vad = audio.shape[0] / sampling_rate
336
+
337
+ self.logger.info(
338
+ "VAD filter removed %s of audio",
339
+ format_timestamp(duration - duration_after_vad),
340
+ )
341
+
342
+ if self.logger.isEnabledFor(logging.DEBUG):
343
+ self.logger.debug(
344
+ "VAD filter kept the following audio segments: %s",
345
+ ", ".join(
346
+ "[%s -> %s]"
347
+ % (
348
+ format_timestamp(chunk["start"] / sampling_rate),
349
+ format_timestamp(chunk["end"] / sampling_rate),
350
+ )
351
+ for chunk in speech_chunks
352
+ ),
353
+ )
354
+
355
+ else:
356
+ speech_chunks = None
357
+
358
+ features = self.feature_extractor(audio, chunk_length=chunk_length)
359
+
360
+ encoder_output = None
361
+ all_language_probs = None
362
+
363
+ if language is None:
364
+ if not self.model.is_multilingual:
365
+ language = "en"
366
+ language_probability = 1
367
+ else:
368
+ if (
369
+ language_detection_segments is None
370
+ or language_detection_segments < 1
371
+ ):
372
+ language_detection_segments = 1
373
+ seek = 0
374
+ detected_language_info = {}
375
+ content_frames = (
376
+ features.shape[-1] - self.feature_extractor.nb_max_frames
377
+ )
378
+ while (
379
+ seek <= content_frames
380
+ and seek
381
+ < self.feature_extractor.nb_max_frames * language_detection_segments
382
+ ):
383
+ segment = features[
384
+ :, seek : seek + self.feature_extractor.nb_max_frames
385
+ ]
386
+ encoder_output = self.encode(segment)
387
+ # results is a list of tuple[str, float] with language names and
388
+ # probabilities.
389
+ results = self.model.detect_language(encoder_output)[0]
390
+ # Parse language names to strip out markers
391
+ all_language_probs = [
392
+ (token[2:-2], prob) for (token, prob) in results
393
+ ]
394
+ # Get top language token and probability
395
+ language, language_probability = all_language_probs[0]
396
+ if (
397
+ language_detection_threshold is None
398
+ or language_probability > language_detection_threshold
399
+ ):
400
+ break
401
+ detected_language_info.setdefault(language, []).append(
402
+ language_probability
403
+ )
404
+ seek += segment.shape[-1]
405
+ else:
406
+ # If no language was confidently detected in any segment, take a majority vote over the
407
+ # top predicted language of each segment to determine the language.
408
+ language = max(
409
+ detected_language_info,
410
+ key=lambda lang: len(detected_language_info[lang]),
411
+ )
412
+ language_probability = max(detected_language_info[language])
413
+
414
+ self.logger.info(
415
+ "Detected language '%s' with probability %.2f",
416
+ language,
417
+ language_probability,
418
+ )
419
+ else:
420
+ if not self.model.is_multilingual and language != "en":
421
+ self.logger.warning(
422
+ "The current model is English-only but the language parameter is set to '%s'; "
423
+ "using 'en' instead." % language
424
+ )
425
+ language = "en"
426
+
427
+ language_probability = 1
428
+
429
+ tokenizer = Tokenizer(
430
+ self.hf_tokenizer,
431
+ self.model.is_multilingual,
432
+ task=task,
433
+ language=language,
434
+ )
435
+
436
+ options = TranscriptionOptions(
437
+ beam_size=beam_size,
438
+ best_of=best_of,
439
+ patience=patience,
440
+ length_penalty=length_penalty,
441
+ repetition_penalty=repetition_penalty,
442
+ no_repeat_ngram_size=no_repeat_ngram_size,
443
+ log_prob_threshold=log_prob_threshold,
444
+ no_speech_threshold=no_speech_threshold,
445
+ compression_ratio_threshold=compression_ratio_threshold,
446
+ condition_on_previous_text=condition_on_previous_text,
447
+ prompt_reset_on_temperature=prompt_reset_on_temperature,
448
+ temperatures=(
449
+ temperature if isinstance(temperature, (list, tuple)) else [temperature]
450
+ ),
451
+ initial_prompt=initial_prompt,
452
+ prefix=prefix,
453
+ suppress_blank=suppress_blank,
454
+ suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
455
+ without_timestamps=without_timestamps,
456
+ max_initial_timestamp=max_initial_timestamp,
457
+ word_timestamps=word_timestamps,
458
+ prepend_punctuations=prepend_punctuations,
459
+ append_punctuations=append_punctuations,
460
+ max_new_tokens=max_new_tokens,
461
+ clip_timestamps=clip_timestamps,
462
+ hallucination_silence_threshold=hallucination_silence_threshold,
463
+ hotwords=hotwords,
464
+ )
465
+
466
+ segments = self.generate_segments(features, tokenizer, options, encoder_output)
467
+
468
+ if speech_chunks:
469
+ segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
470
+
471
+ info = TranscriptionInfo(
472
+ language=language,
473
+ language_probability=language_probability,
474
+ duration=duration,
475
+ duration_after_vad=duration_after_vad,
476
+ transcription_options=options,
477
+ vad_options=vad_parameters,
478
+ all_language_probs=all_language_probs,
479
+ )
480
+
481
+ return segments, info
482
+
483
+ def generate_segments(
484
+ self,
485
+ features: np.ndarray,
486
+ tokenizer: Tokenizer,
487
+ options: TranscriptionOptions,
488
+ encoder_output: Optional[ctranslate2.StorageView] = None,
489
+ ) -> Iterable[Segment]:
490
+ content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
491
+ content_duration = float(content_frames * self.feature_extractor.time_per_frame)
492
+
493
+ if isinstance(options.clip_timestamps, str):
494
+ options = options._replace(
495
+ clip_timestamps=[
496
+ float(ts)
497
+ for ts in (
498
+ options.clip_timestamps.split(",")
499
+ if options.clip_timestamps
500
+ else []
501
+ )
502
+ ]
503
+ )
504
+ seek_points: List[int] = [
505
+ round(ts * self.frames_per_second) for ts in options.clip_timestamps
506
+ ]
507
+ if len(seek_points) == 0:
508
+ seek_points.append(0)
509
+ if len(seek_points) % 2 == 1:
510
+ seek_points.append(content_frames)
511
+ seek_clips: List[Tuple[int, int]] = list(
512
+ zip(seek_points[::2], seek_points[1::2])
513
+ )
514
+
515
+ punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
516
+
517
+ idx = 0
518
+ clip_idx = 0
519
+ seek = seek_clips[clip_idx][0]
520
+ all_tokens = []
521
+ prompt_reset_since = 0
522
+
523
+ if options.initial_prompt is not None:
524
+ if isinstance(options.initial_prompt, str):
525
+ initial_prompt = " " + options.initial_prompt.strip()
526
+ initial_prompt_tokens = tokenizer.encode(initial_prompt)
527
+ all_tokens.extend(initial_prompt_tokens)
528
+ else:
529
+ all_tokens.extend(options.initial_prompt)
530
+
531
+ last_speech_timestamp = 0.0
532
+ # NOTE: This loop is obscurely flattened to make the diff readable.
533
+ # A later commit should turn this into a simpler nested loop.
534
+ # for seek_clip_start, seek_clip_end in seek_clips:
535
+ # while seek < seek_clip_end
536
+ while clip_idx < len(seek_clips):
537
+ seek_clip_start, seek_clip_end = seek_clips[clip_idx]
538
+ if seek_clip_end > content_frames:
539
+ seek_clip_end = content_frames
540
+ if seek < seek_clip_start:
541
+ seek = seek_clip_start
542
+ if seek >= seek_clip_end:
543
+ clip_idx += 1
544
+ if clip_idx < len(seek_clips):
545
+ seek = seek_clips[clip_idx][0]
546
+ continue
547
+ time_offset = seek * self.feature_extractor.time_per_frame
548
+ window_end_time = float(
549
+ (seek + self.feature_extractor.nb_max_frames)
550
+ * self.feature_extractor.time_per_frame
551
+ )
552
+ segment_size = min(
553
+ self.feature_extractor.nb_max_frames,
554
+ content_frames - seek,
555
+ seek_clip_end - seek,
556
+ )
557
+ segment = features[:, seek : seek + segment_size]
558
+ segment_duration = segment_size * self.feature_extractor.time_per_frame
559
+ segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
560
+
561
+ if self.logger.isEnabledFor(logging.DEBUG):
562
+ self.logger.debug(
563
+ "Processing segment at %s", format_timestamp(time_offset)
564
+ )
565
+
566
+ previous_tokens = all_tokens[prompt_reset_since:]
567
+ prompt = self.get_prompt(
568
+ tokenizer,
569
+ previous_tokens,
570
+ without_timestamps=options.without_timestamps,
571
+ prefix=options.prefix if seek == 0 else None,
572
+ hotwords=options.hotwords,
573
+ )
574
+
575
+ if seek > 0 or encoder_output is None:
576
+ encoder_output = self.encode(segment)
577
+
578
+ (
579
+ result,
580
+ avg_logprob,
581
+ temperature,
582
+ compression_ratio,
583
+ ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
584
+
585
+ if options.no_speech_threshold is not None:
586
+ # no voice activity check
587
+ should_skip = result.no_speech_prob > options.no_speech_threshold
588
+
589
+ if (
590
+ options.log_prob_threshold is not None
591
+ and avg_logprob > options.log_prob_threshold
592
+ ):
593
+ # don't skip if the logprob is high enough, despite the no_speech_prob
594
+ should_skip = False
595
+
596
+ if should_skip:
597
+ self.logger.debug(
598
+ "No speech threshold is met (%f > %f)",
599
+ result.no_speech_prob,
600
+ options.no_speech_threshold,
601
+ )
602
+
603
+ # fast-forward to the next segment boundary
604
+ seek += segment_size
605
+ continue
606
+
607
+ tokens = result.sequences_ids[0]
608
+
609
+ previous_seek = seek
610
+ current_segments = []
611
+
612
+ # anomalous words are very long/short/improbable
613
+ def word_anomaly_score(word: dict) -> float:
614
+ probability = word.get("probability", 0.0)
615
+ duration = word["end"] - word["start"]
616
+ score = 0.0
617
+ if probability < 0.15:
618
+ score += 1.0
619
+ if duration < 0.133:
620
+ score += (0.133 - duration) * 15
621
+ if duration > 2.0:
622
+ score += duration - 2.0
623
+ return score
624
+
625
+ def is_segment_anomaly(segment: Optional[dict]) -> bool:
626
+ if segment is None or not segment["words"]:
627
+ return False
628
+ words = [w for w in segment["words"] if w["word"] not in punctuation]
629
+ words = words[:8]
630
+ score = sum(word_anomaly_score(w) for w in words)
631
+ return score >= 3 or score + 0.01 >= len(words)
632
+
633
+ def next_words_segment(segments: List[dict]) -> Optional[dict]:
634
+ return next((s for s in segments if s["words"]), None)
635
+
636
+ single_timestamp_ending = (
637
+ len(tokens) >= 2
638
+ and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
639
+ )
640
+
641
+ consecutive_timestamps = [
642
+ i
643
+ for i in range(len(tokens))
644
+ if i > 0
645
+ and tokens[i] >= tokenizer.timestamp_begin
646
+ and tokens[i - 1] >= tokenizer.timestamp_begin
647
+ ]
648
+
649
+ if len(consecutive_timestamps) > 0:
650
+ slices = list(consecutive_timestamps)
651
+ if single_timestamp_ending:
652
+ slices.append(len(tokens))
653
+
654
+ last_slice = 0
655
+ for current_slice in slices:
656
+ sliced_tokens = tokens[last_slice:current_slice]
657
+ start_timestamp_position = (
658
+ sliced_tokens[0] - tokenizer.timestamp_begin
659
+ )
660
+ end_timestamp_position = (
661
+ sliced_tokens[-1] - tokenizer.timestamp_begin
662
+ )
663
+ start_time = (
664
+ time_offset + start_timestamp_position * self.time_precision
665
+ )
666
+ end_time = (
667
+ time_offset + end_timestamp_position * self.time_precision
668
+ )
669
+
670
+ current_segments.append(
671
+ dict(
672
+ seek=seek,
673
+ start=start_time,
674
+ end=end_time,
675
+ tokens=sliced_tokens,
676
+ )
677
+ )
678
+ last_slice = current_slice
679
+
680
+ if single_timestamp_ending:
681
+ # single timestamp at the end means no speech after the last timestamp.
682
+ seek += segment_size
683
+ else:
684
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
685
+ last_timestamp_position = (
686
+ tokens[last_slice - 1] - tokenizer.timestamp_begin
687
+ )
688
+ seek += last_timestamp_position * self.input_stride
689
+
690
+ else:
691
+ duration = segment_duration
692
+ timestamps = [
693
+ token for token in tokens if token >= tokenizer.timestamp_begin
694
+ ]
695
+ if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
696
+ last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
697
+ duration = last_timestamp_position * self.time_precision
698
+
699
+ current_segments.append(
700
+ dict(
701
+ seek=seek,
702
+ start=time_offset,
703
+ end=time_offset + duration,
704
+ tokens=tokens,
705
+ )
706
+ )
707
+
708
+ seek += segment_size
709
+
710
+ if options.word_timestamps:
711
+ self.add_word_timestamps(
712
+ current_segments,
713
+ tokenizer,
714
+ encoder_output,
715
+ segment_size,
716
+ options.prepend_punctuations,
717
+ options.append_punctuations,
718
+ last_speech_timestamp=last_speech_timestamp,
719
+ )
720
+
721
+ if not single_timestamp_ending:
722
+ last_word_end = get_end(current_segments)
723
+ if last_word_end is not None and last_word_end > time_offset:
724
+ seek = round(last_word_end * self.frames_per_second)
725
+
726
+ # skip silence before possible hallucinations
727
+ if options.hallucination_silence_threshold is not None:
728
+ threshold = options.hallucination_silence_threshold
729
+
730
+ # if first segment might be a hallucination, skip leading silence
731
+ first_segment = next_words_segment(current_segments)
732
+ if first_segment is not None and is_segment_anomaly(first_segment):
733
+ gap = first_segment["start"] - time_offset
734
+ if gap > threshold:
735
+ seek = previous_seek + round(gap * self.frames_per_second)
736
+ continue
737
+
738
+ # skip silence before any possible hallucination that is surrounded
739
+ # by silence or more hallucinations
740
+ hal_last_end = last_speech_timestamp
741
+ for si in range(len(current_segments)):
742
+ segment = current_segments[si]
743
+ if not segment["words"]:
744
+ continue
745
+ if is_segment_anomaly(segment):
746
+ next_segment = next_words_segment(
747
+ current_segments[si + 1 :]
748
+ )
749
+ if next_segment is not None:
750
+ hal_next_start = next_segment["words"][0]["start"]
751
+ else:
752
+ hal_next_start = time_offset + segment_duration
753
+ silence_before = (
754
+ segment["start"] - hal_last_end > threshold
755
+ or segment["start"] < threshold
756
+ or segment["start"] - time_offset < 2.0
757
+ )
758
+ silence_after = (
759
+ hal_next_start - segment["end"] > threshold
760
+ or is_segment_anomaly(next_segment)
761
+ or window_end_time - segment["end"] < 2.0
762
+ )
763
+ if silence_before and silence_after:
764
+ seek = round(
765
+ max(time_offset + 1, segment["start"])
766
+ * self.frames_per_second
767
+ )
768
+ if content_duration - segment["end"] < threshold:
769
+ seek = content_frames
770
+ current_segments[si:] = []
771
+ break
772
+ hal_last_end = segment["end"]
773
+
774
+ last_word_end = get_end(current_segments)
775
+ if last_word_end is not None:
776
+ last_speech_timestamp = last_word_end
777
+
778
+ for segment in current_segments:
779
+ tokens = segment["tokens"]
780
+ text = tokenizer.decode(tokens)
781
+
782
+ if segment["start"] == segment["end"] or not text.strip():
783
+ continue
784
+
785
+ all_tokens.extend(tokens)
786
+ idx += 1
787
+
788
+ yield Segment(
789
+ id=idx,
790
+ seek=seek,
791
+ start=segment["start"],
792
+ end=segment["end"],
793
+ text=text,
794
+ tokens=tokens,
795
+ temperature=temperature,
796
+ avg_logprob=avg_logprob,
797
+ compression_ratio=compression_ratio,
798
+ no_speech_prob=result.no_speech_prob,
799
+ words=(
800
+ [Word(**word) for word in segment["words"]]
801
+ if options.word_timestamps
802
+ else None
803
+ ),
804
+ )
805
+
806
+ if (
807
+ not options.condition_on_previous_text
808
+ or temperature > options.prompt_reset_on_temperature
809
+ ):
810
+ if options.condition_on_previous_text:
811
+ self.logger.debug(
812
+ "Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
813
+ temperature,
814
+ options.prompt_reset_on_temperature,
815
+ )
816
+
817
+ prompt_reset_since = len(all_tokens)
818
+
819
+ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
820
+ # When the model is running on multiple GPUs, the encoder output should be moved
821
+ # to the CPU since we don't know which GPU will handle the next job.
822
+ to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
823
+
824
+ features = np.expand_dims(features, 0)
825
+ features = get_ctranslate2_storage(features)
826
+
827
+ return self.model.encode(features, to_cpu=to_cpu)
828
+
829
+ def generate_with_fallback(
830
+ self,
831
+ encoder_output: ctranslate2.StorageView,
832
+ prompt: List[int],
833
+ tokenizer: Tokenizer,
834
+ options: TranscriptionOptions,
835
+ ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
836
+ decode_result = None
837
+ all_results = []
838
+ below_cr_threshold_results = []
839
+
840
+ max_initial_timestamp_index = int(
841
+ round(options.max_initial_timestamp / self.time_precision)
842
+ )
843
+ if options.max_new_tokens is not None:
844
+ max_length = len(prompt) + options.max_new_tokens
845
+ else:
846
+ max_length = self.max_length
847
+
848
+ if max_length > self.max_length:
849
+ raise ValueError(
850
+ f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
851
+ f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
852
+ f"and `max_new_tokens` is: {max_length}. This exceeds the "
853
+ f"`max_length` of the Whisper model: {self.max_length}. "
854
+ "You should either reduce the length of your prompt, or "
855
+ "reduce the value of `max_new_tokens`, "
856
+ f"so that their combined length is less that {self.max_length}."
857
+ )
858
+
859
+ for temperature in options.temperatures:
860
+ if temperature > 0:
861
+ kwargs = {
862
+ "beam_size": 1,
863
+ "num_hypotheses": options.best_of,
864
+ "sampling_topk": 0,
865
+ "sampling_temperature": temperature,
866
+ }
867
+ else:
868
+ kwargs = {
869
+ "beam_size": options.beam_size,
870
+ "patience": options.patience,
871
+ }
872
+
873
+ result = self.model.generate(
874
+ encoder_output,
875
+ [prompt],
876
+ length_penalty=options.length_penalty,
877
+ repetition_penalty=options.repetition_penalty,
878
+ no_repeat_ngram_size=options.no_repeat_ngram_size,
879
+ max_length=max_length,
880
+ return_scores=True,
881
+ return_no_speech_prob=True,
882
+ suppress_blank=options.suppress_blank,
883
+ suppress_tokens=options.suppress_tokens,
884
+ max_initial_timestamp_index=max_initial_timestamp_index,
885
+ **kwargs,
886
+ )[0]
887
+
888
+ tokens = result.sequences_ids[0]
889
+
890
+ # Recover the average log prob from the returned score.
891
+ seq_len = len(tokens)
892
+ cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
893
+ avg_logprob = cum_logprob / (seq_len + 1)
894
+
895
+ text = tokenizer.decode(tokens).strip()
896
+ compression_ratio = get_compression_ratio(text)
897
+
898
+ decode_result = (
899
+ result,
900
+ avg_logprob,
901
+ temperature,
902
+ compression_ratio,
903
+ )
904
+ all_results.append(decode_result)
905
+
906
+ needs_fallback = False
907
+
908
+ if options.compression_ratio_threshold is not None:
909
+ if compression_ratio > options.compression_ratio_threshold:
910
+ needs_fallback = True # too repetitive
911
+
912
+ self.logger.debug(
913
+ "Compression ratio threshold is not met with temperature %.1f (%f > %f)",
914
+ temperature,
915
+ compression_ratio,
916
+ options.compression_ratio_threshold,
917
+ )
918
+ else:
919
+ below_cr_threshold_results.append(decode_result)
920
+
921
+ if (
922
+ options.log_prob_threshold is not None
923
+ and avg_logprob < options.log_prob_threshold
924
+ ):
925
+ needs_fallback = True # average log probability is too low
926
+
927
+ self.logger.debug(
928
+ "Log probability threshold is not met with temperature %.1f (%f < %f)",
929
+ temperature,
930
+ avg_logprob,
931
+ options.log_prob_threshold,
932
+ )
933
+
934
+ if (
935
+ options.no_speech_threshold is not None
936
+ and result.no_speech_prob > options.no_speech_threshold
937
+ and options.log_prob_threshold is not None
938
+ and avg_logprob < options.log_prob_threshold
939
+ ):
940
+ needs_fallback = False # silence
941
+
942
+ if not needs_fallback:
943
+ break
944
+ else:
945
+ # all failed, select the result with the highest average log probability
946
+ decode_result = max(
947
+ below_cr_threshold_results or all_results, key=lambda x: x[1]
948
+ )
949
+ # to pass final temperature for prompt_reset_on_temperature
950
+ decode_result = (
951
+ decode_result[0],
952
+ decode_result[1],
953
+ temperature,
954
+ decode_result[3],
955
+ )
956
+
957
+ return decode_result
958
+
959
+ def get_prompt(
960
+ self,
961
+ tokenizer: Tokenizer,
962
+ previous_tokens: List[int],
963
+ without_timestamps: bool = False,
964
+ prefix: Optional[str] = None,
965
+ hotwords: Optional[str] = None,
966
+ ) -> List[int]:
967
+ prompt = []
968
+
969
+ if previous_tokens or (hotwords and not prefix):
970
+ prompt.append(tokenizer.sot_prev)
971
+ if hotwords and not prefix:
972
+ hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
973
+ if len(hotwords_tokens) >= self.max_length // 2:
974
+ hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
975
+ prompt.extend(hotwords_tokens)
976
+ if previous_tokens:
977
+ prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
978
+
979
+ prompt.extend(tokenizer.sot_sequence)
980
+
981
+ if without_timestamps:
982
+ prompt.append(tokenizer.no_timestamps)
983
+
984
+ if prefix:
985
+ prefix_tokens = tokenizer.encode(" " + prefix.strip())
986
+ if len(prefix_tokens) >= self.max_length // 2:
987
+ prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
988
+ if not without_timestamps:
989
+ prompt.append(tokenizer.timestamp_begin)
990
+ prompt.extend(prefix_tokens)
991
+
992
+ return prompt
993
+
994
+ def add_word_timestamps(
995
+ self,
996
+ segments: List[dict],
997
+ tokenizer: Tokenizer,
998
+ encoder_output: ctranslate2.StorageView,
999
+ num_frames: int,
1000
+ prepend_punctuations: str,
1001
+ append_punctuations: str,
1002
+ last_speech_timestamp: float,
1003
+ ) -> None:
1004
+ if len(segments) == 0:
1005
+ return
1006
+
1007
+ text_tokens_per_segment = [
1008
+ [token for token in segment["tokens"] if token < tokenizer.eot]
1009
+ for segment in segments
1010
+ ]
1011
+
1012
+ text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
1013
+ alignment = self.find_alignment(
1014
+ tokenizer, text_tokens, encoder_output, num_frames
1015
+ )
1016
+ word_durations = np.array([word["end"] - word["start"] for word in alignment])
1017
+ word_durations = word_durations[word_durations.nonzero()]
1018
+ median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
1019
+ median_duration = min(0.7, float(median_duration))
1020
+ max_duration = median_duration * 2
1021
+
1022
+ # hack: truncate long words at sentence boundaries.
1023
+ # a better segmentation algorithm based on VAD should be able to replace this.
1024
+ if len(word_durations) > 0:
1025
+ sentence_end_marks = ".。!!??"
1026
+ # ensure words at sentence boundaries
1027
+ # are not longer than twice the median word duration.
1028
+ for i in range(1, len(alignment)):
1029
+ if alignment[i]["end"] - alignment[i]["start"] > max_duration:
1030
+ if alignment[i]["word"] in sentence_end_marks:
1031
+ alignment[i]["end"] = alignment[i]["start"] + max_duration
1032
+ elif alignment[i - 1]["word"] in sentence_end_marks:
1033
+ alignment[i]["start"] = alignment[i]["end"] - max_duration
1034
+
1035
+ merge_punctuations(alignment, prepend_punctuations, append_punctuations)
1036
+
1037
+ time_offset = (
1038
+ segments[0]["seek"]
1039
+ * self.feature_extractor.hop_length
1040
+ / self.feature_extractor.sampling_rate
1041
+ )
1042
+
1043
+ word_index = 0
1044
+
1045
+ for segment, text_tokens in zip(segments, text_tokens_per_segment):
1046
+ saved_tokens = 0
1047
+ words = []
1048
+
1049
+ while word_index < len(alignment) and saved_tokens < len(text_tokens):
1050
+ timing = alignment[word_index]
1051
+
1052
+ if timing["word"]:
1053
+ words.append(
1054
+ dict(
1055
+ word=timing["word"],
1056
+ start=round(time_offset + timing["start"], 2),
1057
+ end=round(time_offset + timing["end"], 2),
1058
+ probability=timing["probability"],
1059
+ )
1060
+ )
1061
+
1062
+ saved_tokens += len(timing["tokens"])
1063
+ word_index += 1
1064
+
1065
+ # hack: truncate long words at segment boundaries.
1066
+ # a better segmentation algorithm based on VAD should be able to replace this.
1067
+ if len(words) > 0:
1068
+ # ensure the first and second word after a pause are not longer than
1069
+ # twice the median word duration.
1070
+ if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
1071
+ words[0]["end"] - words[0]["start"] > max_duration
1072
+ or (
1073
+ len(words) > 1
1074
+ and words[1]["end"] - words[0]["start"] > max_duration * 2
1075
+ )
1076
+ ):
1077
+ if (
1078
+ len(words) > 1
1079
+ and words[1]["end"] - words[1]["start"] > max_duration
1080
+ ):
1081
+ boundary = max(
1082
+ words[1]["end"] / 2, words[1]["end"] - max_duration
1083
+ )
1084
+ words[0]["end"] = words[1]["start"] = boundary
1085
+ words[0]["start"] = max(0, words[0]["end"] - max_duration)
1086
+
1087
+ # prefer the segment-level start timestamp if the first word is too long.
1088
+ if (
1089
+ segment["start"] < words[0]["end"]
1090
+ and segment["start"] - 0.5 > words[0]["start"]
1091
+ ):
1092
+ words[0]["start"] = max(
1093
+ 0, min(words[0]["end"] - median_duration, segment["start"])
1094
+ )
1095
+ else:
1096
+ segment["start"] = words[0]["start"]
1097
+
1098
+ # prefer the segment-level end timestamp if the last word is too long.
1099
+ if (
1100
+ segment["end"] > words[-1]["start"]
1101
+ and segment["end"] + 0.5 < words[-1]["end"]
1102
+ ):
1103
+ words[-1]["end"] = max(
1104
+ words[-1]["start"] + median_duration, segment["end"]
1105
+ )
1106
+ else:
1107
+ segment["end"] = words[-1]["end"]
1108
+
1109
+ last_speech_timestamp = segment["end"]
1110
+
1111
+ segment["words"] = words
1112
+
1113
+ def find_alignment(
1114
+ self,
1115
+ tokenizer: Tokenizer,
1116
+ text_tokens: List[int],
1117
+ encoder_output: ctranslate2.StorageView,
1118
+ num_frames: int,
1119
+ median_filter_width: int = 7,
1120
+ ) -> List[dict]:
1121
+ if len(text_tokens) == 0:
1122
+ return []
1123
+
1124
+ result = self.model.align(
1125
+ encoder_output,
1126
+ tokenizer.sot_sequence,
1127
+ [text_tokens],
1128
+ num_frames,
1129
+ median_filter_width=median_filter_width,
1130
+ )[0]
1131
+
1132
+ text_token_probs = result.text_token_probs
1133
+
1134
+ alignments = result.alignments
1135
+ text_indices = np.array([pair[0] for pair in alignments])
1136
+ time_indices = np.array([pair[1] for pair in alignments])
1137
+
1138
+ words, word_tokens = tokenizer.split_to_word_tokens(
1139
+ text_tokens + [tokenizer.eot]
1140
+ )
1141
+ if len(word_tokens) <= 1:
1142
+ # return on eot only
1143
+ # >>> np.pad([], (1, 0))
1144
+ # array([0.])
1145
+ # This results in crashes when we look up jump_times with a float, like
1146
+ # IndexError: arrays used as indices must be of integer (or boolean) type
1147
+ return []
1148
+ word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
1149
+ if len(word_boundaries) <= 1:
1150
+ return []
1151
+
1152
+ jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
1153
+ jump_times = time_indices[jumps] / self.tokens_per_second
1154
+ start_times = jump_times[word_boundaries[:-1]]
1155
+ end_times = jump_times[word_boundaries[1:]]
1156
+ word_probabilities = [
1157
+ np.mean(text_token_probs[i:j])
1158
+ for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
1159
+ ]
1160
+
1161
+ return [
1162
+ dict(
1163
+ word=word, tokens=tokens, start=start, end=end, probability=probability
1164
+ )
1165
+ for word, tokens, start, end, probability in zip(
1166
+ words, word_tokens, start_times, end_times, word_probabilities
1167
+ )
1168
+ ]
1169
+
1170
+
1171
+ def restore_speech_timestamps(
1172
+ segments: Iterable[Segment],
1173
+ speech_chunks: List[dict],
1174
+ sampling_rate: int,
1175
+ ) -> Iterable[Segment]:
1176
+ ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
1177
+
1178
+ for segment in segments:
1179
+ if segment.words:
1180
+ words = []
1181
+ for word in segment.words:
1182
+ # Ensure the word start and end times are resolved to the same chunk.
1183
+ middle = (word.start + word.end) / 2
1184
+ chunk_index = ts_map.get_chunk_index(middle)
1185
+ word = word._replace(
1186
+ start=ts_map.get_original_time(word.start, chunk_index),
1187
+ end=ts_map.get_original_time(word.end, chunk_index),
1188
+ )
1189
+ words.append(word)
1190
+
1191
+ segment = segment._replace(
1192
+ start=words[0].start,
1193
+ end=words[-1].end,
1194
+ words=words,
1195
+ )
1196
+
1197
+ else:
1198
+ segment = segment._replace(
1199
+ start=ts_map.get_original_time(segment.start),
1200
+ end=ts_map.get_original_time(segment.end),
1201
+ )
1202
+
1203
+ yield segment
1204
+
1205
+
1206
+ def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
1207
+ segment = np.ascontiguousarray(segment)
1208
+ segment = ctranslate2.StorageView.from_array(segment)
1209
+ return segment
1210
+
1211
+
1212
+ def get_compression_ratio(text: str) -> float:
1213
+ text_bytes = text.encode("utf-8")
1214
+ return len(text_bytes) / len(zlib.compress(text_bytes))
1215
+
1216
+
1217
+ def get_suppressed_tokens(
1218
+ tokenizer: Tokenizer,
1219
+ suppress_tokens: Optional[List[int]],
1220
+ ) -> Optional[List[int]]:
1221
+ if not suppress_tokens or -1 in suppress_tokens:
1222
+ return suppress_tokens
1223
+
1224
+ suppress_tokens = list(suppress_tokens)
1225
+
1226
+ # Ensure the following special tokens are suppressed when the user does
1227
+ # not use the default set (-1).
1228
+ suppress_tokens.extend(
1229
+ [
1230
+ tokenizer.transcribe,
1231
+ tokenizer.translate,
1232
+ tokenizer.sot,
1233
+ tokenizer.sot_prev,
1234
+ tokenizer.sot_lm,
1235
+ ]
1236
+ )
1237
+
1238
+ return sorted(set(suppress_tokens))
1239
+
1240
+
1241
+ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
1242
+ # merge prepended punctuations
1243
+ i = len(alignment) - 2
1244
+ j = len(alignment) - 1
1245
+ while i >= 0:
1246
+ previous = alignment[i]
1247
+ following = alignment[j]
1248
+ if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
1249
+ # prepend it to the following word
1250
+ following["word"] = previous["word"] + following["word"]
1251
+ following["tokens"] = previous["tokens"] + following["tokens"]
1252
+ previous["word"] = ""
1253
+ previous["tokens"] = []
1254
+ else:
1255
+ j = i
1256
+ i -= 1
1257
+
1258
+ # merge appended punctuations
1259
+ i = 0
1260
+ j = 1
1261
+ while j < len(alignment):
1262
+ previous = alignment[i]
1263
+ following = alignment[j]
1264
+ if not previous["word"].endswith(" ") and following["word"] in appended:
1265
+ # append it to the previous word
1266
+ previous["word"] = previous["word"] + following["word"]
1267
+ previous["tokens"] = previous["tokens"] + following["tokens"]
1268
+ following["word"] = ""
1269
+ following["tokens"] = []
1270
+ else:
1271
+ i = j
1272
+ j += 1
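The temperature fallback at the top of this section relies on get_compression_ratio (defined just above) to catch degenerate, repetitive decodes: repeated text compresses very well, so its ratio spikes. Below is a minimal standalone sketch of that behaviour; the sample strings are illustrative, and the ~2.4 cutoff mentioned in the comments is the usual Whisper-style default for the compression-ratio threshold, not a value defined in this file.

import zlib

def get_compression_ratio(text: str) -> float:
    # Same helper as above: raw UTF-8 size divided by the zlib-compressed size.
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))

normal = "The quick brown fox jumps over the lazy dog near the old river bridge."
looping = "thank you for watching " * 30  # classic repetition failure mode

print(round(get_compression_ratio(normal), 2))   # close to 1.0: barely compressible
print(round(get_compression_ratio(looping), 2))  # far above a ~2.4 cutoff, so decoding would fall back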
utils.py ADDED
@@ -0,0 +1,157 @@
1
+ import logging
2
+ import os
3
+ import re
4
+
5
+ from typing import List, Optional
6
+
7
+ import huggingface_hub
8
+ import requests
9
+
10
+ from tqdm.auto import tqdm
11
+
12
+ _MODELS = {
13
+ "tiny.en": "Systran/faster-whisper-tiny.en",
14
+ "tiny": "Systran/faster-whisper-tiny",
15
+ "base.en": "Systran/faster-whisper-base.en",
16
+ "base": "Systran/faster-whisper-base",
17
+ "small.en": "Systran/faster-whisper-small.en",
18
+ "small": "Systran/faster-whisper-small",
19
+ "medium.en": "Systran/faster-whisper-medium.en",
20
+ "medium": "Systran/faster-whisper-medium",
21
+ "large-v1": "Systran/faster-whisper-large-v1",
22
+ "large-v2": "Systran/faster-whisper-large-v2",
23
+ "large-v3": "Systran/faster-whisper-large-v3",
24
+ "large": "Systran/faster-whisper-large-v3",
25
+ "distil-large-v2": "Systran/faster-distil-whisper-large-v2",
26
+ "distil-medium.en": "Systran/faster-distil-whisper-medium.en",
27
+ "distil-small.en": "Systran/faster-distil-whisper-small.en",
28
+ "distil-large-v3": "Systran/faster-distil-whisper-large-v3",
29
+ }
30
+
31
+
32
+ def available_models() -> List[str]:
33
+ """Returns the names of available models."""
34
+ return list(_MODELS.keys())
35
+
36
+
37
+ def get_assets_path():
38
+ """Returns the path to the assets directory."""
39
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
40
+
41
+
42
+ def get_logger():
43
+ """Returns the module logger."""
44
+ return logging.getLogger("faster_whisper")
45
+
46
+
47
+ def download_model(
48
+ size_or_id: str,
49
+ output_dir: Optional[str] = None,
50
+ local_files_only: bool = False,
51
+ cache_dir: Optional[str] = None,
52
+ ):
53
+ """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
54
+
55
+ Args:
56
+ size_or_id: Size of the model to download from https://huggingface.co/Systran
57
+ (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en,
58
+ distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2,
59
+ distil-large-v3), or a CTranslate2-converted model ID from the Hugging Face Hub
60
+ (e.g. Systran/faster-whisper-large-v3).
61
+ output_dir: Directory where the model should be saved. If not set, the model is saved in
62
+ the cache directory.
63
+ local_files_only: If True, avoid downloading the file and return the path to the local
64
+ cached file if it exists.
65
+ cache_dir: Path to the folder where cached files are stored.
66
+
67
+ Returns:
68
+ The path to the downloaded model.
69
+
70
+ Raises:
71
+ ValueError: if the model size is invalid.
72
+ """
73
+ if re.match(r".*/.*", size_or_id):
74
+ repo_id = size_or_id
75
+ else:
76
+ repo_id = _MODELS.get(size_or_id)
77
+ if repo_id is None:
78
+ raise ValueError(
79
+ "Invalid model size '%s', expected one of: %s"
80
+ % (size_or_id, ", ".join(_MODELS.keys()))
81
+ )
82
+
83
+ allow_patterns = [
84
+ "config.json",
85
+ "preprocessor_config.json",
86
+ "model.bin",
87
+ "tokenizer.json",
88
+ "vocabulary.*",
89
+ ]
90
+
91
+ kwargs = {
92
+ "local_files_only": local_files_only,
93
+ "allow_patterns": allow_patterns,
94
+ "tqdm_class": disabled_tqdm,
95
+ }
96
+
97
+ if output_dir is not None:
98
+ kwargs["local_dir"] = output_dir
99
+ kwargs["local_dir_use_symlinks"] = False
100
+
101
+ if cache_dir is not None:
102
+ kwargs["cache_dir"] = cache_dir
103
+
104
+ try:
105
+ return huggingface_hub.snapshot_download(repo_id, **kwargs)
106
+ except (
107
+ huggingface_hub.utils.HfHubHTTPError,
108
+ requests.exceptions.ConnectionError,
109
+ ) as exception:
110
+ logger = get_logger()
111
+ logger.warning(
112
+ "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
113
+ repo_id,
114
+ exception,
115
+ )
116
+ logger.warning(
117
+ "Trying to load the model directly from the local cache, if it exists."
118
+ )
119
+
120
+ kwargs["local_files_only"] = True
121
+ return huggingface_hub.snapshot_download(repo_id, **kwargs)
122
+
123
+
124
+ def format_timestamp(
125
+ seconds: float,
126
+ always_include_hours: bool = False,
127
+ decimal_marker: str = ".",
128
+ ) -> str:
129
+ assert seconds >= 0, "non-negative timestamp expected"
130
+ milliseconds = round(seconds * 1000.0)
131
+
132
+ hours = milliseconds // 3_600_000
133
+ milliseconds -= hours * 3_600_000
134
+
135
+ minutes = milliseconds // 60_000
136
+ milliseconds -= minutes * 60_000
137
+
138
+ seconds = milliseconds // 1_000
139
+ milliseconds -= seconds * 1_000
140
+
141
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
142
+ return (
143
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
144
+ )
145
+
146
+
147
+ class disabled_tqdm(tqdm):
148
+ def __init__(self, *args, **kwargs):
149
+ kwargs["disable"] = True
150
+ super().__init__(*args, **kwargs)
151
+
152
+
153
+ def get_end(segments: List[dict]) -> Optional[float]:
154
+ return next(
155
+ (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
156
+ segments[-1]["end"] if segments else None,
157
+ )
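For context, here is a short usage sketch of the helpers in this file. The output directory and the printed values are illustrative, and download_model needs either network access or an already populated local cache.

from faster_whisper.utils import available_models, download_model, format_timestamp

# Shorthand sizes that map to the Systran CTranslate2 conversions on the Hub.
print(available_models())  # ['tiny.en', 'tiny', 'base.en', ..., 'distil-large-v3']

# Resolves "small" to Systran/faster-whisper-small and fetches config.json, model.bin,
# tokenizer.json, etc. into ./models (illustrative path) instead of the default cache.
model_dir = download_model("small", output_dir="./models")
print(model_dir)

# format_timestamp renders seconds as [HH:]MM:SS.mmm, which is handy for SRT/VTT output.
print(format_timestamp(3661.5, always_include_hours=True, decimal_marker=","))
# -> 01:01:01,500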
vad.py ADDED
@@ -0,0 +1,291 @@
1
+ import bisect
2
+ import functools
3
+ import os
4
+ import warnings
5
+
6
+ from typing import List, NamedTuple, Optional
7
+
8
+ import numpy as np
9
+
10
+ from faster_whisper.utils import get_assets_path
11
+
12
+
13
+ # The code below is adapted from https://github.com/snakers4/silero-vad.
14
+ class VadOptions(NamedTuple):
15
+ """VAD options.
16
+
17
+ Attributes:
18
+ threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk;
19
+ probabilities ABOVE this value are considered SPEECH. It is better to tune this
20
+ parameter for each dataset separately, but a "lazy" 0.5 is pretty good for most datasets.
21
+ min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
22
+ max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
23
+ than max_speech_duration_s will be split at the timestamp of the last silence that
24
+ lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
25
+ split aggressively just before max_speech_duration_s.
26
+ min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
27
+ before separating it.
28
+ window_size_samples: Audio chunks of size window_size_samples are fed to the Silero VAD model.
29
+ WARNING! Silero VAD models were trained using 512, 1024, or 1536 samples for a 16000 Hz sample rate.
30
+ Values other than these may affect model performance!
31
+ speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side.
32
+ """
33
+
34
+ threshold: float = 0.5
35
+ min_speech_duration_ms: int = 250
36
+ max_speech_duration_s: float = float("inf")
37
+ min_silence_duration_ms: int = 2000
38
+ window_size_samples: int = 1024
39
+ speech_pad_ms: int = 400
40
+
41
+
42
+ def get_speech_timestamps(
43
+ audio: np.ndarray,
44
+ vad_options: Optional[VadOptions] = None,
45
+ **kwargs,
46
+ ) -> List[dict]:
47
+ """This method is used for splitting long audios into speech chunks using silero VAD.
48
+
49
+ Args:
50
+ audio: One-dimensional float array.
51
+ vad_options: Options for VAD processing.
52
+ kwargs: VAD options passed as keyword arguments for backward compatibility.
53
+
54
+ Returns:
55
+ List of dicts containing begin and end samples of each speech chunk.
56
+ """
57
+ if vad_options is None:
58
+ vad_options = VadOptions(**kwargs)
59
+
60
+ threshold = vad_options.threshold
61
+ min_speech_duration_ms = vad_options.min_speech_duration_ms
62
+ max_speech_duration_s = vad_options.max_speech_duration_s
63
+ min_silence_duration_ms = vad_options.min_silence_duration_ms
64
+ window_size_samples = vad_options.window_size_samples
65
+ speech_pad_ms = vad_options.speech_pad_ms
66
+
67
+ if window_size_samples not in [512, 1024, 1536]:
68
+ warnings.warn(
69
+ "Unusual window_size_samples! Supported window_size_samples:\n"
70
+ " - [512, 1024, 1536] for 16000 sampling_rate"
71
+ )
72
+
73
+ sampling_rate = 16000
74
+ min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
75
+ speech_pad_samples = sampling_rate * speech_pad_ms / 1000
76
+ max_speech_samples = (
77
+ sampling_rate * max_speech_duration_s
78
+ - window_size_samples
79
+ - 2 * speech_pad_samples
80
+ )
81
+ min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
82
+ min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
83
+
84
+ audio_length_samples = len(audio)
85
+
86
+ model = get_vad_model()
87
+ state = model.get_initial_state(batch_size=1)
88
+
89
+ speech_probs = []
90
+ for current_start_sample in range(0, audio_length_samples, window_size_samples):
91
+ chunk = audio[current_start_sample : current_start_sample + window_size_samples]
92
+ if len(chunk) < window_size_samples:
93
+ chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
94
+ speech_prob, state = model(chunk, state, sampling_rate)
95
+ speech_probs.append(speech_prob)
96
+
97
+ triggered = False
98
+ speeches = []
99
+ current_speech = {}
100
+ neg_threshold = threshold - 0.15
101
+
102
+ # to save potential segment end (and tolerate some silence)
103
+ temp_end = 0
104
+ # to save potential segment limits in case the maximum segment size is reached
105
+ prev_end = next_start = 0
106
+
107
+ for i, speech_prob in enumerate(speech_probs):
108
+ if (speech_prob >= threshold) and temp_end:
109
+ temp_end = 0
110
+ if next_start < prev_end:
111
+ next_start = window_size_samples * i
112
+
113
+ if (speech_prob >= threshold) and not triggered:
114
+ triggered = True
115
+ current_speech["start"] = window_size_samples * i
116
+ continue
117
+
118
+ if (
119
+ triggered
120
+ and (window_size_samples * i) - current_speech["start"] > max_speech_samples
121
+ ):
122
+ if prev_end:
123
+ current_speech["end"] = prev_end
124
+ speeches.append(current_speech)
125
+ current_speech = {}
126
+ # previously reached silence (< neg_thres) and is still not speech (< thres)
127
+ if next_start < prev_end:
128
+ triggered = False
129
+ else:
130
+ current_speech["start"] = next_start
131
+ prev_end = next_start = temp_end = 0
132
+ else:
133
+ current_speech["end"] = window_size_samples * i
134
+ speeches.append(current_speech)
135
+ current_speech = {}
136
+ prev_end = next_start = temp_end = 0
137
+ triggered = False
138
+ continue
139
+
140
+ if (speech_prob < neg_threshold) and triggered:
141
+ if not temp_end:
142
+ temp_end = window_size_samples * i
143
+ # condition to avoid cutting in very short silence
144
+ if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
145
+ prev_end = temp_end
146
+ if (window_size_samples * i) - temp_end < min_silence_samples:
147
+ continue
148
+ else:
149
+ current_speech["end"] = temp_end
150
+ if (
151
+ current_speech["end"] - current_speech["start"]
152
+ ) > min_speech_samples:
153
+ speeches.append(current_speech)
154
+ current_speech = {}
155
+ prev_end = next_start = temp_end = 0
156
+ triggered = False
157
+ continue
158
+
159
+ if (
160
+ current_speech
161
+ and (audio_length_samples - current_speech["start"]) > min_speech_samples
162
+ ):
163
+ current_speech["end"] = audio_length_samples
164
+ speeches.append(current_speech)
165
+
166
+ for i, speech in enumerate(speeches):
167
+ if i == 0:
168
+ speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
169
+ if i != len(speeches) - 1:
170
+ silence_duration = speeches[i + 1]["start"] - speech["end"]
171
+ if silence_duration < 2 * speech_pad_samples:
172
+ speech["end"] += int(silence_duration // 2)
173
+ speeches[i + 1]["start"] = int(
174
+ max(0, speeches[i + 1]["start"] - silence_duration // 2)
175
+ )
176
+ else:
177
+ speech["end"] = int(
178
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
179
+ )
180
+ speeches[i + 1]["start"] = int(
181
+ max(0, speeches[i + 1]["start"] - speech_pad_samples)
182
+ )
183
+ else:
184
+ speech["end"] = int(
185
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
186
+ )
187
+
188
+ return speeches
189
+
190
+
191
+ def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
192
+ """Collects and concatenates audio chunks."""
193
+ if not chunks:
194
+ return np.array([], dtype=np.float32)
195
+
196
+ return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
197
+
198
+
199
+ class SpeechTimestampsMap:
200
+ """Helper class to restore original speech timestamps."""
201
+
202
+ def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
203
+ self.sampling_rate = sampling_rate
204
+ self.time_precision = time_precision
205
+ self.chunk_end_sample = []
206
+ self.total_silence_before = []
207
+
208
+ previous_end = 0
209
+ silent_samples = 0
210
+
211
+ for chunk in chunks:
212
+ silent_samples += chunk["start"] - previous_end
213
+ previous_end = chunk["end"]
214
+
215
+ self.chunk_end_sample.append(chunk["end"] - silent_samples)
216
+ self.total_silence_before.append(silent_samples / sampling_rate)
217
+
218
+ def get_original_time(
219
+ self,
220
+ time: float,
221
+ chunk_index: Optional[int] = None,
222
+ ) -> float:
223
+ if chunk_index is None:
224
+ chunk_index = self.get_chunk_index(time)
225
+
226
+ total_silence_before = self.total_silence_before[chunk_index]
227
+ return round(total_silence_before + time, self.time_precision)
228
+
229
+ def get_chunk_index(self, time: float) -> int:
230
+ sample = int(time * self.sampling_rate)
231
+ return min(
232
+ bisect.bisect(self.chunk_end_sample, sample),
233
+ len(self.chunk_end_sample) - 1,
234
+ )
235
+
236
+
237
+ @functools.lru_cache
238
+ def get_vad_model():
239
+ """Returns the VAD model instance."""
240
+ path = os.path.join(get_assets_path(), "silero_vad.onnx")
241
+ return SileroVADModel(path)
242
+
243
+
244
+ class SileroVADModel:
245
+ def __init__(self, path):
246
+ try:
247
+ import onnxruntime
248
+ except ImportError as e:
249
+ raise RuntimeError(
250
+ "Applying the VAD filter requires the onnxruntime package"
251
+ ) from e
252
+
253
+ opts = onnxruntime.SessionOptions()
254
+ opts.inter_op_num_threads = 1
255
+ opts.intra_op_num_threads = 1
256
+ opts.log_severity_level = 4
257
+
258
+ self.session = onnxruntime.InferenceSession(
259
+ path,
260
+ providers=["CPUExecutionProvider"],
261
+ sess_options=opts,
262
+ )
263
+
264
+ def get_initial_state(self, batch_size: int):
265
+ h = np.zeros((2, batch_size, 64), dtype=np.float32)
266
+ c = np.zeros((2, batch_size, 64), dtype=np.float32)
267
+ return h, c
268
+
269
+ def __call__(self, x, state, sr: int):
270
+ if len(x.shape) == 1:
271
+ x = np.expand_dims(x, 0)
272
+ if len(x.shape) > 2:
273
+ raise ValueError(
274
+ f"Too many dimensions for input audio chunk {len(x.shape)}"
275
+ )
276
+ if sr / x.shape[1] > 31.25:
277
+ raise ValueError("Input audio chunk is too short")
278
+
279
+ h, c = state
280
+
281
+ ort_inputs = {
282
+ "input": x,
283
+ "h": h,
284
+ "c": c,
285
+ "sr": np.array(sr, dtype="int64"),
286
+ }
287
+
288
+ out, h, c = self.session.run(None, ort_inputs)
289
+ state = (h, c)
290
+
291
+ return out, state
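Putting this file together: run the Silero VAD over 16 kHz mono float32 audio, drop the silence with collect_chunks, and later map timestamps measured on the shortened audio back to the original with SpeechTimestampsMap. The sketch below uses a synthetic placeholder signal (silence plus low-level noise), so the VAD may well report no speech at all; it also assumes onnxruntime is installed and the bundled silero_vad.onnx asset sits where get_vad_model expects it.

import numpy as np

from faster_whisper.vad import (
    SpeechTimestampsMap,
    VadOptions,
    collect_chunks,
    get_speech_timestamps,
)

sampling_rate = 16000

# Placeholder signal: 1 s of silence followed by 1 s of low-level noise.
# Real recorded speech is what the VAD is actually meant to segment.
audio = np.concatenate(
    [
        np.zeros(sampling_rate, dtype=np.float32),
        (0.1 * np.random.randn(sampling_rate)).astype(np.float32),
    ]
)

options = VadOptions(min_silence_duration_ms=500, speech_pad_ms=200)
chunks = get_speech_timestamps(audio, options)  # [{'start': ..., 'end': ...}, ...] in samples

speech_only = collect_chunks(audio, chunks)  # silence removed, ready for transcription

if chunks:
    # Convert a time measured on speech_only (e.g. 0.25 s) back to the original audio.
    ts_map = SpeechTimestampsMap(chunks, sampling_rate)
    print(ts_map.get_original_time(0.25))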
version.py ADDED
@@ -0,0 +1,3 @@
1
+ """Version information."""
2
+
3
+ __version__ = "1.0.2"