aadnk commited on
Commit
295de00
·
1 Parent(s): c90f138

Adding support for faster_whisper

Browse files

This is a re-implementation of Whisper in CTranslate2 that can be 4x faster
and use much less memory than OpenAI's Whisper.

app.py CHANGED
@@ -11,8 +11,11 @@ import zipfile
11
  import numpy as np
12
 
13
  import torch
 
14
  from src.config import ApplicationConfig
15
- from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
 
 
16
  from src.modelCache import ModelCache
17
  from src.source import get_audio_source_collection
18
  from src.vadParallel import ParallelContext, ParallelTranscription
@@ -26,7 +29,8 @@ import gradio as gr
26
  from src.download import ExceededMaximumDuration, download_url
27
  from src.utils import slugify, write_srt, write_vtt
28
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
29
- from src.whisperContainer import WhisperContainer
 
30
 
31
  # Configure more application defaults in config.json5
32
 
@@ -121,7 +125,8 @@ class WhisperTranscriber:
121
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
122
  selectedModel = modelName if modelName is not None else "base"
123
 
124
- model = WhisperContainer(model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
 
125
 
126
  # Result
127
  download = []
@@ -223,7 +228,7 @@ class WhisperTranscriber:
223
  except ExceededMaximumDuration as e:
224
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
225
 
226
- def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
227
  vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
228
  progressListener: ProgressListener = None, **decodeOptions: dict):
229
 
@@ -507,7 +512,9 @@ if __name__ == '__main__':
507
  parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
508
  help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
509
  parser.add_argument("--output_dir", "-o", type=str, default=app_config.output_dir, \
510
- help="directory to save the outputs") # None
 
 
511
 
512
  args = parser.parse_args().__dict__
513
 
 
11
  import numpy as np
12
 
13
  import torch
14
+
15
  from src.config import ApplicationConfig
16
+ from src.hooks.progressListener import ProgressListener
17
+ from src.hooks.subTaskProgressListener import SubTaskProgressListener
18
+ from src.hooks.whisperProgressHook import create_progress_listener_handle
19
  from src.modelCache import ModelCache
20
  from src.source import get_audio_source_collection
21
  from src.vadParallel import ParallelContext, ParallelTranscription
 
29
  from src.download import ExceededMaximumDuration, download_url
30
  from src.utils import slugify, write_srt, write_vtt
31
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
32
+ from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
33
+ from src.whisper.whisperFactory import create_whisper_container
34
 
35
  # Configure more application defaults in config.json5
36
 
 
125
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
126
  selectedModel = modelName if modelName is not None else "base"
127
 
128
+ model = create_whisper_container(whisper_implementation=app_config.whisper_implementation,
129
+ model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
130
 
131
  # Result
132
  download = []
 
228
  except ExceededMaximumDuration as e:
229
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
230
 
231
+ def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
232
  vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
233
  progressListener: ProgressListener = None, **decodeOptions: dict):
234
 
 
512
  parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
513
  help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
514
  parser.add_argument("--output_dir", "-o", type=str, default=app_config.output_dir, \
515
+ help="directory to save the outputs"), \
516
+ parser.add_argument("--whisper_implementation", type=str, default=app_config.whisper_implementation, choices=["whisper", "faster-whisper"],\
517
+ help="the Whisper implementation to use"), \
518
 
519
  args = parser.parse_args().__dict__
520
 
cli.py CHANGED
@@ -11,7 +11,7 @@ from src.config import ApplicationConfig
11
  from src.download import download_url
12
 
13
  from src.utils import optional_float, optional_int, str2bool
14
- from src.whisperContainer import WhisperContainer
15
 
16
  def cli():
17
  app_config = ApplicationConfig.create_default()
@@ -32,8 +32,10 @@ def cli():
32
  parser.add_argument("--output_dir", "-o", type=str, default=output_dir, \
33
  help="directory to save the outputs")
34
  parser.add_argument("--verbose", type=str2bool, default=app_config.verbose, \
35
- help="whether to print out the progress and debug messages")
36
-
 
 
37
  parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
38
  help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
39
  parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(LANGUAGES), \
@@ -92,6 +94,8 @@ def cli():
92
  device: str = args.pop("device")
93
  os.makedirs(output_dir, exist_ok=True)
94
 
 
 
95
  if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
96
  warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
97
  args["language"] = "en"
@@ -115,7 +119,8 @@ def cli():
115
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
116
  transcriber.set_auto_parallel(auto_parallel)
117
 
118
- model = WhisperContainer(model_name, device=device, download_root=model_dir, models=app_config.models)
 
119
 
120
  if (transcriber._has_parallel_devices()):
121
  print("Using parallel devices:", transcriber.parallel_device_list)
 
11
  from src.download import download_url
12
 
13
  from src.utils import optional_float, optional_int, str2bool
14
+ from src.whisper.whisperFactory import create_whisper_container
15
 
16
  def cli():
17
  app_config = ApplicationConfig.create_default()
 
