aadnk commited on
Commit
aa33666
1 Parent(s): 18bb72f

Fix diarization in CLI

Browse files
Files changed (2) hide show
  1. app.py +13 -13
  2. cli.py +7 -7
app.py CHANGED
@@ -240,19 +240,6 @@ class WhisperTranscriber:
240
  # Update progress
241
  current_progress += source_audio_duration
242
 
243
- # Diarization
244
- if self.diarization and self.diarization_kwargs:
245
- print("Diarizing ", source.source_path)
246
- diarization_result = list(self.diarization.run(source.source_path, **self.diarization_kwargs))
247
-
248
- # Print result
249
- print("Diarization result: ")
250
- for entry in diarization_result:
251
- print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
252
-
253
- # Add speakers to result
254
- result = self.diarization.mark_speakers(diarization_result, result)
255
-
256
  source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
257
 
258
  if len(sources) > 1:
@@ -373,6 +360,19 @@ class WhisperTranscriber:
373
  else:
374
  # Default VAD
375
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
  return result
378
 
 
240
  # Update progress
241
  current_progress += source_audio_duration
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
244
 
245
  if len(sources) > 1:
 
360
  else:
361
  # Default VAD
362
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
363
+
364
+ # Diarization
365
+ if self.diarization and self.diarization_kwargs:
366
+ print("Diarizing ", audio_path)
367
+ diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
368
+
369
+ # Print result
370
+ print("Diarization result: ")
371
+ for entry in diarization_result:
372
+ print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
373
+
374
+ # Add speakers to result
375
+ result = self.diarization.mark_speakers(diarization_result, result)
376
 
377
  return result
378
 
cli.py CHANGED
@@ -111,9 +111,9 @@ def cli():
111
  parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
112
  parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
113
  help="whether to perform speaker diarization")
114
- parser.add_argument("--num_speakers", type=int, default=None, help="Number of speakers")
115
- parser.add_argument("--min_speakers", type=int, default=None, help="Minimum number of speakers")
116
- parser.add_argument("--max_speakers", type=int, default=None, help="Maximum number of speakers")
117
 
118
  args = parser.parse_args().__dict__
119
  model_name: str = args.pop("model")
@@ -151,11 +151,11 @@ def cli():
151
  compute_type = args.pop("compute_type")
152
  highlight_words = args.pop("highlight_words")
153
 
154
- diarization = args.pop("diarization")
155
  auth_token = args.pop("auth_token")
156
- num_speakers = args.pop("num_speakers")
157
- min_speakers = args.pop("min_speakers")
158
- max_speakers = args.pop("max_speakers")
 
159
 
160
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
161
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
 
111
  parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
112
  parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
113
  help="whether to perform speaker diarization")
114
+ parser.add_argument("--diarization_num_speakers", type=int, default=None, help="Number of speakers")
115
+ parser.add_argument("--diarization_min_speakers", type=int, default=None, help="Minimum number of speakers")
116
+ parser.add_argument("--diarization_max_speakers", type=int, default=None, help="Maximum number of speakers")
117
 
118
  args = parser.parse_args().__dict__
119
  model_name: str = args.pop("model")
 
151
  compute_type = args.pop("compute_type")
152
  highlight_words = args.pop("highlight_words")
153
 
 
154
  auth_token = args.pop("auth_token")
155
+ diarization = args.pop("diarization")
156
+ num_speakers = args.pop("diarization_num_speakers")
157
+ min_speakers = args.pop("diarization_min_speakers")
158
+ max_speakers = args.pop("diarization_max_speakers")
159
 
160
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
161
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))