Commit · 6c3262d
1 Parent(s): 06d32c0
trt whisper live

Files changed:
- requirements.txt +4 -0
- run_faster_whisper_server.py +5 -0
- run_trt_server.py +5 -0
- whisper_live/__init__.py +0 -0
- whisper_live/__version__.py +1 -0
- whisper_live/client.py +528 -0
- whisper_live/server.py +498 -0
- whisper_live/transcriber.py +1023 -0
- whisper_live/trt_server.py +496 -0
- whisper_live/trt_transcriber.py +347 -0
- whisper_live/vad.py +114 -0
- whisper_live/whisper_utils.py +365 -0

requirements.txt
ADDED
@@ -0,0 +1,4 @@
PyAudio
faster-whisper==0.9.0
websockets
onnxruntime==1.16.0
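
Two of the four dependencies are pinned to exact versions, so a quick sanity check before launching a server can save a confusing startup failure. Below is a minimal sketch of such a check; the package names and pins are taken from the list above, but the check itself is not part of this commit.

```python
# Sketch: warn if the pinned requirements are missing or mismatched.
from importlib.metadata import version, PackageNotFoundError

pins = {"faster-whisper": "0.9.0", "onnxruntime": "1.16.0"}
for package, expected in pins.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        raise SystemExit(f"{package} is not installed (expected {expected})")
    if installed != expected:
        print(f"[WARN]: {package} {installed} installed, requirements.txt pins {expected}")
```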

run_faster_whisper_server.py
ADDED
@@ -0,0 +1,5 @@
from whisper_live.server import TranscriptionServer

if __name__ == "__main__":
    server = TranscriptionServer()
    server.run("0.0.0.0", 6006)
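
This entry point binds the faster-whisper backend to 0.0.0.0:6006. A minimal sketch of talking to it with the `TranscriptionClient` added in `whisper_live/client.py` below; the host, port, and file path here are only illustrative values, not part of the commit.

```python
# Sketch: connect to a running server and transcribe from the microphone,
# a pre-recorded file, or an HLS stream (pick one call).
from whisper_live.client import TranscriptionClient

client = TranscriptionClient("localhost", 6006, is_multilingual=False, lang="en")
client()                                             # live microphone input
# client(audio="path/to/audio.wav")                  # file is resampled to 16 kHz first
# client(hls_url="http://example.com/stream.m3u8")   # HLS stream (hypothetical URL)
```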

run_trt_server.py
ADDED
@@ -0,0 +1,5 @@
from whisper_live.trt_server import TranscriptionServer

if __name__ == "__main__":
    server = TranscriptionServer()
    server.run("0.0.0.0", 6006)

whisper_live/__init__.py
ADDED
File without changes

whisper_live/__version__.py
ADDED
@@ -0,0 +1 @@
__version__="0.0.9"

whisper_live/client.py
ADDED
@@ -0,0 +1,528 @@
import os
import wave

import numpy as np
import scipy
import ffmpeg
import pyaudio
import threading
import textwrap
import json
import websocket
import uuid
import time


def resample(file: str, sr: int = 16000):
    """
    # https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/audio.py#L22
    Open an audio file and read as mono waveform, resampling as necessary,
    save the resampled audio

    Args:
        file (str): The audio file to open
        sr (int): The sample rate to resample the audio if necessary

    Returns:
        resampled_file (str): The resampled audio file
    """
    try:
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
    np_buffer = np.frombuffer(out, dtype=np.int16)

    resampled_file = f"{file.split('.')[0]}_resampled.wav"
    scipy.io.wavfile.write(resampled_file, sr, np_buffer.astype(np.int16))
    return resampled_file


class Client:
    """
    Handles audio recording, streaming, and communication with a server using WebSocket.
    """
    INSTANCES = {}

    def __init__(
        self, host=None, port=None, is_multilingual=False, lang=None, translate=False
    ):
        """
        Initializes a Client instance for audio recording and streaming to a server.

        If host and port are not provided, the WebSocket connection will not be established.
        When translate is True, the task will be set to "translate" instead of "transcribe".
        The audio recording starts immediately upon initialization.

        Args:
            host (str): The hostname or IP address of the server.
            port (int): The port number for the WebSocket server.
            is_multilingual (bool, optional): Specifies if multilingual transcription is enabled. Default is False.
            lang (str, optional): The selected language for transcription when multilingual is disabled. Default is None.
            translate (bool, optional): Specifies if the task is translation. Default is False.
        """
        self.chunk = 1024
        self.format = pyaudio.paInt16
        self.channels = 1
        self.rate = 16000
        self.record_seconds = 60000
        self.recording = False
        self.multilingual = False
        self.language = None
        self.task = "transcribe"
        self.uid = str(uuid.uuid4())
        self.waiting = False
        self.last_response_recieved = None
        self.disconnect_if_no_response_for = 15
        self.multilingual = is_multilingual
        self.language = lang if is_multilingual else "en"
        if translate:
            self.task = "translate"

        self.timestamp_offset = 0.0
        self.audio_bytes = None
        self.p = pyaudio.PyAudio()
        self.stream = self.p.open(
            format=self.format,
            channels=self.channels,
            rate=self.rate,
            input=True,
            frames_per_buffer=self.chunk,
        )

        if host is not None and port is not None:
            socket_url = f"ws://{host}:{port}"
            self.client_socket = websocket.WebSocketApp(
                socket_url,
                on_open=lambda ws: self.on_open(ws),
                on_message=lambda ws, message: self.on_message(ws, message),
                on_error=lambda ws, error: self.on_error(ws, error),
                on_close=lambda ws, close_status_code, close_msg: self.on_close(
                    ws, close_status_code, close_msg
                ),
            )
        else:
            print("[ERROR]: No host or port specified.")
            return

        Client.INSTANCES[self.uid] = self

        # start websocket client in a thread
        self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
        self.ws_thread.setDaemon(True)
        self.ws_thread.start()

        self.frames = b""
        print("[INFO]: * recording")

    def on_message(self, ws, message):
        """
        Callback function called when a message is received from the server.

        It updates various attributes of the client based on the received message, including
        recording status, language detection, and server messages. If a disconnect message
        is received, it sets the recording status to False.

        Args:
            ws (websocket.WebSocketApp): The WebSocket client instance.
            message (str): The received message from the server.

        """
        self.last_response_recieved = time.time()
        message = json.loads(message)

        if self.uid != message.get("uid"):
            print("[ERROR]: invalid client uid")
            return

        if "status" in message.keys() and message["status"] == "WAIT":
            self.waiting = True
            print(
                f"[INFO]: Server is full. Estimated wait time {round(message['message'])} minutes."
            )

        if "message" in message.keys() and message["message"] == "DISCONNECT":
            print("[INFO]: Server overtime disconnected.")
            self.recording = False

        if "message" in message.keys() and message["message"] == "SERVER_READY":
            self.recording = True
            return

        if "language" in message.keys():
            self.language = message.get("language")
            lang_prob = message.get("language_prob")
            print(
                f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
            )
            return

        if "segments" not in message.keys():
            return

        message = message["segments"]
        text = []
        if len(message):
            for seg in message:
                if text and text[-1] == seg["text"]:
                    # already got it
                    continue
                text.append(seg["text"])
        # keep only last 3
        if len(text) > 3:
            text = text[-3:]
        wrapper = textwrap.TextWrapper(width=60)
        word_list = wrapper.wrap(text="".join(text))
        # Print each line.
        if os.name == "nt":
            os.system("cls")
        else:
            os.system("clear")
        for element in word_list:
            print(element)

    def on_error(self, ws, error):
        print(error)

    def on_close(self, ws, close_status_code, close_msg):
        print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")

    def on_open(self, ws):
        """
        Callback function called when the WebSocket connection is successfully opened.

        Sends an initial configuration message to the server, including client UID, multilingual mode,
        language selection, and task type.

        Args:
            ws (websocket.WebSocketApp): The WebSocket client instance.

        """
        print(self.multilingual, self.language, self.task)

        print("[INFO]: Opened connection")
        ws.send(
            json.dumps(
                {
                    "uid": self.uid,
                    "multilingual": self.multilingual,
                    "language": self.language,
                    "task": self.task,
                }
            )
        )

    @staticmethod
    def bytes_to_float_array(audio_bytes):
        """
        Convert audio data from bytes to a NumPy float array.

        It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
        have values between -1 and 1.

        Args:
            audio_bytes (bytes): Audio data in bytes.

        Returns:
            np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
        """
        raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
        return raw_data.astype(np.float32) / 32768.0

    def send_packet_to_server(self, message):
        """
        Send an audio packet to the server using WebSocket.

        Args:
            message (bytes): The audio data packet in bytes to be sent to the server.

        """
        try:
            self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
        except Exception as e:
            print(e)

    def play_file(self, filename):
        """
        Play an audio file and send it to the server for processing.

        Reads an audio file, plays it through the audio output, and simultaneously sends
        the audio data to the server for processing. It uses PyAudio to create an audio
        stream for playback. The audio data is read from the file in chunks, converted to
        floating-point format, and sent to the server using WebSocket communication.
        This method is typically used when you want to process pre-recorded audio and send it
        to the server in real-time.

        Args:
            filename (str): The path to the audio file to be played and sent to the server.
        """

        # read audio and create pyaudio stream
        with wave.open(filename, "rb") as wavfile:
            self.stream = self.p.open(
                format=self.p.get_format_from_width(wavfile.getsampwidth()),
                channels=wavfile.getnchannels(),
                rate=wavfile.getframerate(),
                input=True,
                output=True,
                frames_per_buffer=self.chunk,
            )
            try:
                while self.recording:
                    data = wavfile.readframes(self.chunk)
                    if data == b"":
                        break

                    audio_array = self.bytes_to_float_array(data)
                    self.send_packet_to_server(audio_array.tobytes())
                    self.stream.write(data)

                wavfile.close()

                assert self.last_response_recieved
                while time.time() - self.last_response_recieved < self.disconnect_if_no_response_for:
                    continue
                self.stream.close()
                self.close_websocket()

            except KeyboardInterrupt:
                wavfile.close()
                self.stream.stop_stream()
                self.stream.close()
                self.p.terminate()
                self.close_websocket()
                print("[INFO]: Keyboard interrupt.")

    def close_websocket(self):
        """
        Close the WebSocket connection and join the WebSocket thread.

        First attempts to close the WebSocket connection using `self.client_socket.close()`. After
        closing the connection, it joins the WebSocket thread to ensure proper termination.

        """
        try:
            self.client_socket.close()
        except Exception as e:
            print("[ERROR]: Error closing WebSocket:", e)

        try:
            self.ws_thread.join()
        except Exception as e:
            print("[ERROR]: Error joining WebSocket thread:", e)

    def get_client_socket(self):
        """
        Get the WebSocket client socket instance.

        Returns:
            WebSocketApp: The WebSocket client socket instance currently in use by the client.
        """
        return self.client_socket

    def write_audio_frames_to_file(self, frames, file_name):
        """
        Write audio frames to a WAV file.

        The WAV file is created or overwritten with the specified name. The audio frames should be
        in the correct format and match the specified channel, sample width, and sample rate.

        Args:
            frames (bytes): The audio frames to be written to the file.
            file_name (str): The name of the WAV file to which the frames will be written.

        """
        with wave.open(file_name, "wb") as wavfile:
            wavfile: wave.Wave_write
            wavfile.setnchannels(self.channels)
            wavfile.setsampwidth(2)
            wavfile.setframerate(self.rate)
            wavfile.writeframes(frames)

    def process_hls_stream(self, hls_url):
        """
        Connect to an HLS source, process the audio stream, and send it for transcription.

        Args:
            hls_url (str): The URL of the HLS stream source.
        """
        print("[INFO]: Connecting to HLS stream...")
        process = None  # Initialize process to None

        try:
            # Connecting to the HLS stream using ffmpeg-python
            process = (
                ffmpeg
                .input(hls_url, threads=0)
                .output('-', format='s16le', acodec='pcm_s16le', ac=1, ar=self.rate)
                .run_async(pipe_stdout=True, pipe_stderr=True)
            )

            # Process the stream
            while True:
                in_bytes = process.stdout.read(self.chunk * 2)  # 2 bytes per sample
                if not in_bytes:
                    break
                audio_array = self.bytes_to_float_array(in_bytes)
                self.send_packet_to_server(audio_array.tobytes())

        except Exception as e:
            print(f"[ERROR]: Failed to connect to HLS stream: {e}")
        finally:
            if process:
                process.kill()

        print("[INFO]: HLS stream processing finished.")


    def record(self, out_file="output_recording.wav"):
        """
        Record audio data from the input stream and save it to a WAV file.

        Continuously records audio data from the input stream, sends it to the server via a WebSocket
        connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
        the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`.

        Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file.
        The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`.
        The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
        the method combines all the saved audio chunks into the specified `out_file`.

        Args:
            out_file (str, optional): The name of the output WAV file to save the entire recording. Default is "output_recording.wav".

        """
        n_audio_file = 0
        if not os.path.exists("chunks"):
            os.makedirs("chunks", exist_ok=True)
        try:
            for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
                if not self.recording:
                    break
                data = self.stream.read(self.chunk)
                self.frames += data

                audio_array = Client.bytes_to_float_array(data)

                self.send_packet_to_server(audio_array.tobytes())

                # save frames if more than a minute
                if len(self.frames) > 60 * self.rate:
                    t = threading.Thread(
                        target=self.write_audio_frames_to_file,
                        args=(
                            self.frames[:],
                            f"chunks/{n_audio_file}.wav",
                        ),
                    )
                    t.start()
                    n_audio_file += 1
                    self.frames = b""

        except KeyboardInterrupt:
            if len(self.frames):
                self.write_audio_frames_to_file(
                    self.frames[:], f"chunks/{n_audio_file}.wav"
                )
                n_audio_file += 1
            self.stream.stop_stream()
            self.stream.close()
            self.p.terminate()
            self.close_websocket()

            self.write_output_recording(n_audio_file, out_file)

    def write_output_recording(self, n_audio_file, out_file):
        """
        Combine and save recorded audio chunks into a single WAV file.

        The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
        file, appends its audio data to the final recording, and then deletes the chunk file. After combining
        and saving, the final recording is stored in the specified `out_file`.


        Args:
            n_audio_file (int): The number of audio chunk files to combine.
            out_file (str): The name of the output WAV file to save the final recording.

        """
        input_files = [
            f"chunks/{i}.wav"
            for i in range(n_audio_file)
            if os.path.exists(f"chunks/{i}.wav")
        ]
        with wave.open(out_file, "wb") as wavfile:
            wavfile: wave.Wave_write
            wavfile.setnchannels(self.channels)
            wavfile.setsampwidth(2)
            wavfile.setframerate(self.rate)
            for in_file in input_files:
                with wave.open(in_file, "rb") as wav_in:
                    while True:
                        data = wav_in.readframes(self.chunk)
                        if data == b"":
                            break
                        wavfile.writeframes(data)
                # remove this file
                os.remove(in_file)
            wavfile.close()


class TranscriptionClient:
    """
    Client for handling audio transcription tasks via a WebSocket connection.

    Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
    to send audio data for transcription to a server and receive transcribed text segments.

    Args:
        host (str): The hostname or IP address of the server.
        port (int): The port number to connect to on the server.
        is_multilingual (bool, optional): Indicates whether the transcription should support multiple languages (default is False).
        lang (str, optional): The primary language for transcription (used if `is_multilingual` is False). Default is None, which defaults to English ('en').
        translate (bool, optional): Indicates whether translation tasks are required (default is False).

    Attributes:
        client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.

    Example:
        To create a TranscriptionClient and start transcription on microphone audio:
        ```python
        transcription_client = TranscriptionClient(host="localhost", port=9090, is_multilingual=True)
        transcription_client()
        ```
    """
    def __init__(self, host, port, is_multilingual=False, lang=None, translate=False):
        self.client = Client(host, port, is_multilingual, lang, translate)

    def __call__(self, audio=None, hls_url=None):
        """
        Start the transcription process.

        Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
        to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it
        will be played and streamed to the server; otherwise, it will perform live recording.

        Args:
            audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.

        """
        print("[INFO]: Waiting for server ready ...")
        while not self.client.recording:
            if self.client.waiting:
                self.client.close_websocket()
                return
            pass
        print("[INFO]: Server Ready!")
        if hls_url is not None:
            self.client.process_hls_stream(hls_url)
        elif audio is not None:
            resampled_file = resample(audio)
            self.client.play_file(resampled_file)
        else:
            self.client.record()
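
Every path into the server (`record`, `play_file`, `process_hls_stream`) converts 16-bit PCM chunks to normalized float32 with `bytes_to_float_array` before sending, so the wire format is raw little-endian float32 samples at 16 kHz. A short sketch of that conversion step in isolation; the generated sine wave is only an illustration and is not part of the commit.

```python
# Sketch: the PCM -> float32 conversion the client applies to every chunk it sends.
import numpy as np
from whisper_live.client import Client

rate = 16000
t = np.arange(rate) / rate
pcm = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)  # 1 s of int16 audio

floats = Client.bytes_to_float_array(pcm.tobytes())
assert floats.dtype == np.float32 and np.abs(floats).max() <= 1.0
payload = floats.tobytes()  # what send_packet_to_server() ships as a binary frame
```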

whisper_live/server.py
ADDED
@@ -0,0 +1,498 @@
import websockets
import time
import threading
import json
import textwrap

import logging
logging.basicConfig(level = logging.INFO)

from websockets.sync.server import serve

import torch
import numpy as np
import time
from whisper_live.transcriber import WhisperModel


class TranscriptionServer:
    """
    Represents a transcription server that handles incoming audio from clients.

    Attributes:
        RATE (int): The audio sampling rate (constant) set to 16000.
        vad_model (torch.Module): The voice activity detection model.
        vad_threshold (float): The voice activity detection threshold.
        clients (dict): A dictionary to store connected clients.
        websockets (dict): A dictionary to store WebSocket connections.
        clients_start_time (dict): A dictionary to track client start times.
        max_clients (int): Maximum allowed connected clients.
        max_connection_time (int): Maximum allowed connection time in seconds.
    """

    RATE = 16000

    def __init__(self):
        # voice activity detection model

        self.clients = {}
        self.websockets = {}
        self.clients_start_time = {}
        self.max_clients = 4
        self.max_connection_time = 600

    def get_wait_time(self):
        """
        Calculate and return the estimated wait time for clients.

        Returns:
            float: The estimated wait time in minutes.
        """
        wait_time = None

        for k, v in self.clients_start_time.items():
            current_client_time_remaining = self.max_connection_time - (time.time() - v)

            if wait_time is None or current_client_time_remaining < wait_time:
                wait_time = current_client_time_remaining

        return wait_time / 60

    def recv_audio(self, websocket):
        """
        Receive audio chunks from a client in an infinite loop.

        Continuously receives audio frames from a connected client
        over a WebSocket connection. It processes the audio frames using a
        voice activity detection (VAD) model to determine if they contain speech
        or not. If the audio frame contains speech, it is added to the client's
        audio data for ASR.
        If the maximum number of clients is reached, the method sends a
        "WAIT" status to the client, indicating that they should wait
        until a slot is available.
        If a client's connection exceeds the maximum allowed time, it will
        be disconnected, and the client's resources will be cleaned up.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.

        Raises:
            Exception: If there is an error during the audio frame processing.
        """
        logging.info("New client connected")
        options = websocket.recv()
        options = json.loads(options)

        if len(self.clients) >= self.max_clients:
            logging.warning("Client Queue Full. Asking client to wait ...")
            wait_time = self.get_wait_time()
            response = {
                "uid": options["uid"],
                "status": "WAIT",
                "message": wait_time,
            }
            websocket.send(json.dumps(response))
            websocket.close()
            del websocket
            return

        client = ServeClient(
            websocket,
            multilingual=options["multilingual"],
            language=options["language"],
            task=options["task"],
            client_uid=options["uid"]
        )

        self.clients[websocket] = client
        self.clients_start_time[websocket] = time.time()

        while True:
            try:
                frame_data = websocket.recv()
                frame_np = np.frombuffer(frame_data, dtype=np.float32)

                self.clients[websocket].add_frames(frame_np)

                elapsed_time = time.time() - self.clients_start_time[websocket]
                if elapsed_time >= self.max_connection_time:
                    self.clients[websocket].disconnect()
                    logging.warning(f"{self.clients[websocket]} Client disconnected due to overtime.")
                    self.clients[websocket].cleanup()
                    self.clients.pop(websocket)
                    self.clients_start_time.pop(websocket)
                    websocket.close()
                    del websocket
                    break

            except Exception as e:
                logging.error(e)
                self.clients[websocket].cleanup()
                self.clients.pop(websocket)
                self.clients_start_time.pop(websocket)
                logging.info("Connection Closed.")
                logging.info(self.clients)
                del websocket
                break

    def run(self, host, port=9090):
        """
        Run the transcription server.

        Args:
            host (str): The host address to bind the server.
            port (int): The port number to bind the server.
        """
        with serve(self.recv_audio, host, port) as server:
            server.serve_forever()


class ServeClient:
    """
    Attributes:
        RATE (int): The audio sampling rate (constant) set to 16000.
        SERVER_READY (str): A constant message indicating that the server is ready.
        DISCONNECT (str): A constant message indicating that the client should disconnect.
        client_uid (str): A unique identifier for the client.
        data (bytes): Accumulated audio data.
        frames (bytes): Accumulated audio frames.
        language (str): The language for transcription.
        task (str): The task type, e.g., "transcribe."
        transcriber (WhisperModel): The Whisper model for speech-to-text.
        timestamp_offset (float): The offset in audio timestamps.
        frames_np (numpy.ndarray): NumPy array to store audio frames.
        frames_offset (float): The offset in audio frames.
        text (list): List of transcribed text segments.
        current_out (str): The current incomplete transcription.
        prev_out (str): The previous incomplete transcription.
        t_start (float): Timestamp for the start of transcription.
        exit (bool): A flag to exit the transcription thread.
        same_output_threshold (int): Threshold for consecutive same output segments.
        show_prev_out_thresh (int): Threshold for showing previous output segments.
        add_pause_thresh (int): Threshold for adding a pause (blank) segment.
        transcript (list): List of transcribed segments.
        send_last_n_segments (int): Number of last segments to send to the client.
        wrapper (textwrap.TextWrapper): Text wrapper for formatting text.
        pick_previous_segments (int): Number of previous segments to include in the output.
        websocket: The WebSocket connection for the client.
    """
    RATE = 16000
    SERVER_READY = "SERVER_READY"
    DISCONNECT = "DISCONNECT"

    def __init__(self, websocket, task="transcribe", device=None, multilingual=False, language=None, client_uid=None):
        """
        Initialize a ServeClient instance.
        The Whisper model is initialized based on the client's language and device availability.
        The transcription thread is started upon initialization. A "SERVER_READY" message is sent
        to the client to indicate that the server is ready.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
            device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
            multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.

        """
        self.client_uid = client_uid
        self.data = b""
        self.frames = b""
        self.language = language if multilingual else "en"
        self.task = task
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.transcriber = WhisperModel(
            "small" if multilingual else "small.en",
            device=device,
            compute_type="int8" if device=="cpu" else "float16",
            local_files_only=False,
        )

        self.timestamp_offset = 0.0
        self.frames_np = None
        self.frames_offset = 0.0
        self.text = []
        self.current_out = ''
        self.prev_out = ''
        self.t_start=None
        self.exit = False
        self.same_output_threshold = 0
        self.show_prev_out_thresh = 5   # if pause(no output from whisper) show previous output for 5 seconds
        self.add_pause_thresh = 3       # add a blank to segment list as a pause(no speech) for 3 seconds
        self.transcript = []
        self.send_last_n_segments = 10

        # text formatting
        self.wrapper = textwrap.TextWrapper(width=50)
        self.pick_previous_segments = 2

        # threading
        self.websocket = websocket
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()
        self.websocket.send(
            json.dumps(
                {
                    "uid": self.client_uid,
                    "message": self.SERVER_READY
                }
            )
        )

    def fill_output(self, output):
        """
        Format the current incomplete transcription output by combining it with previous complete segments.
        The resulting transcription is wrapped into two lines, each containing a maximum of 50 characters.

        It ensures that the combined transcription fits within two lines, with a maximum of 50 characters per line.
        Segments are concatenated in the order they exist in the list of previous segments, with the most
        recent complete segment first and older segments prepended as needed to maintain the character limit.
        If a 3-second pause is detected in the previous segments, any text preceding it is discarded to ensure
        the transcription starts with the most recent complete content. The resulting transcription is returned
        as a single string.

        Args:
            output(str): The current incomplete transcription segment.

        Returns:
            str: A formatted transcription wrapped in two lines.
        """
        text = ''
        pick_prev = min(len(self.text), self.pick_previous_segments)
        for seg in self.text[-pick_prev:]:
            # discard everything before a 3 second pause
            if seg == '':
                text = ''
            else:
                text += seg
        wrapped = "".join(text + output)
        return wrapped

    def add_frames(self, frame_np):
        """
        Add audio frames to the ongoing audio stream buffer.

        This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
        of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
        to prevent excessive memory usage.

        If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
        of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
        audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.

        Args:
            frame_np (numpy.ndarray): The audio frame data as a NumPy array.

        """
        if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE:
            self.frames_offset += 30.0
            self.frames_np = self.frames_np[int(30*self.RATE):]
        if self.frames_np is None:
            self.frames_np = frame_np.copy()
        else:
            self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)

    def speech_to_text(self):
        """
        Process an audio stream in an infinite loop, continuously transcribing the speech.

        This method continuously receives audio frames, performs real-time transcription, and sends
        transcribed segments to the client via a WebSocket connection.

        If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
        It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
        are sent to the client in real-time, and a history of segments is maintained to provide context. Pauses in speech
        (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
        there is no speech for a specified duration to indicate a pause.

        Raises:
            Exception: If there is an issue with audio processing or WebSocket communication.

        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                continue

            # clip audio if the current chunk exceeds 30 seconds, this basically implies that
            # no valid segment for the last 30 seconds from whisper
            if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE:
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5

            samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
            duration = input_bytes.shape[0] / self.RATE
            if duration<1.0:
                continue
            try:
                input_sample = input_bytes.copy()

                # whisper transcribe with prompt
                result, info = self.transcriber.transcribe(
                    input_sample,
                    initial_prompt=None,
                    language=self.language,
                    task=self.task,
                    vad_filter=True,
                    vad_parameters={"threshold": 0.5}
                )

                if self.language is None:
                    if info.language_probability > 0.5:
                        self.language = info.language
                        logging.info(f"Detected language {self.language} with probability {info.language_probability}")
                        self.websocket.send(json.dumps(
                            {"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability}))
                    else:
                        # detect language again
                        continue

                if len(result):
                    self.t_start = None
                    last_segment = self.update_segments(result, duration)
                    if len(self.transcript) < self.send_last_n_segments:
                        segments = self.transcript
                    else:
                        segments = self.transcript[-self.send_last_n_segments:]
                    if last_segment is not None:
                        segments = segments + [last_segment]
                else:
                    # show previous output if there is pause i.e. no output from whisper
                    segments = []
                    if self.t_start is None: self.t_start = time.time()
                    if time.time() - self.t_start < self.show_prev_out_thresh:
                        if len(self.transcript) < self.send_last_n_segments:
                            segments = self.transcript
                        else:
                            segments = self.transcript[-self.send_last_n_segments:]

                    # add a blank if there is no speech for 3 seconds
                    if len(self.text) and self.text[-1] != '':
                        if time.time() - self.t_start > self.add_pause_thresh:
                            self.text.append('')

                try:
                    self.websocket.send(
                        json.dumps({
                            "uid": self.client_uid,
                            "segments": segments
                        })
                    )
                except Exception as e:
                    logging.error(f"[ERROR]: {e}")

            except Exception as e:
                logging.error(f"[ERROR]: {e}")
                time.sleep(0.01)

    def update_segments(self, segments, duration):
        """
        Processes the segments from whisper. Appends all the segments to the list
        except for the last segment assuming that it is incomplete.

        Updates the ongoing transcript with transcribed segments, including their start and end times.
        Complete segments are appended to the transcript in chronological order. Incomplete segments
        (assumed to be the last one) are processed to identify repeated content. If the same incomplete
        segment is seen multiple times, it updates the offset and appends the segment to the transcript.
        A threshold is used to detect repeated content and ensure it is only included once in the transcript.
        The timestamp offset is updated based on the duration of processed segments. The method returns the
        last processed segment, allowing it to be sent to the client for real-time updates.

        Args:
            segments(dict) : dictionary of segments as returned by whisper
            duration(float): duration of the current chunk

        Returns:
            dict or None: The last processed segment with its start time, end time, and transcribed text.
            Returns None if there are no valid segments to process.
        """
        offset = None
        self.current_out = ''
        last_segment = None
        # process complete segments
        if len(segments) > 1:
            for i, s in enumerate(segments[:-1]):
                text_ = s.text
                self.text.append(text_)
                start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end)
                self.transcript.append(
                    {
                        'start': start,
                        'end': end,
                        'text': text_
                    }
                )

                offset = min(duration, s.end)

        self.current_out += segments[-1].text
        last_segment = {
            'start': self.timestamp_offset + segments[-1].start,
            'end': self.timestamp_offset + min(duration, segments[-1].end),
            'text': self.current_out
        }

        # if same incomplete segment is seen multiple times then update the offset
        # and append the segment to the list
        if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
            self.same_output_threshold += 1
        else:
            self.same_output_threshold = 0

        if self.same_output_threshold > 5:
            if not len(self.text) or self.text[-1].strip().lower()!=self.current_out.strip().lower():
                self.text.append(self.current_out)
                self.transcript.append(
                    {
                        'start': self.timestamp_offset,
                        'end': self.timestamp_offset + duration,
                        'text': self.current_out
                    }
                )
            self.current_out = ''
            offset = duration
            self.same_output_threshold = 0
            last_segment = None
        else:
            self.prev_out = self.current_out

        # update offset
        if offset is not None:
            self.timestamp_offset += offset

        return last_segment

    def disconnect(self):
        """
        Notify the client of disconnection and send a disconnect message.

        This method sends a disconnect message to the client via the WebSocket connection to notify them
        that the transcription service is disconnecting gracefully.

        """
        self.websocket.send(
            json.dumps(
                {
                    "uid": self.client_uid,
                    "message": self.DISCONNECT
                }
            )
        )

    def cleanup(self):
        """
        Perform cleanup tasks before exiting the transcription service.

        This method performs necessary cleanup tasks, including stopping the transcription thread, marking
        the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
        associated with the transcription process.

        """
        logging.info("Cleaning up.")
        self.exit = True
        self.transcriber.destroy()
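
Both sides speak a small JSON protocol on top of the WebSocket: the client opens with its options, the server answers with `SERVER_READY`, then streams back the most recent segments, and sends `WAIT` or `DISCONNECT` when it is full or the 600-second limit is reached. The sketch below just lists those messages as the code above constructs them; the field values are illustrative examples.

```python
# Sketch: the JSON messages exchanged between client.py and server.py.
handshake = {"uid": "…", "multilingual": False, "language": "en", "task": "transcribe"}
server_ready = {"uid": "…", "message": "SERVER_READY"}
server_full = {"uid": "…", "status": "WAIT", "message": 4.2}  # estimated wait in minutes
segments_update = {
    "uid": "…",
    "segments": [{"start": 0.0, "end": 1.8, "text": " Hello world."}],
}
overtime = {"uid": "…", "message": "DISCONNECT"}
```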

whisper_live/transcriber.py
ADDED
@@ -0,0 +1,1023 @@
# original https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py

import itertools
import logging
import os
import zlib

from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union

import ctranslate2
import numpy as np
import tokenizers

from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_logger
from faster_whisper.vad import (
    SpeechTimestampsMap,
    VadOptions,
    collect_chunks,
    get_speech_timestamps,
)


class Word(NamedTuple):
    start: float
    end: float
    word: str
    probability: float


class Segment(NamedTuple):
    id: int
    seek: int
    start: float
    end: float
    text: str
    tokens: List[int]
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
    words: Optional[List[Word]]


class TranscriptionOptions(NamedTuple):
    beam_size: int
    best_of: int
    patience: float
    length_penalty: float
    repetition_penalty: float
    no_repeat_ngram_size: int
    log_prob_threshold: Optional[float]
    no_speech_threshold: Optional[float]
    compression_ratio_threshold: Optional[float]
    condition_on_previous_text: bool
    prompt_reset_on_temperature: float
    temperatures: List[float]
    initial_prompt: Optional[Union[str, Iterable[int]]]
    prefix: Optional[str]
    suppress_blank: bool
    suppress_tokens: Optional[List[int]]
    without_timestamps: bool
    max_initial_timestamp: float
    word_timestamps: bool
    prepend_punctuations: str
    append_punctuations: str


class TranscriptionInfo(NamedTuple):
    language: str
    language_probability: float
    duration: float
    duration_after_vad: float
    all_language_probs: Optional[List[Tuple[str, float]]]
    transcription_options: TranscriptionOptions
    vad_options: VadOptions


class WhisperModel:
    def __init__(
        self,
        model_size_or_path: str,
        device: str = "auto",
        device_index: Union[int, List[int]] = 0,
        compute_type: str = "default",
        cpu_threads: int = 0,
        num_workers: int = 1,
        download_root: Optional[str] = None,
        local_files_only: bool = False,
    ):
        """Initializes the Whisper model.

        Args:
          model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
            small, small.en, medium, medium.en, large-v1, large-v2, or large), a path to a converted
            model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub.
            When a size or a model ID is configured, the converted model is downloaded
            from the Hugging Face Hub.
          device: Device to use for computation ("cpu", "cuda", "auto").
          device_index: Device ID to use.
            The model can also be loaded on multiple GPUs by passing a list of IDs
            (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
            when transcribe() is called from multiple Python threads (see also num_workers).
          compute_type: Type to use for computation.
            See https://opennmt.net/CTranslate2/quantization.html.
          cpu_threads: Number of threads to use when running on CPU (4 by default).
            A non zero value overrides the OMP_NUM_THREADS environment variable.
          num_workers: When transcribe() is called from multiple Python threads,
            having multiple workers enables true parallelism when running the model
            (concurrent calls to self.model.generate() will run in parallel).
            This can improve the global throughput at the cost of increased memory usage.
          download_root: Directory where the models should be saved. If not set, the models
            are saved in the standard Hugging Face cache directory.
          local_files_only: If True, avoid downloading the file and return the path to the
            local cached file if it exists.
        """
        self.logger = get_logger()

        if os.path.isdir(model_size_or_path):
            model_path = model_size_or_path
        else:
            model_path = download_model(
                model_size_or_path,
                local_files_only=local_files_only,
                cache_dir=download_root,
            )

        self.model = ctranslate2.models.Whisper(
            model_path,
            device=device,
            device_index=device_index,
            compute_type=compute_type,
            intra_threads=cpu_threads,
            inter_threads=num_workers,
        )

        tokenizer_file = os.path.join(model_path, "tokenizer.json")
        if os.path.isfile(tokenizer_file):
            self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
        else:
            self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
                "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
            )

        self.feature_extractor = FeatureExtractor()
        self.num_samples_per_token = self.feature_extractor.hop_length * 2
        self.frames_per_second = (
            self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
        )
        self.tokens_per_second = (
            self.feature_extractor.sampling_rate // self.num_samples_per_token
        )
        self.input_stride = 2
        self.time_precision = 0.02
        self.max_length = 448

    @property
    def supported_languages(self) -> List[str]:
        """The languages supported by the model."""
        return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]

    def transcribe(
        self,
        audio: Union[str, BinaryIO, np.ndarray],
        language: Optional[str] = None,
        task: str = "transcribe",
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1,
        length_penalty: float = 1,
        repetition_penalty: float = 1,
        no_repeat_ngram_size: int = 0,
        temperature: Union[float, List[float], Tuple[float, ...]] = [
            0.0,
            0.2,
            0.4,
            0.6,
            0.8,
            1.0,
        ],
        compression_ratio_threshold: Optional[float] = 2.4,
        log_prob_threshold: Optional[float] = -1.0,
        no_speech_threshold: Optional[float] = 0.6,
        condition_on_previous_text: bool = True,
        prompt_reset_on_temperature: float = 0.5,
        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
        prefix: Optional[str] = None,
        suppress_blank: bool = True,
        suppress_tokens: Optional[List[int]] = [-1],
        without_timestamps: bool = False,
        max_initial_timestamp: float = 1.0,
        word_timestamps: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,,!!??::”)]}、",
        vad_filter: bool = False,
        vad_parameters: Optional[Union[dict, VadOptions]] = None,
    ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
        """Transcribes an input file.

        Arguments:
          audio: Path to the input file (or a file-like object), or the audio waveform.
          language: The language spoken in the audio. It should be a language code such
            as "en" or "fr". If not set, the language will be detected in the first 30 seconds
            of audio.
          task: Task to execute (transcribe or translate).
          beam_size: Beam size to use for decoding.
          best_of: Number of candidates when sampling with non-zero temperature.
          patience: Beam search patience factor.
          length_penalty: Exponential length penalty constant.
          repetition_penalty: Penalty applied to the score of previously generated tokens
            (set > 1 to penalize).
          no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
          temperature: Temperature for sampling. It can be a tuple of temperatures,
            which will be successively used upon failures according to either
            `compression_ratio_threshold` or `log_prob_threshold`.
          compression_ratio_threshold: If the gzip compression ratio is above this value,
            treat as failed.
          log_prob_threshold: If the average log probability over sampled tokens is
            below this value, treat as failed.
          no_speech_threshold: If the no_speech probability is higher than this value AND
            the average log probability over sampled tokens is below `log_prob_threshold`,
            consider the segment as silent.
|
225 |
+
condition_on_previous_text: If True, the previous output of the model is provided
|
226 |
+
as a prompt for the next window; disabling may make the text inconsistent across
|
227 |
+
windows, but the model becomes less prone to getting stuck in a failure loop,
|
228 |
+
such as repetition looping or timestamps going out of sync.
|
229 |
+
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
|
230 |
+
Arg has effect only if condition_on_previous_text is True.
|
231 |
+
initial_prompt: Optional text string or iterable of token ids to provide as a
|
232 |
+
prompt for the first window.
|
233 |
+
prefix: Optional text to provide as a prefix for the first window.
|
234 |
+
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
235 |
+
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
236 |
+
of symbols as defined in the model config.json file.
|
237 |
+
without_timestamps: Only sample text tokens.
|
238 |
+
max_initial_timestamp: The initial timestamp cannot be later than this.
|
239 |
+
word_timestamps: Extract word-level timestamps using the cross-attention pattern
|
240 |
+
and dynamic time warping, and include the timestamps for each word in each segment.
|
241 |
+
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
|
242 |
+
with the next word
|
243 |
+
append_punctuations: If word_timestamps is True, merge these punctuation symbols
|
244 |
+
with the previous word
|
245 |
+
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
|
246 |
+
without speech. This step is using the Silero VAD model
|
247 |
+
https://github.com/snakers4/silero-vad.
|
248 |
+
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
|
249 |
+
parameters and default values in the class `VadOptions`).
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
A tuple with:
|
253 |
+
|
254 |
+
- a generator over transcribed segments
|
255 |
+
- an instance of TranscriptionInfo
|
256 |
+
"""
|
257 |
+
sampling_rate = self.feature_extractor.sampling_rate
|
258 |
+
|
259 |
+
if not isinstance(audio, np.ndarray):
|
260 |
+
audio = decode_audio(audio, sampling_rate=sampling_rate)
|
261 |
+
|
262 |
+
duration = audio.shape[0] / sampling_rate
|
263 |
+
duration_after_vad = duration
|
264 |
+
|
265 |
+
self.logger.info(
|
266 |
+
"Processing audio with duration %s", format_timestamp(duration)
|
267 |
+
)
|
268 |
+
|
269 |
+
if vad_filter:
|
270 |
+
if vad_parameters is None:
|
271 |
+
vad_parameters = VadOptions()
|
272 |
+
elif isinstance(vad_parameters, dict):
|
273 |
+
vad_parameters = VadOptions(**vad_parameters)
|
274 |
+
speech_chunks = get_speech_timestamps(audio, vad_parameters)
|
275 |
+
audio = collect_chunks(audio, speech_chunks)
|
276 |
+
duration_after_vad = audio.shape[0] / sampling_rate
|
277 |
+
|
278 |
+
self.logger.info(
|
279 |
+
"VAD filter removed %s of audio",
|
280 |
+
format_timestamp(duration - duration_after_vad),
|
281 |
+
)
|
282 |
+
|
283 |
+
if self.logger.isEnabledFor(logging.DEBUG):
|
284 |
+
self.logger.debug(
|
285 |
+
"VAD filter kept the following audio segments: %s",
|
286 |
+
", ".join(
|
287 |
+
"[%s -> %s]"
|
288 |
+
% (
|
289 |
+
format_timestamp(chunk["start"] / sampling_rate),
|
290 |
+
format_timestamp(chunk["end"] / sampling_rate),
|
291 |
+
)
|
292 |
+
for chunk in speech_chunks
|
293 |
+
),
|
294 |
+
)
|
295 |
+
|
296 |
+
else:
|
297 |
+
speech_chunks = None
|
298 |
+
|
299 |
+
features = self.feature_extractor(audio)
|
300 |
+
|
301 |
+
encoder_output = None
|
302 |
+
all_language_probs = None
|
303 |
+
|
304 |
+
if language is None:
|
305 |
+
if not self.model.is_multilingual:
|
306 |
+
language = "en"
|
307 |
+
language_probability = 1
|
308 |
+
else:
|
309 |
+
segment = features[:, : self.feature_extractor.nb_max_frames]
|
310 |
+
encoder_output = self.encode(segment)
|
311 |
+
# results is a list of tuple[str, float] with language names and
|
312 |
+
# probabilities.
|
313 |
+
results = self.model.detect_language(encoder_output)[0]
|
314 |
+
# Parse language names to strip out markers
|
315 |
+
all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
|
316 |
+
# Get top language token and probability
|
317 |
+
language, language_probability = all_language_probs[0]
|
318 |
+
|
319 |
+
self.logger.info(
|
320 |
+
"Detected language '%s' with probability %.2f",
|
321 |
+
language,
|
322 |
+
language_probability,
|
323 |
+
)
|
324 |
+
else:
|
325 |
+
if not self.model.is_multilingual and language != "en":
|
326 |
+
self.logger.warning(
|
327 |
+
"The current model is English-only but the language parameter is set to '%s'; "
|
328 |
+
"using 'en' instead." % language
|
329 |
+
)
|
330 |
+
language = "en"
|
331 |
+
|
332 |
+
language_probability = 1
|
333 |
+
|
334 |
+
tokenizer = Tokenizer(
|
335 |
+
self.hf_tokenizer,
|
336 |
+
self.model.is_multilingual,
|
337 |
+
task=task,
|
338 |
+
language=language,
|
339 |
+
)
|
340 |
+
|
341 |
+
options = TranscriptionOptions(
|
342 |
+
beam_size=beam_size,
|
343 |
+
best_of=best_of,
|
344 |
+
patience=patience,
|
345 |
+
length_penalty=length_penalty,
|
346 |
+
repetition_penalty=repetition_penalty,
|
347 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
348 |
+
log_prob_threshold=log_prob_threshold,
|
349 |
+
no_speech_threshold=no_speech_threshold,
|
350 |
+
compression_ratio_threshold=compression_ratio_threshold,
|
351 |
+
condition_on_previous_text=condition_on_previous_text,
|
352 |
+
prompt_reset_on_temperature=prompt_reset_on_temperature,
|
353 |
+
temperatures=(
|
354 |
+
temperature if isinstance(temperature, (list, tuple)) else [temperature]
|
355 |
+
),
|
356 |
+
initial_prompt=initial_prompt,
|
357 |
+
prefix=prefix,
|
358 |
+
suppress_blank=suppress_blank,
|
359 |
+
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
|
360 |
+
without_timestamps=without_timestamps,
|
361 |
+
max_initial_timestamp=max_initial_timestamp,
|
362 |
+
word_timestamps=word_timestamps,
|
363 |
+
prepend_punctuations=prepend_punctuations,
|
364 |
+
append_punctuations=append_punctuations,
|
365 |
+
)
|
366 |
+
|
367 |
+
segments = self.generate_segments(features, tokenizer, options, encoder_output)
|
368 |
+
|
369 |
+
if speech_chunks:
|
370 |
+
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
|
371 |
+
|
372 |
+
info = TranscriptionInfo(
|
373 |
+
language=language,
|
374 |
+
language_probability=language_probability,
|
375 |
+
duration=duration,
|
376 |
+
duration_after_vad=duration_after_vad,
|
377 |
+
transcription_options=options,
|
378 |
+
vad_options=vad_parameters,
|
379 |
+
all_language_probs=all_language_probs,
|
380 |
+
)
|
381 |
+
|
382 |
+
return segments, info
|
383 |
+
|
384 |
+
def generate_segments(
|
385 |
+
self,
|
386 |
+
features: np.ndarray,
|
387 |
+
tokenizer: Tokenizer,
|
388 |
+
options: TranscriptionOptions,
|
389 |
+
encoder_output: Optional[ctranslate2.StorageView] = None,
|
390 |
+
) -> Iterable[Segment]:
|
391 |
+
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
392 |
+
idx = 0
|
393 |
+
seek = 0
|
394 |
+
all_tokens = []
|
395 |
+
prompt_reset_since = 0
|
396 |
+
|
397 |
+
if options.initial_prompt is not None:
|
398 |
+
if isinstance(options.initial_prompt, str):
|
399 |
+
initial_prompt = " " + options.initial_prompt.strip()
|
400 |
+
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
401 |
+
all_tokens.extend(initial_prompt_tokens)
|
402 |
+
else:
|
403 |
+
all_tokens.extend(options.initial_prompt)
|
404 |
+
|
405 |
+
last_speech_timestamp = 0.0
|
406 |
+
all_segments = []
|
407 |
+
while seek < content_frames:
|
408 |
+
time_offset = seek * self.feature_extractor.time_per_frame
|
409 |
+
segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
|
410 |
+
segment_size = min(
|
411 |
+
self.feature_extractor.nb_max_frames, content_frames - seek
|
412 |
+
)
|
413 |
+
segment_duration = segment_size * self.feature_extractor.time_per_frame
|
414 |
+
|
415 |
+
if self.logger.isEnabledFor(logging.DEBUG):
|
416 |
+
self.logger.debug(
|
417 |
+
"Processing segment at %s", format_timestamp(time_offset)
|
418 |
+
)
|
419 |
+
|
420 |
+
previous_tokens = all_tokens[prompt_reset_since:]
|
421 |
+
prompt = self.get_prompt(
|
422 |
+
tokenizer,
|
423 |
+
previous_tokens,
|
424 |
+
without_timestamps=options.without_timestamps,
|
425 |
+
prefix=options.prefix if seek == 0 else None,
|
426 |
+
)
|
427 |
+
|
428 |
+
if seek > 0 or encoder_output is None:
|
429 |
+
encoder_output = self.encode(segment)
|
430 |
+
|
431 |
+
(
|
432 |
+
result,
|
433 |
+
avg_logprob,
|
434 |
+
temperature,
|
435 |
+
compression_ratio,
|
436 |
+
) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
|
437 |
+
|
438 |
+
if options.no_speech_threshold is not None:
|
439 |
+
# no voice activity check
|
440 |
+
should_skip = result.no_speech_prob > options.no_speech_threshold
|
441 |
+
|
442 |
+
if (
|
443 |
+
options.log_prob_threshold is not None
|
444 |
+
and avg_logprob > options.log_prob_threshold
|
445 |
+
):
|
446 |
+
# don't skip if the logprob is high enough, despite the no_speech_prob
|
447 |
+
should_skip = False
|
448 |
+
|
449 |
+
if should_skip:
|
450 |
+
self.logger.debug(
|
451 |
+
"No speech threshold is met (%f > %f)",
|
452 |
+
result.no_speech_prob,
|
453 |
+
options.no_speech_threshold,
|
454 |
+
)
|
455 |
+
|
456 |
+
# fast-forward to the next segment boundary
|
457 |
+
seek += segment_size
|
458 |
+
continue
|
459 |
+
|
460 |
+
tokens = result.sequences_ids[0]
|
461 |
+
|
462 |
+
previous_seek = seek
|
463 |
+
current_segments = []
|
464 |
+
|
465 |
+
single_timestamp_ending = (
|
466 |
+
len(tokens) >= 2
|
467 |
+
and tokens[-2] < tokenizer.timestamp_begin
|
468 |
+
and tokens[-1] >= tokenizer.timestamp_begin
|
469 |
+
)
|
470 |
+
|
471 |
+
consecutive_timestamps = [
|
472 |
+
i
|
473 |
+
for i in range(len(tokens))
|
474 |
+
if i > 0
|
475 |
+
and tokens[i] >= tokenizer.timestamp_begin
|
476 |
+
and tokens[i - 1] >= tokenizer.timestamp_begin
|
477 |
+
]
|
478 |
+
|
479 |
+
if len(consecutive_timestamps) > 0:
|
480 |
+
slices = list(consecutive_timestamps)
|
481 |
+
if single_timestamp_ending:
|
482 |
+
slices.append(len(tokens))
|
483 |
+
|
484 |
+
last_slice = 0
|
485 |
+
for current_slice in slices:
|
486 |
+
sliced_tokens = tokens[last_slice:current_slice]
|
487 |
+
start_timestamp_position = (
|
488 |
+
sliced_tokens[0] - tokenizer.timestamp_begin
|
489 |
+
)
|
490 |
+
end_timestamp_position = (
|
491 |
+
sliced_tokens[-1] - tokenizer.timestamp_begin
|
492 |
+
)
|
493 |
+
start_time = (
|
494 |
+
time_offset + start_timestamp_position * self.time_precision
|
495 |
+
)
|
496 |
+
end_time = (
|
497 |
+
time_offset + end_timestamp_position * self.time_precision
|
498 |
+
)
|
499 |
+
|
500 |
+
current_segments.append(
|
501 |
+
dict(
|
502 |
+
seek=seek,
|
503 |
+
start=start_time,
|
504 |
+
end=end_time,
|
505 |
+
tokens=sliced_tokens,
|
506 |
+
)
|
507 |
+
)
|
508 |
+
last_slice = current_slice
|
509 |
+
|
510 |
+
if single_timestamp_ending:
|
511 |
+
# single timestamp at the end means no speech after the last timestamp.
|
512 |
+
seek += segment_size
|
513 |
+
else:
|
514 |
+
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
515 |
+
last_timestamp_position = (
|
516 |
+
tokens[last_slice - 1] - tokenizer.timestamp_begin
|
517 |
+
)
|
518 |
+
seek += last_timestamp_position * self.input_stride
|
519 |
+
|
520 |
+
else:
|
521 |
+
duration = segment_duration
|
522 |
+
timestamps = [
|
523 |
+
token for token in tokens if token >= tokenizer.timestamp_begin
|
524 |
+
]
|
525 |
+
if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
|
526 |
+
last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
|
527 |
+
duration = last_timestamp_position * self.time_precision
|
528 |
+
|
529 |
+
current_segments.append(
|
530 |
+
dict(
|
531 |
+
seek=seek,
|
532 |
+
start=time_offset,
|
533 |
+
end=time_offset + duration,
|
534 |
+
tokens=tokens,
|
535 |
+
)
|
536 |
+
)
|
537 |
+
|
538 |
+
seek += segment_size
|
539 |
+
|
540 |
+
if options.word_timestamps:
|
541 |
+
self.add_word_timestamps(
|
542 |
+
current_segments,
|
543 |
+
tokenizer,
|
544 |
+
encoder_output,
|
545 |
+
segment_size,
|
546 |
+
options.prepend_punctuations,
|
547 |
+
options.append_punctuations,
|
548 |
+
last_speech_timestamp=last_speech_timestamp,
|
549 |
+
)
|
550 |
+
|
551 |
+
word_end_timestamps = [
|
552 |
+
w["end"] for s in current_segments for w in s["words"]
|
553 |
+
]
|
554 |
+
if len(word_end_timestamps) > 0:
|
555 |
+
last_speech_timestamp = word_end_timestamps[-1]
|
556 |
+
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
557 |
+
seek_shift = round(
|
558 |
+
(word_end_timestamps[-1] - time_offset) * self.frames_per_second
|
559 |
+
)
|
560 |
+
|
561 |
+
if seek_shift > 0:
|
562 |
+
seek = previous_seek + seek_shift
|
563 |
+
|
564 |
+
for segment in current_segments:
|
565 |
+
tokens = segment["tokens"]
|
566 |
+
text = tokenizer.decode(tokens)
|
567 |
+
|
568 |
+
if segment["start"] == segment["end"] or not text.strip():
|
569 |
+
continue
|
570 |
+
|
571 |
+
all_tokens.extend(tokens)
|
572 |
+
idx += 1
|
573 |
+
|
574 |
+
all_segments.append(Segment(
|
575 |
+
id=idx,
|
576 |
+
seek=seek,
|
577 |
+
start=segment["start"],
|
578 |
+
end=segment["end"],
|
579 |
+
text=text,
|
580 |
+
tokens=tokens,
|
581 |
+
temperature=temperature,
|
582 |
+
avg_logprob=avg_logprob,
|
583 |
+
compression_ratio=compression_ratio,
|
584 |
+
no_speech_prob=result.no_speech_prob,
|
585 |
+
words=(
|
586 |
+
[Word(**word) for word in segment["words"]]
|
587 |
+
if options.word_timestamps
|
588 |
+
else None
|
589 |
+
),
|
590 |
+
))
|
591 |
+
|
592 |
+
if (
|
593 |
+
not options.condition_on_previous_text
|
594 |
+
or temperature > options.prompt_reset_on_temperature
|
595 |
+
):
|
596 |
+
if options.condition_on_previous_text:
|
597 |
+
self.logger.debug(
|
598 |
+
"Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
|
599 |
+
temperature,
|
600 |
+
options.prompt_reset_on_temperature,
|
601 |
+
)
|
602 |
+
|
603 |
+
prompt_reset_since = len(all_tokens)
|
604 |
+
return all_segments
|
605 |
+
|
606 |
+
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
607 |
+
# When the model is running on multiple GPUs, the encoder output should be moved
|
608 |
+
# to the CPU since we don't know which GPU will handle the next job.
|
609 |
+
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
|
610 |
+
|
611 |
+
features = np.expand_dims(features, 0)
|
612 |
+
features = get_ctranslate2_storage(features)
|
613 |
+
|
614 |
+
return self.model.encode(features, to_cpu=to_cpu)
|
615 |
+
|
616 |
+
def generate_with_fallback(
|
617 |
+
self,
|
618 |
+
encoder_output: ctranslate2.StorageView,
|
619 |
+
prompt: List[int],
|
620 |
+
tokenizer: Tokenizer,
|
621 |
+
options: TranscriptionOptions,
|
622 |
+
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
|
623 |
+
decode_result = None
|
624 |
+
all_results = []
|
625 |
+
below_cr_threshold_results = []
|
626 |
+
|
627 |
+
max_initial_timestamp_index = int(
|
628 |
+
round(options.max_initial_timestamp / self.time_precision)
|
629 |
+
)
|
630 |
+
|
631 |
+
for temperature in options.temperatures:
|
632 |
+
if temperature > 0:
|
633 |
+
kwargs = {
|
634 |
+
"beam_size": 1,
|
635 |
+
"num_hypotheses": options.best_of,
|
636 |
+
"sampling_topk": 0,
|
637 |
+
"sampling_temperature": temperature,
|
638 |
+
}
|
639 |
+
else:
|
640 |
+
kwargs = {
|
641 |
+
"beam_size": options.beam_size,
|
642 |
+
"patience": options.patience,
|
643 |
+
}
|
644 |
+
|
645 |
+
result = self.model.generate(
|
646 |
+
encoder_output,
|
647 |
+
[prompt],
|
648 |
+
length_penalty=options.length_penalty,
|
649 |
+
repetition_penalty=options.repetition_penalty,
|
650 |
+
no_repeat_ngram_size=options.no_repeat_ngram_size,
|
651 |
+
max_length=self.max_length,
|
652 |
+
return_scores=True,
|
653 |
+
return_no_speech_prob=True,
|
654 |
+
suppress_blank=options.suppress_blank,
|
655 |
+
suppress_tokens=options.suppress_tokens,
|
656 |
+
max_initial_timestamp_index=max_initial_timestamp_index,
|
657 |
+
**kwargs,
|
658 |
+
)[0]
|
659 |
+
|
660 |
+
tokens = result.sequences_ids[0]
|
661 |
+
|
662 |
+
# Recover the average log prob from the returned score.
|
663 |
+
seq_len = len(tokens)
|
664 |
+
cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
|
665 |
+
avg_logprob = cum_logprob / (seq_len + 1)
|
666 |
+
|
667 |
+
text = tokenizer.decode(tokens).strip()
|
668 |
+
compression_ratio = get_compression_ratio(text)
|
669 |
+
|
670 |
+
decode_result = (
|
671 |
+
result,
|
672 |
+
avg_logprob,
|
673 |
+
temperature,
|
674 |
+
compression_ratio,
|
675 |
+
)
|
676 |
+
all_results.append(decode_result)
|
677 |
+
|
678 |
+
needs_fallback = False
|
679 |
+
|
680 |
+
if options.compression_ratio_threshold is not None:
|
681 |
+
if compression_ratio > options.compression_ratio_threshold:
|
682 |
+
needs_fallback = True # too repetitive
|
683 |
+
|
684 |
+
self.logger.debug(
|
685 |
+
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
|
686 |
+
temperature,
|
687 |
+
compression_ratio,
|
688 |
+
options.compression_ratio_threshold,
|
689 |
+
)
|
690 |
+
else:
|
691 |
+
below_cr_threshold_results.append(decode_result)
|
692 |
+
|
693 |
+
if (
|
694 |
+
options.log_prob_threshold is not None
|
695 |
+
and avg_logprob < options.log_prob_threshold
|
696 |
+
):
|
697 |
+
needs_fallback = True # average log probability is too low
|
698 |
+
|
699 |
+
self.logger.debug(
|
700 |
+
"Log probability threshold is not met with temperature %.1f (%f < %f)",
|
701 |
+
temperature,
|
702 |
+
avg_logprob,
|
703 |
+
options.log_prob_threshold,
|
704 |
+
)
|
705 |
+
|
706 |
+
if (
|
707 |
+
options.no_speech_threshold is not None
|
708 |
+
and result.no_speech_prob > options.no_speech_threshold
|
709 |
+
):
|
710 |
+
needs_fallback = False # silence
|
711 |
+
|
712 |
+
if not needs_fallback:
|
713 |
+
break
|
714 |
+
else:
|
715 |
+
# all failed, select the result with the highest average log probability
|
716 |
+
decode_result = max(
|
717 |
+
below_cr_threshold_results or all_results, key=lambda x: x[1]
|
718 |
+
)
|
719 |
+
|
720 |
+
return decode_result
|
721 |
+
|
722 |
+
def get_prompt(
|
723 |
+
self,
|
724 |
+
tokenizer: Tokenizer,
|
725 |
+
previous_tokens: List[int],
|
726 |
+
without_timestamps: bool = False,
|
727 |
+
prefix: Optional[str] = None,
|
728 |
+
) -> List[int]:
|
729 |
+
prompt = []
|
730 |
+
|
731 |
+
if previous_tokens:
|
732 |
+
prompt.append(tokenizer.sot_prev)
|
733 |
+
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
734 |
+
|
735 |
+
prompt.extend(tokenizer.sot_sequence)
|
736 |
+
|
737 |
+
if without_timestamps:
|
738 |
+
prompt.append(tokenizer.no_timestamps)
|
739 |
+
|
740 |
+
if prefix:
|
741 |
+
prefix_tokens = tokenizer.encode(" " + prefix.strip())
|
742 |
+
if len(prefix_tokens) >= self.max_length // 2:
|
743 |
+
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
744 |
+
if not without_timestamps:
|
745 |
+
prompt.append(tokenizer.timestamp_begin)
|
746 |
+
prompt.extend(prefix_tokens)
|
747 |
+
|
748 |
+
return prompt
|
749 |
+
|
750 |
+
def add_word_timestamps(
|
751 |
+
self,
|
752 |
+
segments: List[dict],
|
753 |
+
tokenizer: Tokenizer,
|
754 |
+
encoder_output: ctranslate2.StorageView,
|
755 |
+
num_frames: int,
|
756 |
+
prepend_punctuations: str,
|
757 |
+
append_punctuations: str,
|
758 |
+
last_speech_timestamp: float,
|
759 |
+
) -> None:
|
760 |
+
if len(segments) == 0:
|
761 |
+
return
|
762 |
+
|
763 |
+
text_tokens_per_segment = [
|
764 |
+
[token for token in segment["tokens"] if token < tokenizer.eot]
|
765 |
+
for segment in segments
|
766 |
+
]
|
767 |
+
|
768 |
+
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
769 |
+
alignment = self.find_alignment(
|
770 |
+
tokenizer, text_tokens, encoder_output, num_frames
|
771 |
+
)
|
772 |
+
word_durations = np.array([word["end"] - word["start"] for word in alignment])
|
773 |
+
word_durations = word_durations[word_durations.nonzero()]
|
774 |
+
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
775 |
+
max_duration = median_duration * 2
|
776 |
+
|
777 |
+
# hack: truncate long words at sentence boundaries.
|
778 |
+
# a better segmentation algorithm based on VAD should be able to replace this.
|
779 |
+
if len(word_durations) > 0:
|
780 |
+
sentence_end_marks = ".。!!??"
|
781 |
+
# ensure words at sentence boundaries
|
782 |
+
# are not longer than twice the median word duration.
|
783 |
+
for i in range(1, len(alignment)):
|
784 |
+
if alignment[i]["end"] - alignment[i]["start"] > max_duration:
|
785 |
+
if alignment[i]["word"] in sentence_end_marks:
|
786 |
+
alignment[i]["end"] = alignment[i]["start"] + max_duration
|
787 |
+
elif alignment[i - 1]["word"] in sentence_end_marks:
|
788 |
+
alignment[i]["start"] = alignment[i]["end"] - max_duration
|
789 |
+
|
790 |
+
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
791 |
+
|
792 |
+
time_offset = (
|
793 |
+
segments[0]["seek"]
|
794 |
+
* self.feature_extractor.hop_length
|
795 |
+
/ self.feature_extractor.sampling_rate
|
796 |
+
)
|
797 |
+
|
798 |
+
word_index = 0
|
799 |
+
|
800 |
+
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
801 |
+
saved_tokens = 0
|
802 |
+
words = []
|
803 |
+
|
804 |
+
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
805 |
+
timing = alignment[word_index]
|
806 |
+
|
807 |
+
if timing["word"]:
|
808 |
+
words.append(
|
809 |
+
dict(
|
810 |
+
word=timing["word"],
|
811 |
+
start=round(time_offset + timing["start"], 2),
|
812 |
+
end=round(time_offset + timing["end"], 2),
|
813 |
+
probability=timing["probability"],
|
814 |
+
)
|
815 |
+
)
|
816 |
+
|
817 |
+
saved_tokens += len(timing["tokens"])
|
818 |
+
word_index += 1
|
819 |
+
|
820 |
+
# hack: truncate long words at segment boundaries.
|
821 |
+
# a better segmentation algorithm based on VAD should be able to replace this.
|
822 |
+
if len(words) > 0:
|
823 |
+
# ensure the first and second word after a pause is not longer than
|
824 |
+
# twice the median word duration.
|
825 |
+
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
|
826 |
+
words[0]["end"] - words[0]["start"] > max_duration
|
827 |
+
or (
|
828 |
+
len(words) > 1
|
829 |
+
and words[1]["end"] - words[0]["start"] > max_duration * 2
|
830 |
+
)
|
831 |
+
):
|
832 |
+
if (
|
833 |
+
len(words) > 1
|
834 |
+
and words[1]["end"] - words[1]["start"] > max_duration
|
835 |
+
):
|
836 |
+
boundary = max(
|
837 |
+
words[1]["end"] / 2, words[1]["end"] - max_duration
|
838 |
+
)
|
839 |
+
words[0]["end"] = words[1]["start"] = boundary
|
840 |
+
words[0]["start"] = max(0, words[0]["end"] - max_duration)
|
841 |
+
|
842 |
+
# prefer the segment-level start timestamp if the first word is too long.
|
843 |
+
if (
|
844 |
+
segment["start"] < words[0]["end"]
|
845 |
+
and segment["start"] - 0.5 > words[0]["start"]
|
846 |
+
):
|
847 |
+
words[0]["start"] = max(
|
848 |
+
0, min(words[0]["end"] - median_duration, segment["start"])
|
849 |
+
)
|
850 |
+
else:
|
851 |
+
segment["start"] = words[0]["start"]
|
852 |
+
|
853 |
+
# prefer the segment-level end timestamp if the last word is too long.
|
854 |
+
if (
|
855 |
+
segment["end"] > words[-1]["start"]
|
856 |
+
and segment["end"] + 0.5 < words[-1]["end"]
|
857 |
+
):
|
858 |
+
words[-1]["end"] = max(
|
859 |
+
words[-1]["start"] + median_duration, segment["end"]
|
860 |
+
)
|
861 |
+
else:
|
862 |
+
segment["end"] = words[-1]["end"]
|
863 |
+
|
864 |
+
last_speech_timestamp = segment["end"]
|
865 |
+
|
866 |
+
segment["words"] = words
|
867 |
+
|
868 |
+
def find_alignment(
|
869 |
+
self,
|
870 |
+
tokenizer: Tokenizer,
|
871 |
+
text_tokens: List[int],
|
872 |
+
encoder_output: ctranslate2.StorageView,
|
873 |
+
num_frames: int,
|
874 |
+
median_filter_width: int = 7,
|
875 |
+
) -> List[dict]:
|
876 |
+
if len(text_tokens) == 0:
|
877 |
+
return []
|
878 |
+
|
879 |
+
result = self.model.align(
|
880 |
+
encoder_output,
|
881 |
+
tokenizer.sot_sequence,
|
882 |
+
[text_tokens],
|
883 |
+
num_frames,
|
884 |
+
median_filter_width=median_filter_width,
|
885 |
+
)[0]
|
886 |
+
|
887 |
+
text_token_probs = result.text_token_probs
|
888 |
+
|
889 |
+
alignments = result.alignments
|
890 |
+
text_indices = np.array([pair[0] for pair in alignments])
|
891 |
+
time_indices = np.array([pair[1] for pair in alignments])
|
892 |
+
|
893 |
+
words, word_tokens = tokenizer.split_to_word_tokens(
|
894 |
+
text_tokens + [tokenizer.eot]
|
895 |
+
)
|
896 |
+
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
897 |
+
if len(word_boundaries) <= 1:
|
898 |
+
return []
|
899 |
+
|
900 |
+
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
901 |
+
jump_times = time_indices[jumps] / self.tokens_per_second
|
902 |
+
start_times = jump_times[word_boundaries[:-1]]
|
903 |
+
end_times = jump_times[word_boundaries[1:]]
|
904 |
+
word_probabilities = [
|
905 |
+
np.mean(text_token_probs[i:j])
|
906 |
+
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
907 |
+
]
|
908 |
+
|
909 |
+
return [
|
910 |
+
dict(
|
911 |
+
word=word, tokens=tokens, start=start, end=end, probability=probability
|
912 |
+
)
|
913 |
+
for word, tokens, start, end, probability in zip(
|
914 |
+
words, word_tokens, start_times, end_times, word_probabilities
|
915 |
+
)
|
916 |
+
]
|
917 |
+
|
918 |
+
def destroy(self):
|
919 |
+
del self.model
|
920 |
+
|
921 |
+
|
922 |
+
def restore_speech_timestamps(
|
923 |
+
segments: Iterable[Segment],
|
924 |
+
speech_chunks: List[dict],
|
925 |
+
sampling_rate: int,
|
926 |
+
) -> Iterable[Segment]:
|
927 |
+
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
|
928 |
+
|
929 |
+
for segment in segments:
|
930 |
+
if segment.words:
|
931 |
+
words = []
|
932 |
+
for word in segment.words:
|
933 |
+
# Ensure the word start and end times are resolved to the same chunk.
|
934 |
+
middle = (word.start + word.end) / 2
|
935 |
+
chunk_index = ts_map.get_chunk_index(middle)
|
936 |
+
word = word._replace(
|
937 |
+
start=ts_map.get_original_time(word.start, chunk_index),
|
938 |
+
end=ts_map.get_original_time(word.end, chunk_index),
|
939 |
+
)
|
940 |
+
words.append(word)
|
941 |
+
|
942 |
+
segment = segment._replace(
|
943 |
+
start=words[0].start,
|
944 |
+
end=words[-1].end,
|
945 |
+
words=words,
|
946 |
+
)
|
947 |
+
|
948 |
+
else:
|
949 |
+
segment = segment._replace(
|
950 |
+
start=ts_map.get_original_time(segment.start),
|
951 |
+
end=ts_map.get_original_time(segment.end),
|
952 |
+
)
|
953 |
+
|
954 |
+
return segments
|
955 |
+
|
956 |
+
|
957 |
+
def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
|
958 |
+
segment = np.ascontiguousarray(segment)
|
959 |
+
segment = ctranslate2.StorageView.from_array(segment)
|
960 |
+
return segment
|
961 |
+
|
962 |
+
|
963 |
+
def get_compression_ratio(text: str) -> float:
|
964 |
+
text_bytes = text.encode("utf-8")
|
965 |
+
return len(text_bytes) / len(zlib.compress(text_bytes))
|
966 |
+
|
967 |
+
|
968 |
+
def get_suppressed_tokens(
|
969 |
+
tokenizer: Tokenizer,
|
970 |
+
suppress_tokens: Optional[List[int]],
|
971 |
+
) -> Optional[List[int]]:
|
972 |
+
if not suppress_tokens or -1 in suppress_tokens:
|
973 |
+
return suppress_tokens
|
974 |
+
|
975 |
+
suppress_tokens = list(suppress_tokens)
|
976 |
+
|
977 |
+
# Ensure the following special tokens are suppressed when the user does
|
978 |
+
# not use the default set (-1).
|
979 |
+
suppress_tokens.extend(
|
980 |
+
[
|
981 |
+
tokenizer.transcribe,
|
982 |
+
tokenizer.translate,
|
983 |
+
tokenizer.sot,
|
984 |
+
tokenizer.sot_prev,
|
985 |
+
tokenizer.sot_lm,
|
986 |
+
]
|
987 |
+
)
|
988 |
+
|
989 |
+
return sorted(set(suppress_tokens))
|
990 |
+
|
991 |
+
|
992 |
+
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
|
993 |
+
# merge prepended punctuations
|
994 |
+
i = len(alignment) - 2
|
995 |
+
j = len(alignment) - 1
|
996 |
+
while i >= 0:
|
997 |
+
previous = alignment[i]
|
998 |
+
following = alignment[j]
|
999 |
+
if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
|
1000 |
+
# prepend it to the following word
|
1001 |
+
following["word"] = previous["word"] + following["word"]
|
1002 |
+
following["tokens"] = previous["tokens"] + following["tokens"]
|
1003 |
+
previous["word"] = ""
|
1004 |
+
previous["tokens"] = []
|
1005 |
+
else:
|
1006 |
+
j = i
|
1007 |
+
i -= 1
|
1008 |
+
|
1009 |
+
# merge appended punctuations
|
1010 |
+
i = 0
|
1011 |
+
j = 1
|
1012 |
+
while j < len(alignment):
|
1013 |
+
previous = alignment[i]
|
1014 |
+
following = alignment[j]
|
1015 |
+
if not previous["word"].endswith(" ") and following["word"] in appended:
|
1016 |
+
# append it to the previous word
|
1017 |
+
previous["word"] = previous["word"] + following["word"]
|
1018 |
+
previous["tokens"] = previous["tokens"] + following["tokens"]
|
1019 |
+
following["word"] = ""
|
1020 |
+
following["tokens"] = []
|
1021 |
+
else:
|
1022 |
+
i = j
|
1023 |
+
j += 1
|
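The class above exposes the same transcribe() interface as faster-whisper. A minimal sketch of driving this vendored copy directly (the module path whisper_live.transcriber, the model size "small.en", and the file path audio.wav are assumptions for illustration, not part of this commit):

# Minimal sketch: drive the vendored WhisperModel directly (hypothetical model size and path).
from whisper_live.transcriber import WhisperModel

model = WhisperModel("small.en", device="auto", compute_type="default")

# vad_filter=True runs the Silero-based VAD step shown in transcribe() before decoding.
segments, info = model.transcribe("audio.wav", vad_filter=True)

print(f"language={info.language} p={info.language_probability:.2f}")
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")

Note that, unlike upstream faster-whisper, generate_segments() here collects the segments into a list rather than yielding them lazily, so the whole result is available as soon as transcribe() returns.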
whisper_live/trt_server.py
ADDED
@@ -0,0 +1,496 @@
import websockets
import time
import threading
import json
import textwrap

import logging
logging.basicConfig(level = logging.INFO)

from websockets.sync.server import serve

import torch
import numpy as np
import time
from whisper_live.vad import VoiceActivityDetection
from whisper_live.trt_transcriber import WhisperTRTLLM


from scipy.io.wavfile import write
import functools

save_counter = 0
def save_wav(normalized_float32):
    global save_counter
    scaled_int16 = (normalized_float32 * 32768).astype(np.int16)
    write(f"outputs/output{save_counter}.wav", 16000, scaled_int16)
    save_counter += 1


class TranscriptionServer:
    """
    Represents a transcription server that handles incoming audio from clients.

    Attributes:
        RATE (int): The audio sampling rate (constant) set to 16000.
        vad_model (torch.Module): The voice activity detection model.
        vad_threshold (float): The voice activity detection threshold.
        clients (dict): A dictionary to store connected clients.
        websockets (dict): A dictionary to store WebSocket connections.
        clients_start_time (dict): A dictionary to track client start times.
        max_clients (int): Maximum allowed connected clients.
        max_connection_time (int): Maximum allowed connection time in seconds.
    """

    RATE = 16000

    def __init__(self):
        # voice activity detection model
        self.vad_model = VoiceActivityDetection()
        self.vad_threshold = 0.5
        self.clients = {}
        self.websockets = {}
        self.clients_start_time = {}
        self.max_clients = 4
        self.max_connection_time = 600

    def get_wait_time(self):
        """
        Calculate and return the estimated wait time for clients.

        Returns:
            float: The estimated wait time in minutes.
        """
        wait_time = None

        for k, v in self.clients_start_time.items():
            current_client_time_remaining = self.max_connection_time - (time.time() - v)

            if wait_time is None or current_client_time_remaining < wait_time:
                wait_time = current_client_time_remaining

        return wait_time / 60

    def recv_audio(self, websocket):
        """
        Receive audio chunks from a client in an infinite loop.

        Continuously receives audio frames from a connected client
        over a WebSocket connection. It processes the audio frames using a
        voice activity detection (VAD) model to determine if they contain speech
        or not. If the audio frame contains speech, it is added to the client's
        audio data for ASR.
        If the maximum number of clients is reached, the method sends a
        "WAIT" status to the client, indicating that they should wait
        until a slot is available.
        If a client's connection exceeds the maximum allowed time, it will
        be disconnected, and the client's resources will be cleaned up.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.

        Raises:
            Exception: If there is an error during the audio frame processing.
        """
        logging.info("New client connected")
        options = websocket.recv()
        options = json.loads(options)

        if len(self.clients) >= self.max_clients:
            logging.warning("Client Queue Full. Asking client to wait ...")
            wait_time = self.get_wait_time()
            response = {
                "uid": options["uid"],
                "status": "WAIT",
                "message": wait_time,
            }
            websocket.send(json.dumps(response))
            websocket.close()
            del websocket
            return

        client = ServeClient(
            websocket,
            multilingual=options["multilingual"],
            language=options["language"],
            task=options["task"],
            client_uid=options["uid"]
        )

        self.clients[websocket] = client
        self.clients_start_time[websocket] = time.time()
        no_voice_activity_chunks = 0
        while True:
            try:
                frame_data = websocket.recv()
                frame_np = np.frombuffer(frame_data, dtype=np.float32)
                # VAD
                try:
                    speech_prob = self.vad_model(torch.from_numpy(frame_np.copy()), self.RATE).item()
                    if speech_prob < self.vad_threshold:
                        no_voice_activity_chunks += 1
                        print("No speech", no_voice_activity_chunks)
                        if no_voice_activity_chunks > 2:
                            if not self.clients[websocket].eos:
                                self.clients[websocket].set_eos(True)
                        continue
                    no_voice_activity_chunks = 0
                    self.clients[websocket].set_eos(False)

                except Exception as e:
                    logging.error(e)
                    return
                self.clients[websocket].add_frames(frame_np)

                elapsed_time = time.time() - self.clients_start_time[websocket]
                if elapsed_time >= self.max_connection_time:
                    self.clients[websocket].disconnect()
                    logging.warning(f"{self.clients[websocket]} Client disconnected due to overtime.")
                    self.clients[websocket].cleanup()
                    self.clients.pop(websocket)
                    self.clients_start_time.pop(websocket)
                    websocket.close()
                    del websocket
                    break

            except Exception as e:
                logging.error(e)
                self.clients[websocket].cleanup()
                self.clients.pop(websocket)
                self.clients_start_time.pop(websocket)
                logging.info("Connection Closed.")
                logging.info(self.clients)
                del websocket
                break

    def run(self, host, port=9090):
        """
        Run the transcription server.

        Args:
            host (str): The host address to bind the server.
            port (int): The port number to bind the server.
        """
        with serve(self.recv_audio, host, port) as server:
            server.serve_forever()


class ServeClient:
    """
    Attributes:
        RATE (int): The audio sampling rate (constant) set to 16000.
        SERVER_READY (str): A constant message indicating that the server is ready.
        DISCONNECT (str): A constant message indicating that the client should disconnect.
        client_uid (str): A unique identifier for the client.
        data (bytes): Accumulated audio data.
        frames (bytes): Accumulated audio frames.
        language (str): The language for transcription.
        task (str): The task type, e.g., "transcribe."
        transcriber (WhisperModel): The Whisper model for speech-to-text.
        timestamp_offset (float): The offset in audio timestamps.
        frames_np (numpy.ndarray): NumPy array to store audio frames.
        frames_offset (float): The offset in audio frames.
        text (list): List of transcribed text segments.
        current_out (str): The current incomplete transcription.
        prev_out (str): The previous incomplete transcription.
        t_start (float): Timestamp for the start of transcription.
        exit (bool): A flag to exit the transcription thread.
        same_output_threshold (int): Threshold for consecutive same output segments.
        show_prev_out_thresh (int): Threshold for showing previous output segments.
        add_pause_thresh (int): Threshold for adding a pause (blank) segment.
        transcript (list): List of transcribed segments.
        send_last_n_segments (int): Number of last segments to send to the client.
        wrapper (textwrap.TextWrapper): Text wrapper for formatting text.
        pick_previous_segments (int): Number of previous segments to include in the output.
        websocket: The WebSocket connection for the client.
    """
    RATE = 16000
    SERVER_READY = "SERVER_READY"
    DISCONNECT = "DISCONNECT"

    def __init__(self, websocket, task="transcribe", device=None, multilingual=False, language=None, client_uid=None):
        """
        Initialize a ServeClient instance.
        The Whisper model is initialized based on the client's language and device availability.
        The transcription thread is started upon initialization. A "SERVER_READY" message is sent
        to the client to indicate that the server is ready.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
            device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
            multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.

        """
        self.client_uid = client_uid
        self.data = b""
        self.frames = b""
        self.language = language if multilingual else "en"
        self.task = task
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.transcriber = WhisperTRTLLM(
            "whisper_small_en", False, "assets", device="cuda")

        self.timestamp_offset = 0.0
        self.frames_np = None
        self.frames_offset = 0.0
        self.text = []
        self.current_out = ''
        self.prev_out = ''
        self.t_start = None
        self.exit = False
        self.same_output_threshold = 0
        self.show_prev_out_thresh = 5   # if pause(no output from whisper) show previous output for 5 seconds
        self.add_pause_thresh = 3       # add a blank to segment list as a pause(no speech) for 3 seconds
        self.transcript = []
        self.send_last_n_segments = 10

        # text formatting
        self.wrapper = textwrap.TextWrapper(width=50)
        self.pick_previous_segments = 2

        # threading
        self.websocket = websocket
        self.lock = threading.Lock()
        self.eos = False
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()
        self.websocket.send(
            json.dumps(
                {
                    "uid": self.client_uid,
                    "message": self.SERVER_READY
                }
            )
        )

    def set_eos(self, eos):
        self.lock.acquire()
        # if self.eos != eos:
        #     logging.info(f"[WhisperLive:] setting eos: {eos}")
        self.eos = eos
        self.lock.release()

    def add_frames(self, frame_np):
        """
        Add audio frames to the ongoing audio stream buffer.

        This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
        of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
        to prevent excessive memory usage.

        If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
        of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
        audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.

        Args:
            frame_np (numpy.ndarray): The audio frame data as a NumPy array.

        """
        self.lock.acquire()
        if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE:
            self.frames_offset += 30.0
            self.frames_np = self.frames_np[int(30*self.RATE):]
        if self.frames_np is None:
            self.frames_np = frame_np.copy()
        else:
            self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
        self.lock.release()

    def speech_to_text(self):
        """
        Process an audio stream in an infinite loop, continuously transcribing the speech.

        This method continuously receives audio frames, performs real-time transcription, and sends
        transcribed segments to the client via a WebSocket connection.

        If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
        It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
        are sent to the client in real-time, and a history of segments is maintained to provide context. Pauses in speech
        (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
        there is no speech for a specified duration to indicate a pause.

        Raises:
            Exception: If there is an issue with audio processing or WebSocket communication.

        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                continue

            # clip audio if the current chunk exceeds 30 seconds, this basically implies that
            # no valid segment for the last 30 seconds from whisper
            if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE:
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5

            samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
            duration = input_bytes.shape[0] / self.RATE
            if duration < 1.0 or not self.eos:
                continue

            try:
                input_sample = input_bytes.copy()
                save_wav(input_sample)
                # whisper transcribe with prompt
                mel, duration = self.transcriber.log_mel_spectrogram(input_sample)
                print(mel.shape, duration)
                result = self.transcriber.transcribe(mel)
                self.append_segment(result)
                self.set_eos(False)
                self.timestamp_offset += duration
                if len(result):
                    segments = self.transcript[-self.send_last_n_segments:]
                    try:
                        self.websocket.send(
                            json.dumps({
                                "uid": self.client_uid,
                                "segments": segments
                            })
                        )
                    except Exception as e:
                        logging.error(f"[ERROR]: {e}")

            except Exception as e:
                logging.error(f"[ERROR]: {e}")
                time.sleep(0.01)

    def append_segment(self, result):
        if not len(self.transcript):
            self.transcript.append({"text": result + " "})
        else:
            if self.transcript[-1]["text"].strip()[-1] == ".":
                if result[0] >= "a" and result[0] <= "z":
                    self.transcript[-1]["text"] = replace_last_occurrence(
                        self.transcript[-1]["text"], ".", ","
                    )
            elif self.transcript[-1]["text"].strip()[-1] == "?":
                if result[0] >= "a" and result[0] <= "z":
                    self.transcript[-1]["text"] = replace_last_occurrence(
                        self.transcript[-1]["text"], "?", ","
                    )

            self.transcript.append({"text": result + " "})

    def update_segments(self, segments, duration):
        """
        Processes the segments from whisper. Appends all the segments to the list
        except for the last segment assuming that it is incomplete.

        Updates the ongoing transcript with transcribed segments, including their start and end times.
        Complete segments are appended to the transcript in chronological order. Incomplete segments
        (assumed to be the last one) are processed to identify repeated content. If the same incomplete
        segment is seen multiple times, it updates the offset and appends the segment to the transcript.
        A threshold is used to detect repeated content and ensure it is only included once in the transcript.
        The timestamp offset is updated based on the duration of processed segments. The method returns the
        last processed segment, allowing it to be sent to the client for real-time updates.

        Args:
            segments(dict) : dictionary of segments as returned by whisper
            duration(float): duration of the current chunk

        Returns:
            dict or None: The last processed segment with its start time, end time, and transcribed text.
            Returns None if there are no valid segments to process.
        """
        offset = None
        self.current_out = ''
        last_segment = None
        # process complete segments
        if len(segments) > 1:
            for i, s in enumerate(segments[:-1]):
                text_ = s.text
                self.text.append(text_)
                start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end)
                self.transcript.append(
                    {
                        'start': start,
                        'end': end,
                        'text': text_
                    }
                )

                offset = min(duration, s.end)

        self.current_out += segments[-1].text
        last_segment = {
            'start': self.timestamp_offset + segments[-1].start,
            'end': self.timestamp_offset + min(duration, segments[-1].end),
            'text': self.current_out
        }

        # if same incomplete segment is seen multiple times then update the offset
        # and append the segment to the list
        if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
            self.same_output_threshold += 1
        else:
            self.same_output_threshold = 0

        if self.same_output_threshold > 5:
            if not len(self.text) or self.text[-1].strip().lower() != self.current_out.strip().lower():
                self.text.append(self.current_out)
                self.transcript.append(
                    {
                        'start': self.timestamp_offset,
                        'end': self.timestamp_offset + duration,
                        'text': self.current_out
                    }
                )
            self.current_out = ''
            offset = duration
            self.same_output_threshold = 0
            last_segment = None
        else:
            self.prev_out = self.current_out

        # update offset
        if offset is not None:
            self.timestamp_offset += offset

        return last_segment

    def disconnect(self):
        """
        Notify the client of disconnection and send a disconnect message.

        This method sends a disconnect message to the client via the WebSocket connection to notify them
        that the transcription service is disconnecting gracefully.

        """
        self.websocket.send(
            json.dumps(
                {
                    "uid": self.client_uid,
                    "message": self.DISCONNECT
                }
            )
        )

    def cleanup(self):
        """
        Perform cleanup tasks before exiting the transcription service.

        This method performs necessary cleanup tasks, including stopping the transcription thread, marking
        the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
        associated with the transcription process.

        """
        logging.info("Cleaning up.")
        self.exit = True
        self.transcriber.destroy()


def replace_last_occurrence(input_str, old_char, new_char):
    parts = input_str.rsplit(old_char, 1)
    if len(parts) == 2:
        return parts[0] + new_char + parts[1]
    else:
        return input_str
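From recv_audio() above, the wire protocol is: the client first sends a JSON options message (uid, multilingual, language, task), then streams raw float32 PCM at 16 kHz; the server replies with JSON messages ("SERVER_READY", "WAIT", "DISCONNECT", or a "segments" list). A minimal sketch of a hand-rolled client is below; the repo already ships a full client in whisper_live/client.py, and the host/port and the use of websockets' sync client API are assumptions here, not part of this commit.

# Minimal protocol sketch (not the bundled client); host/port are placeholders.
import json
import uuid

import numpy as np
from websockets.sync.client import connect  # sync client, websockets >= 11

with connect("ws://localhost:6006") as ws:
    # 1) handshake: JSON options matching the keys read in recv_audio()
    ws.send(json.dumps({
        "uid": str(uuid.uuid4()),
        "multilingual": False,
        "language": "en",
        "task": "transcribe",
    }))
    print(ws.recv())  # expect {"uid": ..., "message": "SERVER_READY"}

    # 2) stream audio as raw float32 PCM at 16 kHz; silence is only a format
    #    illustration -- the server emits segments once the VAD sees speech end.
    chunk = np.zeros(16000, dtype=np.float32)
    ws.send(chunk.tobytes())

    # 3) results arrive as JSON with a "segments" list (blocks until one is sent)
    print(ws.recv())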
whisper_live/trt_transcriber.py
ADDED
@@ -0,0 +1,347 @@
import argparse
import json
import re
import time
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import torch
import torch.nn.functional as F  # needed for F.pad when padding > 0
import numpy as np
from whisper.tokenizer import get_tokenizer
from whisper_live.whisper_utils import (mel_filters, store_transcripts,
                                        write_error_stats, load_audio,
                                        load_audio_wav_format, pad_or_trim)

import tensorrt_llm
import tensorrt_llm.logger as logger
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
                                 trt_dtype_to_torch)
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from tensorrt_llm.runtime.session import Session, TensorInfo


SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk


class WhisperEncoding:

    def __init__(self, engine_dir):
        self.session = self.get_session(engine_dir)

    def get_session(self, engine_dir):
        config_path = engine_dir / 'encoder_config.json'
        with open(config_path, 'r') as f:
            config = json.load(f)

        use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
        dtype = config['builder_config']['precision']
        n_mels = config['builder_config']['n_mels']
        num_languages = config['builder_config']['num_languages']

        self.dtype = dtype
        self.n_mels = n_mels
        self.num_languages = num_languages

        serialize_path = engine_dir / f'whisper_encoder_{self.dtype}_tp1_rank0.engine'

        with open(serialize_path, 'rb') as f:
            session = Session.from_serialized_engine(f.read())

        return session

    def get_audio_features(self, mel):
        inputs = OrderedDict()
        output_list = []

        inputs.update({'x': mel})
        output_list.append(
            TensorInfo('x', str_dtype_to_trt(self.dtype), mel.shape))

        output_info = self.session.infer_shapes(output_list)

        logger.debug(f'output info {output_info}')
        outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device='cuda')
            for t in output_info
        }
        stream = torch.cuda.current_stream()
        ok = self.session.run(inputs=inputs,
                              outputs=outputs,
                              stream=stream.cuda_stream)
        assert ok, 'Engine execution failed'
        stream.synchronize()
        audio_features = outputs['output']
        return audio_features


class WhisperDecoding:

    def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
        self.decoder_config = self.get_config(engine_dir)
        self.decoder_generation_session = self.get_session(
            engine_dir, runtime_mapping, debug_mode)

    def get_config(self, engine_dir):
        config_path = engine_dir / 'decoder_config.json'
        with open(config_path, 'r') as f:
            config = json.load(f)
        decoder_config = OrderedDict()
        decoder_config.update(config['plugin_config'])
        decoder_config.update(config['builder_config'])
        return decoder_config

    def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
        dtype = self.decoder_config['precision']
        serialize_path = engine_dir / f'whisper_decoder_{dtype}_tp1_rank0.engine'
        with open(serialize_path, "rb") as f:
            decoder_engine_buffer = f.read()

        decoder_model_config = ModelConfig(
            num_heads=self.decoder_config['num_heads'],
            num_kv_heads=self.decoder_config['num_heads'],
            hidden_size=self.decoder_config['hidden_size'],
            vocab_size=self.decoder_config['vocab_size'],
            num_layers=self.decoder_config['num_layers'],
            gpt_attention_plugin=self.decoder_config['gpt_attention_plugin'],
            remove_input_padding=self.decoder_config['remove_input_padding'],
            cross_attention=self.decoder_config['cross_attention'],
            has_position_embedding=self.decoder_config['has_position_embedding'],
            has_token_type_embedding=self.decoder_config['has_token_type_embedding'],
        )
        decoder_generation_session = tensorrt_llm.runtime.GenerationSession(
            decoder_model_config,
            decoder_engine_buffer,
            runtime_mapping,
            debug_mode=debug_mode)

        return decoder_generation_session

    def generate(self,
                 decoder_input_ids,
                 encoder_outputs,
                 eot_id,
                 max_new_tokens=40,
                 num_beams=1):
        encoder_input_lengths = torch.tensor(
            [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])],
            dtype=torch.int32,
            device='cuda')

        decoder_input_lengths = torch.tensor(
            [decoder_input_ids.shape[-1] for _ in range(decoder_input_ids.shape[0])],
            dtype=torch.int32,
            device='cuda')
        decoder_max_input_length = torch.max(decoder_input_lengths).item()

        # generation config
        sampling_config = SamplingConfig(end_id=eot_id,
                                         pad_id=eot_id,
                                         num_beams=num_beams)
        self.decoder_generation_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            beam_width=num_beams,
            encoder_max_input_length=encoder_outputs.shape[1])

        torch.cuda.synchronize()

        decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
        output_ids = self.decoder_generation_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_outputs,
            encoder_input_lengths=encoder_input_lengths,
        )
        torch.cuda.synchronize()

        # get the list of int from output_ids tensor
        output_ids = output_ids.cpu().numpy().tolist()
        return output_ids


class WhisperTRTLLM(object):

    def __init__(
        self,
        engine_dir,
        debug_mode=False,
        assets_dir=None,
        device=None
    ):
        world_size = 1
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
        engine_dir = Path(engine_dir)

        self.encoder = WhisperEncoding(engine_dir)
        self.decoder = WhisperDecoding(engine_dir,
                                       runtime_mapping,
                                       debug_mode=False)
        self.n_mels = self.encoder.n_mels
        # self.tokenizer = get_tokenizer(num_languages=self.encoder.num_languages,
        #                                tokenizer_dir=assets_dir)
        self.device = device
        self.tokenizer = get_tokenizer(
            False,
            num_languages=self.encoder.num_languages,
            language="en",
            task="transcribe",
        )
        self.filters = mel_filters(self.device, self.encoder.n_mels, assets_dir)

    def log_mel_spectrogram(
        self,
        audio: Union[str, np.ndarray, torch.Tensor],
        padding: int = 0,
        return_duration=True
    ):
        """
        Compute the log-Mel spectrogram of an audio waveform.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
            The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

        padding: int
            Number of zero samples to pad to the right

        Returns
        -------
        torch.Tensor, shape = (80 or 128, n_frames)
            A Tensor that contains the Mel spectrogram
        """
        if not torch.is_tensor(audio):
            if isinstance(audio, str):
                if audio.endswith('.wav'):
                    audio, _ = load_audio_wav_format(audio)
                else:
                    audio = load_audio(audio)
            assert isinstance(audio, np.ndarray), f"Unsupported audio type: {type(audio)}"
            duration = audio.shape[-1] / SAMPLE_RATE
            audio = pad_or_trim(audio, N_SAMPLES)
            audio = audio.astype(np.float32)
            audio = torch.from_numpy(audio)

        if self.device is not None:
            audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        window = torch.hann_window(N_FFT).to(audio.device)
        stft = torch.stft(audio,
                          N_FFT,
                          HOP_LENGTH,
                          window=window,
                          return_complex=True)
        magnitudes = stft[..., :-1].abs()**2

        mel_spec = self.filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        if return_duration:
            return log_spec, duration
        else:
            return log_spec

    def process_batch(
            self,
            mel,
            text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
            num_beams=1):
        prompt_id = self.tokenizer.encode(
            text_prefix, allowed_special=set(self.tokenizer.special_tokens.keys()))

        prompt_id = torch.tensor(prompt_id)
        batch_size = mel.shape[0]
        decoder_input_ids = prompt_id.repeat(batch_size, 1)

        encoder_output = self.encoder.get_audio_features(mel)
        output_ids = self.decoder.generate(decoder_input_ids,
                                           encoder_output,
                                           self.tokenizer.eot,
                                           max_new_tokens=96,
                                           num_beams=num_beams)
        texts = []
        for i in range(len(output_ids)):
            text = self.tokenizer.decode(output_ids[i][0]).strip()
            texts.append(text)
        return texts

    def transcribe(
        self,
        mel,
        text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        dtype='float16',
        batch_size=1,
        num_beams=1,
    ):
        mel = mel.type(str_dtype_to_torch(dtype))
        mel = mel.unsqueeze(0)
        predictions = self.process_batch(mel, text_prefix, num_beams)
        prediction = predictions[0]

        # remove all special tokens in the prediction
        prediction = re.sub(r'<\|.*?\|>', '', prediction)
        return prediction.strip()


def decode_wav_file(
        model,
        mel,
        text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        dtype='float16',
        batch_size=1,
        num_beams=1,
        normalizer=None,
        mel_filters_dir=None):
    mel = mel.type(str_dtype_to_torch(dtype))
    mel = mel.unsqueeze(0)
    # repeat the mel spectrogram to match the batch size
    mel = mel.repeat(batch_size, 1, 1)
    predictions = model.process_batch(mel, text_prefix, num_beams)
    prediction = predictions[0]

    # remove all special tokens in the prediction
    prediction = re.sub(r'<\|.*?\|>', '', prediction)
    if normalizer:
        prediction = normalizer(prediction)

    return prediction.strip()


if __name__ == "__main__":
    tensorrt_llm.logger.set_level("error")
    model = WhisperTRTLLM("../whisper_small_en", False, "../assets", device="cuda")
    mel, total_duration = model.log_mel_spectrogram(
        "/root/Code/outputs/output3.wav",
    )
    results = model.transcribe(mel)
    print(results, total_duration)
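For orientation, here is a hypothetical way to drive WhisperTRTLLM directly, mirroring the __main__ block above. The engine directory, assets directory, and wav path are placeholders for artifacts produced by a TensorRT-LLM Whisper build; they are not files shipped in this commit.

import tensorrt_llm
from whisper_live.trt_transcriber import WhisperTRTLLM

tensorrt_llm.logger.set_level("error")
# Placeholder paths: point these at your TensorRT-LLM Whisper engines and assets.
model = WhisperTRTLLM("path/to/whisper_engines", assets_dir="path/to/assets", device="cuda")

# log_mel_spectrogram accepts a 16 kHz .wav path (or a NumPy array / tensor)
# and, by default, returns (mel, duration_in_seconds).
mel, duration = model.log_mel_spectrogram("path/to/audio_16k.wav")
text = model.transcribe(mel)  # greedy decoding with the default English prompt
print(f"{duration:.1f}s -> {text}")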
whisper_live/vad.py
ADDED
@@ -0,0 +1,114 @@
# original: https://github.com/snakers4/silero-vad/blob/master/utils_vad.py

import os
import subprocess
import torch
import numpy as np
import onnxruntime


class VoiceActivityDetection():

    def __init__(self, force_onnx_cpu=True):
        path = self.download()
        opts = onnxruntime.SessionOptions()
        opts.log_severity_level = 3

        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
        else:
            self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)

        self.reset_states()
        self.sample_rates = [8000, 16000]

    def _validate_input(self, x, sr: int):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        if sr != 16000 and (sr % 16000 == 0):
            step = sr // 16000
            x = x[:, ::step]
            sr = 16000

        if sr not in self.sample_rates:
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")

        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        return x, sr

    def reset_states(self, batch_size=1):
        self._h = np.zeros((2, batch_size, 64)).astype('float32')
        self._c = np.zeros((2, batch_size, 64)).astype('float32')
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):
        x, sr = self._validate_input(x, sr)
        batch_size = x.shape[0]

        if not self._last_batch_size:
            self.reset_states(batch_size)
        if (self._last_sr) and (self._last_sr != sr):
            self.reset_states(batch_size)
        if (self._last_batch_size) and (self._last_batch_size != batch_size):
            self.reset_states(batch_size)

        if sr in [8000, 16000]:
            ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
            ort_outs = self.session.run(None, ort_inputs)
            out, self._h, self._c = ort_outs
        else:
            raise ValueError()

        self._last_sr = sr
        self._last_batch_size = batch_size

        out = torch.tensor(out)
        return out

    def audio_forward(self, x, sr: int, num_samples: int = 512):
        outs = []
        x, sr = self._validate_input(x, sr)

        if x.shape[1] % num_samples:
            pad_num = num_samples - (x.shape[1] % num_samples)
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

        self.reset_states(x.shape[0])
        for i in range(0, x.shape[1], num_samples):
            wavs_batch = x[:, i:i + num_samples]
            out_chunk = self.__call__(wavs_batch, sr)
            outs.append(out_chunk)

        stacked = torch.cat(outs, dim=1)
        return stacked.cpu()

    @staticmethod
    def download(model_url="https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx"):
        target_dir = os.path.expanduser("~/.cache/whisper-live/")

        # Ensure the target directory exists
        os.makedirs(target_dir, exist_ok=True)

        # Define the target file path
        model_filename = os.path.join(target_dir, "silero_vad.onnx")

        # Check if the model file already exists
        if not os.path.exists(model_filename):
            # If it doesn't exist, download the model using wget
            print("Downloading VAD ONNX model...")
            try:
                subprocess.run(["wget", "-O", model_filename, model_url], check=True)
            except subprocess.CalledProcessError:
                print("Failed to download the model using wget.")
        return model_filename
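A short, illustrative use of VoiceActivityDetection on a single streamed chunk; the 0.5 decision threshold is an assumption chosen for the example, not something this module defines.

import torch
from whisper_live.vad import VoiceActivityDetection

vad = VoiceActivityDetection()                 # downloads the ONNX model on first use
chunk = torch.zeros(512, dtype=torch.float32)  # one 32 ms frame at 16 kHz
speech_prob = vad(chunk, 16000).item()         # probability that the frame contains speech
if speech_prob > 0.5:                          # assumed threshold; tune per application
    print("speech detected")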
whisper_live/whisper_utils.py
ADDED
@@ -0,0 +1,365 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from subprocess import CalledProcessError, run
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import kaldialign
import numpy as np
import soundfile
import torch
import torch.nn.functional as F

Pathlike = Union[str, Path]

SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", "-ac",
        "1", "-acodec", "pcm_s16le", "-ar",
        str(sr), "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def load_audio_wav_format(wav_path):
    # make sure audio in .wav format
    assert wav_path.endswith(
        '.wav'), f"Only support .wav format, but got {wav_path}"
    waveform, sample_rate = soundfile.read(wav_path)
    assert sample_rate == 16000, f"Only support 16k sample rate, but got {sample_rate}"
    return waveform, sample_rate


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(dim=axis,
                                       index=torch.arange(length,
                                                          device=array.device))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array,
                          [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(device,
                n_mels: int,
                mel_filters_dir: str = None) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    if mel_filters_dir is None:
        mel_filters_path = os.path.join(os.path.dirname(__file__), "assets",
                                        "mel_filters.npz")
    else:
        mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz")
    with np.load(mel_filters_path) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
    return_duration: bool = False,
    mel_filters_dir: str = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (80 or 128, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            if audio.endswith('.wav'):
                audio, _ = load_audio_wav_format(audio)
            else:
                audio = load_audio(audio)
        assert isinstance(audio, np.ndarray), f"Unsupported audio type: {type(audio)}"
        duration = audio.shape[-1] / SAMPLE_RATE
        audio = pad_or_trim(audio, N_SAMPLES)
        audio = audio.astype(np.float32)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio,
                      N_FFT,
                      HOP_LENGTH,
                      window=window,
                      return_complex=True)
    magnitudes = stft[..., :-1].abs()**2

    filters = mel_filters(audio.device, n_mels, mel_filters_dir)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    if return_duration:
        return log_spec, duration
    else:
        return log_spec


def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str,
                                                                str]]) -> None:
    """Save predicted results and reference transcripts to a file.
    https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cur_id, the second is
        the reference transcript and the third element is the predicted result.
    Returns:
      Return None.
    """
    with open(filename, "w") as f:
        for cut_id, ref, hyp in texts:
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)


def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str]],
    enable_log: bool = True,
) -> float:
    """Write statistics based on predicted results and reference transcripts.
    https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

            FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).
    results:
      An iterable of tuples. The first element is the cur_id, the second is
      the reference transcript and the third element is the predicted result.
    enable_log:
      If True, also print detailed WER to the console.
      Otherwise, it is written only to the given file.
    Returns:
      Return the total error rate (WER, in percent) as a float.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
                     f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
                     f"{del_errs} del, {sub_errs} sub ]")

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [[
                list(filter(lambda a: a != ERR, x)),
                list(filter(lambda a: a != ERR, y)),
            ] for x, y in ali]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [[
                ERR if x == [] else " ".join(x),
                ERR if y == [] else " ".join(y),
            ] for x, y in ali]

        print(
            f"{cut_id}:\t" + " ".join((ref_word if ref_word == hyp_word else
                                       f"({ref_word}->{hyp_word})"
                                       for ref_word, hyp_word in ali)),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()],
                                    reverse=True):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp",
          file=f)
    for _, word, counts in sorted([(sum(v[1:]), k, v)
                                   for k, v in words.items()],
                                  reverse=True):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)
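To make the WER helpers concrete, here is a toy run of write_error_stats on made-up reference/hypothesis pairs. The data is invented purely for illustration; references and hypotheses are passed as word lists, which is what kaldialign.align expects.

import io
from whisper_live.whisper_utils import write_error_stats

results = [
    ("utt1", "the cat sat".split(), "the cat sat".split()),
    ("utt2", "hello world".split(), "hello word".split()),
]

report = io.StringIO()
wer = write_error_stats(report, "toy-test", results)
print(f"WER = {wer}")     # 1 substitution over 5 reference words -> 20.0
print(report.getvalue())  # detailed per-utterance and per-word breakdown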