32
  parser.add_argument("--output_dir", "-o", type=str, default=output_dir, \
33
  help="directory to save the outputs")
34
  parser.add_argument("--verbose", type=str2bool, default=app_config.verbose, \
35
+ help="whether to print out the progress and debug messages"), \
36
+ parser.add_argument("--whisper_implementation", type=str, default=app_config.whisper_implementation, choices=["whisper", "faster-whisper"],\
37
+ help="the Whisper implementation to use"), \
38
+
39
  parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
40
  help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
41
  parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(LANGUAGES), \
 
94
  device: str = args.pop("device")
95
  os.makedirs(output_dir, exist_ok=True)
96
 
97
+ whisper_implementation = args.pop("whisper_implementation")
98
+
99
  if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
100
  warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
101
  args["language"] = "en"
 
119
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
120
  transcriber.set_auto_parallel(auto_parallel)
121
 
122
+ model = create_whisper_container(whisper_implementation=whisper_implementation,
123
+ device=device, download_root=model_dir, models=app_config.models)
124
 
125
  if (transcriber._has_parallel_devices()):
126
  print("Using parallel devices:", transcriber.parallel_device_list)
config.json5 CHANGED
@@ -62,6 +62,9 @@
62
 
63
  // * General options *
64
 
 
 
 
65
  // The default model name.
66
  "default_model_name": "medium",
67
  // The default VAD.
 
62
 
63
  // * General options *
64
 
65
+ // The default implementation to use for Whisper. Can be "whisper" or "faster-whisper".
66
+ "whisper_implementation": "whisper",
67
+
68
  // The default model name.
69
  "default_model_name": "medium",
70
  // The default VAD.
requirements-fastWhisper.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ctranslate2
2
+ faster-whisper
3
+ ffmpeg-python==0.2.0
4
+ gradio==3.23.0
5
+ yt-dlp
6
+ json5
7
+ torch
8
+ torchaudio
src/config.py CHANGED
@@ -8,8 +8,6 @@ import torch
8
 
9
  from tqdm import tqdm
10
 
11
- from src.conversion.hf_converter import convert_hf_whisper
12
-
13
  class ModelConfig:
14
  def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
15
  """
@@ -25,86 +23,11 @@ class ModelConfig:
25
  self.path = path
26
  self.type = type
27
 
28
- def download_url(self, root_dir: str):
29
- import whisper
30
-
31
- # See if path is already set
32
- if self.path is not None:
33
- return self.path
34
-
35
- if root_dir is None:
36
- root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
37
-
38
- model_type = self.type.lower() if self.type is not None else "whisper"
39
-
40
- if model_type in ["huggingface", "hf"]:
41
- self.path = self.url
42
- destination_target = os.path.join(root_dir, self.name + ".pt")
43
-
44
- # Convert from HuggingFace format to Whisper format
45
- if os.path.exists(destination_target):
46
- print(f"File {destination_target} already exists, skipping conversion")
47
- else:
48
- print("Saving HuggingFace model in Whisper format to " + destination_target)
49
- convert_hf_whisper(self.url, destination_target)
50
-
51
- self.path = destination_target
52
-
53
- elif model_type in ["whisper", "w"]:
54
- self.path = self.url
55
-
56
- # See if URL is just a file
57
- if self.url in whisper._MODELS:
58
- # No need to download anything - Whisper will handle it
59
- self.path = self.url
60
- elif self.url.startswith("file://"):
61
- # Get file path
62
- self.path = urlparse(self.url).path
63
- # See if it is an URL
64
- elif self.url.startswith("http://") or self.url.startswith("https://"):
65
- # Extension (or file name)
66
- extension = os.path.splitext(self.url)[-1]
67
- download_target = os.path.join(root_dir, self.name + extension)
68
-
69
- if os.path.exists(download_target) and not os.path.isfile(download_target):
70
- raise RuntimeError(f"{download_target} exists and is not a regular file")
71
-
72
- if not os.path.isfile(download_target):
73
- self._download_file(self.url, download_target)
74
- else:
75
- print(f"File {download_target} already exists, skipping download")
76
-
77
- self.path = download_target
78
- # Must be a local file
79
- else:
80
- self.path = self.url
81
-
82
- else:
83
- raise ValueError(f"Unknown model type {model_type}")
84
-
85
- return self.path
86
-
87
- def _download_file(self, url: str, destination: str):
88
- with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
89
- with tqdm(
90
- total=int(source.info().get("Content-Length")),
91
- ncols=80,
92
- unit="iB",
93
- unit_scale=True,
94
- unit_divisor=1024,
95
- ) as loop:
96
- while True:
97
- buffer = source.read(8192)
98
- if not buffer:
99
- break
100
-
101
- output.write(buffer)
102
- loop.update(len(buffer))
103
-
104
  class ApplicationConfig:
105
  def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
106
  share: bool = False, server_name: str = None, server_port: int = 7860,
107
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
 
108
  default_model_name: str = "medium", default_vad: str = "silero-vad",
109
  vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
110
  auto_parallel: bool = False, output_dir: str = None,
@@ -132,6 +55,7 @@ class ApplicationConfig:
132
  self.queue_concurrency_count = queue_concurrency_count
133
  self.delete_uploaded_files = delete_uploaded_files
134
 
 
135
  self.default_model_name = default_model_name
136
  self.default_vad = default_vad
137
  self.vad_parallel_devices = vad_parallel_devices
 
8
 
9
  from tqdm import tqdm
10
 
 
 
11
  class ModelConfig:
12
  def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
13
  """
 
