aadnk committed on
Commit 43189ac
1 Parent(s): b84f4f2

Fix threads argument

Files changed (2)
  1. app.py +6 -1
  2. cli.py +3 -0
app.py CHANGED
@@ -29,7 +29,7 @@ import ffmpeg
 import gradio as gr
 
 from src.download import ExceededMaximumDuration, download_url
-from src.utils import slugify, write_srt, write_vtt
+from src.utils import optional_int, slugify, write_srt, write_vtt
 from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
 from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
 from src.whisper.whisperFactory import create_whisper_container
@@ -596,9 +596,14 @@ if __name__ == '__main__':
                         help="the Whisper implementation to use")
     parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
                         help="the compute type to use for inference")
+    parser.add_argument("--threads", type=optional_int, default=0,
+                        help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
     args = parser.parse_args().__dict__
 
     updated_config = default_app_config.update(**args)
 
+    if (threads := args.pop("threads")) > 0:
+        torch.set_num_threads(threads)
+
     create_ui(app_config=updated_config)
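The app.py change adds a --threads option parsed by optional_int, a helper imported from src.utils whose body is not part of this diff. A minimal sketch of such a parser, assuming it behaves like the identically named argparse type= callback in the upstream Whisper CLI (accepting either an integer or the literal "None"):

    def optional_int(string):
        # Hypothetical sketch, not the code in src/utils: map the literal
        # string "None" to no value, otherwise parse an integer.
        return None if string == "None" else int(string)

With the flag in place, the app can be launched with a thread cap, for example: python app.py --threads 4 (the value 4 is only an illustration).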
cli.py CHANGED
@@ -113,6 +113,9 @@ def cli():
     device: str = args.pop("device")
     os.makedirs(output_dir, exist_ok=True)
 
+    if (threads := args.pop("threads")) > 0:
+        torch.set_num_threads(threads)
+
     whisper_implementation = args.pop("whisper_implementation")
     print(f"Using {whisper_implementation} for Whisper")
 
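cli.py gains the same guard, so both entry points cap Torch's CPU thread pool before any transcription work starts. A minimal standalone sketch of the pattern, with a hypothetical threads value standing in for the parsed --threads argument:

    import torch

    threads = 4  # hypothetical value a user might pass as --threads 4
    if threads > 0:
        # torch.set_num_threads limits the intra-op thread pool used for
        # CPU inference; with the default of 0 the guard is skipped and
        # Torch keeps its own default (which otherwise follows the
        # MKL_NUM_THREADS / OMP_NUM_THREADS environment variables
        # mentioned in the new help text).
        torch.set_num_threads(threads)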