Marcus Edel commited on
Commit
a7b86dc
·
unverified ·
2 Parent(s): a330b66 80a0363

Merge pull request #8 from makaveli10/reuse_loaded_model

Browse files
Files changed (1) hide show
  1. whisper_live/trt_server.py +9 -7
whisper_live/trt_server.py CHANGED
@@ -54,7 +54,7 @@ class TranscriptionServer:
54
  self.clients_start_time = {}
55
  self.max_clients = 4
56
  self.max_connection_time = 600
57
- print("done loading")
58
 
59
  def get_wait_time(self):
60
  """
@@ -113,6 +113,9 @@ class TranscriptionServer:
113
  websocket.close()
114
  del websocket
115
  return
 
 
 
116
 
117
  client = ServeClient(
118
  websocket,
@@ -122,7 +125,7 @@ class TranscriptionServer:
122
  client_uid=options["uid"],
123
  transcription_queue=transcription_queue,
124
  llm_queue=llm_queue,
125
- model_path=whisper_tensorrt_path
126
  )
127
 
128
  self.clients[websocket] = client
@@ -237,7 +240,7 @@ class ServeClient:
237
  client_uid=None,
238
  transcription_queue=None,
239
  llm_queue=None,
240
- model_path=None
241
  ):
242
  """
243
  Initialize a ServeClient instance.
@@ -254,16 +257,15 @@ class ServeClient:
254
  client_uid (str, optional): A unique identifier for the client. Defaults to None.
255
 
256
  """
 
 
 
257
  self.client_uid = client_uid
258
  self.transcription_queue = transcription_queue
259
  self.llm_queue = llm_queue
260
  self.data = b""
261
  self.frames = b""
262
- self.language = language if multilingual else "en"
263
  self.task = task
264
- self.transcriber = WhisperTRTLLM(model_path, False, "assets", device="cuda")
265
-
266
-
267
  self.last_prompt = None
268
 
269
  self.timestamp_offset = 0.0
 
54
  self.clients_start_time = {}
55
  self.max_clients = 4
56
  self.max_connection_time = 600
57
+ self.transcriber = None
58
 
59
  def get_wait_time(self):
60
  """
 
113
  websocket.close()
114
  del websocket
115
  return
116
+
117
+ if self.transcriber is None:
118
+ self.transcriber = WhisperTRTLLM(whisper_tensorrt_path, assets_dir="assets", device="cuda")
119
 
120
  client = ServeClient(
121
  websocket,
 
125
  client_uid=options["uid"],
126
  transcription_queue=transcription_queue,
127
  llm_queue=llm_queue,
128
+ transcriber=self.transcriber
129
  )
130
 
131
  self.clients[websocket] = client
 
240
  client_uid=None,
241
  transcription_queue=None,
242
  llm_queue=None,
243
+ transcriber=None
244
  ):
245
  """
246
  Initialize a ServeClient instance.
 
257
  client_uid (str, optional): A unique identifier for the client. Defaults to None.
258
 
259
  """
260
+ if transcriber is None:
261
+ raise ValueError("Transcriber is None.")
262
+ self.transcriber = transcriber
263
  self.client_uid = client_uid
264
  self.transcription_queue = transcription_queue
265
  self.llm_queue = llm_queue
266
  self.data = b""
267
  self.frames = b""
 
268
  self.task = task
 
 
 
269
  self.last_prompt = None
270
 
271
  self.timestamp_offset = 0.0