23
  self.path = path
24
  self.type = type
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class ApplicationConfig:
27
  def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
28
  share: bool = False, server_name: str = None, server_port: int = 7860,
29
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
30
+ whisper_implementation: str = "whisper",
31
  default_model_name: str = "medium", default_vad: str = "silero-vad",
32
  vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
33
  auto_parallel: bool = False, output_dir: str = None,
 
55
  self.queue_concurrency_count = queue_concurrency_count
56
  self.delete_uploaded_files = delete_uploaded_files
57
 
58
+ self.whisper_implementation = whisper_implementation
59
  self.default_model_name = default_model_name
60
  self.default_vad = default_vad
61
  self.vad_parallel_devices = vad_parallel_devices
src/conversion/hf_converter.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from copy import deepcopy
4
  import torch
5
- from transformers import WhisperForConditionalGeneration
6
 
7
  WHISPER_MAPPING = {
8
  "layers": "blocks",
@@ -43,7 +42,8 @@ def rename_keys(s_dict):
43
  return s_dict
44
 
45
 
46
- def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
 
47
  transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
48
  config = transformer_model.config
49
 
 
2
 
3
  from copy import deepcopy
4
  import torch
 
5
 
6
  WHISPER_MAPPING = {
7
  "layers": "blocks",
 
42
  return s_dict
43
 
44
 
45
+ def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str)
46
+ from transformers import WhisperForConditionalGeneration
47
  transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
48
  config = transformer_model.config
49
 
src/hooks/progressListener.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ class ProgressListener:
4
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
5
+ self.total = total
6
+
7
+ def on_finished(self):
8
+ pass
src/hooks/subTaskProgressListener.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.hooks.progressListener import ProgressListener
2
+
3
+ from typing import Union
4
+
5
+ class SubTaskProgressListener(ProgressListener):
6
+ """
7
+ A sub task listener that reports the progress of a sub task to a base task listener
8
+ Parameters
9
+ ----------
10
+ base_task_listener : ProgressListener
11
+ The base progress listener to accumulate overall progress in.
12
+ base_task_total : float
13
+ The maximum total progress that will be reported to the base progress listener.
14
+ sub_task_start : float
15
+ The starting progress of a sub task, in respect to the base progress listener.
16
+ sub_task_total : float
17
+ The total amount of progress a sub task will report to the base progress listener.
18
+ """
19
+ def __init__(
20
+ self,
21
+ base_task_listener: ProgressListener,
22
+ base_task_total: float,
23
+ sub_task_start: float,
24
+ sub_task_total: float,
25
+ ):
26
+ self.base_task_listener = base_task_listener
27
+ self.base_task_total = base_task_total
28
+ self.sub_task_start = sub_task_start
29
+ self.sub_task_total = sub_task_total
30
+
31
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
32
+ sub_task_progress_frac = current / total
33
+ sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
34
+ self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)
35
+
36
+ def on_finished(self):
37
+ self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
src/hooks/whisperProgressHook.py CHANGED
@@ -3,12 +3,7 @@ import threading
3
  from typing import List, Union
4
  import tqdm
5
 
6
- class ProgressListener:
7
- def on_progress(self, current: Union[int, float], total: Union[int, float]):
8
- self.total = total
9
-
10
- def on_finished(self):
11
- pass
12
 
13
  class ProgressListenerHandle:
14
  def __init__(self, listener: ProgressListener):
@@ -23,41 +18,6 @@ class ProgressListenerHandle:
23
  if exc_type is None:
24
  self.listener.on_finished()
25
 
26
- class SubTaskProgressListener(ProgressListener):
27
- """
28
- A sub task listener that reports the progress of a sub task to a base task listener
29
-
30
- Parameters
31
- ----------
32
- base_task_listener : ProgressListener
33
- The base progress listener to accumulate overall progress in.
34
- base_task_total : float
35
- The maximum total progress that will be reported to the base progress listener.
36
- sub_task_start : float
37
- The starting progress of a sub task, in respect to the base progress listener.
38
- sub_task_total : float
39
- The total amount of progress a sub task will report to the base progress listener.
40
- """
41
- def __init__(
42
- self,
43
- base_task_listener: ProgressListener,
44
- base_task_total: float,
45
- sub_task_start: float,
46
- sub_task_total: float,
47
- ):
48
- self.base_task_listener = base_task_listener
49
- self.base_task_total = base_task_total
50
- self.sub_task_start = sub_task_start
51
- self.sub_task_total = sub_task_total
52
-
53
- def on_progress(self, current: Union[int, float], total: Union[int, float]):
54
- sub_task_progress_frac = current / total
55
- sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
56
- self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)
57
-
58
- def on_finished(self):
59
- self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
60
-
61
  class _CustomProgressBar(tqdm.tqdm):
