Spaces:
Paused
Paused
Merge pull request #8 from makaveli10/reuse_loaded_model
Browse files
whisper_live/trt_server.py
CHANGED
@@ -54,7 +54,7 @@ class TranscriptionServer:
|
|
54 |
self.clients_start_time = {}
|
55 |
self.max_clients = 4
|
56 |
self.max_connection_time = 600
|
57 |
-
|
58 |
|
59 |
def get_wait_time(self):
|
60 |
"""
|
@@ -113,6 +113,9 @@ class TranscriptionServer:
|
|
113 |
websocket.close()
|
114 |
del websocket
|
115 |
return
|
|
|
|
|
|
|
116 |
|
117 |
client = ServeClient(
|
118 |
websocket,
|
@@ -122,7 +125,7 @@ class TranscriptionServer:
|
|
122 |
client_uid=options["uid"],
|
123 |
transcription_queue=transcription_queue,
|
124 |
llm_queue=llm_queue,
|
125 |
-
|
126 |
)
|
127 |
|
128 |
self.clients[websocket] = client
|
@@ -237,7 +240,7 @@ class ServeClient:
|
|
237 |
client_uid=None,
|
238 |
transcription_queue=None,
|
239 |
llm_queue=None,
|
240 |
-
|
241 |
):
|
242 |
"""
|
243 |
Initialize a ServeClient instance.
|
@@ -254,16 +257,15 @@ class ServeClient:
|
|
254 |
client_uid (str, optional): A unique identifier for the client. Defaults to None.
|
255 |
|
256 |
"""
|
|
|
|
|
|
|
257 |
self.client_uid = client_uid
|
258 |
self.transcription_queue = transcription_queue
|
259 |
self.llm_queue = llm_queue
|
260 |
self.data = b""
|
261 |
self.frames = b""
|
262 |
-
self.language = language if multilingual else "en"
|
263 |
self.task = task
|
264 |
-
self.transcriber = WhisperTRTLLM(model_path, False, "assets", device="cuda")
|
265 |
-
|
266 |
-
|
267 |
self.last_prompt = None
|
268 |
|
269 |
self.timestamp_offset = 0.0
|
|
|
54 |
self.clients_start_time = {}
|
55 |
self.max_clients = 4
|
56 |
self.max_connection_time = 600
|
57 |
+
self.transcriber = None
|
58 |
|
59 |
def get_wait_time(self):
|
60 |
"""
|
|
|
113 |
websocket.close()
|
114 |
del websocket
|
115 |
return
|
116 |
+
|
117 |
+
if self.transcriber is None:
|
118 |
+
self.transcriber = WhisperTRTLLM(whisper_tensorrt_path, assets_dir="assets", device="cuda")
|
119 |
|
120 |
client = ServeClient(
|
121 |
websocket,
|
|
|
125 |
client_uid=options["uid"],
|
126 |
transcription_queue=transcription_queue,
|
127 |
llm_queue=llm_queue,
|
128 |
+
transcriber=self.transcriber
|
129 |
)
|
130 |
|
131 |
self.clients[websocket] = client
|
|
|
240 |
client_uid=None,
|
241 |
transcription_queue=None,
|
242 |
llm_queue=None,
|
243 |
+
transcriber=None
|
244 |
):
|
245 |
"""
|
246 |
Initialize a ServeClient instance.
|
|
|
257 |
client_uid (str, optional): A unique identifier for the client. Defaults to None.
|
258 |
|
259 |
"""
|
260 |
+
if transcriber is None:
|
261 |
+
raise ValueError("Transcriber is None.")
|
262 |
+
self.transcriber = transcriber
|
263 |
self.client_uid = client_uid
|
264 |
self.transcription_queue = transcription_queue
|
265 |
self.llm_queue = llm_queue
|
266 |
self.data = b""
|
267 |
self.frames = b""
|
|
|
268 |
self.task = task
|
|
|
|
|
|
|
269 |
self.last_prompt = None
|
270 |
|
271 |
self.timestamp_offset = 0.0
|