62
  def __init__(self, *args, **kwargs):
63
  super().__init__(*args, **kwargs)
 
3
  from typing import List, Union
4
  import tqdm
5
 
6
+ from src.hooks.progressListener import ProgressListener
 
 
 
 
 
7
 
8
  class ProgressListenerHandle:
9
  def __init__(self, listener: ProgressListener):
 
18
  if exc_type is None:
19
  self.listener.on_finished()
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class _CustomProgressBar(tqdm.tqdm):
22
  def __init__(self, *args, **kwargs):
23
  super().__init__(*args, **kwargs)
src/utils.py CHANGED
@@ -4,6 +4,9 @@ import re
4
 
5
  import zlib
6
  from typing import Iterator, TextIO
 
 
 
7
 
8
 
9
  def exact_div(x, y):
@@ -112,4 +115,21 @@ def slugify(value, allow_unicode=False):
112
  else:
113
  value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
114
  value = re.sub(r'[^\w\s-]', '', value.lower())
115
- return re.sub(r'[-\s]+', '-', value).strip('-_')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import zlib
6
  from typing import Iterator, TextIO
7
+ import tqdm
8
+
9
+ import urllib3
10
 
11
 
12
  def exact_div(x, y):
 
115
  else:
116
  value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
117
  value = re.sub(r'[^\w\s-]', '', value.lower())
118
+ return re.sub(r'[-\s]+', '-', value).strip('-_')
119
+
120
+ def download_file(url: str, destination: str):
121
+ with urllib3.request.urlopen(url) as source, open(destination, "wb") as output:
122
+ with tqdm(
123
+ total=int(source.info().get("Content-Length")),
124
+ ncols=80,
125
+ unit="iB",
126
+ unit_scale=True,
127
+ unit_divisor=1024,
128
+ ) as loop:
129
+ while True:
130
+ buffer = source.read(8192)
131
+ if not buffer:
132
+ break
133
+
134
+ output.write(buffer)
135
+ loop.update(len(buffer))
src/vad.py CHANGED
@@ -5,11 +5,13 @@ import time
5
  from typing import Any, Deque, Iterator, List, Dict
6
 
7
  from pprint import pprint
8
- from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
 
 
9
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
10
 
11
  from src.segments import merge_timestamps
12
- from src.whisperContainer import WhisperCallback
13
 
14
  # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
15
  try:
@@ -136,7 +138,7 @@ class AbstractTranscription(ABC):
136
  pprint(merged)
137
  return merged
138
 
139
- def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
140
  progressListener: ProgressListener = None):
141
  """
142
  Transcribe the given audo file.
 
5
  from typing import Any, Deque, Iterator, List, Dict
6
 
7
  from pprint import pprint
8
+ from src.hooks.progressListener import ProgressListener
9
+ from src.hooks.subTaskProgressListener import SubTaskProgressListener
10
+ from src.hooks.whisperProgressHook import create_progress_listener_handle
11
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
12
 
13
  from src.segments import merge_timestamps
14
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
15
 
16
  # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
17
  try:
 
138
  pprint(merged)
139
  return merged
140
 
141
+ def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
142
  progressListener: ProgressListener = None):
143
  """
144
  Transcribe the given audo file.
src/vadParallel.py CHANGED
@@ -2,15 +2,16 @@ import multiprocessing
2
  from queue import Empty
3
  import threading
4
  import time
5
- from src.hooks.whisperProgressHook import ProgressListener
6
  from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
7
- from src.whisperContainer import WhisperCallback
8
 
9
  from multiprocessing import Pool, Queue
10
 
11
  from typing import Any, Dict, List, Union
12
  import os
13
 
 
 
14
  class _ProgressListenerToQueue(ProgressListener):
15
  def __init__(self, progress_queue: Queue):
16
  self.progress_queue = progress_queue
@@ -104,7 +105,7 @@ class ParallelTranscription(AbstractTranscription):
104
  def __init__(self, sampling_rate: int = 16000):
105
  super().__init__(sampling_rate=sampling_rate)
106
 
107
- def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
108
  cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
109
  progress_listener: ProgressListener = None):
110
  total_duration = get_audio_duration(audio)
@@ -276,7 +277,7 @@ class ParallelTranscription(AbstractTranscription):
276
  return config.override_timestamps
277
  return super().get_merged_timestamps(timestamps, config, total_duration)
278
 
279
- def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig,
280
  progressListener: ProgressListener = None):
281
  # Override device ID the first time
282
  if (os.environ.get("INITIALIZED", None) is None):
 
2
  from queue import Empty
3
  import threading
4
  import time
5
+ from src.hooks.progressListener import ProgressListener
6
  from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
 
7
 
8
  from multiprocessing import Pool, Queue
9
 
10
  from typing import Any, Dict, List, Union
11
  import os
12
 
13
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
14
+
15
  class _ProgressListenerToQueue(ProgressListener):
16
  def __init__(self, progress_queue: Queue):
17
  self.progress_queue = progress_queue
 
105
  def __init__(self, sampling_rate: int = 16000):
106
  super().__init__(sampling_rate=sampling_rate)
107
 
108
+ def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
109
  cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
110
  progress_listener: ProgressListener = None):
111
  total_duration = get_audio_duration(audio)
 
277
  return config.override_timestamps
278
  return super().get_merged_timestamps(timestamps, config, total_duration)
279
 
280
+ def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: ParallelTranscriptionConfig,
281
  progressListener: ProgressListener = None):
282
  # Override device ID the first time
283
  if (os.environ.get("INITIALIZED", None) is None):
src/whisper/abstractWhisperContainer.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import List
3
+ from src.config import ModelConfig
4
+
5
+ from src.hooks.progressListener import ProgressListener
6
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
7
+
8
+ class AbstractWhisperCallback:
9
+ @abc.abstractmethod
10
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
11
+ """
12
+ Peform the transcription of the given audio file or data.
13
+
14
+ Parameters
15
+ ----------
16
+ audio: Union[str, np.ndarray, torch.Tensor]
17
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
18
+ segment_index: int
19
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
20
+ task: str
21
+ The task - either translate or transcribe.
22
+ progress_listener: ProgressListener
23
+ A callback to receive progress updates.
24
+ """
25
+ raise NotImplementedError()
26
+
27
+ def _concat_prompt(self, prompt1, prompt2):
28
+ if (prompt1 is None):
29
+ return prompt2
30
+ elif (prompt2 is None):
31
+ return prompt1
32
+ else:
33
+ return prompt1 + " " + prompt2
34
+
35
+ class AbstractWhisperContainer:
36
+ def __init__(self, model_name: str, device: str = None, download_root: str = None,
37
+ cache: ModelCache = None, models: List[ModelConfig] = []):
38
+ self.model_name = model_name
39
+ self.device = device
40
+ self.download_root = download_root
41
+ self.cache = cache
42
+
43
+ # Will be created on demand
44
+ self.model = None
45
+
46
+ # List of known models
47
+ self.models = models
48
+
49
+ def get_model(self):
50
+ if self.model is None:
51
+
52
+ if (self.cache is None):
53
+ self.model = self._create_model()
54
+ else:
55
+ model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
56
+ self.model = self.cache.get(model_key, self._create_model)
57
+ return self.model
58
+
59
+ @abc.abstractmethod
60
+ def _create_model(self):
61
+ raise NotImplementedError()
62
+
63
+ def ensure_downloaded(self):
64
+ pass
65
+
66
+ @abc.abstractmethod
67
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict) -> AbstractWhisperCallback:
68
+ """
69
+ Create a WhisperCallback object that can be used to transcript audio files.
70
+
71
+ Parameters
72
+ ----------
73
+ language: str
74
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
75
+ task: str
76
+ The task - either translate or transcribe.
77
+ initial_prompt: str
78
+ The initial prompt to use for the transcription.
79
+ decodeOptions: dict
80
+ Additional options to pass to the decoder. Must be pickleable.
81
+
82
+ Returns
83
+ -------
84
+ A WhisperCallback object.
85
+ """
86
+ raise NotImplementedError()
87
+
88
+ # This is required for multiprocessing
89
+ def __getstate__(self):
90
+ return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
91
+
92
+ def __setstate__(self, state):
93
+ self.model_name = state["model_name"]
94
+ self.device = state["device"]
95
+ self.download_root = state["download_root"]
96
+ self.models = state["models"]
97
+ self.model = None
98
+ # Depickled objects must use the global cache
99
+ self.cache = GLOBAL_MODEL_CACHE
src/whisper/fasterWhisperContainer.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ from faster_whisper import WhisperModel, download_model
5
+ from src.config import ModelConfig
6
+ from src.hooks.progressListener import ProgressListener
7
+ from src.modelCache import ModelCache
8
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
9
+
10
+ class FasterWhisperContainer(AbstractWhisperContainer):
11
+ def __init__(self, model_name: str, device: str = None, download_root: str = None,
12
+ cache: ModelCache = None,
13
+ models: List[ModelConfig] = []):
14
+ super().__init__(model_name, device, download_root, cache, models)
15
+
16
+ def ensure_downloaded(self):
17
+ """
18
+ Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
19
+ passing the container to a subprocess.
20
+ """
21
+ model_config = self._get_model_config()
22
+
23
+ if os.path.isdir(model_config.url):
24
+ model_config.path = model_config.url
25
+ else:
26
+ model_config.path = download_model(model_config.url, output_dir=self.download_root)
27
+
28
+ def _get_model_config(self) -> ModelConfig:
29
+ """
30
+ Get the model configuration for the model.
31
+ """
32
+ for model in self.models:
33
+ if model.name == self.model_name:
34
+ return model
35
+ return None
36
+
37
+ def _create_model(self):
38
+ print("Loading faster whisper model " + self.model_name)
39
+ model_config = self._get_model_config()
40
+
41
+ if model_config.type == "whisper" and model_config.url not in ["tiny", "base", "small", "medium", "large", "large-v2"]:
42
+ raise Exception("FasterWhisperContainer does not yet support Whisper models. Use ct2-transformers-converter to convert the model to a faster-whisper model.")
43
+
44
+ device = self.device
45
+
46
+ if (device is None):
47
+ device = "auto"
48
+
49
+ model = WhisperModel(model_config.url, device=device, compute_type="float16")
50
+ return model
51
+
52
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
53
+ """
54
+ Create a WhisperCallback object that can be used to transcript audio files.
55
+
56
+ Parameters
57
+ ----------
58
+ language: str
59
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
60
+ task: str
61
+ The task - either translate or transcribe.
62
+ initial_prompt: str
63
+ The initial prompt to use for the transcription.
64
+ decodeOptions: dict
65
+ Additional options to pass to the decoder. Must be pickleable.
66
+
67
+ Returns
68
+ -------
69
+ A WhisperCallback object.
70
+ """
71
+ return FasterWhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
72
+
73
+ class FasterWhisperCallback(AbstractWhisperCallback):
74
+ def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
75
+ self.model_container = model_container
76
+ self.language = language
77
+ self.task = task
78
+ self.initial_prompt = initial_prompt
79
+ self.decodeOptions = decodeOptions
80
+
81
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
82
+ """
83
+ Peform the transcription of the given audio file or data.
84
+
85
+ Parameters
86
+ ----------
87
+ audio: Union[str, np.ndarray, torch.Tensor]
88
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
89
+ segment_index: int
90
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
91
+ task: str
92
+ The task - either translate or transcribe.
93
+ progress_listener: ProgressListener
94
+ A callback to receive progress updates.
95
+ """
96
+ model: WhisperModel = self.model_container.get_model()
97
+ language_code = self._lookup_language_code(self.language) if self.language else None
98
+
99
+ segments_generator, info = model.transcribe(audio, \
100
+ language=language_code if language_code else detected_language, task=self.task, \
101
+ initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
102
+ **self.decodeOptions
103
+ )
104
+
105
+ segments = []
106
+
107
+ for segment in segments_generator:
108
+ segments.append(segment)
109
+
110
+ if progress_listener is not None:
111
+ progress_listener.on_progress(segment.end, info.duration)
112
+
113
+ text = " ".join([segment.text for segment in segments])
114
+
115
+ # Convert the segments to a format that is easier to serialize
116
+ whisper_segments = [{
117
+ "text": segment.text,
118
+ "start": segment.start,
119
+ "end": segment.end,
120
+
121
+ # Extra fields added by faster-whisper
122
+ "words": [{
123
+ "start": word.start,
124
+ "end": word.end,
125
+ "word": word.word,
126
+ "probability": word.probability
127
+ } for word in (segment.words if segment.words is not None else []) ]
128
+ } for segment in segments]
129
+
130
+ result = {
131
+ "segments": whisper_segments,
132
+ "text": text,
133
+ "language": info.language if info else None,
134
+
135
+ # Extra fields added by faster-whisper
136
+ "language_probability": info.language_probability if info else None,
137
+ "duration": info.duration if info else None
138
+ }
139
+
140
+ if progress_listener is not None:
141
+ progress_listener.on_finished()
142
+ return result
143
+
144
+ def _lookup_language_code(self, language: str):
145
+ lookup = {
146
+ "english": "en", "chinese": "zh-cn", "german": "de", "spanish": "es", "russian": "ru", "korean": "ko",
147
+ "french": "fr", "japanese": "ja", "portuguese": "pt", "turkish": "tr", "polish": "pl", "catalan": "ca",
148
+ "dutch": "nl", "arabic": "ar", "swedish": "sv", "italian": "it", "indonesian": "id", "hindi": "hi",
149
+ "finnish": "fi", "vietnamese": "vi", "hebrew": "he", "ukrainian": "uk", "greek": "el", "malay": "ms",
150
+ "czech": "cs", "romanian": "ro", "danish": "da", "hungarian": "hu", "tamil": "ta", "norwegian": "no",
151
+ "thai": "th", "urdu": "ur", "croatian": "hr", "bulgarian": "bg", "lithuanian": "lt", "latin": "la",
152
+ "maori": "mi", "malayalam": "ml", "welsh": "cy", "slovak": "sk", "telugu": "te", "persian": "fa",
153
+ "latvian": "lv", "bengali": "bn", "serbian": "sr", "azerbaijani": "az", "slovenian": "sl",
154
+ "kannada": "kn", "estonian": "et", "macedonian": "mk", "breton": "br", "basque": "eu", "icelandic": "is",
155
+ "armenian": "hy", "nepali": "ne", "mongolian": "mn", "bosnian": "bs", "kazakh": "kk", "albanian": "sq",
156
+ "swahili": "sw", "galician": "gl", "marathi": "mr", "punjabi": "pa", "sinhala": "si", "khmer": "km",
157
+ "shona": "sn", "yoruba": "yo", "somali": "so", "afrikaans": "af", "occitan": "oc", "georgian": "ka",
158
+ "belarusian": "be", "tajik": "tg", "sindhi": "sd", "gujarati": "gu", "amharic": "am", "yiddish": "yi",
159
+ "lao": "lo", "uzbek": "uz", "faroese": "fo", "haitian creole": "ht", "pashto": "ps", "turkmen": "tk",
160
+ "nynorsk": "nn", "maltese": "mt", "sanskrit": "sa", "luxembourgish": "lb", "myanmar": "my", "tibetan": "bo",
161
+ "tagalog": "tl", "malagasy": "mg", "assamese": "as", "tatar": "tt", "hawaiian": "haw", "lingala": "ln",
162
+ "hausa": "ha", "bashkir": "ba", "javanese": "jv", "sundanese": "su"
163
+ }
164
+
165
+ return lookup.get(language.lower() if language is not None else None, language)
src/{whisperContainer.py → whisper/whisperContainer.py} RENAMED
@@ -1,40 +1,27 @@
1
  # External programs
 
2
  import os
3
  import sys
4
  from typing import List
 
 
 
5
 
6
  import whisper
7
  from whisper import Whisper
8
 
9
  from src.config import ModelConfig
10
- from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
11
 
12
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
 
13
 
14
- class WhisperContainer:
15
- def __init__(self, model_name: str, device: str = None, download_root: str = None,
16
- cache: ModelCache = None, models: List[ModelConfig] = []):
17
- self.model_name = model_name
18
- self.device = device
19
- self.download_root = download_root
20
- self.cache = cache
21
-
22
- # Will be created on demand
23
- self.model = None
24
-
25
- # List of known models
26
- self.models = models
27
 
28
- def get_model(self):
29
- if self.model is None:
30
-
31
- if (self.cache is None):
32
- self.model = self._create_model()
33
- else:
34
- model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
35
- self.model = self.cache.get(model_key, self._create_model)
36
- return self.model
37
-
38
  def ensure_downloaded(self):
39
  """
40
  Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
@@ -43,7 +30,7 @@ class WhisperContainer:
43
  # Warning: Using private API here
44
  try:
45
  root_dir = self.download_root
46
- model_config = self.get_model_config()
47
 
48
  if root_dir is None:
49
  root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
@@ -60,7 +47,7 @@ class WhisperContainer:
60
  print("Error pre-downloading model: " + str(e))
61
  return False
62
 
63
- def get_model_config(self) -> ModelConfig:
64
  """
65
  Get the model configuration for the model.
66
  """
@@ -71,10 +58,10 @@ class WhisperContainer:
71
 
72
  def _create_model(self):
73
  print("Loading whisper model " + self.model_name)
74
-
75
- model_config = self.get_model_config()
76
  # Note that the model will not be downloaded in the case of an official Whisper model
77
- model_path = model_config.download_url(self.download_root)
78
 
79
  return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
80
 
@@ -99,21 +86,73 @@ class WhisperContainer:
99
  """
100
  return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
101
 
102
- # This is required for multiprocessing
103
- def __getstate__(self):
104
- return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- def __setstate__(self, state):
107
- self.model_name = state["model_name"]
108
- self.device = state["device"]
109
- self.download_root = state["download_root"]
110
- self.models = state["models"]
111
- self.model = None
112
- # Depickled objects must use the global cache
113
- self.cache = GLOBAL_MODEL_CACHE
114
 
 
115
 
116
- class WhisperCallback:
117
  def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
118
  self.model_container = model_container
119
  self.language = language
@@ -133,14 +172,8 @@ class WhisperCallback:
133
  The target language of the transcription. If not specified, the language will be inferred from the audio content.
134
  task: str
135
  The task - either translate or transcribe.
136
- prompt: str
137
- The prompt to use for the transcription.
138
- detected_language: str
139
- The detected language of the audio file.
140
-
141
- Returns
142
- -------
143
- The result of the Whisper call.
144
  """
145
  model = self.model_container.get_model()
146
 
@@ -155,12 +188,4 @@ class WhisperCallback:
155
  language=self.language if self.language else detected_language, task=self.task, \
156
  initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
157
  **self.decodeOptions
158
- )
159
-
160
- def _concat_prompt(self, prompt1, prompt2):
161
- if (prompt1 is None):
162
- return prompt2
163
- elif (prompt2 is None):
164
- return prompt1
165
- else:
166
- return prompt1 + " " + prompt2
 
1
  # External programs
2
+ import abc
3
  import os
4
  import sys
5
  from typing import List
6
+ from urllib.parse import urlparse
7
+ import urllib3
8
+ from src.hooks.progressListener import ProgressListener
9
 
10
  import whisper
11
  from whisper import Whisper
12
 
13
  from src.config import ModelConfig
14
+ from src.hooks.whisperProgressHook import create_progress_listener_handle
15
 
16
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
17
+ from src.utils import download_file
18
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
19
 
20
+ class WhisperContainer(AbstractWhisperContainer):
21
+ def __init__(self, model_name: str, device: str = None, download_root: str = None,
22
+ cache: ModelCache = None, models: List[ModelConfig] = []):
23
+ super().__init__(model_name, device, download_root, cache, models)
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
25
  def ensure_downloaded(self):
26
  """
27
  Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
 
30
  # Warning: Using private API here
31
  try:
32
  root_dir = self.download_root
33
+ model_config = self._get_model_config()
34
 
35
  if root_dir is None:
36
  root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
 
47
  print("Error pre-downloading model: " + str(e))
48
  return False
49
 
50
+ def _get_model_config(self) -> ModelConfig:
51
  """
52
  Get the model configuration for the model.
53
  """
 
58
 
59
  def _create_model(self):
60
  print("Loading whisper model " + self.model_name)
61
+ model_config = self._get_model_config()
62
+
63
  # Note that the model will not be downloaded in the case of an official Whisper model
64
+ model_path = self._get_model_path(model_config, self.download_root)
65
 
66
  return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
67
 
 
86
  """
87
  return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
88
 
89
+ def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
90
+ from src.conversion.hf_converter import convert_hf_whisper
91
+ """
92
+ Download the model.
93
+
94
+ Parameters
95
+ ----------
96
+ model_config: ModelConfig
97
+ The model configuration.
98
+ """
99
+ # See if path is already set
100
+ if model_config.path is not None:
101
+ return model_config.path
102
+
103
+ if root_dir is None:
104
+ root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
105
+
106
+ model_type = model_config.type.lower() if model_config.type is not None else "whisper"
107
+
108
+ if model_type in ["huggingface", "hf"]:
109
+ model_config.path = model_config.url
110
+ destination_target = os.path.join(root_dir, model_config.name + ".pt")
111
+
112
+ # Convert from HuggingFace format to Whisper format
113
+ if os.path.exists(destination_target):
114
+ print(f"File {destination_target} already exists, skipping conversion")
115
+ else:
116
+ print("Saving HuggingFace model in Whisper format to " + destination_target)
117
+ convert_hf_whisper(model_config.url, destination_target)
118
+
119
+ model_config.path = destination_target
120
+
121
+ elif model_type in ["whisper", "w"]:
122
+ model_config.path = model_config.url
123
+
124
+ # See if URL is just a file
125
+ if model_config.url in whisper._MODELS:
126
+ # No need to download anything - Whisper will handle it
127
+ model_config.path = model_config.url
128
+ elif model_config.url.startswith("file://"):
129
+ # Get file path
130
+ model_config.path = urlparse(model_config.url).path
131
+ # See if it is an URL
132
+ elif model_config.url.startswith("http://") or model_config.url.startswith("https://"):
133
+ # Extension (or file name)
134
+ extension = os.path.splitext(model_config.url)[-1]
135
+ download_target = os.path.join(root_dir, model_config.name + extension)
136
+
137
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
138
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
139
+
140
+ if not os.path.isfile(download_target):
141
+ download_file(model_config.url, download_target)
142
+ else:
143
+ print(f"File {download_target} already exists, skipping download")
144
+
145
+ model_config.path = download_target
146
+ # Must be a local file
147
+ else:
148
+ model_config.path = model_config.url
149
 
150
+ else:
151
+ raise ValueError(f"Unknown model type {model_type}")
 
 
 
 
 
 
152
 
153
+ return model_config.path
154
 
155
+ class WhisperCallback(AbstractWhisperCallback):
156
  def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
157
  self.model_container = model_container
158
  self.language = language
 
172
  The target language of the transcription. If not specified, the language will be inferred from the audio content.
173
  task: str
174
  The task - either translate or transcribe.
175
+ progress_listener: ProgressListener
176
+ A callback to receive progress updates.
 
 
 
 
 
 
177
  """
178
  model = self.model_container.get_model()
179
 
 
188
  language=self.language if self.language else detected_language, task=self.task, \
189
  initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
190
  **self.decodeOptions
191
+ )
 
 
 
 
 
 
 
 
src/whisper/whisperFactory.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from src import modelCache
3
+ from src.config import ModelConfig
4
+ from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
5
+
6
+ def create_whisper_container(whisper_implementation: str,
7
+ model_name: str, device: str = None, download_root: str = None,
8
+ cache: modelCache = None, models: List[ModelConfig] = []) -> AbstractWhisperContainer:
9
+ if (whisper_implementation == "whisper"):
10
+ from src.whisper.whisperContainer import WhisperContainer
11
+ return WhisperContainer(model_name, device, download_root, cache, models)
12
+ elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
13
+ from src.whisper.fasterWhisperContainer import FasterWhisperContainer
14
+ return FasterWhisperContainer(model_name, device, download_root, cache, models)
15
+ else:
16
+ raise ValueError("Unknown Whisper implementation: " + whisper_implementation)