avans06 committed on
Commit
1e744c4
·
1 Parent(s): 28514b1

Enhanced the Translation Model's capabilities and optimized the Web UI.

Browse files

1. Translation Model Enhancements:

* Added support for the M2M100 model.
* Added three new options for the translation model: Batch Size, No Repeat Ngram Size, Num Beams.
* When the Translation Model is used for translation, two additional subtitle (srt) files are now generated: one in the original language (*-original.srt) and one bilingual (*-bilingual.srt).
* In response to adjustments in the Translation Model functionality, nllbLangs has been renamed to translationLangs, and nllbModel has been renamed to translationModel.

2. Web UI Enhancements:

* Placed the translation model under tabs, with tabs for M2M100, NLLB, MT5.
* Organized the audio input under tabs for URL, Upload, Microphone.
* Categorized VAD options under tabs for VAD, Merge Window, Max Merge Size, Padding, Prompt Window, Initial Prompt Mode.
* Grouped Word Timestamps options under tabs for Word Timestamps, Highlight Words, Prepend Punctuations, Append Punctuations.
* On the Full page, the Whisper Advanced options have been organized into tabs, including Initial Prompt, Temperature, Best Of, Beam Size, Patience, Length Penalty, Suppress Tokens, Condition on previous text, FP16, Temperature increment on fallback, Compression ratio threshold, Logprob threshold, and No speech threshold.

3. New advanced options and program adjustments for Whisper:

* In the Whisper Advanced options on the Full page, Repetition Penalty and No Repeat Ngram Size options have been added for use with faster-whisper.
* Merged languages into translationLangs.

app.py CHANGED
@@ -1,7 +1,7 @@
1
  from datetime import datetime
2
  import json
3
  import math
4
- from typing import Iterator, Union
5
  import argparse
6
 
7
  from io import StringIO
@@ -20,7 +20,6 @@ from src.diarization.diarizationContainer import DiarizationContainer
20
  from src.hooks.progressListener import ProgressListener
21
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
22
  from src.hooks.whisperProgressHook import create_progress_listener_handle
23
- from src.languages import _TO_LANGUAGE_CODE, get_language_names, get_language_from_name, get_language_from_code
24
  from src.modelCache import ModelCache
25
  from src.prompts.jsonPromptStrategy import JsonPromptStrategy
26
  from src.prompts.prependPromptStrategy import PrependPromptStrategy
@@ -34,18 +33,18 @@ import ffmpeg
34
  import gradio as gr
35
 
36
  from src.download import ExceededMaximumDuration, download_url
37
- from src.utils import optional_int, slugify, str2bool, write_srt, write_vtt
38
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
39
  from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
40
  from src.whisper.whisperFactory import create_whisper_container
41
- from src.nllb.nllbModel import NllbModel
42
- from src.nllb.nllbLangs import _TO_NLLB_LANG_CODE
43
- from src.nllb.nllbLangs import get_nllb_lang_names
44
- from src.nllb.nllbLangs import get_nllb_lang_from_name
45
-
46
  import shutil
47
  import zhconv
48
  import tqdm
 
49
 
50
  # Configure more application defaults in config.json5
51
 
@@ -114,120 +113,231 @@ class WhisperTranscriber:
114
  self.diarization.cleanup()
115
  self.diarization_kwargs = None
116
 
117
- # Entry function for the simple tab
118
- def transcribe_webui_simple(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
119
- vad, vadMergeWindow, vadMaxMergeSize,
120
- word_timestamps: bool = False, highlight_words: bool = False,
121
- diarization: bool = False, diarization_speakers: int = 2,
122
- diarization_min_speakers = 1, diarization_max_speakers = 8):
123
- return self.transcribe_webui_simple_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
124
- vad, vadMergeWindow, vadMaxMergeSize,
125
- word_timestamps, highlight_words,
126
- diarization, diarization_speakers,
127
- diarization_min_speakers, diarization_max_speakers)
128
 
129
- # Entry function for the simple tab progress
130
- def transcribe_webui_simple_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
131
- vad, vadMergeWindow, vadMaxMergeSize,
132
- word_timestamps: bool = False, highlight_words: bool = False,
133
- diarization: bool = False, diarization_speakers: int = 2,
134
- diarization_min_speakers = 1, diarization_max_speakers = 8,
135
- progress=gr.Progress()):
136
-
137
- vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
138
-
139
- if diarization:
140
- if diarization_speakers is not None and diarization_speakers < 1:
141
- self.set_diarization(auth_token=self.app_config.auth_token, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  else:
143
- self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
144
- else:
145
- self.unset_diarization()
146
-
147
- return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
148
- word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
149
-
150
- # Entry function for the full tab
151
- def transcribe_webui_full(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
152
- vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
153
- # Word timestamps
154
- word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
155
- initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
156
- condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
157
- compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
158
- diarization: bool = False, diarization_speakers: int = 2,
159
- diarization_min_speakers = 1, diarization_max_speakers = 8):
160
-
161
- return self.transcribe_webui_full_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
162
- vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
163
- word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
164
- initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
165
- condition_on_previous_text, fp16, temperature_increment_on_fallback,
166
- compression_ratio_threshold, logprob_threshold, no_speech_threshold,
167
- diarization, diarization_speakers,
168
- diarization_min_speakers, diarization_max_speakers)
169
-
170
- # Entry function for the full tab with progress
171
- def transcribe_webui_full_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
172
- vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
173
- # Word timestamps
174
- word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
175
- initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
176
- condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
177
- compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
178
- diarization: bool = False, diarization_speakers: int = 2,
179
- diarization_min_speakers = 1, diarization_max_speakers = 8,
180
- progress=gr.Progress()):
181
-
182
- # Handle temperature_increment_on_fallback
183
- if temperature_increment_on_fallback is not None:
184
- temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
185
- else:
186
- temperature = [temperature]
187
-
188
- vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
189
 
190
- # Set diarization
191
- if diarization:
192
- if diarization_speakers is not None and diarization_speakers < 1:
193
- self.set_diarization(auth_token=self.app_config.auth_token, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
194
- else:
195
- self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
196
- else:
197
- self.unset_diarization()
198
-
199
- return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
200
- initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
201
- condition_on_previous_text=condition_on_previous_text, fp16=fp16,
202
- compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
203
- word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
204
- progress=progress)
205
-
206
- def transcribe_webui(self, modelName: str, languageName: str, nllbModelName: str, nllbLangName: str, urlData: str, multipleFiles, microphoneData: str, task: str,
207
- vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
208
- **decodeOptions: dict):
209
- try:
210
  progress(0, desc="init audio sources")
211
- sources = self.__get_source(urlData, multipleFiles, microphoneData)
 
 
 
 
 
 
 
212
  if (len(sources) == 0):
213
  raise Exception("init audio sources failed...")
 
214
  try:
215
  progress(0, desc="init whisper model")
216
- whisper_lang = get_language_from_name(languageName)
217
- selectedLanguage = languageName.lower() if languageName is not None and len(languageName) > 0 else None
218
- selectedModel = modelName if modelName is not None else "base"
219
 
220
  model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
221
  model_name=selectedModel, compute_type=self.app_config.compute_type,
222
- cache=self.model_cache, models=self.app_config.models)
223
 
224
  progress(0, desc="init translate model")
225
- nllb_lang = get_nllb_lang_from_name(nllbLangName)
226
- selectedNllbModelName = nllbModelName if nllbModelName is not None and len(nllbModelName) > 0 else "nllb-200-distilled-600M/facebook"
227
- selectedNllbModel = next((modelConfig for modelConfig in self.app_config.nllb_models if modelConfig.name == selectedNllbModelName), None)
228
-
229
- nllb_model = NllbModel(model_config=selectedNllbModel, whisper_lang=whisper_lang, nllb_lang=nllb_lang) # load_model=True
230
-
 
 
 
 
 
 
 
 
 
 
 
 
231
  progress(0, desc="init transcribe")
232
  # Result
233
  download = []
@@ -238,7 +348,7 @@ class WhisperTranscriber:
238
  # Write result
239
  downloadDirectory = tempfile.mkdtemp()
240
  source_index = 0
241
- extra_tasks_count = 1 if nllb_lang is not None else 0
242
 
243
  outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
244
 
@@ -267,10 +377,10 @@ class WhisperTranscriber:
267
  sub_task_total=sub_task_total)
268
 
269
  # Transcribe
270
- result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
271
- if whisper_lang is None and result["language"] is not None and len(result["language"]) > 0:
272
- whisper_lang = get_language_from_code(result["language"])
273
- nllb_model.whisper_lang = whisper_lang
274
 
275
  short_name, suffix = source.get_short_name_suffix(max_length=self.app_config.input_max_file_name_length)
276
  filePrefix = slugify(source_prefix + short_name, allow_unicode=True)
@@ -278,7 +388,7 @@ class WhisperTranscriber:
278
  # Update progress
279
  current_progress += source_audio_duration
280
 
281
- source_download, source_text, source_vtt = self.write_result(result, nllb_model, filePrefix + suffix.replace(".", "_"), outputDirectory, highlight_words, scaled_progress_listener)
282
 
283
  if self.app_config.merge_subtitle_with_sources and self.app_config.output_dir is not None:
284
  print("\nmerge subtitle(srt) with source file [" + source.source_name + "]\n")
@@ -287,8 +397,8 @@ class WhisperTranscriber:
287
  srt_path = source_download[0]
288
  save_path = os.path.join(self.app_config.output_dir, filePrefix)
289
  # save_without_ext, ext = os.path.splitext(save_path)
290
- source_lang = "." + whisper_lang.code if whisper_lang is not None else ""
291
- translate_lang = "." + nllb_lang.code if nllb_lang is not None else ""
292
  output_with_srt = save_path + source_lang + translate_lang + suffix
293
 
294
  #ffmpeg -i "input.mp4" -i "input.srt" -c copy -c:s mov_text output.mp4
@@ -363,12 +473,11 @@ class WhisperTranscriber:
363
  except ExceededMaximumDuration as e:
364
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
365
  except Exception as e:
366
- import traceback
367
  print(traceback.format_exc())
368
- return [], ("Error occurred during transcribe: " + str(e)), ""
369
 
370
 
371
- def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None,
372
  vadOptions: VadOptions = VadOptions(),
373
  progressListener: ProgressListener = None, **decodeOptions: dict):
374
 
@@ -398,7 +507,7 @@ class WhisperTranscriber:
398
  raise ValueError("Invalid vadInitialPromptMode: " + initial_prompt_mode)
399
 
400
  # Callable for processing an audio file
401
- whisperCallable = model.create_callback(language, task, prompt_strategy=prompt_strategy, **decodeOptions)
402
 
403
  # The results
404
  if (vadOptions.vad == 'silero-vad'):
@@ -513,7 +622,7 @@ class WhisperTranscriber:
513
 
514
  return config
515
 
516
- def write_result(self, result: dict, nllb_model: NllbModel, source_name: str, output_dir: str, highlight_words: bool = False, progressListener: ProgressListener = None):
517
  if not os.path.exists(output_dir):
518
  os.makedirs(output_dir)
519
 
@@ -522,7 +631,7 @@ class WhisperTranscriber:
522
  language = result["language"]
523
  languageMaxLineWidth = self.__get_max_line_width(language)
524
 
525
- if nllb_model.nllb_lang is not None:
526
  try:
527
  segments_progress_listener = SubTaskProgressListener(progressListener,
528
  base_task_total=progressListener.sub_task_total,
@@ -530,17 +639,15 @@ class WhisperTranscriber:
530
  sub_task_total=1)
531
  pbar = tqdm.tqdm(total=len(segments))
532
  perf_start_time = time.perf_counter()
533
- nllb_model.load_model()
534
  for idx, segment in enumerate(segments):
535
  seg_text = segment["text"]
536
- if language == "zh":
537
- segment["text"] = zhconv.convert(seg_text, "zh-tw")
538
- if nllb_model.nllb_lang is not None:
539
- segment["text"] = nllb_model.translation(seg_text)
540
  pbar.update(1)
541
  segments_progress_listener.on_progress(idx+1, len(segments), desc=f"Process segments: {idx}/{len(segments)}")
542
 
543
- nllb_model.release_vram()
544
  perf_end_time = time.perf_counter()
545
  # Call the finished callback
546
  if segments_progress_listener is not None:
@@ -549,24 +656,57 @@ class WhisperTranscriber:
549
  print("\n\nprocess segments took {} seconds.\n\n".format(perf_end_time - perf_start_time))
550
  except Exception as e:
551
  # Ignore error - it's just a cleanup
 
552
  print("Error process segments: " + str(e))
553
 
554
  print("Max line width " + str(languageMaxLineWidth) + " for language:" + language)
555
  vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
556
  srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
557
  json_result = json.dumps(result, indent=4, ensure_ascii=False)
558
-
559
- if language == "zh" or (nllb_model.nllb_lang is not None and nllb_model.nllb_lang.code == "zho_Hant"):
560
- vtt = zhconv.convert(vtt, "zh-tw")
561
- srt = zhconv.convert(srt, "zh-tw")
562
- text = zhconv.convert(text, "zh-tw")
563
- json_result = zhconv.convert(json_result, "zh-tw")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
 
565
  output_files = []
566
  output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
567
  output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
568
  output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
569
  output_files.append(self.__create_file(json_result, output_dir, source_name + "-result.json"));
 
 
 
 
570
 
571
  return output_files, text, vtt
572
 
@@ -593,6 +733,10 @@ class WhisperTranscriber:
593
  write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
594
  elif format == 'srt':
595
  write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
 
 
 
 
596
  else:
597
  raise Exception("Unknown format " + format)
598
 
@@ -621,6 +765,16 @@ class WhisperTranscriber:
621
  self.diarization = None
622
 
623
  def create_ui(app_config: ApplicationConfig):
 
 
 
 
 
 
 
 
 
 
624
  ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
625
  app_config.delete_uploaded_files, app_config.output_dir, app_config)
626
 
@@ -639,59 +793,69 @@ def create_ui(app_config: ApplicationConfig):
639
  # Try to convert from camel-case to title-case
640
  implementation_name = app_config.whisper_implementation.title().replace("_", " ").replace("-", " ")
641
 
642
- ui_description = implementation_name + " is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
643
- ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
644
- ui_description += " as well as speech translation and language identification. "
645
 
646
- ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
647
 
648
  # Recommend faster-whisper
649
  if is_whisper:
650
- ui_description += "\n\n\n\nFor faster inference on GPU, try [faster-whisper](https://huggingface.co/spaces/aadnk/faster-whisper-webui)."
651
 
652
  if app_config.input_audio_max_duration > 0:
653
- ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
654
-
655
- ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
656
- ui_article += "\n\nWhisper's Task 'translate' only implements the functionality of translating other languages into English. "
657
- ui_article += "OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the NLLB Model to implement the translation task. "
658
- ui_article += "However, it's important to note that the NLLB Model runs slowly, and the completion time may be twice as long as usual. "
659
- ui_article += "\n\nThe larger the parameters of the NLLB model, the better its performance is expected to be. "
660
- ui_article += "However, it also requires higher computational resources, making it slower to operate. "
661
- ui_article += "On the other hand, the version converted from ct2 (CTranslate2) requires lower resources and operates at a faster speed."
662
- ui_article += "\n\nCurrently, enabling word-level timestamps cannot be used in conjunction with NLLB Model translation "
663
- ui_article += "because Word Timestamps will split the source text, and after translation, it becomes a non-word-level string. "
664
- ui_article += "\n\nThe 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. "
665
- ui_article += "This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English. "
666
-
667
- whisper_models = app_config.get_model_names()
668
- nllb_models = app_config.get_nllb_model_names()
 
 
669
 
670
- common_whisper_inputs = lambda : [
671
- gr.Dropdown(label="Whisper Model (for audio)", choices=whisper_models, value=app_config.default_model_name),
672
- gr.Dropdown(label="Whisper Language", choices=sorted(get_language_names()), value=app_config.language),
673
- ]
674
- common_nllb_inputs = lambda : [
675
- gr.Dropdown(label="NLLB Model (for translate)", choices=nllb_models),
676
- gr.Dropdown(label="NLLB Language", choices=sorted(get_nllb_lang_names())),
677
- ]
678
- common_audio_inputs = lambda : [
679
- gr.Text(label="URL (YouTube, etc.)"),
680
- gr.File(label="Upload Files", file_count="multiple"),
681
- gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
682
- gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
683
- ]
684
-
685
- common_vad_inputs = lambda : [
686
- gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
687
- gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
688
- gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
689
- ]
690
 
691
- common_word_timestamps_inputs = lambda : [
692
- gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps),
693
- gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
694
- ]
 
 
 
 
 
 
 
 
 
 
 
 
695
 
696
  has_diarization_libs = Diarization.has_libraries()
697
 
@@ -699,12 +863,12 @@ def create_ui(app_config: ApplicationConfig):
699
  print("Diarization libraries not found - disabling diarization")
700
  app_config.diarization = False
701
 
702
- common_diarization_inputs = lambda : [
703
- gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs),
704
- gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs),
705
- gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
706
- gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs)
707
- ]
708
 
709
  common_output = lambda : [
710
  gr.File(label="Download"),
@@ -714,84 +878,152 @@ def create_ui(app_config: ApplicationConfig):
714
 
715
  is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
716
 
717
- simple_callback = gr.CSVLogger()
718
-
719
- with gr.Blocks() as simple_transcribe:
720
- gr.Markdown(ui_description)
 
 
721
  with gr.Row():
722
  with gr.Column():
723
- simple_submit = gr.Button("Submit", variant="primary")
724
  with gr.Column():
725
  with gr.Row():
726
- simple_input = common_whisper_inputs()
727
- with gr.Row():
728
- simple_input += common_nllb_inputs()
 
 
 
 
 
 
 
 
 
 
729
  with gr.Column():
730
- simple_input += common_audio_inputs() + common_vad_inputs() + common_word_timestamps_inputs() + common_diarization_inputs()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731
  with gr.Column():
732
- simple_output = common_output()
733
- simple_flag = gr.Button("Flag")
734
- gr.Markdown(ui_article)
735
-
736
- # This needs to be called at some point prior to the first call to callback.flag()
737
- simple_callback.setup(simple_input + simple_output, "flagged")
738
-
739
- simple_submit.click(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
740
- inputs=simple_input, outputs=simple_output)
741
- # We can choose which components to flag -- in this case, we'll flag all of them
742
- simple_flag.click(lambda *args: print("simple_callback.flag...") or simple_callback.flag(args), simple_input + simple_output, None, preprocess=False)
743
-
744
- full_description = ui_description + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
745
 
746
- full_callback = gr.CSVLogger()
 
747
 
748
- with gr.Blocks() as full_transcribe:
749
- gr.Markdown(full_description)
 
 
750
  with gr.Row():
751
  with gr.Column():
752
- full_submit = gr.Button("Submit", variant="primary")
753
  with gr.Column():
754
  with gr.Row():
755
- full_input1 = common_whisper_inputs()
756
- with gr.Row():
757
- full_input1 += common_nllb_inputs()
 
 
 
 
 
 
 
 
 
 
758
  with gr.Column():
759
- full_input1 += common_audio_inputs() + common_vad_inputs() + [
760
- gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
761
- gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
762
- gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode")]
763
-
764
- full_input2 = common_word_timestamps_inputs() + [
765
- gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
766
- gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
767
- gr.TextArea(label="Initial Prompt"),
768
- gr.Number(label="Temperature", value=app_config.temperature),
769
- gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
770
- gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
771
- gr.Number(label="Patience - Zero temperature", value=app_config.patience),
772
- gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
773
- gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
774
- gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
775
- gr.Checkbox(label="FP16", value=app_config.fp16),
776
- gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
777
- gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
778
- gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
779
- gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)] + common_diarization_inputs()
780
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
  with gr.Column():
782
- full_output = common_output()
783
- full_flag = gr.Button("Flag")
784
- gr.Markdown(ui_article)
785
-
786
- # This needs to be called at some point prior to the first call to callback.flag()
787
- full_callback.setup(full_input1 + full_input2 + full_output, "flagged")
788
-
789
- full_submit.click(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
790
- inputs=full_input1+full_input2, outputs=full_output)
791
- # We can choose which components to flag -- in this case, we'll flag all of them
792
- full_flag.click(lambda *args: print("full_callback.flag...") or full_callback.flag(args), full_input1 + full_input2 + full_output, None, preprocess=False)
 
 
793
 
794
- demo = gr.TabbedInterface([simple_transcribe, full_transcribe], tab_names=["Simple", "Full"])
795
 
796
  # Queue up the demo
797
  if is_queue_mode:
@@ -807,8 +1039,7 @@ def create_ui(app_config: ApplicationConfig):
807
 
808
  if __name__ == '__main__':
809
  default_app_config = ApplicationConfig.create_default()
810
- whisper_models = default_app_config.get_model_names()
811
- nllb_models = default_app_config.get_nllb_model_names()
812
 
813
  # Environment variable overrides
814
  default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
@@ -846,9 +1077,10 @@ if __name__ == '__main__':
846
  help="the compute type to use for inference")
847
  parser.add_argument("--threads", type=optional_int, default=0,
848
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
849
  parser.add_argument("--vad_max_merge_size", type=int, default=default_app_config.vad_max_merge_size, \
850
  help="The number of VAD - Max Merge Size (s).") # 30
851
- parser.add_argument("--language", type=str, default=None, choices=sorted(get_language_names()) + sorted([k.title() for k in _TO_LANGUAGE_CODE.keys()]),
852
  help="language spoken in the audio, specify None to perform language detection")
853
  parser.add_argument("--save_downloaded_files", action='store_true', \
854
  help="True to move downloaded files to outputs directory. This argument will take effect only after output_dir is set.")
@@ -858,6 +1090,7 @@ if __name__ == '__main__':
858
  help="Maximum length of a file name.")
859
  parser.add_argument("--autolaunch", action='store_true', \
860
  help="open the webui URL in the system's default browser upon launch")
 
861
  parser.add_argument('--auth_token', type=str, default=default_app_config.auth_token, help='HuggingFace API Token (optional)')
862
  parser.add_argument("--diarization", type=str2bool, default=default_app_config.diarization, \
863
  help="whether to perform speaker diarization")
 
1
  from datetime import datetime
2
  import json
3
  import math
4
+ from typing import Iterator, Union, List
5
  import argparse
6
 
7
  from io import StringIO
 
20
  from src.hooks.progressListener import ProgressListener
21
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
22
  from src.hooks.whisperProgressHook import create_progress_listener_handle
 
23
  from src.modelCache import ModelCache
24
  from src.prompts.jsonPromptStrategy import JsonPromptStrategy
25
  from src.prompts.prependPromptStrategy import PrependPromptStrategy
 
33
  import gradio as gr
34
 
35
  from src.download import ExceededMaximumDuration, download_url
36
+ from src.utils import optional_int, slugify, str2bool, write_srt, write_srt_original, write_vtt
37
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
38
  from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
39
  from src.whisper.whisperFactory import create_whisper_container
40
+ from src.translation.translationModel import TranslationModel
41
+ from src.translation.translationLangs import (TranslationLang,
42
+ _TO_LANG_CODE_WHISPER, get_lang_whisper_names, get_lang_from_whisper_name, get_lang_from_whisper_code,
43
+ get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name)
 
44
  import shutil
45
  import zhconv
46
  import tqdm
47
+ import traceback
48
 
49
  # Configure more application defaults in config.json5
50
 
 
113
  self.diarization.cleanup()
114
  self.diarization_kwargs = None
115
 
116
def transcribe_webui_simple(self, data: dict):
    """Entry point for the simple tab when queue mode is disabled.

    Progress bars are not shown in this mode, so the call is forwarded to
    ``transcribe_webui_simple_progress`` without an explicit progress
    tracker.

    Parameters
    ----------
    data: dict
        The submitted inputs, passed through unchanged.
    """
    return self.transcribe_webui_simple_progress(data)
 
 
 
 
 
 
 
 
 
118
 
119
def transcribe_webui_simple_progress(self, data: dict, progress=gr.Progress()):
    """Entry point for the simple tab with progress tracking.

    Progress tracking requires queuing to be enabled.

    Parameters
    ----------
    data: dict
        Mapping whose keys expose an ``elem_id`` attribute and whose values
        are the submitted values.
    progress:
        Progress tracker forwarded to ``transcribe_webui``.
    """
    # Re-key the submitted values by each key's elem_id so that
    # transcribe_webui can look options up by name.
    options_by_id = {component.elem_id: value for component, value in data.items()}
    return self.transcribe_webui(options_by_id, progress=progress)
126
+
127
def transcribe_webui_full(self, data: dict):
    """Entry point for the full tab when queue mode is disabled.

    Progress bars are not shown in this mode, so the call is forwarded to
    ``transcribe_webui_full_progress`` without an explicit progress tracker.

    Parameters
    ----------
    data: dict
        The submitted inputs, passed through unchanged.
    """
    return self.transcribe_webui_full_progress(data)
129
+
130
def transcribe_webui_full_progress(self, data: dict, progress=gr.Progress()):
    """Entry point for the full tab with progress tracking.

    Progress tracking requires queuing to be enabled.

    Parameters
    ----------
    data: dict
        Mapping whose keys expose an ``elem_id`` attribute and whose values
        are the submitted values.
    progress:
        Progress tracker forwarded to ``transcribe_webui``.
    """
    # Re-key the submitted values by each key's elem_id so that
    # transcribe_webui can look options up by name.
    options_by_id = {component.elem_id: value for component, value in data.items()}
    return self.transcribe_webui(options_by_id, progress=progress)
137
+
138
+ def transcribe_webui(self, decodeOptions: dict, progress: gr.Progress = None):
139
+ """
140
+ Transcribe an audio file using Whisper
141
+ https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L37
142
+ Parameters
143
+ ----------
144
+ model: Whisper
145
+ The Whisper model instance
146
+
147
+ temperature: Union[float, Tuple[float, ...]]
148
+ Temperature for sampling. It can be a tuple of temperatures, which will be successively used
149
+ upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
150
+
151
+ compression_ratio_threshold: float
152
+ If the gzip compression ratio is above this value, treat as failed
153
+
154
+ logprob_threshold: float
155
+ If the average log probability over sampled tokens is below this value, treat as failed
156
+
157
+ no_speech_threshold: float
158
+ If the no_speech probability is higher than this value AND the average log probability
159
+ over sampled tokens is below `logprob_threshold`, consider the segment as silent
160
+
161
+ condition_on_previous_text: bool
162
+ if True, the previous output of the model is provided as a prompt for the next window;
163
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
164
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
165
+
166
+ word_timestamps: bool
167
+ Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
168
+ and include the timestamps for each word in each segment.
169
+
170
+ prepend_punctuations: str
171
+ If word_timestamps is True, merge these punctuation symbols with the next word
172
+
173
+ append_punctuations: str
174
+ If word_timestamps is True, merge these punctuation symbols with the previous word
175
+
176
+ initial_prompt: Optional[str]
177
+ Optional text to provide as a prompt for the first window. This can be used to provide, or
178
+ "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
179
+ to make it more likely to predict those words correctly.
180
+
181
+ decode_options: dict
182
+ Keyword arguments to construct `DecodingOptions` instances
183
+ https://github.com/openai/whisper/blob/main/whisper/decoding.py#L81
184
+
185
+ task: str = "transcribe"
186
+ whether to perform X->X "transcribe" or X->English "translate"
187
+
188
+ language: Optional[str] = None
189
+ language that the audio is in; uses detected language if None
190
+
191
+ temperature: float = 0.0
192
+ sample_len: Optional[int] = None # maximum number of tokens to sample
193
+ best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
194
+ beam_size: Optional[int] = None # number of beams in beam search, if t == 0
195
+ patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
196
+ sampling-related options
197
+
198
+ length_penalty: Optional[float] = None
199
+ "alpha" in Google NMT, or None for length norm, when ranking generations
200
+ to select which to return among the beams or best-of-N samples
201
+
202
+ prompt: Optional[Union[str, List[int]]] = None # for the previous context
203
+ prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
204
+ text or tokens to feed as the prompt or the prefix; for more info:
205
+ https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
206
+
207
+ suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
208
+ suppress_blank: bool = True # this will suppress blank outputs
209
+ list of tokens ids (or comma-separated token ids) to suppress
210
+ "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
211
+
212
+ without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
213
+ max_initial_timestamp: Optional[float] = 1.0
214
+ timestamp sampling options
215
+
216
+ fp16: bool = True # use fp16 for most of the calculation
217
+ implementation details
218
+ repetition_penalty: float
219
+ The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
220
+ no_repeat_ngram_size: int
221
+ The model ensures that a sequence of words of no_repeat_ngram_size isn’t repeated in the output sequence. If specified, it must be a positive integer greater than 1.
222
+ """
223
+ try:
224
+ whisperModelName: str = decodeOptions.pop("whisperModelName")
225
+ whisperLangName: str = decodeOptions.pop("whisperLangName")
226
+
227
+ translateInput: str = decodeOptions.pop("translateInput")
228
+ m2m100ModelName: str = decodeOptions.pop("m2m100ModelName")
229
+ m2m100LangName: str = decodeOptions.pop("m2m100LangName")
230
+ nllbModelName: str = decodeOptions.pop("nllbModelName")
231
+ nllbLangName: str = decodeOptions.pop("nllbLangName")
232
+ mt5ModelName: str = decodeOptions.pop("mt5ModelName")
233
+ mt5LangName: str = decodeOptions.pop("mt5LangName")
234
+
235
+ translationBatchSize: int = decodeOptions.pop("translationBatchSize")
236
+ translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
237
+ translationNumBeams: int = decodeOptions.pop("translationNumBeams")
238
+
239
+ sourceInput: str = decodeOptions.pop("sourceInput")
240
+ urlData: str = decodeOptions.pop("urlData")
241
+ multipleFiles: List = decodeOptions.pop("multipleFiles")
242
+ microphoneData: str = decodeOptions.pop("microphoneData")
243
+ task: str = decodeOptions.pop("task")
244
+
245
+ vad: str = decodeOptions.pop("vad")
246
+ vadMergeWindow: float = decodeOptions.pop("vadMergeWindow")
247
+ vadMaxMergeSize: float = decodeOptions.pop("vadMaxMergeSize")
248
+ vadPadding: float = decodeOptions.pop("vadPadding", self.app_config.vad_padding)
249
+ vadPromptWindow: float = decodeOptions.pop("vadPromptWindow", self.app_config.vad_prompt_window)
250
+ vadInitialPromptMode: str = decodeOptions.pop("vadInitialPromptMode", self.app_config.vad_initial_prompt_mode)
251
+
252
+ diarization: bool = decodeOptions.pop("diarization", False)
253
+ diarization_speakers: int = decodeOptions.pop("diarization_speakers", 2)
254
+ diarization_min_speakers: int = decodeOptions.pop("diarization_min_speakers", 1)
255
+ diarization_max_speakers: int = decodeOptions.pop("diarization_max_speakers", 8)
256
+ highlight_words: bool = decodeOptions.pop("highlight_words", False)
257
+
258
+ temperature: float = decodeOptions.pop("temperature", None)
259
+ temperature_increment_on_fallback: float = decodeOptions.pop("temperature_increment_on_fallback", None)
260
+
261
+ whisperRepetitionPenalty: float = decodeOptions.get("repetition_penalty", None)
262
+ whisperNoRepeatNgramSize: int = decodeOptions.get("no_repeat_ngram_size", None)
263
+ if whisperRepetitionPenalty is not None and whisperRepetitionPenalty <= 1.0:
264
+ decodeOptions.pop("repetition_penalty")
265
+ if whisperNoRepeatNgramSize is not None and whisperNoRepeatNgramSize <= 1:
266
+ decodeOptions.pop("no_repeat_ngram_size")
267
+
268
+ # word_timestamps = options.get("word_timestamps", False)
269
+ # condition_on_previous_text = options.get("condition_on_previous_text", False)
270
+
271
+ # prepend_punctuations = options.get("prepend_punctuations", None)
272
+ # append_punctuations = options.get("append_punctuations", None)
273
+ # initial_prompt = options.get("initial_prompt", None)
274
+ # best_of = options.get("best_of", None)
275
+ # beam_size = options.get("beam_size", None)
276
+ # patience = options.get("patience", None)
277
+ # length_penalty = options.get("length_penalty", None)
278
+ # suppress_tokens = options.get("suppress_tokens", None)
279
+ # compression_ratio_threshold = options.get("compression_ratio_threshold", None)
280
+ # logprob_threshold = options.get("logprob_threshold", None)
281
+
282
+ vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
283
+
284
+ if diarization:
285
+ if diarization_speakers is not None and diarization_speakers < 1:
286
+ self.set_diarization(auth_token=self.app_config.auth_token, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
287
+ else:
288
+ self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
289
  else:
290
+ self.unset_diarization()
291
+
292
+ # Handle temperature_increment_on_fallback
293
+ if temperature is not None:
294
+ if temperature_increment_on_fallback is not None:
295
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
296
+ else:
297
+ temperature = [temperature]
298
+ decodeOptions["temperature"] = temperature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  progress(0, desc="init audio sources")
301
+
302
+ if sourceInput == "urlData":
303
+ sources = self.__get_source(urlData, None, None)
304
+ elif sourceInput == "multipleFiles":
305
+ sources = self.__get_source(None, multipleFiles, None)
306
+ elif sourceInput == "microphoneData":
307
+ sources = self.__get_source(None, None, microphoneData)
308
+
309
  if (len(sources) == 0):
310
  raise Exception("init audio sources failed...")
311
+
312
  try:
313
  progress(0, desc="init whisper model")
314
+ whisperLang: TranslationLang = get_lang_from_whisper_name(whisperLangName)
315
+ whisperLangCode = whisperLang.whisper.code if whisperLang is not None and whisperLang.whisper is not None else None
316
+ selectedModel = whisperModelName if whisperModelName is not None else "base"
317
 
318
  model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
319
  model_name=selectedModel, compute_type=self.app_config.compute_type,
320
+ cache=self.model_cache, models=self.app_config.models["whisper"])
321
 
322
  progress(0, desc="init translate model")
323
+ translationLang = None
324
+ translationModel = None
325
+ if translateInput == "m2m100" and m2m100LangName is not None and len(m2m100LangName) > 0:
326
+ selectedModelName = m2m100ModelName if m2m100ModelName is not None and len(m2m100ModelName) > 0 else "m2m100_418M/facebook"
327
+ selectedModel = next((modelConfig for modelConfig in self.app_config.models["m2m100"] if modelConfig.name == selectedModelName), None)
328
+ translationLang = get_lang_from_m2m100_name(m2m100LangName)
329
+ elif translateInput == "nllb" and nllbLangName is not None and len(nllbLangName) > 0:
330
+ selectedModelName = nllbModelName if nllbModelName is not None and len(nllbModelName) > 0 else "nllb-200-distilled-600M/facebook"
331
+ selectedModel = next((modelConfig for modelConfig in self.app_config.models["nllb"] if modelConfig.name == selectedModelName), None)
332
+ translationLang = get_lang_from_nllb_name(nllbLangName)
333
+ elif translateInput == "mt5" and mt5LangName is not None and len(mt5LangName) > 0:
334
+ selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
335
+ selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
336
+ translationLang = get_lang_from_m2m100_name(mt5LangName)
337
+
338
+ if translationLang is not None:
339
+ translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
340
+
341
  progress(0, desc="init transcribe")
342
  # Result
343
  download = []
 
348
  # Write result
349
  downloadDirectory = tempfile.mkdtemp()
350
  source_index = 0
351
+ extra_tasks_count = 1 if translationLang is not None else 0
352
 
353
  outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
354
 
 
377
  sub_task_total=sub_task_total)
378
 
379
  # Transcribe
380
result = self.transcribe_file(model, source.source_path, whisperLangCode, task, vadOptions, scaled_progress_listener, **decodeOptions)
# No source language was selected: adopt the language reported in the
# transcription result so later steps know what the source language is.
if whisperLang is None and result["language"] is not None and len(result["language"]) > 0:
    whisperLang = get_lang_from_whisper_code(result["language"])
    # translationModel is only created when a translation language was chosen;
    # without this guard a plain transcription with auto-detected language
    # fails with AttributeError on None.
    if translationModel is not None:
        translationModel.whisperLang = whisperLang
384
 
385
  short_name, suffix = source.get_short_name_suffix(max_length=self.app_config.input_max_file_name_length)
386
  filePrefix = slugify(source_prefix + short_name, allow_unicode=True)
 
388
  # Update progress
389
  current_progress += source_audio_duration
390
 
391
+ source_download, source_text, source_vtt = self.write_result(result, whisperLang, translationModel, filePrefix + suffix.replace(".", "_"), outputDirectory, highlight_words, scaled_progress_listener)
392
 
393
  if self.app_config.merge_subtitle_with_sources and self.app_config.output_dir is not None:
394
  print("\nmerge subtitle(srt) with source file [" + source.source_name + "]\n")
 
397
  srt_path = source_download[0]
398
  save_path = os.path.join(self.app_config.output_dir, filePrefix)
399
  # save_without_ext, ext = os.path.splitext(save_path)
400
# Language suffixes for the merged output file name; each one is left empty
# when the corresponding language code is not available.
source_lang = "." + whisperLang.whisper.code if whisperLang is not None and whisperLang.whisper is not None else ""
# write_result treats translationLang.nllb as possibly None, so apply the
# same guard here instead of dereferencing it unconditionally.
translate_lang = "." + translationLang.nllb.code if translationLang is not None and translationLang.nllb is not None else ""
402
  output_with_srt = save_path + source_lang + translate_lang + suffix
403
 
404
  #ffmpeg -i "input.mp4" -i "input.srt" -c copy -c:s mov_text output.mp4
 
473
  except ExceededMaximumDuration as e:
474
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
475
  except Exception as e:
 
476
  print(traceback.format_exc())
477
+ return [], ("Error occurred during transcribe: " + str(e)), traceback.format_exc()
478
 
479
 
480
+ def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, languageCode: str, task: str = None,
481
  vadOptions: VadOptions = VadOptions(),
482
  progressListener: ProgressListener = None, **decodeOptions: dict):
483
 
 
507
  raise ValueError("Invalid vadInitialPromptMode: " + initial_prompt_mode)
508
 
509
  # Callable for processing an audio file
510
+ whisperCallable = model.create_callback(languageCode, task, prompt_strategy=prompt_strategy, **decodeOptions)
511
 
512
  # The results
513
  if (vadOptions.vad == 'silero-vad'):
 
622
 
623
  return config
624
 
625
+ def write_result(self, result: dict, whisperLang: TranslationLang, translationModel: TranslationModel, source_name: str, output_dir: str, highlight_words: bool = False, progressListener: ProgressListener = None):
626
  if not os.path.exists(output_dir):
627
  os.makedirs(output_dir)
628
 
 
631
  language = result["language"]
632
  languageMaxLineWidth = self.__get_max_line_width(language)
633
 
634
+ if translationModel is not None and translationModel.translationLang is not None:
635
  try:
636
  segments_progress_listener = SubTaskProgressListener(progressListener,
637
  base_task_total=progressListener.sub_task_total,
 
639
  sub_task_total=1)
640
  pbar = tqdm.tqdm(total=len(segments))
641
  perf_start_time = time.perf_counter()
642
+ translationModel.load_model()
643
  for idx, segment in enumerate(segments):
644
  seg_text = segment["text"]
645
+ segment["original"] = seg_text
646
+ segment["text"] = translationModel.translation(seg_text)
 
 
647
  pbar.update(1)
648
  segments_progress_listener.on_progress(idx+1, len(segments), desc=f"Process segments: {idx}/{len(segments)}")
649
 
650
+ translationModel.release_vram()
651
  perf_end_time = time.perf_counter()
652
  # Call the finished callback
653
  if segments_progress_listener is not None:
 
656
  print("\n\nprocess segments took {} seconds.\n\n".format(perf_end_time - perf_start_time))
657
  except Exception as e:
658
  # Translation is best-effort: log the failure and continue writing the result files
659
+ print(traceback.format_exc())
660
  print("Error process segments: " + str(e))
661
 
662
  print("Max line width " + str(languageMaxLineWidth) + " for language:" + language)
663
  vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
664
  srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
665
  json_result = json.dumps(result, indent=4, ensure_ascii=False)
666
+ srt_original = None
667
+ srt_bilingual = None
668
+ if translationModel is not None and translationModel.translationLang is not None:
669
+ srt_original = self.__get_subs(result["segments"], "srt_original", languageMaxLineWidth, highlight_words=highlight_words)
670
+ srt_bilingual = self.__get_subs(result["segments"], "srt_bilingual", languageMaxLineWidth, highlight_words=highlight_words)
671
+
672
+ whisperLangZho: bool = whisperLang is not None and whisperLang.nllb is not None and whisperLang.nllb.code in ["zho_Hant", "zho_Hans", "yue_Hant"]
673
+ translationZho: bool = translationModel is not None and translationModel.translationLang is not None and translationModel.translationLang.nllb is not None and translationModel.translationLang.nllb.code in ["zho_Hant", "zho_Hans", "yue_Hant"]
674
+ if whisperLangZho or translationZho:
675
+ locale = None
676
+ if whisperLangZho:
677
+ if whisperLang.nllb.code == "zho_Hant":
678
+ locale = "zh-tw"
679
+ elif whisperLang.nllb.code == "zho_Hans":
680
+ locale = "zh-cn"
681
+ elif whisperLang.nllb.code == "yue_Hant":
682
+ locale = "zh-hk"
683
+ if translationZho:
684
+ if translationModel.translationLang.nllb.code == "zho_Hant":
685
+ locale = "zh-tw"
686
+ elif translationModel.translationLang.nllb.code == "zho_Hans":
687
+ locale = "zh-cn"
688
+ elif translationModel.translationLang.nllb.code == "yue_Hant":
689
+ locale = "zh-hk"
690
+ if locale is not None:
691
+ vtt = zhconv.convert(vtt, locale)
692
+ srt = zhconv.convert(srt, locale)
693
+ text = zhconv.convert(text, locale)
694
+ json_result = zhconv.convert(json_result, locale)
695
+ if translationModel is not None and translationModel.translationLang is not None:
696
+ if srt_original is not None and len(srt_original) > 0:
697
+ srt_original = zhconv.convert(srt_original, locale)
698
+ if srt_bilingual is not None and len(srt_bilingual) > 0:
699
+ srt_bilingual = zhconv.convert(srt_bilingual, locale)
700
 
701
  output_files = []
702
  output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
703
  output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
704
  output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
705
  output_files.append(self.__create_file(json_result, output_dir, source_name + "-result.json"));
706
+ if srt_original is not None and len(srt_original) > 0:
707
+ output_files.append(self.__create_file(srt_original, output_dir, source_name + "-original.srt"));
708
+ if srt_bilingual is not None and len(srt_bilingual) > 0:
709
+ output_files.append(self.__create_file(srt_bilingual, output_dir, source_name + "-bilingual.srt"));
710
 
711
  return output_files, text, vtt
712
 
 
733
  write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
734
  elif format == 'srt':
735
  write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
736
+ elif format == 'srt_original':
737
+ write_srt_original(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
738
+ elif format == 'srt_bilingual':
739
+ write_srt_original(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words, bilingual=True)
740
  else:
741
  raise Exception("Unknown format " + format)
742
 
 
765
  self.diarization = None
766
 
767
  def create_ui(app_config: ApplicationConfig):
768
# Optional markdown documents shown in the UI; each stays None when its file
# cannot be read.
optionsMd: str = None
readmeMd: str = None
# Read the two files independently so that a missing options.md does not
# also prevent README.md from being loaded.
try:
    # Build the path with os.path.join: the literal "docs\options.md" contains
    # an invalid "\o" escape and only resolves on Windows.
    with open(os.path.join("docs", "options.md"), "r", encoding="utf-8") as optionsFile:
        optionsMd = optionsFile.read()
except Exception as e:
    print("Error occurred during read options.md file: ", str(e))
try:
    with open("README.md", "r", encoding="utf-8") as readmeFile:
        readmeMd = readmeFile.read()
except Exception as e:
    print("Error occurred during read README.md file: ", str(e))
777
+
778
  ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
779
  app_config.delete_uploaded_files, app_config.output_dir, app_config)
780
 
 
793
  # Try to convert from camel-case to title-case
794
  implementation_name = app_config.whisper_implementation.title().replace("_", " ").replace("-", " ")
795
 
796
+ uiDescription = implementation_name + " is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
797
+ uiDescription += " audio and is also a multi-task model that can perform multilingual speech recognition "
798
+ uiDescription += " as well as speech translation and language identification. "
799
 
800
+ uiDescription += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
801
 
802
  # Recommend faster-whisper
803
  if is_whisper:
804
+ uiDescription += "\n\n\n\nFor faster inference on GPU, try [faster-whisper](https://huggingface.co/spaces/aadnk/faster-whisper-webui)."
805
 
806
  if app_config.input_audio_max_duration > 0:
807
+ uiDescription += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
808
+
809
+ uiArticle = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
810
+ uiArticle += "\n\nWhisper's Task 'translate' only implements the functionality of translating other languages into English. "
811
+ uiArticle += "OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the NLLB Model to implement the translation task. "
812
+ uiArticle += "However, it's important to note that the NLLB Model runs slowly, and the completion time may be twice as long as usual. "
813
+ uiArticle += "\n\nThe larger the parameters of the NLLB model, the better its performance is expected to be. "
814
+ uiArticle += "However, it also requires higher computational resources, making it slower to operate. "
815
+ uiArticle += "On the other hand, the version converted from ct2 (CTranslate2) requires lower resources and operates at a faster speed."
816
+ uiArticle += "\n\nCurrently, enabling word-level timestamps cannot be used in conjunction with NLLB Model translation "
817
+ uiArticle += "because Word Timestamps will split the source text, and after translation, it becomes a non-word-level string. "
818
+ uiArticle += "\n\nThe 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. "
819
+ uiArticle += "This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English. "
820
+
821
+ whisper_models = app_config.get_model_names("whisper")
822
+ nllb_models = app_config.get_model_names("nllb")
823
+ m2m100_models = app_config.get_model_names("m2m100")
824
+ mt5_models = app_config.get_model_names("mt5")
825
 
826
+ common_whisper_inputs = lambda : {
827
+ gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
828
+ gr.Dropdown(label="Whisper - Language", choices=sorted(get_lang_whisper_names()), value=app_config.language, elem_id="whisperLangName"),
829
+ }
830
+ common_m2m100_inputs = lambda : {
831
+ gr.Dropdown(label="M2M100 - Model (for translate)", choices=m2m100_models, elem_id="m2m100ModelName"),
832
+ gr.Dropdown(label="M2M100 - Language", choices=sorted(get_lang_m2m100_names()), elem_id="m2m100LangName"),
833
+ }
834
+ common_nllb_inputs = lambda : {
835
+ gr.Dropdown(label="NLLB - Model (for translate)", choices=nllb_models, elem_id="nllbModelName"),
836
+ gr.Dropdown(label="NLLB - Language", choices=sorted(get_lang_nllb_names()), elem_id="nllbLangName"),
837
+ }
838
+ common_mt5_inputs = lambda : {
839
+ gr.Dropdown(label="MT5 - Model (for translate)", choices=mt5_models, elem_id="mt5ModelName"),
840
+ gr.Dropdown(label="MT5 - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="mt5LangName"),
841
+ }
 
 
 
 
842
 
843
+ common_translation_inputs = lambda : {
844
+ gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
845
+ gr.Number(label="Translation - No Repeat Ngram Size", precision=0, value=app_config.translation_no_repeat_ngram_size, elem_id="translationNoRepeatNgramSize"),
846
+ gr.Number(label="Translation - Num Beams", precision=0, value=app_config.translation_num_beams, elem_id="translationNumBeams")
847
+ }
848
+
849
+ common_vad_inputs = lambda : {
850
+ gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD", elem_id="vad"),
851
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window, elem_id="vadMergeWindow"),
852
+ gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size, elem_id="vadMaxMergeSize"),
853
+ }
854
+
855
+ common_word_timestamps_inputs = lambda : {
856
+ gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps, elem_id="word_timestamps"),
857
+ gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words, elem_id="highlight_words"),
858
+ }
859
 
860
  has_diarization_libs = Diarization.has_libraries()
861
 
 
863
  print("Diarization libraries not found - disabling diarization")
864
  app_config.diarization = False
865
 
866
+ common_diarization_inputs = lambda : {
867
+ gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs, elem_id="diarization"),
868
+ gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs, elem_id="diarization_speakers"),
869
+ gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs, elem_id="diarization_min_speakers"),
870
+ gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs, elem_id="diarization_max_speakers")
871
+ }
872
 
873
  common_output = lambda : [
874
  gr.File(label="Download"),
 
878
 
879
  is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
880
 
881
+ simpleInputDict = {}
882
+
883
+ with gr.Blocks() as simpleTranscribe:
884
+ simpleTranslateInput = gr.State(value="m2m100", elem_id = "translateInput")
885
+ simpleSourceInput = gr.State(value="urlData", elem_id = "sourceInput")
886
+ gr.Markdown(uiDescription)
887
  with gr.Row():
888
  with gr.Column():
889
+ simpleSubmit = gr.Button("Submit", variant="primary")
890
  with gr.Column():
891
  with gr.Row():
892
+ simpleInputDict = common_whisper_inputs()
893
+ with gr.Tab(label="M2M100") as simpleM2M100Tab:
894
+ with gr.Row():
895
+ simpleInputDict.update(common_m2m100_inputs())
896
+ with gr.Tab(label="NLLB") as simpleNllbTab:
897
+ with gr.Row():
898
+ simpleInputDict.update(common_nllb_inputs())
899
+ with gr.Tab(label="MT5") as simpleMT5Tab:
900
+ with gr.Row():
901
+ simpleInputDict.update(common_mt5_inputs())
902
+ simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
903
+ simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
904
+ simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
905
  with gr.Column():
906
+ with gr.Tab(label="URL") as simpleUrlTab:
907
+ simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
908
+ with gr.Tab(label="Upload") as simpleUploadTab:
909
+ simpleInputDict.update({gr.File(label="Upload Files", file_count="multiple", elem_id = "multipleFiles")})
910
+ with gr.Tab(label="Microphone") as simpleMicTab:
911
+ simpleInputDict.update({gr.Audio(source="microphone", type="filepath", label="Microphone Input", elem_id = "microphoneData")})
912
+ simpleUrlTab.select(fn=lambda: "urlData", inputs = [], outputs= [simpleSourceInput] )
913
+ simpleUploadTab.select(fn=lambda: "multipleFiles", inputs = [], outputs= [simpleSourceInput] )
914
+ simpleMicTab.select(fn=lambda: "microphoneData", inputs = [], outputs= [simpleSourceInput] )
915
+ simpleInputDict.update({gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task, elem_id = "task")})
916
+ with gr.Accordion("VAD options", open=False):
917
+ simpleInputDict.update(common_vad_inputs())
918
+ with gr.Accordion("Word Timestamps options", open=False):
919
+ simpleInputDict.update(common_word_timestamps_inputs())
920
+ with gr.Accordion("Diarization options", open=False):
921
+ simpleInputDict.update(common_diarization_inputs())
922
+ with gr.Accordion("Translation options", open=False):
923
+ simpleInputDict.update(common_translation_inputs())
924
  with gr.Column():
925
+ simpleOutput = common_output()
926
+ with gr.Accordion("Article"):
927
+ gr.Markdown(uiArticle)
928
+ if optionsMd is not None:
929
+ with gr.Accordion("docs/options.md", open=False):
930
+ gr.Markdown(optionsMd)
931
+ if readmeMd is not None:
932
+ with gr.Accordion("README.md", open=False):
933
+ gr.Markdown(readmeMd)
934
+
935
+ simpleInputDict.update({simpleTranslateInput, simpleSourceInput})
936
+ simpleSubmit.click(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
937
+ inputs=simpleInputDict, outputs=simpleOutput)
938
 
939
+ fullInputDict = {}
940
+ fullDescription = uiDescription + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
941
 
942
+ with gr.Blocks() as fullTranscribe:
943
+ fullTranslateInput = gr.State(value="m2m100", elem_id = "translateInput")
944
+ fullSourceInput = gr.State(value="urlData", elem_id = "sourceInput")
945
+ gr.Markdown(fullDescription)
946
  with gr.Row():
947
  with gr.Column():
948
+ fullSubmit = gr.Button("Submit", variant="primary")
949
  with gr.Column():
950
  with gr.Row():
951
+ fullInputDict = common_whisper_inputs()
952
+ with gr.Tab(label="M2M100") as fullM2M100Tab:
953
+ with gr.Row():
954
+ fullInputDict.update(common_m2m100_inputs())
955
+ with gr.Tab(label="NLLB") as fullNllbTab:
956
+ with gr.Row():
957
+ fullInputDict.update(common_nllb_inputs())
958
+ with gr.Tab(label="MT5") as fullMT5Tab:
959
+ with gr.Row():
960
+ fullInputDict.update(common_mt5_inputs())
961
+ fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
962
+ fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
963
+ fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
964
  with gr.Column():
965
+ with gr.Tab(label="URL") as fullUrlTab:
966
+ fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
967
+ with gr.Tab(label="Upload") as fullUploadTab:
968
+ fullInputDict.update({gr.File(label="Upload Files", file_count="multiple", elem_id = "multipleFiles")})
969
+ with gr.Tab(label="Microphone") as fullMicTab:
970
+ fullInputDict.update({gr.Audio(source="microphone", type="filepath", label="Microphone Input", elem_id = "microphoneData")})
971
+ fullUrlTab.select(fn=lambda: "urlData", inputs = [], outputs= [fullSourceInput] )
972
+ fullUploadTab.select(fn=lambda: "multipleFiles", inputs = [], outputs= [fullSourceInput] )
973
+ fullMicTab.select(fn=lambda: "microphoneData", inputs = [], outputs= [fullSourceInput] )
974
+ fullInputDict.update({gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task, elem_id = "task")})
975
+ with gr.Accordion("VAD options", open=False):
976
+ fullInputDict.update(common_vad_inputs())
977
+ fullInputDict.update({
978
+ gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding, elem_id = "vadPadding"),
979
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window, elem_id = "vadPromptWindow"),
980
+ gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode", value=app_config.vad_initial_prompt_mode, elem_id = "vadInitialPromptMode")})
981
+ with gr.Accordion("Word Timestamps options", open=False):
982
+ fullInputDict.update(common_word_timestamps_inputs())
983
+ fullInputDict.update({
984
+ gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations, elem_id = "prepend_punctuations"),
985
+ gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations, elem_id = "append_punctuations")})
986
+ with gr.Accordion("Whisper Advanced options", open=False):
987
+ fullInputDict.update({
988
+ gr.TextArea(label="Initial Prompt", elem_id = "initial_prompt"),
989
+ gr.Number(label="Temperature", value=app_config.temperature, elem_id = "temperature"),
990
+ gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0, elem_id = "best_of"),
991
+ gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0, elem_id = "beam_size"),
992
+ gr.Number(label="Patience - Zero temperature", value=app_config.patience, elem_id = "patience"),
993
+ gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty, elem_id = "length_penalty"),
994
+ gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens, elem_id = "suppress_tokens"),
995
+ gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text, elem_id = "condition_on_previous_text"),
996
+ gr.Checkbox(label="FP16", value=app_config.fp16, elem_id = "fp16"),
997
+ gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback, elem_id = "temperature_increment_on_fallback"),
998
+ gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold, elem_id = "compression_ratio_threshold"),
999
+ gr.Number(label="Logprob threshold", value=app_config.logprob_threshold, elem_id = "logprob_threshold"),
1000
+ gr.Number(label="No speech threshold", value=app_config.no_speech_threshold, elem_id = "no_speech_threshold"),
1001
+ })
1002
+ if app_config.whisper_implementation == "faster-whisper":
1003
+ fullInputDict.update({
1004
+ gr.Number(label="Repetition Penalty", value=app_config.repetition_penalty, elem_id = "repetition_penalty"),
1005
+ gr.Number(label="No Repeat Ngram Size", value=app_config.no_repeat_ngram_size, precision=0, elem_id = "no_repeat_ngram_size")
1006
+ })
1007
+ with gr.Accordion("Diarization options", open=False):
1008
+ fullInputDict.update(common_diarization_inputs())
1009
+ with gr.Accordion("Translation options", open=False):
1010
+ fullInputDict.update(common_translation_inputs())
1011
  with gr.Column():
1012
+ fullOutput = common_output()
1013
+ with gr.Accordion("Article"):
1014
+ gr.Markdown(uiArticle)
1015
+ if optionsMd is not None:
1016
+ with gr.Accordion("docs/options.md", open=False):
1017
+ gr.Markdown(optionsMd)
1018
+ if readmeMd is not None:
1019
+ with gr.Accordion("README.md", open=False):
1020
+ gr.Markdown(readmeMd)
1021
+
1022
+ fullInputDict.update({fullTranslateInput, fullSourceInput})
1023
+ fullSubmit.click(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
1024
+ inputs=fullInputDict, outputs=fullOutput)
1025
 
1026
+ demo = gr.TabbedInterface([simpleTranscribe, fullTranscribe], tab_names=["Simple", "Full"])
1027
 
1028
  # Queue up the demo
1029
  if is_queue_mode:
 
1039
 
1040
  if __name__ == '__main__':
1041
  default_app_config = ApplicationConfig.create_default()
1042
+ whisper_models = default_app_config.get_model_names("whisper")
 
1043
 
1044
  # Environment variable overrides
1045
  default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
 
1077
  help="the compute type to use for inference")
1078
  parser.add_argument("--threads", type=optional_int, default=0,
1079
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
1080
+
1081
  parser.add_argument("--vad_max_merge_size", type=int, default=default_app_config.vad_max_merge_size, \
1082
  help="The number of VAD - Max Merge Size (s).") # 30
1083
+ parser.add_argument("--language", type=str, default=None, choices=sorted(get_lang_whisper_names()) + sorted([k.title() for k in _TO_LANG_CODE_WHISPER.keys()]),
1084
  help="language spoken in the audio, specify None to perform language detection")
1085
  parser.add_argument("--save_downloaded_files", action='store_true', \
1086
  help="True to move downloaded files to outputs directory. This argument will take effect only after output_dir is set.")
 
1090
  help="Maximum length of a file name.")
1091
  parser.add_argument("--autolaunch", action='store_true', \
1092
  help="open the webui URL in the system's default browser upon launch")
1093
+
1094
  parser.add_argument('--auth_token', type=str, default=default_app_config.auth_token, help='HuggingFace API Token (optional)')
1095
  parser.add_argument("--diarization", type=str2bool, default=default_app_config.diarization, \
1096
  help="whether to perform speaker diarization")
cli.py CHANGED
@@ -10,7 +10,7 @@ from app import VadOptions, WhisperTranscriber
10
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
11
  from src.diarization.diarization import Diarization
12
  from src.download import download_url
13
- from src.languages import get_language_names
14
 
15
  from src.utils import optional_float, optional_int, str2bool
16
  from src.whisper.whisperFactory import create_whisper_container
@@ -43,7 +43,7 @@ def cli():
43
 
44
  parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
45
  help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
46
- parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(get_language_names()), \
47
  help="language spoken in the audio, specify None to perform language detection")
48
 
49
  parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
 
10
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
11
  from src.diarization.diarization import Diarization
12
  from src.download import download_url
13
+ from src.translation.translationLangs import get_lang_whisper_names # from src.languages import get_language_names
14
 
15
  from src.utils import optional_float, optional_int, str2bool
16
  from src.whisper.whisperFactory import create_whisper_container
 
43
 
44
  parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
45
  help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
46
+ parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(get_lang_whisper_names()), \
47
  help="language spoken in the audio, specify None to perform language detection")
48
 
49
  parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
config.json5 CHANGED
@@ -1,254 +1,300 @@
1
  {
2
- "models": [
3
- // Configuration for the built-in models. You can remove any of these
4
- // if you don't want to use the default models.
5
- {
6
- "name": "tiny",
7
- "url": "tiny"
8
- },
9
- {
10
- "name": "base",
11
- "url": "base"
12
- },
13
- {
14
- "name": "small",
15
- "url": "small"
16
- },
17
- {
18
- "name": "medium",
19
- "url": "medium"
20
- },
21
- {
22
- "name": "large",
23
- "url": "large"
24
- },
25
- {
26
- "name": "large-v2",
27
- "url": "large-v2"
28
- },
29
- {
30
- "name": "large-v3",
31
- "url": "large-v3"
32
- },
33
- // Uncomment to add custom Japanese models
34
- //{
35
- // "name": "whisper-large-v2-mix-jp",
36
- // "url": "vumichien/whisper-large-v2-mix-jp",
37
- // // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
38
- // // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
39
- // "type": "huggingface",
40
- //},
41
- //{
42
- // "name": "local-model",
43
- // "url": "path/to/local/model",
44
- //},
45
- //{
46
- // "name": "remote-model",
47
- // "url": "https://example.com/path/to/model",
48
- //}
 
49
  ],
50
- "nllb_models": [
51
- {
52
- "name": "nllb-200-distilled-1.3B-ct2fast:int8_float16/michaelfeil",
53
- "url": "michaelfeil/ct2fast-nllb-200-distilled-1.3B",
54
- "type": "huggingface"
55
- },
56
- {
57
- "name": "nllb-200-3.3B-ct2fast:int8_float16/michaelfeil",
58
- "url": "michaelfeil/ct2fast-nllb-200-3.3B",
59
- "type": "huggingface"
60
- },
61
- {
62
- "name": "nllb-200-1.3B-ct2:float16/JustFrederik",
63
- "url": "JustFrederik/nllb-200-1.3B-ct2-float16",
64
- "type": "huggingface"
65
- },
66
- {
67
- "name": "nllb-200-distilled-1.3B-ct2:float16/JustFrederik",
68
- "url": "JustFrederik/nllb-200-distilled-1.3B-ct2-float16",
69
- "type": "huggingface"
70
- },
71
- {
72
- "name": "nllb-200-1.3B-ct2:int8/JustFrederik",
73
- "url": "JustFrederik/nllb-200-1.3B-ct2-int8",
74
- "type": "huggingface"
75
- },
76
- {
77
- "name": "nllb-200-distilled-1.3B-ct2:int8/JustFrederik",
78
- "url": "JustFrederik/nllb-200-distilled-1.3B-ct2-int8",
79
- "type": "huggingface"
80
- },
81
- {
82
- "name": "mt5-zh-ja-en-trimmed/K024",
83
- "url": "K024/mt5-zh-ja-en-trimmed",
84
- "type": "huggingface"
85
- },
86
- {
87
- "name": "mt5-zh-ja-en-trimmed-fine-tuned-v1/engmatic-earth",
88
- "url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
89
- "type": "huggingface"
90
- },
91
- {
92
- "name": "nllb-200-distilled-600M/facebook",
93
- "url": "facebook/nllb-200-distilled-600M",
94
- "type": "huggingface"
95
- },
96
- {
97
- "name": "nllb-200-distilled-600M-ct2/JustFrederik",
98
- "url": "JustFrederik/nllb-200-distilled-600M-ct2",
99
- "type": "huggingface"
100
- },
101
- {
102
- "name": "nllb-200-distilled-600M-ct2:float16/JustFrederik",
103
- "url": "JustFrederik/nllb-200-distilled-600M-ct2-float16",
104
- "type": "huggingface"
105
- },
106
- {
107
- "name": "nllb-200-distilled-600M-ct2:int8/JustFrederik",
108
- "url": "JustFrederik/nllb-200-distilled-600M-ct2-int8",
109
- "type": "huggingface"
110
- },
111
- // Uncomment to add official Facebook 1.3B and 3.3B model
112
- // The official Facebook 1.3B and 3.3B model files are too large,
113
- // and to avoid occupying too much disk space on Hugging Face's free spaces,
114
- // these models are not included in the config.
115
- //{
116
- // "name": "nllb-200-distilled-1.3B/facebook",
117
- // "url": "facebook/nllb-200-distilled-1.3B",
118
- // "type": "huggingface"
119
- //},
120
- //{
121
- // "name": "nllb-200-1.3B/facebook",
122
- // "url": "facebook/nllb-200-1.3B",
123
- // "type": "huggingface"
124
- //},
125
- //{
126
- // "name": "nllb-200-3.3B/facebook",
127
- // "url": "facebook/nllb-200-3.3B",
128
- // "type": "huggingface"
129
- //},
130
- //{
131
- // "name": "nllb-200-distilled-1.3B-ct2/JustFrederik",
132
- // "url": "JustFrederik/nllb-200-distilled-1.3B-ct2",
133
- // "type": "huggingface"
134
- //},
135
- //{
136
- // "name": "nllb-200-1.3B-ct2/JustFrederik",
137
- // "url": "JustFrederik/nllb-200-1.3B-ct2",
138
- // "type": "huggingface"
139
- //},
140
- //{
141
- // "name": "nllb-200-3.3B-ct2:float16/JustFrederik",
142
- // "url": "JustFrederik/nllb-200-3.3B-ct2-float16",
143
- // "type": "huggingface"
144
- //},
145
  ],
146
- // Configuration options that will be used if they are not specified in the command line arguments.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- // * WEBUI options *
149
 
150
- // Maximum audio file length in seconds, or -1 for no limit. Ignored by CLI.
151
- "input_audio_max_duration": 1800,
152
- // True to share the app on HuggingFace.
153
- "share": false,
154
- // The host or IP to bind to. If None, bind to localhost.
155
- "server_name": null,
156
- // The port to bind to.
157
- "server_port": 7860,
158
- // The number of workers to use for the web server. Use -1 to disable queueing.
159
- "queue_concurrency_count": 1,
160
- // Whether or not to automatically delete all uploaded files, to save disk space
161
- "delete_uploaded_files": true,
162
 
163
- // * General options *
164
 
165
- // The default implementation to use for Whisper. Can be "whisper" or "faster-whisper".
166
- // Note that you must either install the requirements for faster-whisper (requirements-fasterWhisper.txt)
167
- // or whisper (requirements.txt)
168
- "whisper_implementation": "faster-whisper",
169
 
170
- // The default model name.
171
- "default_model_name": "large-v2",
172
- // The default VAD.
173
- "default_vad": "silero-vad",
174
- // A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
175
- "vad_parallel_devices": "",
176
- // The number of CPU cores to use for VAD pre-processing.
177
- "vad_cpu_cores": 1,
178
- // The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
179
- "vad_process_timeout": 1800,
180
- // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
181
- "auto_parallel": false,
182
- // Directory to save the outputs (CLI will use the current directory if not specified)
183
- "output_dir": null,
184
- // The path to save model files; uses ~/.cache/whisper by default
185
- "model_dir": null,
186
- // Device to use for PyTorch inference, or Null to use the default device
187
- "device": null,
188
- // Whether to print out the progress and debug messages
189
- "verbose": true,
190
- // Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')
191
- "task": "transcribe",
192
- // Language spoken in the audio, specify None to perform language detection
193
- "language": null,
194
- // The window size (in seconds) to merge voice segments
195
- "vad_merge_window": 5,
196
- // The maximum size (in seconds) of a voice segment
197
- "vad_max_merge_size": 90,
198
- // The padding (in seconds) to add to each voice segment
199
- "vad_padding": 1,
200
- // Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)
201
- "vad_initial_prompt_mode": "prepend_first_segment",
202
- // The window size of the prompt to pass to Whisper
203
- "vad_prompt_window": 3,
204
- // Temperature to use for sampling
205
- "temperature": 0,
206
- // Number of candidates when sampling with non-zero temperature
207
- "best_of": 5,
208
- // Number of beams in beam search, only applicable when temperature is zero
209
- "beam_size": 5,
210
- // Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
211
- "patience": 1,
212
- // Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
213
- "length_penalty": null,
214
- // Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
215
- "suppress_tokens": "-1",
216
- // Optional text to provide as a prompt for the first window
217
- "initial_prompt": null,
218
- // If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop
219
- "condition_on_previous_text": true,
220
- // Whether to perform inference in fp16; True by default
221
- "fp16": true,
222
- // The compute type used by faster-whisper. Can be "int8". "int16" or "float16".
223
- "compute_type": "auto",
224
- // Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
225
- "temperature_increment_on_fallback": 0.2,
226
- // If the gzip compression ratio is higher than this value, treat the decoding as failed
227
- "compression_ratio_threshold": 2.4,
228
- // If the average log probability is lower than this value, treat the decoding as failed
229
- "logprob_threshold": -1.0,
230
- // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
231
- "no_speech_threshold": 0.6,
232
 
233
- // (experimental) extract word-level timestamps and refine the results based on them
234
- "word_timestamps": false,
235
- // if word_timestamps is True, merge these punctuation symbols with the next word
236
- "prepend_punctuations": "\"\'“¿([{-",
237
- // if word_timestamps is True, merge these punctuation symbols with the previous word
238
- "append_punctuations": "\"\'.。,,!!??::”)]}、",
239
- // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
240
- "highlight_words": false,
241
 
242
- // Diarization settings
243
- "auth_token": null,
244
- // Whether to perform speaker diarization
245
- "diarization": false,
246
- // The number of speakers to detect
247
- "diarization_speakers": 2,
248
- // The minimum number of speakers to detect
249
- "diarization_min_speakers": 1,
250
- // The maximum number of speakers to detect
251
- "diarization_max_speakers": 8,
252
- // The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
253
- "diarization_process_timeout": 60,
254
  }
 
1
  {
2
+ "models": {
3
+ "whisper": [
4
+ // Configuration for the built-in models. You can remove any of these
5
+ // if you don't want to use the default models.
6
+ {
7
+ "name": "tiny",
8
+ "url": "tiny"
9
+ },
10
+ {
11
+ "name": "base",
12
+ "url": "base"
13
+ },
14
+ {
15
+ "name": "small",
16
+ "url": "small"
17
+ },
18
+ {
19
+ "name": "medium",
20
+ "url": "medium"
21
+ },
22
+ {
23
+ "name": "large",
24
+ "url": "large"
25
+ },
26
+ {
27
+ "name": "large-v2",
28
+ "url": "large-v2"
29
+ },
30
+ {
31
+ "name": "large-v3",
32
+ "url": "large-v3"
33
+ }
34
+ // Uncomment to add custom Japanese models
35
+ //{
36
+ // "name": "whisper-large-v2-mix-jp",
37
+ // "url": "vumichien/whisper-large-v2-mix-jp",
38
+ // // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
39
+ // // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
40
+ // "type": "huggingface",
41
+ //},
42
+ //{
43
+ // "name": "local-model",
44
+ // "url": "path/to/local/model",
45
+ //},
46
+ //{
47
+ // "name": "remote-model",
48
+ // "url": "https://example.com/path/to/model",
49
+ //}
50
  ],
51
+ "m2m100": [
52
+ {
53
+ "name": "m2m100_1.2B-ct2fast/michaelfeil",
54
+ "url": "michaelfeil/ct2fast-m2m100_1.2B",
55
+ "type": "huggingface",
56
+ "tokenizer_url": "facebook/m2m100_1.2B"
57
+ },
58
+ {
59
+ "name": "m2m100_418M-ct2fast/michaelfeil",
60
+ "url": "michaelfeil/ct2fast-m2m100_418M",
61
+ "type": "huggingface",
62
+ "tokenizer_url": "facebook/m2m100_418M"
63
+ },
64
+ //{
65
+ // "name": "m2m100-12B-ct2fast/michaelfeil",
66
+ // "url": "michaelfeil/ct2fast-m2m100-12B-last-ckpt",
67
+ // "type": "huggingface",
68
+ // "tokenizer_url": "facebook/m2m100-12B-last-ckpt"
69
+ //},
70
+ {
71
+ "name": "m2m100_1.2B/facebook",
72
+ "url": "facebook/m2m100_1.2B",
73
+ "type": "huggingface"
74
+ },
75
+ {
76
+ "name": "m2m100_418M/facebook",
77
+ "url": "facebook/m2m100_418M",
78
+ "type": "huggingface"
79
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  ],
81
+ "nllb": [
82
+ {
83
+ "name": "nllb-200-distilled-1.3B-ct2fast:int8_float16/michaelfeil",
84
+ "url": "michaelfeil/ct2fast-nllb-200-distilled-1.3B",
85
+ "type": "huggingface",
86
+ "tokenizer_url": "facebook/nllb-200-distilled-1.3B"
87
+ },
88
+ {
89
+ "name": "nllb-200-3.3B-ct2fast:int8_float16/michaelfeil",
90
+ "url": "michaelfeil/ct2fast-nllb-200-3.3B",
91
+ "type": "huggingface",
92
+ "tokenizer_url": "facebook/nllb-200-3.3B"
93
+ },
94
+ {
95
+ "name": "nllb-200-1.3B-ct2:float16/JustFrederik",
96
+ "url": "JustFrederik/nllb-200-1.3B-ct2-float16",
97
+ "type": "huggingface",
98
+ "tokenizer_url": "facebook/nllb-200-1.3B"
99
+ },
100
+ {
101
+ "name": "nllb-200-distilled-1.3B-ct2:float16/JustFrederik",
102
+ "url": "JustFrederik/nllb-200-distilled-1.3B-ct2-float16",
103
+ "type": "huggingface",
104
+ "tokenizer_url": "facebook/nllb-200-distilled-1.3B"
105
+ },
106
+ {
107
+ "name": "nllb-200-1.3B-ct2:int8/JustFrederik",
108
+ "url": "JustFrederik/nllb-200-1.3B-ct2-int8",
109
+ "type": "huggingface",
110
+ "tokenizer_url": "facebook/nllb-200-1.3B"
111
+ },
112
+ {
113
+ "name": "nllb-200-distilled-1.3B-ct2:int8/JustFrederik",
114
+ "url": "JustFrederik/nllb-200-distilled-1.3B-ct2-int8",
115
+ "type": "huggingface",
116
+ "tokenizer_url": "facebook/nllb-200-distilled-1.3B"
117
+ },
118
+ {
119
+ "name": "nllb-200-distilled-600M/facebook",
120
+ "url": "facebook/nllb-200-distilled-600M",
121
+ "type": "huggingface"
122
+ },
123
+ {
124
+ "name": "nllb-200-distilled-600M-ct2/JustFrederik",
125
+ "url": "JustFrederik/nllb-200-distilled-600M-ct2",
126
+ "type": "huggingface",
127
+ "tokenizer_url": "facebook/nllb-200-distilled-600M"
128
+ },
129
+ {
130
+ "name": "nllb-200-distilled-600M-ct2:float16/JustFrederik",
131
+ "url": "JustFrederik/nllb-200-distilled-600M-ct2-float16",
132
+ "type": "huggingface",
133
+ "tokenizer_url": "facebook/nllb-200-distilled-600M"
134
+ },
135
+ {
136
+ "name": "nllb-200-distilled-600M-ct2:int8/JustFrederik",
137
+ "url": "JustFrederik/nllb-200-distilled-600M-ct2-int8",
138
+ "type": "huggingface",
139
+ "tokenizer_url": "facebook/nllb-200-distilled-600M"
140
+ }
141
+ // Uncomment to add official Facebook 1.3B and 3.3B model
142
+ // The official Facebook 1.3B and 3.3B model files are too large,
143
+ // and to avoid occupying too much disk space on Hugging Face's free spaces,
144
+ // these models are not included in the config.
145
+ //{
146
+ // "name": "nllb-200-distilled-1.3B/facebook",
147
+ // "url": "facebook/nllb-200-distilled-1.3B",
148
+ // "type": "huggingface"
149
+ //},
150
+ //{
151
+ // "name": "nllb-200-1.3B/facebook",
152
+ // "url": "facebook/nllb-200-1.3B",
153
+ // "type": "huggingface"
154
+ //},
155
+ //{
156
+ // "name": "nllb-200-3.3B/facebook",
157
+ // "url": "facebook/nllb-200-3.3B",
158
+ // "type": "huggingface"
159
+ //},
160
+ //{
161
+ // "name": "nllb-200-distilled-1.3B-ct2/JustFrederik",
162
+ // "url": "JustFrederik/nllb-200-distilled-1.3B-ct2",
163
+ // "type": "huggingface",
164
+ // "tokenizer_url": "facebook/nllb-200-distilled-1.3B"
165
+ //},
166
+ //{
167
+ // "name": "nllb-200-1.3B-ct2/JustFrederik",
168
+ // "url": "JustFrederik/nllb-200-1.3B-ct2",
169
+ // "type": "huggingface",
170
+ // "tokenizer_url": "facebook/nllb-200-1.3B"
171
+ //},
172
+ //{
173
+ // "name": "nllb-200-3.3B-ct2:float16/JustFrederik",
174
+ // "url": "JustFrederik/nllb-200-3.3B-ct2-float16",
175
+ // "type": "huggingface",
176
+ // "tokenizer_url": "facebook/nllb-200-3.3B"
177
+ //},
178
+ ],
179
+ "mt5": [
180
+ {
181
+ "name": "mt5-zh-ja-en-trimmed/K024",
182
+ "url": "K024/mt5-zh-ja-en-trimmed",
183
+ "type": "huggingface"
184
+ },
185
+ {
186
+ "name": "mt5-zh-ja-en-trimmed-fine-tuned-v1/engmatic-earth",
187
+ "url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
188
+ "type": "huggingface"
189
+ }
190
+ ]
191
+ },
192
+ // Configuration options that will be used if they are not specified in the command line arguments.
193
 
194
+ // * WEBUI options *
195
 
196
+ // Maximum audio file length in seconds, or -1 for no limit. Ignored by CLI.
197
+ "input_audio_max_duration": 1800,
198
+ // True to share the app on HuggingFace.
199
+ "share": false,
200
+ // The host or IP to bind to. If None, bind to localhost.
201
+ "server_name": null,
202
+ // The port to bind to.
203
+ "server_port": 7860,
204
+ // The number of workers to use for the web server. Use -1 to disable queueing.
205
+ "queue_concurrency_count": 1,
206
+ // Whether or not to automatically delete all uploaded files, to save disk space
207
+ "delete_uploaded_files": true,
208
 
209
+ // * General options *
210
 
211
+ // The default implementation to use for Whisper. Can be "whisper" or "faster-whisper".
212
+ // Note that you must either install the requirements for faster-whisper (requirements-fasterWhisper.txt)
213
+ // or whisper (requirements.txt)
214
+ "whisper_implementation": "faster-whisper",
215
 
216
+ // The default model name.
217
+ "default_model_name": "large-v2",
218
+ // The default VAD.
219
+ "default_vad": "silero-vad",
220
+ // A comma-delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
221
+ "vad_parallel_devices": "",
222
+ // The number of CPU cores to use for VAD pre-processing.
223
+ "vad_cpu_cores": 1,
224
+ // The number of seconds before inactive processes are terminated. Use 0 to close processes immediately, or None for no timeout.
225
+ "vad_process_timeout": 1800,
226
+ // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
227
+ "auto_parallel": false,
228
+ // Directory to save the outputs (CLI will use the current directory if not specified)
229
+ "output_dir": null,
230
+ // The path to save model files; uses ~/.cache/whisper by default
231
+ "model_dir": null,
232
+ // Device to use for PyTorch inference, or Null to use the default device
233
+ "device": null,
234
+ // Whether to print out the progress and debug messages
235
+ "verbose": true,
236
+ // Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')
237
+ "task": "transcribe",
238
+ // Language spoken in the audio, specify None to perform language detection
239
+ "language": null,
240
+ // The window size (in seconds) to merge voice segments
241
+ "vad_merge_window": 5,
242
+ // The maximum size (in seconds) of a voice segment
243
+ "vad_max_merge_size": 90,
244
+ // The padding (in seconds) to add to each voice segment
245
+ "vad_padding": 1,
246
+ // Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)
247
+ "vad_initial_prompt_mode": "prepend_first_segment",
248
+ // The window size of the prompt to pass to Whisper
249
+ "vad_prompt_window": 3,
250
+ // Temperature to use for sampling
251
+ "temperature": 0,
252
+ // Number of candidates when sampling with non-zero temperature
253
+ "best_of": 5,
254
+ // Number of beams in beam search, only applicable when temperature is zero
255
+ "beam_size": 5,
256
+ // Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
257
+ "patience": 1,
258
+ // Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
259
+ "length_penalty": null,
260
+ // Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
261
+ "suppress_tokens": "-1",
262
+ // Optional text to provide as a prompt for the first window
263
+ "initial_prompt": null,
264
+ // If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop
265
+ "condition_on_previous_text": true,
266
+ // Whether to perform inference in fp16; True by default
267
+ "fp16": true,
268
+ // The compute type used by faster-whisper. Can be "int8", "int16" or "float16".
269
+ "compute_type": "auto",
270
+ // Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
271
+ "temperature_increment_on_fallback": 0.2,
272
+ // If the gzip compression ratio is higher than this value, treat the decoding as failed
273
+ "compression_ratio_threshold": 2.4,
274
+ // If the average log probability is lower than this value, treat the decoding as failed
275
+ "logprob_threshold": -1.0,
276
+ // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
277
+ "no_speech_threshold": 0.6,
278
 
279
+ // (experimental) extract word-level timestamps and refine the results based on them
280
+ "word_timestamps": false,
281
+ // if word_timestamps is True, merge these punctuation symbols with the next word
282
+ "prepend_punctuations": "\"\'“¿([{-",
283
+ // if word_timestamps is True, merge these punctuation symbols with the previous word
284
+ "append_punctuations": "\"\'.。,,!!??::”)]}、",
285
+ // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
286
+ "highlight_words": false,
287
 
288
+ // Diarization settings
289
+ "auth_token": null,
290
+ // Whether to perform speaker diarization
291
+ "diarization": false,
292
+ // The number of speakers to detect
293
+ "diarization_speakers": 2,
294
+ // The minimum number of speakers to detect
295
+ "diarization_min_speakers": 1,
296
+ // The maximum number of speakers to detect
297
+ "diarization_max_speakers": 8,
298
+ // The number of seconds before inactive processes are terminated. Use 0 to close processes immediately, or None for no timeout.
299
+ "diarization_process_timeout": 60,
300
  }
requirements-whisper.txt CHANGED
@@ -1,6 +1,5 @@
1
  git+https://github.com/huggingface/transformers
2
  git+https://github.com/openai/whisper.git
3
- transformers
4
  ffmpeg-python==0.2.0
5
  gradio==3.50.2
6
  yt-dlp
 
1
  git+https://github.com/huggingface/transformers
2
  git+https://github.com/openai/whisper.git
 
3
  ffmpeg-python==0.2.0
4
  gradio==3.50.2
5
  yt-dlp
src/config.py CHANGED
@@ -1,16 +1,11 @@
1
  from enum import Enum
2
- import urllib
3
 
4
  import os
5
- from typing import List
6
- from urllib.parse import urlparse
7
- import json5
8
- import torch
9
 
10
- from tqdm import tqdm
11
 
12
  class ModelConfig:
13
- def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
14
  """
15
  Initialize a model configuration.
16
 
@@ -23,6 +18,7 @@ class ModelConfig:
23
  self.url = url
24
  self.path = path
25
  self.type = type
 
26
 
27
  VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
28
 
@@ -33,7 +29,7 @@ class VadInitialPromptMode(Enum):
33
 
34
  @staticmethod
35
  def from_string(s: str):
36
- normalized = s.lower() if s is not None else None
37
 
38
  if normalized == "prepend_all_segments":
39
  return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
@@ -47,11 +43,11 @@ class VadInitialPromptMode(Enum):
47
  return None
48
 
49
  class ApplicationConfig:
50
- def __init__(self, models: List[ModelConfig] = [], nllb_models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
51
- share: bool = False, server_name: str = None, server_port: int = 7860,
52
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
53
- whisper_implementation: str = "whisper",
54
- default_model_name: str = "medium", default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
55
  vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
56
  auto_parallel: bool = False, output_dir: str = None,
57
  model_dir: str = None, device: str = None,
@@ -66,6 +62,7 @@ class ApplicationConfig:
66
  compute_type: str = "float16",
67
  temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
68
  logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
 
69
  # Word timestamp settings
70
  word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
71
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
@@ -73,10 +70,14 @@ class ApplicationConfig:
73
  # Diarization
74
  auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
75
  diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
76
- diarization_process_timeout: int = 60):
 
 
 
 
 
77
 
78
  self.models = models
79
- self.nllb_models = nllb_models
80
 
81
  # WebUI settings
82
  self.input_audio_max_duration = input_audio_max_duration
@@ -120,6 +121,8 @@ class ApplicationConfig:
120
  self.compression_ratio_threshold = compression_ratio_threshold
121
  self.logprob_threshold = logprob_threshold
122
  self.no_speech_threshold = no_speech_threshold
 
 
123
 
124
  # Word timestamp settings
125
  self.word_timestamps = word_timestamps
@@ -134,12 +137,13 @@ class ApplicationConfig:
134
  self.diarization_min_speakers = diarization_min_speakers
135
  self.diarization_max_speakers = diarization_max_speakers
136
  self.diarization_process_timeout = diarization_process_timeout
 
 
 
 
137
 
138
- def get_model_names(self):
139
- return [ x.name for x in self.models ]
140
-
141
- def get_nllb_model_names(self):
142
- return [ x.name for x in self.nllb_models ]
143
 
144
  def update(self, **new_values):
145
  result = ApplicationConfig(**self.__dict__)
@@ -165,9 +169,9 @@ class ApplicationConfig:
165
  # Load using json5
166
  data = json5.load(f)
167
  data_models = data.pop("models", [])
168
- data_nllb_models = data.pop("nllb_models", [])
169
-
170
- models = [ ModelConfig(**x) for x in data_models ]
171
- nllb_models = [ ModelConfig(**x) for x in data_nllb_models ]
172
 
173
- return ApplicationConfig(models, nllb_models, **data)
 
1
  from enum import Enum
 
2
 
3
  import os
4
+ from typing import List, Dict, Literal
 
 
 
5
 
 
6
 
7
  class ModelConfig:
8
+ def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None):
9
  """
10
  Initialize a model configuration.
11
 
 
18
  self.url = url
19
  self.path = path
20
  self.type = type
21
+ self.tokenizer_url = tokenizer_url
22
 
23
  VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
24
 
 
29
 
30
  @staticmethod
31
  def from_string(s: str):
32
+ normalized = s.lower() if s is not None and len(s) > 0 else None
33
 
34
  if normalized == "prepend_all_segments":
35
  return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
 
43
  return None
44
 
45
  class ApplicationConfig:
46
+ def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]],
47
+ input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
48
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
49
+ whisper_implementation: str = "whisper", default_model_name: str = "medium",
50
+ default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
51
  vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
52
  auto_parallel: bool = False, output_dir: str = None,
53
  model_dir: str = None, device: str = None,
 
62
  compute_type: str = "float16",
63
  temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
64
  logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
65
+ repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
66
  # Word timestamp settings
67
  word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
68
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
 
70
  # Diarization
71
  auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
72
  diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
73
+ diarization_process_timeout: int = 60,
74
+ # Translation
75
+ translation_batch_size: int = 2,
76
+ translation_no_repeat_ngram_size: int = 3,
77
+ translation_num_beams: int = 2,
78
+ ):
79
 
80
  self.models = models
 
81
 
82
  # WebUI settings
83
  self.input_audio_max_duration = input_audio_max_duration
 
121
  self.compression_ratio_threshold = compression_ratio_threshold
122
  self.logprob_threshold = logprob_threshold
123
  self.no_speech_threshold = no_speech_threshold
124
+ self.repetition_penalty = repetition_penalty
125
+ self.no_repeat_ngram_size = no_repeat_ngram_size
126
 
127
  # Word timestamp settings
128
  self.word_timestamps = word_timestamps
 
137
  self.diarization_min_speakers = diarization_min_speakers
138
  self.diarization_max_speakers = diarization_max_speakers
139
  self.diarization_process_timeout = diarization_process_timeout
140
+ # Translation
141
+ self.translation_batch_size = translation_batch_size
142
+ self.translation_no_repeat_ngram_size = translation_no_repeat_ngram_size
143
+ self.translation_num_beams = translation_num_beams
144
 
145
+ def get_model_names(self, name: str):
146
+ return [ x.name for x in self.models[name] ]
 
 
 
147
 
148
  def update(self, **new_values):
149
  result = ApplicationConfig(**self.__dict__)
 
169
  # Load using json5
170
  data = json5.load(f)
171
  data_models = data.pop("models", [])
172
+ models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]] = {
173
+ key: [ModelConfig(**item) for item in value]
174
+ for key, value in data_models.items()
175
+ }
176
 
177
+ return ApplicationConfig(models, **data)
src/languages.py DELETED
@@ -1,147 +0,0 @@
1
- class Language():
2
- def __init__(self, code, name):
3
- self.code = code
4
- self.name = name
5
-
6
- def __str__(self):
7
- return "Language(code={}, name={})".format(self.code, self.name)
8
-
9
- LANGUAGES = [
10
- Language('en', 'English'),
11
- Language('zh', 'Chinese'),
12
- Language('de', 'German'),
13
- Language('es', 'Spanish'),
14
- Language('ru', 'Russian'),
15
- Language('ko', 'Korean'),
16
- Language('fr', 'French'),
17
- Language('ja', 'Japanese'),
18
- Language('pt', 'Portuguese'),
19
- Language('tr', 'Turkish'),
20
- Language('pl', 'Polish'),
21
- Language('ca', 'Catalan'),
22
- Language('nl', 'Dutch'),
23
- Language('ar', 'Arabic'),
24
- Language('sv', 'Swedish'),
25
- Language('it', 'Italian'),
26
- Language('id', 'Indonesian'),
27
- Language('hi', 'Hindi'),
28
- Language('fi', 'Finnish'),
29
- Language('vi', 'Vietnamese'),
30
- Language('he', 'Hebrew'),
31
- Language('uk', 'Ukrainian'),
32
- Language('el', 'Greek'),
33
- Language('ms', 'Malay'),
34
- Language('cs', 'Czech'),
35
- Language('ro', 'Romanian'),
36
- Language('da', 'Danish'),
37
- Language('hu', 'Hungarian'),
38
- Language('ta', 'Tamil'),
39
- Language('no', 'Norwegian'),
40
- Language('th', 'Thai'),
41
- Language('ur', 'Urdu'),
42
- Language('hr', 'Croatian'),
43
- Language('bg', 'Bulgarian'),
44
- Language('lt', 'Lithuanian'),
45
- Language('la', 'Latin'),
46
- Language('mi', 'Maori'),
47
- Language('ml', 'Malayalam'),
48
- Language('cy', 'Welsh'),
49
- Language('sk', 'Slovak'),
50
- Language('te', 'Telugu'),
51
- Language('fa', 'Persian'),
52
- Language('lv', 'Latvian'),
53
- Language('bn', 'Bengali'),
54
- Language('sr', 'Serbian'),
55
- Language('az', 'Azerbaijani'),
56
- Language('sl', 'Slovenian'),
57
- Language('kn', 'Kannada'),
58
- Language('et', 'Estonian'),
59
- Language('mk', 'Macedonian'),
60
- Language('br', 'Breton'),
61
- Language('eu', 'Basque'),
62
- Language('is', 'Icelandic'),
63
- Language('hy', 'Armenian'),
64
- Language('ne', 'Nepali'),
65
- Language('mn', 'Mongolian'),
66
- Language('bs', 'Bosnian'),
67
- Language('kk', 'Kazakh'),
68
- Language('sq', 'Albanian'),
69
- Language('sw', 'Swahili'),
70
- Language('gl', 'Galician'),
71
- Language('mr', 'Marathi'),
72
- Language('pa', 'Punjabi'),
73
- Language('si', 'Sinhala'),
74
- Language('km', 'Khmer'),
75
- Language('sn', 'Shona'),
76
- Language('yo', 'Yoruba'),
77
- Language('so', 'Somali'),
78
- Language('af', 'Afrikaans'),
79
- Language('oc', 'Occitan'),
80
- Language('ka', 'Georgian'),
81
- Language('be', 'Belarusian'),
82
- Language('tg', 'Tajik'),
83
- Language('sd', 'Sindhi'),
84
- Language('gu', 'Gujarati'),
85
- Language('am', 'Amharic'),
86
- Language('yi', 'Yiddish'),
87
- Language('lo', 'Lao'),
88
- Language('uz', 'Uzbek'),
89
- Language('fo', 'Faroese'),
90
- Language('ht', 'Haitian creole'),
91
- Language('ps', 'Pashto'),
92
- Language('tk', 'Turkmen'),
93
- Language('nn', 'Nynorsk'),
94
- Language('mt', 'Maltese'),
95
- Language('sa', 'Sanskrit'),
96
- Language('lb', 'Luxembourgish'),
97
- Language('my', 'Myanmar'),
98
- Language('bo', 'Tibetan'),
99
- Language('tl', 'Tagalog'),
100
- Language('mg', 'Malagasy'),
101
- Language('as', 'Assamese'),
102
- Language('tt', 'Tatar'),
103
- Language('haw', 'Hawaiian'),
104
- Language('ln', 'Lingala'),
105
- Language('ha', 'Hausa'),
106
- Language('ba', 'Bashkir'),
107
- Language('jw', 'Javanese'),
108
- Language('su', 'Sundanese')
109
- ]
110
-
111
- _TO_LANGUAGE_CODE = {
112
- **{language.code: language for language in LANGUAGES},
113
- "burmese": "my",
114
- "valencian": "ca",
115
- "flemish": "nl",
116
- "haitian": "ht",
117
- "letzeburgesch": "lb",
118
- "pushto": "ps",
119
- "panjabi": "pa",
120
- "moldavian": "ro",
121
- "moldovan": "ro",
122
- "sinhalese": "si",
123
- "castilian": "es",
124
- }
125
-
126
- _FROM_LANGUAGE_NAME = {
127
- **{language.name.lower(): language for language in LANGUAGES}
128
- }
129
-
130
- def get_language_from_code(language_code, default=None) -> Language:
131
- """Return the language name from the language code."""
132
- return _TO_LANGUAGE_CODE.get(language_code, default)
133
-
134
- def get_language_from_name(language, default=None) -> Language:
135
- """Return the language code from the language name."""
136
- return _FROM_LANGUAGE_NAME.get(language.lower() if language else None, default)
137
-
138
- def get_language_names():
139
- """Return a list of language names."""
140
- return [language.name for language in LANGUAGES]
141
-
142
- if __name__ == "__main__":
143
- # Test lookup
144
- print(get_language_from_code('en'))
145
- print(get_language_from_name('English'))
146
-
147
- print(get_language_names())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/nllb/nllbLangs.py DELETED
@@ -1,251 +0,0 @@
1
- class NllbLang():
2
- def __init__(self, code, name, code_whisper=None, name_whisper=None):
3
- self.code = code
4
- self.name = name
5
- self.code_whisper = code_whisper
6
- self.name_whisper = name_whisper
7
-
8
- def __str__(self):
9
- return "Language(code={}, name={})".format(self.code, self.name)
10
-
11
- NLLB_LANGS = [
12
- NllbLang('ace_Arab', 'Acehnese (Arabic script)'),
13
- NllbLang('ace_Latn', 'Acehnese (Latin script)'),
14
- NllbLang('acm_Arab', 'Mesopotamian Arabic', 'ar', 'Arabic'),
15
- NllbLang('acq_Arab', 'Ta’izzi-Adeni Arabic', 'ar', 'Arabic'),
16
- NllbLang('aeb_Arab', 'Tunisian Arabic'),
17
- NllbLang('afr_Latn', 'Afrikaans', 'am', 'Amharic'),
18
- NllbLang('ajp_Arab', 'South Levantine Arabic', 'ar', 'Arabic'),
19
- NllbLang('aka_Latn', 'Akan'),
20
- NllbLang('amh_Ethi', 'Amharic'),
21
- NllbLang('apc_Arab', 'North Levantine Arabic', 'ar', 'Arabic'),
22
- NllbLang('arb_Arab', 'Modern Standard Arabic', 'ar', 'Arabic'),
23
- NllbLang('arb_Latn', 'Modern Standard Arabic (Romanized)'),
24
- NllbLang('ars_Arab', 'Najdi Arabic', 'ar', 'Arabic'),
25
- NllbLang('ary_Arab', 'Moroccan Arabic', 'ar', 'Arabic'),
26
- NllbLang('arz_Arab', 'Egyptian Arabic', 'ar', 'Arabic'),
27
- NllbLang('asm_Beng', 'Assamese', 'as', 'Assamese'),
28
- NllbLang('ast_Latn', 'Asturian'),
29
- NllbLang('awa_Deva', 'Awadhi'),
30
- NllbLang('ayr_Latn', 'Central Aymara'),
31
- NllbLang('azb_Arab', 'South Azerbaijani', 'az', 'Azerbaijani'),
32
- NllbLang('azj_Latn', 'North Azerbaijani', 'az', 'Azerbaijani'),
33
- NllbLang('bak_Cyrl', 'Bashkir', 'ba', 'Bashkir'),
34
- NllbLang('bam_Latn', 'Bambara'),
35
- NllbLang('ban_Latn', 'Balinese'),
36
- NllbLang('bel_Cyrl', 'Belarusian', 'be', 'Belarusian'),
37
- NllbLang('bem_Latn', 'Bemba'),
38
- NllbLang('ben_Beng', 'Bengali', 'bn', 'Bengali'),
39
- NllbLang('bho_Deva', 'Bhojpuri'),
40
- NllbLang('bjn_Arab', 'Banjar (Arabic script)'),
41
- NllbLang('bjn_Latn', 'Banjar (Latin script)'),
42
- NllbLang('bod_Tibt', 'Standard Tibetan', 'bo', 'Tibetan'),
43
- NllbLang('bos_Latn', 'Bosnian', 'bs', 'Bosnian'),
44
- NllbLang('bug_Latn', 'Buginese'),
45
- NllbLang('bul_Cyrl', 'Bulgarian', 'bg', 'Bulgarian'),
46
- NllbLang('cat_Latn', 'Catalan', 'ca', 'Catalan'),
47
- NllbLang('ceb_Latn', 'Cebuano'),
48
- NllbLang('ces_Latn', 'Czech', 'cs', 'Czech'),
49
- NllbLang('cjk_Latn', 'Chokwe'),
50
- NllbLang('ckb_Arab', 'Central Kurdish'),
51
- NllbLang('crh_Latn', 'Crimean Tatar'),
52
- NllbLang('cym_Latn', 'Welsh', 'cy', 'Welsh'),
53
- NllbLang('dan_Latn', 'Danish', 'da', 'Danish'),
54
- NllbLang('deu_Latn', 'German', 'de', 'German'),
55
- NllbLang('dik_Latn', 'Southwestern Dinka'),
56
- NllbLang('dyu_Latn', 'Dyula'),
57
- NllbLang('dzo_Tibt', 'Dzongkha'),
58
- NllbLang('ell_Grek', 'Greek', 'el', 'Greek'),
59
- NllbLang('eng_Latn', 'English', 'en', 'English'),
60
- NllbLang('epo_Latn', 'Esperanto'),
61
- NllbLang('est_Latn', 'Estonian', 'et', 'Estonian'),
62
- NllbLang('eus_Latn', 'Basque', 'eu', 'Basque'),
63
- NllbLang('ewe_Latn', 'Ewe'),
64
- NllbLang('fao_Latn', 'Faroese', 'fo', 'Faroese'),
65
- NllbLang('fij_Latn', 'Fijian'),
66
- NllbLang('fin_Latn', 'Finnish', 'fi', 'Finnish'),
67
- NllbLang('fon_Latn', 'Fon'),
68
- NllbLang('fra_Latn', 'French', 'fr', 'French'),
69
- NllbLang('fur_Latn', 'Friulian'),
70
- NllbLang('fuv_Latn', 'Nigerian Fulfulde'),
71
- NllbLang('gla_Latn', 'Scottish Gaelic'),
72
- NllbLang('gle_Latn', 'Irish'),
73
- NllbLang('glg_Latn', 'Galician', 'gl', 'Galician'),
74
- NllbLang('grn_Latn', 'Guarani'),
75
- NllbLang('guj_Gujr', 'Gujarati', 'gu', 'Gujarati'),
76
- NllbLang('hat_Latn', 'Haitian Creole', 'ht', 'Haitian creole'),
77
- NllbLang('hau_Latn', 'Hausa', 'ha', 'Hausa'),
78
- NllbLang('heb_Hebr', 'Hebrew', 'he', 'Hebrew'),
79
- NllbLang('hin_Deva', 'Hindi', 'hi', 'Hindi'),
80
- NllbLang('hne_Deva', 'Chhattisgarhi'),
81
- NllbLang('hrv_Latn', 'Croatian', 'hr', 'Croatian'),
82
- NllbLang('hun_Latn', 'Hungarian', 'hu', 'Hungarian'),
83
- NllbLang('hye_Armn', 'Armenian', 'hy', 'Armenian'),
84
- NllbLang('ibo_Latn', 'Igbo'),
85
- NllbLang('ilo_Latn', 'Ilocano'),
86
- NllbLang('ind_Latn', 'Indonesian', 'id', 'Indonesian'),
87
- NllbLang('isl_Latn', 'Icelandic', 'is', 'Icelandic'),
88
- NllbLang('ita_Latn', 'Italian', 'it', 'Italian'),
89
- NllbLang('jav_Latn', 'Javanese', 'jw', 'Javanese'),
90
- NllbLang('jpn_Jpan', 'Japanese', 'ja', 'Japanese'),
91
- NllbLang('kab_Latn', 'Kabyle'),
92
- NllbLang('kac_Latn', 'Jingpho'),
93
- NllbLang('kam_Latn', 'Kamba'),
94
- NllbLang('kan_Knda', 'Kannada', 'kn', 'Kannada'),
95
- NllbLang('kas_Arab', 'Kashmiri (Arabic script)'),
96
- NllbLang('kas_Deva', 'Kashmiri (Devanagari script)'),
97
- NllbLang('kat_Geor', 'Georgian', 'ka', 'Georgian'),
98
- NllbLang('knc_Arab', 'Central Kanuri (Arabic script)'),
99
- NllbLang('knc_Latn', 'Central Kanuri (Latin script)'),
100
- NllbLang('kaz_Cyrl', 'Kazakh', 'kk', 'Kazakh'),
101
- NllbLang('kbp_Latn', 'Kabiyè'),
102
- NllbLang('kea_Latn', 'Kabuverdianu'),
103
- NllbLang('khm_Khmr', 'Khmer', 'km', 'Khmer'),
104
- NllbLang('kik_Latn', 'Kikuyu'),
105
- NllbLang('kin_Latn', 'Kinyarwanda'),
106
- NllbLang('kir_Cyrl', 'Kyrgyz'),
107
- NllbLang('kmb_Latn', 'Kimbundu'),
108
- NllbLang('kmr_Latn', 'Northern Kurdish'),
109
- NllbLang('kon_Latn', 'Kikongo'),
110
- NllbLang('kor_Hang', 'Korean', 'ko', 'Korean'),
111
- NllbLang('lao_Laoo', 'Lao', 'lo', 'Lao'),
112
- NllbLang('lij_Latn', 'Ligurian'),
113
- NllbLang('lim_Latn', 'Limburgish'),
114
- NllbLang('lin_Latn', 'Lingala', 'ln', 'Lingala'),
115
- NllbLang('lit_Latn', 'Lithuanian', 'lt', 'Lithuanian'),
116
- NllbLang('lmo_Latn', 'Lombard'),
117
- NllbLang('ltg_Latn', 'Latgalian'),
118
- NllbLang('ltz_Latn', 'Luxembourgish', 'lb', 'Luxembourgish'),
119
- NllbLang('lua_Latn', 'Luba-Kasai'),
120
- NllbLang('lug_Latn', 'Ganda'),
121
- NllbLang('luo_Latn', 'Luo'),
122
- NllbLang('lus_Latn', 'Mizo'),
123
- NllbLang('lvs_Latn', 'Standard Latvian', 'lv', 'Latvian'),
124
- NllbLang('mag_Deva', 'Magahi'),
125
- NllbLang('mai_Deva', 'Maithili'),
126
- NllbLang('mal_Mlym', 'Malayalam', 'ml', 'Malayalam'),
127
- NllbLang('mar_Deva', 'Marathi', 'mr', 'Marathi'),
128
- NllbLang('min_Arab', 'Minangkabau (Arabic script)'),
129
- NllbLang('min_Latn', 'Minangkabau (Latin script)'),
130
- NllbLang('mkd_Cyrl', 'Macedonian', 'mk', 'Macedonian'),
131
- NllbLang('plt_Latn', 'Plateau Malagasy', 'mg', 'Malagasy'),
132
- NllbLang('mlt_Latn', 'Maltese', 'mt', 'Maltese'),
133
- NllbLang('mni_Beng', 'Meitei (Bengali script)'),
134
- NllbLang('khk_Cyrl', 'Halh Mongolian', 'mn', 'Mongolian'),
135
- NllbLang('mos_Latn', 'Mossi'),
136
- NllbLang('mri_Latn', 'Maori', 'mi', 'Maori'),
137
- NllbLang('mya_Mymr', 'Burmese', 'my', 'Myanmar'),
138
- NllbLang('nld_Latn', 'Dutch', 'nl', 'Dutch'),
139
- NllbLang('nno_Latn', 'Norwegian Nynorsk', 'nn', 'Nynorsk'),
140
- NllbLang('nob_Latn', 'Norwegian Bokmål', 'no', 'Norwegian'),
141
- NllbLang('npi_Deva', 'Nepali', 'ne', 'Nepali'),
142
- NllbLang('nso_Latn', 'Northern Sotho'),
143
- NllbLang('nus_Latn', 'Nuer'),
144
- NllbLang('nya_Latn', 'Nyanja'),
145
- NllbLang('oci_Latn', 'Occitan', 'oc', 'Occitan'),
146
- NllbLang('gaz_Latn', 'West Central Oromo'),
147
- NllbLang('ory_Orya', 'Odia'),
148
- NllbLang('pag_Latn', 'Pangasinan'),
149
- NllbLang('pan_Guru', 'Eastern Panjabi', 'pa', 'Punjabi'),
150
- NllbLang('pap_Latn', 'Papiamento'),
151
- NllbLang('pes_Arab', 'Western Persian', 'fa', 'Persian'),
152
- NllbLang('pol_Latn', 'Polish', 'pl', 'Polish'),
153
- NllbLang('por_Latn', 'Portuguese', 'pt', 'Portuguese'),
154
- NllbLang('prs_Arab', 'Dari'),
155
- NllbLang('pbt_Arab', 'Southern Pashto', 'ps', 'Pashto'),
156
- NllbLang('quy_Latn', 'Ayacucho Quechua'),
157
- NllbLang('ron_Latn', 'Romanian', 'ro', 'Romanian'),
158
- NllbLang('run_Latn', 'Rundi'),
159
- NllbLang('rus_Cyrl', 'Russian', 'ru', 'Russian'),
160
- NllbLang('sag_Latn', 'Sango'),
161
- NllbLang('san_Deva', 'Sanskrit', 'sa', 'Sanskrit'),
162
- NllbLang('sat_Olck', 'Santali'),
163
- NllbLang('scn_Latn', 'Sicilian'),
164
- NllbLang('shn_Mymr', 'Shan'),
165
- NllbLang('sin_Sinh', 'Sinhala', 'si', 'Sinhala'),
166
- NllbLang('slk_Latn', 'Slovak', 'sk', 'Slovak'),
167
- NllbLang('slv_Latn', 'Slovenian', 'sl', 'Slovenian'),
168
- NllbLang('smo_Latn', 'Samoan'),
169
- NllbLang('sna_Latn', 'Shona', 'sn', 'Shona'),
170
- NllbLang('snd_Arab', 'Sindhi', 'sd', 'Sindhi'),
171
- NllbLang('som_Latn', 'Somali', 'so', 'Somali'),
172
- NllbLang('sot_Latn', 'Southern Sotho'),
173
- NllbLang('spa_Latn', 'Spanish', 'es', 'Spanish'),
174
- NllbLang('als_Latn', 'Tosk Albanian', 'sq', 'Albanian'),
175
- NllbLang('srd_Latn', 'Sardinian'),
176
- NllbLang('srp_Cyrl', 'Serbian', 'sr', 'Serbian'),
177
- NllbLang('ssw_Latn', 'Swati'),
178
- NllbLang('sun_Latn', 'Sundanese', 'su', 'Sundanese'),
179
- NllbLang('swe_Latn', 'Swedish', 'sv', 'Swedish'),
180
- NllbLang('swh_Latn', 'Swahili', 'sw', 'Swahili'),
181
- NllbLang('szl_Latn', 'Silesian'),
182
- NllbLang('tam_Taml', 'Tamil', 'ta', 'Tamil'),
183
- NllbLang('tat_Cyrl', 'Tatar', 'tt', 'Tatar'),
184
- NllbLang('tel_Telu', 'Telugu', 'te', 'Telugu'),
185
- NllbLang('tgk_Cyrl', 'Tajik', 'tg', 'Tajik'),
186
- NllbLang('tgl_Latn', 'Tagalog', 'tl', 'Tagalog'),
187
- NllbLang('tha_Thai', 'Thai', 'th', 'Thai'),
188
- NllbLang('tir_Ethi', 'Tigrinya'),
189
- NllbLang('taq_Latn', 'Tamasheq (Latin script)'),
190
- NllbLang('taq_Tfng', 'Tamasheq (Tifinagh script)'),
191
- NllbLang('tpi_Latn', 'Tok Pisin'),
192
- NllbLang('tsn_Latn', 'Tswana'),
193
- NllbLang('tso_Latn', 'Tsonga'),
194
- NllbLang('tuk_Latn', 'Turkmen', 'tk', 'Turkmen'),
195
- NllbLang('tum_Latn', 'Tumbuka'),
196
- NllbLang('tur_Latn', 'Turkish', 'tr', 'Turkish'),
197
- NllbLang('twi_Latn', 'Twi'),
198
- NllbLang('tzm_Tfng', 'Central Atlas Tamazight'),
199
- NllbLang('uig_Arab', 'Uyghur'),
200
- NllbLang('ukr_Cyrl', 'Ukrainian', 'uk', 'Ukrainian'),
201
- NllbLang('umb_Latn', 'Umbundu'),
202
- NllbLang('urd_Arab', 'Urdu', 'ur', 'Urdu'),
203
- NllbLang('uzn_Latn', 'Northern Uzbek', 'uz', 'Uzbek'),
204
- NllbLang('vec_Latn', 'Venetian'),
205
- NllbLang('vie_Latn', 'Vietnamese', 'vi', 'Vietnamese'),
206
- NllbLang('war_Latn', 'Waray'),
207
- NllbLang('wol_Latn', 'Wolof'),
208
- NllbLang('xho_Latn', 'Xhosa'),
209
- NllbLang('ydd_Hebr', 'Eastern Yiddish', 'yi', 'Yiddish'),
210
- NllbLang('yor_Latn', 'Yoruba', 'yo', 'Yoruba'),
211
- NllbLang('yue_Hant', 'Yue Chinese', 'zh', 'Chinese'),
212
- NllbLang('zho_Hans', 'Chinese (Simplified)', 'zh', 'Chinese'),
213
- NllbLang('zho_Hant', 'Chinese (Traditional)', 'zh', 'Chinese'),
214
- NllbLang('zsm_Latn', 'Standard Malay', 'ms', 'Malay'),
215
- NllbLang('zul_Latn', 'Zulu'),
216
- ]
217
-
218
- _TO_NLLB_LANG_CODE = {language.code.lower(): language for language in NLLB_LANGS if language.code is not None}
219
-
220
- _TO_NLLB_LANG_NAME = {language.name.lower(): language for language in NLLB_LANGS if language.name is not None}
221
-
222
- _TO_NLLB_LANG_WHISPER_CODE = {language.code_whisper.lower(): language for language in NLLB_LANGS if language.code_whisper is not None}
223
-
224
- _TO_NLLB_LANG_WHISPER_NAME = {language.name_whisper.lower(): language for language in NLLB_LANGS if language.name_whisper is not None}
225
-
226
- def get_nllb_lang_from_code(lang_code, default=None) -> NllbLang:
227
- """Return the language from the language code."""
228
- return _TO_NLLB_LANG_CODE.get(lang_code, default)
229
-
230
- def get_nllb_lang_from_name(lang_name, default=None) -> NllbLang:
231
- """Return the language from the language name."""
232
- return _TO_NLLB_LANG_NAME.get(lang_name.lower() if lang_name else None, default)
233
-
234
- def get_nllb_lang_from_code_whisper(lang_code_whisper, default=None) -> NllbLang:
235
- """Return the language from the language code."""
236
- return _TO_NLLB_LANG_WHISPER_CODE.get(lang_code_whisper, default)
237
-
238
- def get_nllb_lang_from_name_whisper(lang_name_whisper, default=None) -> NllbLang:
239
- """Return the language from the language name."""
240
- return _TO_NLLB_LANG_WHISPER_NAME.get(lang_name_whisper.lower() if lang_name_whisper else None, default)
241
-
242
- def get_nllb_lang_names():
243
- """Return a list of language names."""
244
- return [language.name for language in NLLB_LANGS]
245
-
246
- if __name__ == "__main__":
247
- # Test lookup
248
- print(get_nllb_lang_from_code('eng_Latn'))
249
- print(get_nllb_lang_from_name('English'))
250
-
251
- print(get_nllb_lang_names())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/translation/translationLangs.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Lang():
    """A language as identified by a single model family.

    Args:
        code: The language code used by that model family.
        *names: Zero or more display names for the language; stored as a tuple.
    """

    def __init__(self, code: str, *names: str):
        self.code = code
        self.names = names

    def __repr__(self):
        return f"code:{self.code}, name:{self.names}"


class TranslationLang():
    """Groups the identifiers one language has in NLLB, Whisper and M2M100.

    Args:
        nllb: The NLLB identifier of the language.
        whisper: The Whisper identifier, or None when none is given.
        m2m100: The M2M100 identifier. When omitted, the Whisper identifier
            is reused. An identifier that carries no names is stored as None.
    """

    def __init__(self, nllb: Lang, whisper: Lang = None, m2m100: Lang = None):
        self.nllb = nllb
        self.whisper = whisper

        # M2M100 falls back to the Whisper entry when not given explicitly.
        if m2m100 is None:
            m2m100 = whisper

        # An entry without names is treated as absent.
        # NOTE(review): presumably `Lang("xx")` with no names is how a caller
        # opts out of the Whisper fallback - confirm against the language table.
        has_names = m2m100 is not None and len(m2m100.names) > 0
        self.m2m100 = m2m100 if has_names else None

    def __repr__(self):
        # Each present entry is rendered as "LABEL=<lang> " (trailing space kept
        # so the output format is unchanged apart from the corrected label).
        labelled = (
            ("NLLB", self.nllb),
            ("WHISPER", self.whisper),
            ("M2M100", self.m2m100),  # was mistyped as "M@M100"
        )
        result = "".join(
            f"{label}={lang} " for label, lang in labelled if lang is not None
        )
        return f"Language {result}"
28
+
29
+ """
30
+ Model available Languages
31
+
32
+ [NLLB]
33
+ ace_Latn:Acehnese (Latin script), aka_Latn:Akan, als_Latn:Tosk Albanian, amh_Ethi:Amharic, asm_Beng:Assamese, awa_Deva:Awadhi, ayr_Latn:Central Aymara, azb_Arab:South Azerbaijani, azj_Latn:North Azerbaijani, bak_Cyrl:Bashkir, bam_Latn:Bambara, ban_Latn:Balinese, bel_Cyrl:Belarusian, bem_Latn:Bemba, ben_Beng:Bengali, bho_Deva:Bhojpuri, bjn_Latn:Banjar (Latin script), bod_Tibt:Standard Tibetan, bug_Latn:Buginese, ceb_Latn:Cebuano, cjk_Latn:Chokwe, ckb_Arab:Central Kurdish, crh_Latn:Crimean Tatar, cym_Latn:Welsh, dik_Latn:Southwestern Dinka, diq_Latn:Southern Zaza, dyu_Latn:Dyula, dzo_Tibt:Dzongkha, ewe_Latn:Ewe, fao_Latn:Faroese, fij_Latn:Fijian, fon_Latn:Fon, fur_Latn:Friulian, fuv_Latn:Nigerian Fulfulde, gaz_Latn:West Central Oromo, gla_Latn:Scottish Gaelic, gle_Latn:Irish, grn_Latn:Guarani, guj_Gujr:Gujarati, hat_Latn:Haitian Creole, hau_Latn:Hausa, hin_Deva:Hindi, hne_Deva:Chhattisgarhi, hye_Armn:Armenian, ibo_Latn:Igbo, ilo_Latn:Ilocano, ind_Latn:Indonesian, jav_Latn:Javanese, kab_Latn:Kabyle, kac_Latn:Jingpho, kam_Latn:Kamba, kan_Knda:Kannada, kas_Arab:Kashmiri (Arabic script), kas_Deva:Kashmiri (Devanagari script), kat_Geor:Georgian, kaz_Cyrl:Kazakh, kbp_Latn:Kabiyè, kea_Latn:Kabuverdianu, khk_Cyrl:Halh Mongolian, khm_Khmr:Khmer, kik_Latn:Kikuyu, kin_Latn:Kinyarwanda, kir_Cyrl:Kyrgyz, kmb_Latn:Kimbundu, kmr_Latn:Northern Kurdish, knc_Arab:Central Kanuri (Arabic script), knc_Latn:Central Kanuri (Latin script), kon_Latn:Kikongo, lao_Laoo:Lao, lij_Latn:Ligurian, lim_Latn:Limburgish, lin_Latn:Lingala, lmo_Latn:Lombard, ltg_Latn:Latgalian, ltz_Latn:Luxembourgish, lua_Latn:Luba-Kasai, lug_Latn:Ganda, luo_Latn:Luo, lus_Latn:Mizo, mag_Deva:Magahi, mai_Deva:Maithili, mal_Mlym:Malayalam, mar_Deva:Marathi, min_Latn:Minangkabau (Latin script), mlt_Latn:Maltese, mni_Beng:Meitei (Bengali script), mos_Latn:Mossi, mri_Latn:Maori, mya_Mymr:Burmese, npi_Deva:Nepali, nso_Latn:Northern Sotho, nus_Latn:Nuer, nya_Latn:Nyanja, ory_Orya:Odia, pag_Latn:Pangasinan, pan_Guru:Eastern 
Panjabi, pap_Latn:Papiamento, pbt_Arab:Southern Pashto, pes_Arab:Western Persian, plt_Latn:Plateau Malagasy, prs_Arab:Dari, quy_Latn:Ayacucho Quechua, run_Latn:Rundi, sag_Latn:Sango, san_Deva:Sanskrit, sat_Beng:Santali, scn_Latn:Sicilian, shn_Mymr:Shan, sin_Sinh:Sinhala, smo_Latn:Samoan, sna_Latn:Shona, snd_Arab:Sindhi, som_Latn:Somali, sot_Latn:Southern Sotho, srd_Latn:Sardinian, ssw_Latn:Swati, sun_Latn:Sundanese, swh_Latn:Swahili, szl_Latn:Silesian, tam_Taml:Tamil, taq_Latn:Tamasheq (Latin script), tat_Cyrl:Tatar, tel_Telu:Telugu, tgk_Cyrl:Tajik, tgl_Latn:Tagalog, tha_Thai:Thai, tir_Ethi:Tigrinya, tpi_Latn:Tok Pisin, tsn_Latn:Tswana, tso_Latn:Tsonga, tuk_Latn:Turkmen, tum_Latn:Tumbuka, tur_Latn:Turkish, twi_Latn:Twi, tzm_Tfng:Central Atlas Tamazight, uig_Arab:Uyghur, umb_Latn:Umbundu, urd_Arab:Urdu, uzn_Latn:Northern Uzbek, vec_Latn:Venetian, war_Latn:Waray, wol_Latn:Wolof, xho_Latn:Xhosa, ydd_Hebr:Eastern Yiddish, yor_Latn:Yoruba, zsm_Latn:Standard Malay, zul_Latn:Zulu
34
+ https://github.com/facebookresearch/LASER/blob/main/nllb/README.md
35
+
36
+ In the NLLB model, languages are identified by a FLORES-200 code of the form {language}_{script}, where the language is an ISO 639-3 code and the script is an ISO 15924 code.
37
+ https://github.com/sillsdev/serval/wiki/FLORES%E2%80%90200-Language-Code-Resolution-for-NMT-Engine
38
+
39
+ [whisper]
40
+ en:english, zh:chinese, de:german, es:spanish, ru:russian, ko:korean, fr:french, ja:japanese, pt:portuguese, tr:turkish, pl:polish, ca:catalan, nl:dutch, ar:arabic, sv:swedish, it:italian, id:indonesian, hi:hindi, fi:finnish, vi:vietnamese, he:hebrew, uk:ukrainian, el:greek, ms:malay, cs:czech, ro:romanian, da:danish, hu:hungarian, ta:tamil, no:norwegian, th:thai, ur:urdu, hr:croatian, bg:bulgarian, lt:lithuanian, la:latin, mi:maori, ml:malayalam, cy:welsh, sk:slovak, te:telugu, fa:persian, lv:latvian, bn:bengali, sr:serbian, az:azerbaijani, sl:slovenian, kn:kannada, et:estonian, mk:macedonian, br:breton, eu:basque, is:icelandic, hy:armenian, ne:nepali, mn:mongolian, bs:bosnian, kk:kazakh, sq:albanian, sw:swahili, gl:galician, mr:marathi, pa:punjabi, si:sinhala, km:khmer, sn:shona, yo:yoruba, so:somali, af:afrikaans, oc:occitan, ka:georgian, be:belarusian, tg:tajik, sd:sindhi, gu:gujarati, am:amharic, yi:yiddish, lo:lao, uz:uzbek, fo:faroese, ht:haitian creole, ps:pashto, tk:turkmen, nn:nynorsk, mt:maltese, sa:sanskrit, lb:luxembourgish, my:myanmar, bo:tibetan, tl:tagalog, mg:malagasy, as:assamese, tt:tatar, haw:hawaiian, ln:lingala, ha:hausa, ba:bashkir, jw:javanese, su:sundanese, yue:cantonese,
41
+ https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
42
+
43
+ [m2m100]
44
+ af:Afrikaans, am:Amharic, ar:Arabic, ast:Asturian, az:Azerbaijani, ba:Bashkir, be:Belarusian, bg:Bulgarian, bn:Bengali, br:Breton, bs:Bosnian, ca:Catalan; Valencian, ceb:Cebuano, cs:Czech, cy:Welsh, da:Danish, de:German, el:Greek, en:English, es:Spanish, et:Estonian, fa:Persian, ff:Fulah, fi:Finnish, fr:French, fy:Western Frisian, ga:Irish, gd:Gaelic; Scottish Gaelic, gl:Galician, gu:Gujarati, ha:Hausa, he:Hebrew, hi:Hindi, hr:Croatian, ht:Haitian; Haitian Creole, hu:Hungarian, hy:Armenian, id:Indonesian, ig:Igbo, ilo:Iloko, is:Icelandic, it:Italian, ja:Japanese, jv:Javanese, ka:Georgian, kk:Kazakh, km:Central Khmer, kn:Kannada, ko:Korean, lb:Luxembourgish; Letzeburgesch, lg:Ganda, ln:Lingala, lo:Lao, lt:Lithuanian, lv:Latvian, mg:Malagasy, mk:Macedonian, ml:Malayalam, mn:Mongolian, mr:Marathi, ms:Malay, my:Burmese, ne:Nepali, nl:Dutch; Flemish, no:Norwegian, ns:Northern Sotho, oc:Occitan (post 1500), or:Oriya, pa:Panjabi; Punjabi, pl:Polish, ps:Pushto; Pashto, pt:Portuguese, ro:Romanian; Moldavian; Moldovan, ru:Russian, sd:Sindhi, si:Sinhala; Sinhalese, sk:Slovak, sl:Slovenian, so:Somali, sq:Albanian, sr:Serbian, ss:Swati, su:Sundanese, sv:Swedish, sw:Swahili, ta:Tamil, th:Thai, tl:Tagalog, tn:Tswana, tr:Turkish, uk:Ukrainian, ur:Urdu, uz:Uzbek, vi:Vietnamese, wo:Wolof, xh:Xhosa, yi:Yiddish, yo:Yoruba, zh:Chinese, zu:Zulu
45
+ https://huggingface.co/facebook/m2m100_1.2B
46
+
47
+ The available languages for m2m100 and whisper are almost identical. Most of the codes correspond to the ISO 639-1 standard. For detailed information, please refer to the official documentation provided.
48
+ """
49
+ TranslationLangs = [
50
+ TranslationLang(Lang("ace_Arab", "Acehnese (Arabic script)")),
51
+ TranslationLang(Lang("ace_Latn", "Acehnese (Latin script)")),
52
+ TranslationLang(Lang("acm_Arab", "Mesopotamian Arabic"), Lang("ar", "Arabic")),
53
+ TranslationLang(Lang("acq_Arab", "Ta’izzi-Adeni Arabic"), Lang("ar", "Arabic")),
54
+ TranslationLang(Lang("aeb_Arab", "Tunisian Arabic")),
55
+ TranslationLang(Lang("afr_Latn", "Afrikaans"), Lang("af", "Afrikaans")),
56
+ TranslationLang(Lang("ajp_Arab", "South Levantine Arabic"), Lang("ar", "Arabic")),
57
+ TranslationLang(Lang("aka_Latn", "Akan")),
58
+ TranslationLang(Lang("amh_Ethi", "Amharic"), Lang("am", "Amharic")),
59
+ TranslationLang(Lang("apc_Arab", "North Levantine Arabic"), Lang("ar", "Arabic")),
60
+ TranslationLang(Lang("arb_Arab", "Modern Standard Arabic"), Lang("ar", "Arabic")),
61
+ TranslationLang(Lang("arb_Latn", "Modern Standard Arabic (Romanized)")),
62
+ TranslationLang(Lang("ars_Arab", "Najdi Arabic"), Lang("ar", "Arabic")),
63
+ TranslationLang(Lang("ary_Arab", "Moroccan Arabic"), Lang("ar", "Arabic")),
64
+ TranslationLang(Lang("arz_Arab", "Egyptian Arabic"), Lang("ar", "Arabic")),
65
+ TranslationLang(Lang("asm_Beng", "Assamese"), Lang("as", "Assamese")),
66
+ TranslationLang(Lang("ast_Latn", "Asturian"), None, Lang("ast", "Asturian")),
67
+ TranslationLang(Lang("awa_Deva", "Awadhi")),
68
+ TranslationLang(Lang("ayr_Latn", "Central Aymara")),
69
+ TranslationLang(Lang("azb_Arab", "South Azerbaijani"), Lang("az", "Azerbaijani")),
70
+ TranslationLang(Lang("azj_Latn", "North Azerbaijani"), Lang("az", "Azerbaijani")),
71
+ TranslationLang(Lang("bak_Cyrl", "Bashkir"), Lang("ba", "Bashkir")),
72
+ TranslationLang(Lang("bam_Latn", "Bambara")),
73
+ TranslationLang(Lang("ban_Latn", "Balinese")),
74
+ TranslationLang(Lang("bel_Cyrl", "Belarusian"), Lang("be", "Belarusian")),
75
+ TranslationLang(Lang("bem_Latn", "Bemba")),
76
+ TranslationLang(Lang("ben_Beng", "Bengali"), Lang("bn", "Bengali")),
77
+ TranslationLang(Lang("bho_Deva", "Bhojpuri")),
78
+ TranslationLang(Lang("bjn_Arab", "Banjar (Arabic script)")),
79
+ TranslationLang(Lang("bjn_Latn", "Banjar (Latin script)")),
80
+ TranslationLang(Lang("bod_Tibt", "Standard Tibetan"), Lang("bo", "Tibetan")),
81
+ TranslationLang(Lang("bos_Latn", "Bosnian"), Lang("bs", "Bosnian")),
82
+ TranslationLang(Lang("bug_Latn", "Buginese")),
83
+ TranslationLang(Lang("bul_Cyrl", "Bulgarian"), Lang("bg", "Bulgarian")),
84
+ TranslationLang(Lang("cat_Latn", "Catalan"), Lang("ca", "Catalan", "valencian")),
85
+ TranslationLang(Lang("ceb_Latn", "Cebuano"), None, Lang("ceb", "Cebuano")),
86
+ TranslationLang(Lang("ces_Latn", "Czech"), Lang("cs", "Czech")),
87
+ TranslationLang(Lang("cjk_Latn", "Chokwe")),
88
+ TranslationLang(Lang("ckb_Arab", "Central Kurdish")),
89
+ TranslationLang(Lang("crh_Latn", "Crimean Tatar")),
90
+ TranslationLang(Lang("cym_Latn", "Welsh"), Lang("cy", "Welsh")),
91
+ TranslationLang(Lang("dan_Latn", "Danish"), Lang("da", "Danish")),
92
+ TranslationLang(Lang("deu_Latn", "German"), Lang("de", "German")),
93
+ TranslationLang(Lang("dik_Latn", "Southwestern Dinka")),
94
+ TranslationLang(Lang("dyu_Latn", "Dyula")),
95
+ TranslationLang(Lang("dzo_Tibt", "Dzongkha")),
96
+ TranslationLang(Lang("ell_Grek", "Greek"), Lang("el", "Greek")),
97
+ TranslationLang(Lang("eng_Latn", "English"), Lang("en", "English")),
98
+ TranslationLang(Lang("epo_Latn", "Esperanto")),
99
+ TranslationLang(Lang("est_Latn", "Estonian"), Lang("et", "Estonian")),
100
+ TranslationLang(Lang("eus_Latn", "Basque"), Lang("eu", "Basque")),
101
+ TranslationLang(Lang("ewe_Latn", "Ewe")),
102
+ TranslationLang(Lang("fao_Latn", "Faroese"), Lang("fo", "Faroese")),
103
+ TranslationLang(Lang("fij_Latn", "Fijian")),
104
+ TranslationLang(Lang("fin_Latn", "Finnish"), Lang("fi", "Finnish")),
105
+ TranslationLang(Lang("fon_Latn", "Fon")),
106
+ TranslationLang(Lang("fra_Latn", "French"), Lang("fr", "French")),
107
+ TranslationLang(Lang("fur_Latn", "Friulian")),
108
+ TranslationLang(Lang("fuv_Latn", "Nigerian Fulfulde"), None, Lang("ff", "Fulah")),
109
+ TranslationLang(Lang("gla_Latn", "Scottish Gaelic"), None, Lang("gd", "Scottish Gaelic")),
110
+ TranslationLang(Lang("gle_Latn", "Irish"), None, Lang("ga", "Irish")),
111
+ TranslationLang(Lang("glg_Latn", "Galician"), Lang("gl", "Galician")),
112
+ TranslationLang(Lang("grn_Latn", "Guarani")),
113
+ TranslationLang(Lang("guj_Gujr", "Gujarati"), Lang("gu", "Gujarati")),
114
+ TranslationLang(Lang("hat_Latn", "Haitian Creole"), Lang("ht", "Haitian creole", "haitian")),
115
+ TranslationLang(Lang("hau_Latn", "Hausa"), Lang("ha", "Hausa")),
116
+ TranslationLang(Lang("heb_Hebr", "Hebrew"), Lang("he", "Hebrew")),
117
+ TranslationLang(Lang("hin_Deva", "Hindi"), Lang("hi", "Hindi")),
118
+ TranslationLang(Lang("hne_Deva", "Chhattisgarhi")),
119
+ TranslationLang(Lang("hrv_Latn", "Croatian"), Lang("hr", "Croatian")),
120
+ TranslationLang(Lang("hun_Latn", "Hungarian"), Lang("hu", "Hungarian")),
121
+ TranslationLang(Lang("hye_Armn", "Armenian"), Lang("hy", "Armenian")),
122
+ TranslationLang(Lang("ibo_Latn", "Igbo"), None, Lang("ig", "Igbo")),
123
+ TranslationLang(Lang("ilo_Latn", "Ilocano"), None, Lang("ilo", "Iloko")),
124
+ TranslationLang(Lang("ind_Latn", "Indonesian"), Lang("id", "Indonesian")),
125
+ TranslationLang(Lang("isl_Latn", "Icelandic"), Lang("is", "Icelandic")),
126
+ TranslationLang(Lang("ita_Latn", "Italian"), Lang("it", "Italian")),
127
+ TranslationLang(Lang("jav_Latn", "Javanese"), Lang("jw", "Javanese"), Lang("jv", "Javanese")),
128
+ TranslationLang(Lang("jpn_Jpan", "Japanese"), Lang("ja", "Japanese")),
129
+ TranslationLang(Lang("kab_Latn", "Kabyle")),
130
+ TranslationLang(Lang("kac_Latn", "Jingpho")),
131
+ TranslationLang(Lang("kam_Latn", "Kamba")),
132
+ TranslationLang(Lang("kan_Knda", "Kannada"), Lang("kn", "Kannada")),
133
+ TranslationLang(Lang("kas_Arab", "Kashmiri (Arabic script)")),
134
+ TranslationLang(Lang("kas_Deva", "Kashmiri (Devanagari script)")),
135
+ TranslationLang(Lang("kat_Geor", "Georgian"), Lang("ka", "Georgian")),
136
+ TranslationLang(Lang("knc_Arab", "Central Kanuri (Arabic script)")),
137
+ TranslationLang(Lang("knc_Latn", "Central Kanuri (Latin script)")),
138
+ TranslationLang(Lang("kaz_Cyrl", "Kazakh"), Lang("kk", "Kazakh")),
139
+ TranslationLang(Lang("kbp_Latn", "Kabiyè")),
140
+ TranslationLang(Lang("kea_Latn", "Kabuverdianu")),
141
+ TranslationLang(Lang("khm_Khmr", "Khmer"), Lang("km", "Khmer")),
142
+ TranslationLang(Lang("kik_Latn", "Kikuyu")),
143
+ TranslationLang(Lang("kin_Latn", "Kinyarwanda")),
144
+ TranslationLang(Lang("kir_Cyrl", "Kyrgyz")),
145
+ TranslationLang(Lang("kmb_Latn", "Kimbundu")),
146
+ TranslationLang(Lang("kmr_Latn", "Northern Kurdish")),
147
+ TranslationLang(Lang("kon_Latn", "Kikongo")),
148
+ TranslationLang(Lang("kor_Hang", "Korean"), Lang("ko", "Korean")),
149
+ TranslationLang(Lang("lao_Laoo", "Lao"), Lang("lo", "Lao")),
150
+ TranslationLang(Lang("lij_Latn", "Ligurian")),
151
+ TranslationLang(Lang("lim_Latn", "Limburgish")),
152
+ TranslationLang(Lang("lin_Latn", "Lingala"), Lang("ln", "Lingala")),
153
+ TranslationLang(Lang("lit_Latn", "Lithuanian"), Lang("lt", "Lithuanian")),
154
+ TranslationLang(Lang("lmo_Latn", "Lombard")),
155
+ TranslationLang(Lang("ltg_Latn", "Latgalian")),
156
+ TranslationLang(Lang("ltz_Latn", "Luxembourgish"), Lang("lb", "Luxembourgish", "letzeburgesch")),
157
+ TranslationLang(Lang("lua_Latn", "Luba-Kasai")),
158
+ TranslationLang(Lang("lug_Latn", "Ganda"), None, Lang("lg", "Ganda")),
159
+ TranslationLang(Lang("luo_Latn", "Luo")),
160
+ TranslationLang(Lang("lus_Latn", "Mizo")),
161
+ TranslationLang(Lang("lvs_Latn", "Standard Latvian"), Lang("lv", "Latvian")),
162
+ TranslationLang(Lang("mag_Deva", "Magahi")),
163
+ TranslationLang(Lang("mai_Deva", "Maithili")),
164
+ TranslationLang(Lang("mal_Mlym", "Malayalam"), Lang("ml", "Malayalam")),
165
+ TranslationLang(Lang("mar_Deva", "Marathi"), Lang("mr", "Marathi")),
166
+ TranslationLang(Lang("min_Arab", "Minangkabau (Arabic script)")),
167
+ TranslationLang(Lang("min_Latn", "Minangkabau (Latin script)")),
168
+ TranslationLang(Lang("mkd_Cyrl", "Macedonian"), Lang("mk", "Macedonian")),
169
+ TranslationLang(Lang("plt_Latn", "Plateau Malagasy"), Lang("mg", "Malagasy")),
170
+ TranslationLang(Lang("mlt_Latn", "Maltese"), Lang("mt", "Maltese")),
171
+ TranslationLang(Lang("mni_Beng", "Meitei (Bengali script)")),
172
+ TranslationLang(Lang("khk_Cyrl", "Halh Mongolian"), Lang("mn", "Mongolian")),
173
+ TranslationLang(Lang("mos_Latn", "Mossi")),
174
+ TranslationLang(Lang("mri_Latn", "Maori"), Lang("mi", "Maori")),
175
+ TranslationLang(Lang("mya_Mymr", "Burmese"), Lang("my", "Myanmar", "burmese")),
176
+ TranslationLang(Lang("nld_Latn", "Dutch"), Lang("nl", "Dutch", "flemish")),
177
+ TranslationLang(Lang("nno_Latn", "Norwegian Nynorsk"), Lang("nn", "Nynorsk")),
178
+ TranslationLang(Lang("nob_Latn", "Norwegian Bokmål"), Lang("no", "Norwegian")),
179
+ TranslationLang(Lang("npi_Deva", "Nepali"), Lang("ne", "Nepali")),
180
+ TranslationLang(Lang("nso_Latn", "Northern Sotho"), None, Lang("ns", "Northern Sotho")),
181
+ TranslationLang(Lang("nus_Latn", "Nuer")),
182
+ TranslationLang(Lang("nya_Latn", "Nyanja")),
183
+ TranslationLang(Lang("oci_Latn", "Occitan"), Lang("oc", "Occitan")),
184
+ TranslationLang(Lang("gaz_Latn", "West Central Oromo")),
185
+ TranslationLang(Lang("ory_Orya", "Odia"), None, Lang("or", "Oriya")),
186
+ TranslationLang(Lang("pag_Latn", "Pangasinan")),
187
+ TranslationLang(Lang("pan_Guru", "Eastern Panjabi"), Lang("pa", "Punjabi", "panjabi")),
188
+ TranslationLang(Lang("pap_Latn", "Papiamento")),
189
+ TranslationLang(Lang("pes_Arab", "Western Persian"), Lang("fa", "Persian")),
190
+ TranslationLang(Lang("pol_Latn", "Polish"), Lang("pl", "Polish")),
191
+ TranslationLang(Lang("por_Latn", "Portuguese"), Lang("pt", "Portuguese")),
192
+ TranslationLang(Lang("prs_Arab", "Dari")),
193
+ TranslationLang(Lang("pbt_Arab", "Southern Pashto"), Lang("ps", "Pashto", "pushto")),
194
+ TranslationLang(Lang("quy_Latn", "Ayacucho Quechua")),
195
+ TranslationLang(Lang("ron_Latn", "Romanian"), Lang("ro", "Romanian", "moldavian", "moldovan")),
196
+ TranslationLang(Lang("run_Latn", "Rundi")),
197
+ TranslationLang(Lang("rus_Cyrl", "Russian"), Lang("ru", "Russian")),
198
+ TranslationLang(Lang("sag_Latn", "Sango")),
199
+ TranslationLang(Lang("san_Deva", "Sanskrit"), Lang("sa", "Sanskrit")),
200
+ TranslationLang(Lang("sat_Olck", "Santali")),
201
+ TranslationLang(Lang("scn_Latn", "Sicilian")),
202
+ TranslationLang(Lang("shn_Mymr", "Shan")),
203
+ TranslationLang(Lang("sin_Sinh", "Sinhala"), Lang("si", "Sinhala", "sinhalese")),
204
+ TranslationLang(Lang("slk_Latn", "Slovak"), Lang("sk", "Slovak")),
205
+ TranslationLang(Lang("slv_Latn", "Slovenian"), Lang("sl", "Slovenian")),
206
+ TranslationLang(Lang("smo_Latn", "Samoan")),
207
+ TranslationLang(Lang("sna_Latn", "Shona"), Lang("sn", "Shona")),
208
+ TranslationLang(Lang("snd_Arab", "Sindhi"), Lang("sd", "Sindhi")),
209
+ TranslationLang(Lang("som_Latn", "Somali"), Lang("so", "Somali")),
210
+ TranslationLang(Lang("sot_Latn", "Southern Sotho")),
211
+ TranslationLang(Lang("spa_Latn", "Spanish"), Lang("es", "Spanish", "castilian")),
212
+ TranslationLang(Lang("als_Latn", "Tosk Albanian"), Lang("sq", "Albanian")),
213
+ TranslationLang(Lang("srd_Latn", "Sardinian")),
214
+ TranslationLang(Lang("srp_Cyrl", "Serbian"), Lang("sr", "Serbian")),
215
+ TranslationLang(Lang("ssw_Latn", "Swati"), None, Lang("ss", "Swati")),
216
+ TranslationLang(Lang("sun_Latn", "Sundanese"), Lang("su", "Sundanese")),
217
+ TranslationLang(Lang("swe_Latn", "Swedish"), Lang("sv", "Swedish")),
218
+ TranslationLang(Lang("swh_Latn", "Swahili"), Lang("sw", "Swahili")),
219
+ TranslationLang(Lang("szl_Latn", "Silesian")),
220
+ TranslationLang(Lang("tam_Taml", "Tamil"), Lang("ta", "Tamil")),
221
+ TranslationLang(Lang("tat_Cyrl", "Tatar"), Lang("tt", "Tatar")),
222
+ TranslationLang(Lang("tel_Telu", "Telugu"), Lang("te", "Telugu")),
223
+ TranslationLang(Lang("tgk_Cyrl", "Tajik"), Lang("tg", "Tajik")),
224
+ TranslationLang(Lang("tgl_Latn", "Tagalog"), Lang("tl", "Tagalog")),
225
+ TranslationLang(Lang("tha_Thai", "Thai"), Lang("th", "Thai")),
226
+ TranslationLang(Lang("tir_Ethi", "Tigrinya")),
227
+ TranslationLang(Lang("taq_Latn", "Tamasheq (Latin script)")),
228
+ TranslationLang(Lang("taq_Tfng", "Tamasheq (Tifinagh script)")),
229
+ TranslationLang(Lang("tpi_Latn", "Tok Pisin")),
230
+ TranslationLang(Lang("tsn_Latn", "Tswana"), None, Lang("tn", "Tswana")),
231
+ TranslationLang(Lang("tso_Latn", "Tsonga")),
232
+ TranslationLang(Lang("tuk_Latn", "Turkmen"), Lang("tk", "Turkmen")),
233
+ TranslationLang(Lang("tum_Latn", "Tumbuka")),
234
+ TranslationLang(Lang("tur_Latn", "Turkish"), Lang("tr", "Turkish")),
235
+ TranslationLang(Lang("twi_Latn", "Twi")),
236
+ TranslationLang(Lang("tzm_Tfng", "Central Atlas Tamazight")),
237
+ TranslationLang(Lang("uig_Arab", "Uyghur")),
238
+ TranslationLang(Lang("ukr_Cyrl", "Ukrainian"), Lang("uk", "Ukrainian")),
239
+ TranslationLang(Lang("umb_Latn", "Umbundu")),
240
+ TranslationLang(Lang("urd_Arab", "Urdu"), Lang("ur", "Urdu")),
241
+ TranslationLang(Lang("uzn_Latn", "Northern Uzbek"), Lang("uz", "Uzbek")),
242
+ TranslationLang(Lang("vec_Latn", "Venetian")),
243
+ TranslationLang(Lang("vie_Latn", "Vietnamese"), Lang("vi", "Vietnamese")),
244
+ TranslationLang(Lang("war_Latn", "Waray")),
245
+ TranslationLang(Lang("wol_Latn", "Wolof"), None, Lang("wo", "Wolof")),
246
+ TranslationLang(Lang("xho_Latn", "Xhosa"), None, Lang("xh", "Xhosa")),
247
+ TranslationLang(Lang("ydd_Hebr", "Eastern Yiddish"), Lang("yi", "Yiddish")),
248
+ TranslationLang(Lang("yor_Latn", "Yoruba"), Lang("yo", "Yoruba")),
249
+ TranslationLang(Lang("yue_Hant", "Yue Chinese"), Lang("yue", "cantonese"), Lang("zh", "Chinese (zh-yue)")),
250
+ TranslationLang(Lang("zho_Hans", "Chinese (Simplified)"), Lang("zh", "Chinese (Simplified)", "Chinese", "mandarin")),
251
+ TranslationLang(Lang("zho_Hant", "Chinese (Traditional)"), Lang("zh", "Chinese (Traditional)")),
252
+ TranslationLang(Lang("zsm_Latn", "Standard Malay"), Lang("ms", "Malay")),
253
+ TranslationLang(Lang("zul_Latn", "Zulu"), None, Lang("zu", "Zulu")),
254
+ TranslationLang(None, Lang("br", "Breton")), # Both whisper and m2m100 support the Breton language, but nllb does not have this language.
255
+ ]
256
+
257
+
258
def _index_by_model_names(model_attr):
    """Map every lower-cased language name of one backend to its TranslationLang.

    `model_attr` names the TranslationLang attribute to read ("nllb", "m2m100"
    or "whisper"). Entries whose attribute is None are skipped. When two
    entries share a name, the one listed later in TranslationLangs wins.
    """
    index = {}
    for entry in TranslationLangs:
        backend_lang = getattr(entry, model_attr)
        if backend_lang is None:
            continue
        for alias in backend_lang.names:
            index[alias.lower()] = entry
    return index


# Lower-cased language name -> TranslationLang, one table per backend.
_TO_LANG_NAME_NLLB = _index_by_model_names("nllb")

_TO_LANG_NAME_M2M100 = _index_by_model_names("m2m100")

_TO_LANG_NAME_WHISPER = _index_by_model_names("whisper")

# Lower-cased Whisper language code -> TranslationLang; entries without a
# Whisper language or with an empty code are left out.
_TO_LANG_CODE_WHISPER = {
    entry.whisper.code.lower(): entry
    for entry in TranslationLangs
    if entry.whisper is not None and len(entry.whisper.code) > 0
}
265
+
266
+
267
def get_lang_from_nllb_name(nllbName, default=None) -> TranslationLang:
    """Look up a TranslationLang by its NLLB language name.

    The match is case-insensitive. A falsy name (None or "") or a name that
    is not in the NLLB name table yields `default`.
    """
    lookup_key = None
    if nllbName:
        lookup_key = nllbName.lower()
    return _TO_LANG_NAME_NLLB.get(lookup_key, default)
270
+
271
def get_lang_from_m2m100_name(m2m100Name, default=None) -> TranslationLang:
    """Look up a TranslationLang by its m2m100 language name.

    The match is case-insensitive. A falsy name (None or "") or a name that
    is not in the m2m100 name table yields `default`.
    """
    lookup_key = None
    if m2m100Name:
        lookup_key = m2m100Name.lower()
    return _TO_LANG_NAME_M2M100.get(lookup_key, default)
274
+
275
def get_lang_from_whisper_name(whisperName, default=None) -> TranslationLang:
    """Look up a TranslationLang by its Whisper language name.

    The match is case-insensitive. A falsy name (None or "") or a name that
    is not in the Whisper name table yields `default`.
    """
    lookup_key = None
    if whisperName:
        lookup_key = whisperName.lower()
    return _TO_LANG_NAME_WHISPER.get(lookup_key, default)
278
+
279
def get_lang_from_whisper_code(whisperCode, default=None) -> TranslationLang:
    """Return the TranslationLang registered for a Whisper language code.

    Args:
        whisperCode: Whisper language code such as "ja" or "zh". Matching is
            case-insensitive, like the name-based lookups in this module.
        default: Value returned when the code is falsy (None or "") or is
            not present in the code table.

    Returns:
        The matching TranslationLang, or `default`.
    """
    # The table is keyed by lower-cased codes, so normalise the query the same
    # way; otherwise "JA" would miss even though "ja" is registered.
    lookup_key = whisperCode.lower() if whisperCode else None
    return _TO_LANG_CODE_WHISPER.get(lookup_key, default)
282
+
283
def get_lang_nllb_names():
    """List the lower-cased NLLB language names, in lookup-table order."""
    return list(_TO_LANG_NAME_NLLB)
286
+
287
def get_lang_m2m100_names(codes=None):
    """Return the lower-cased m2m100 language names, optionally filtered by code.

    Args:
        codes: Optional iterable of m2m100 language codes (e.g. ["en", "ja"]).
            When omitted, None or empty, every m2m100 name is returned.
            Otherwise only languages whose m2m100 code equals one of the
            given codes are included.

    Returns:
        A list of unique names, in the order they first appear in
        TranslationLangs.
    """
    # Compare whole codes: a substring test would let "lo" (Lao) also select
    # "ilo" (Iloko), or "as" select "ast" (Asturian).
    wanted = set(codes) if codes else None
    # dict.fromkeys drops duplicate names while keeping first-seen order.
    return list(dict.fromkeys(
        name.lower()
        for language in TranslationLangs
        if language.m2m100 is not None
        and (wanted is None or language.m2m100.code in wanted)
        for name in language.m2m100.names
    ))
290
+
291
def get_lang_whisper_names():
    """List the lower-cased Whisper language names, in lookup-table order."""
    return list(_TO_LANG_NAME_WHISPER)
294
+
295
if __name__ == "__main__":
    # Manual smoke test: print a few single lookups, then two name listings.
    _lookup_checks = (
        ("name:Chinese (Traditional)", get_lang_from_nllb_name, "Chinese (Traditional)"),
        ("name:moldavian", get_lang_from_m2m100_name, "moldavian"),
        ("code:ja", get_lang_from_whisper_code, "ja"),
        ("name:English", get_lang_from_nllb_name, 'English'),
    )
    for _label, _lookup, _query in _lookup_checks:
        print(_label, _lookup(_query))

    print(get_lang_m2m100_names(["en", "ja", "zh"]))
    print(get_lang_nllb_names())
src/{nllb/nllbModel.py → translation/translationModel.py} RENAMED
@@ -9,24 +9,26 @@ import transformers
9
 
10
  from typing import Optional
11
  from src.config import ModelConfig
12
- from src.languages import Language
13
- from src.nllb.nllbLangs import NllbLang, get_nllb_lang_from_code_whisper
14
 
15
- class NllbModel:
16
  def __init__(
17
  self,
18
- model_config: ModelConfig,
19
  device: str = None,
20
- whisper_lang: Language = None,
21
- nllb_lang: NllbLang = None,
22
- download_root: Optional[str] = None,
23
- local_files_only: bool = False,
24
- load_model: bool = False,
 
 
 
25
  ):
26
- """Initializes the Nllb-200 model.
27
 
28
  Args:
29
- model_config: Config of the model to use (distilled-600M, distilled-1.3B,
30
  1.3B, 3.3B...) or a path to a converted
31
  model directory. When a size is configured, the converted model is downloaded
32
  from the Hugging Face Hub.
@@ -44,62 +46,72 @@ class NllbModel:
44
  having multiple workers enables true parallelism when running the model
45
  (concurrent calls to self.model.generate() will run in parallel).
46
  This can improve the global throughput at the cost of increased memory usage.
47
- download_root: Directory where the models should be saved. If not set, the models
48
  are saved in the standard Hugging Face cache directory.
49
- local_files_only: If True, avoid downloading the file and return the path to the
50
  local cached file if it exists.
51
  """
52
- self.whisper_lang = whisper_lang
53
- self.nllb_whisper_lang = get_nllb_lang_from_code_whisper(whisper_lang.code.lower() if whisper_lang is not None else "en")
54
- self.nllb_lang = nllb_lang
55
- self.model_config = model_config
56
 
57
- if nllb_lang is None:
58
  return
 
 
 
 
59
 
60
- if os.path.isdir(model_config.url):
61
- self.model_path = model_config.url
62
  else:
63
- self.model_path = download_model(
64
- model_config,
65
- local_files_only=local_files_only,
66
- cache_dir=download_root,
67
  )
68
 
69
  if device is None:
70
  if torch.cuda.is_available():
71
- device = "cuda" if "ct2" in self.model_path else "cuda:0"
72
  else:
73
  device = "cpu"
74
 
75
  self.device = device
76
 
77
- if load_model:
78
  self.load_model()
79
 
80
  def load_model(self):
81
- print('\n\nLoading model: %s\n\n' % self.model_path)
82
- if "ct2" in self.model_path:
83
- self.target_prefix = [self.nllb_lang.code]
84
- self.trans_tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path, src_lang=self.nllb_whisper_lang.code)
85
- self.trans_model = ctranslate2.Translator(self.model_path, compute_type="auto", device=self.device)
86
- elif "mt5" in self.model_path:
87
- self.mt5_prefix = self.whisper_lang.code + "2" + self.nllb_lang.code_whisper + ": "
88
- self.trans_tokenizer = transformers.T5Tokenizer.from_pretrained(self.model_path, legacy=False) #requires spiece.model
89
- self.trans_model = transformers.MT5ForConditionalGeneration.from_pretrained(self.model_path)
90
- self.trans_translator = transformers.pipeline('text2text-generation', model=self.trans_model, device=self.device, tokenizer=self.trans_tokenizer)
91
- else: #NLLB
92
- self.trans_tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path)
93
- self.trans_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.model_path)
94
- self.trans_translator = transformers.pipeline('translation', model=self.trans_model, device=self.device, tokenizer=self.trans_tokenizer, src_lang=self.nllb_whisper_lang.code, tgt_lang=self.nllb_lang.code)
 
 
 
 
 
 
 
95
 
96
  def release_vram(self):
97
  try:
98
  if torch.cuda.is_available():
99
- if "ct2" not in self.model_path:
100
  device = torch.device("cpu")
101
- self.trans_model.to(device)
102
- del self.trans_model
103
  torch.cuda.empty_cache()
104
  print("release vram end.")
105
  except Exception as e:
@@ -110,16 +122,16 @@ class NllbModel:
110
  output = None
111
  result = None
112
  try:
113
- if "ct2" in self.model_path:
114
- source = self.trans_tokenizer.convert_ids_to_tokens(self.trans_tokenizer.encode(text))
115
- output = self.trans_model.translate_batch([source], target_prefix=[self.target_prefix])
116
  target = output[0].hypotheses[0][1:]
117
- result = self.trans_tokenizer.decode(self.trans_tokenizer.convert_tokens_to_ids(target))
118
- elif "mt5" in self.model_path:
119
- output = self.trans_translator(self.mt5_prefix + text, max_length=max_length, num_beams=4)
120
  result = output[0]['generated_text']
121
- else: #NLLB
122
- output = self.trans_translator(text, max_length=max_length)
123
  result = output[0]['translation_text']
124
  except Exception as e:
125
  print("Error translation text: " + str(e))
@@ -133,6 +145,8 @@ _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
133
  "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
134
  "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
135
  "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
 
 
136
  "mt5-zh-ja-en-trimmed",
137
  "mt5-zh-ja-en-trimmed-fine-tuned-v1"]
138
 
@@ -140,10 +154,10 @@ def check_model_name(name):
140
  return any(allowed_name in name for allowed_name in _MODELS)
141
 
142
  def download_model(
143
- model_config: ModelConfig,
144
- output_dir: Optional[str] = None,
145
- local_files_only: bool = False,
146
- cache_dir: Optional[str] = None,
147
  ):
148
  """"download_model" is referenced from the "utils.py" script
149
  of the "faster_whisper" project, authored by guillaumekln.
@@ -153,13 +167,13 @@ def download_model(
153
  The model is downloaded from https://huggingface.co/facebook.
154
 
155
  Args:
156
- model_config: config of the model to download (facebook/nllb-distilled-600M,
157
  facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
158
- output_dir: Directory where the model should be saved. If not set, the model is saved in
159
  the cache directory.
160
- local_files_only: If True, avoid downloading the file and return the path to the local
161
  cached file if it exists.
162
- cache_dir: Path to the folder where cached files are stored.
163
 
164
  Returns:
165
  The path to the downloaded model.
@@ -167,19 +181,20 @@ def download_model(
167
  Raises:
168
  ValueError: if the model size is invalid.
169
  """
170
- if not check_model_name(model_config.name):
171
  raise ValueError(
172
- "Invalid model name '%s', expected one of: %s" % (model_config.name, ", ".join(_MODELS))
173
  )
174
 
175
- repo_id = model_config.url #"facebook/nllb-200-%s" %
176
 
177
- allow_patterns = [
178
  "config.json",
179
  "generation_config.json",
180
  "model.bin",
181
  "pytorch_model.bin",
182
  "pytorch_model.bin.index.json",
 
183
  "pytorch_model-00001-of-00003.bin",
184
  "pytorch_model-00002-of-00003.bin",
185
  "pytorch_model-00003-of-00003.bin",
@@ -190,30 +205,31 @@ def download_model(
190
  "shared_vocabulary.json",
191
  "special_tokens_map.json",
192
  "spiece.model",
 
193
  ]
194
 
195
  kwargs = {
196
- "local_files_only": local_files_only,
197
- "allow_patterns": allow_patterns,
198
  #"tqdm_class": disabled_tqdm,
199
  }
200
 
201
- if output_dir is not None:
202
- kwargs["local_dir"] = output_dir
203
  kwargs["local_dir_use_symlinks"] = False
204
 
205
- if cache_dir is not None:
206
- kwargs["cache_dir"] = cache_dir
207
 
208
  try:
209
- return huggingface_hub.snapshot_download(repo_id, **kwargs)
210
  except (
211
  huggingface_hub.utils.HfHubHTTPError,
212
  requests.exceptions.ConnectionError,
213
  ) as exception:
214
  warnings.warn(
215
  "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
216
- repo_id,
217
  exception,
218
  )
219
  warnings.warn(
@@ -221,4 +237,4 @@ def download_model(
221
  )
222
 
223
  kwargs["local_files_only"] = True
224
- return huggingface_hub.snapshot_download(repo_id, **kwargs)
 
9
 
10
  from typing import Optional
11
  from src.config import ModelConfig
12
+ from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
 
13
 
14
+ class TranslationModel:
15
  def __init__(
16
  self,
17
+ modelConfig: ModelConfig,
18
  device: str = None,
19
+ whisperLang: TranslationLang = None,
20
+ translationLang: TranslationLang = None,
21
+ batchSize: int = 2,
22
+ noRepeatNgramSize: int = 3,
23
+ numBeams: int = 2,
24
+ downloadRoot: Optional[str] = None,
25
+ localFilesOnly: bool = False,
26
+ loadModel: bool = False,
27
  ):
28
+ """Initializes the M2M100 / Nllb-200 / mt5 model.
29
 
30
  Args:
31
+ modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
32
  1.3B, 3.3B...) or a path to a converted
33
  model directory. When a size is configured, the converted model is downloaded
34
  from the Hugging Face Hub.
 
46
  having multiple workers enables true parallelism when running the model
47
  (concurrent calls to self.model.generate() will run in parallel).
48
  This can improve the global throughput at the cost of increased memory usage.
49
+ downloadRoot: Directory where the models should be saved. If not set, the models
50
  are saved in the standard Hugging Face cache directory.
51
+ localFilesOnly: If True, avoid downloading the file and return the path to the
52
  local cached file if it exists.
53
  """
54
+ self.modelConfig = modelConfig
55
+ self.whisperLang = whisperLang # self.translationLangWhisper = get_lang_from_whisper_code(whisperLang.code.lower() if whisperLang is not None else "en")
56
+ self.translationLang = translationLang
 
57
 
58
+ if translationLang is None:
59
  return
60
+
61
+ self.batchSize = batchSize
62
+ self.noRepeatNgramSize = noRepeatNgramSize
63
+ self.numBeams = numBeams
64
 
65
+ if os.path.isdir(modelConfig.url):
66
+ self.modelPath = modelConfig.url
67
  else:
68
+ self.modelPath = download_model(
69
+ modelConfig,
70
+ localFilesOnly=localFilesOnly,
71
+ cacheDir=downloadRoot,
72
  )
73
 
74
  if device is None:
75
  if torch.cuda.is_available():
76
+ device = "cuda" if "ct2" in self.modelPath else "cuda:0"
77
  else:
78
  device = "cpu"
79
 
80
  self.device = device
81
 
82
+ if loadModel:
83
  self.load_model()
84
 
85
  def load_model(self):
86
+ print('\n\nLoading model: %s\n\n' % self.modelPath)
87
+ if "ct2" in self.modelPath:
88
+ if "nllb" in self.modelPath:
89
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
90
+ self.targetPrefix = [self.translationLang.nllb.code]
91
+ elif "m2m100" in self.modelPath:
92
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
93
+ self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
94
+ self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
95
+ elif "mt5" in self.modelPath:
96
+ self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
97
+ self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
98
+ self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
99
+ self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
100
+ else:
101
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
102
+ self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
103
+ if "m2m100" in self.modelPath:
104
+ self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
105
+ else: #NLLB
106
+ self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)
107
 
108
  def release_vram(self):
109
  try:
110
  if torch.cuda.is_available():
111
+ if "ct2" not in self.modelPath:
112
  device = torch.device("cpu")
113
+ self.transModel.to(device)
114
+ del self.transModel
115
  torch.cuda.empty_cache()
116
  print("release vram end.")
117
  except Exception as e:
 
122
  output = None
123
  result = None
124
  try:
125
+ if "ct2" in self.modelPath:
126
+ source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
127
+ output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
128
  target = output[0].hypotheses[0][1:]
129
+ result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
130
+ elif "mt5" in self.modelPath:
131
+ output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
132
  result = output[0]['generated_text']
133
+ else: #M2M100 & NLLB
134
+ output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
135
  result = output[0]['translation_text']
136
  except Exception as e:
137
  print("Error translation text: " + str(e))
 
145
  "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
146
  "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
147
  "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
148
+ "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
149
+ "m2m100_1.2B", "m2m100_418M",
150
  "mt5-zh-ja-en-trimmed",
151
  "mt5-zh-ja-en-trimmed-fine-tuned-v1"]
152
 
 
154
  return any(allowed_name in name for allowed_name in _MODELS)
155
 
156
  def download_model(
157
+ modelConfig: ModelConfig,
158
+ outputDir: Optional[str] = None,
159
+ localFilesOnly: bool = False,
160
+ cacheDir: Optional[str] = None,
161
  ):
162
  """"download_model" is referenced from the "utils.py" script
163
  of the "faster_whisper" project, authored by guillaumekln.
 
167
  The model is downloaded from https://huggingface.co/facebook.
168
 
169
  Args:
170
+ modelConfig: config of the model to download (facebook/nllb-distilled-600M,
171
  facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
172
+ outputDir: Directory where the model should be saved. If not set, the model is saved in
173
  the cache directory.
174
+ localFilesOnly: If True, avoid downloading the file and return the path to the local
175
  cached file if it exists.
176
+ cacheDir: Path to the folder where cached files are stored.
177
 
178
  Returns:
179
  The path to the downloaded model.
 
181
  Raises:
182
  ValueError: if the model size is invalid.
183
  """
184
+ if not check_model_name(modelConfig.name):
185
  raise ValueError(
186
+ "Invalid model name '%s', expected one of: %s" % (modelConfig.name, ", ".join(_MODELS))
187
  )
188
 
189
+ repoId = modelConfig.url #"facebook/nllb-200-%s" %
190
 
191
+ allowPatterns = [
192
  "config.json",
193
  "generation_config.json",
194
  "model.bin",
195
  "pytorch_model.bin",
196
  "pytorch_model.bin.index.json",
197
+ "pytorch_model-*.bin",
198
  "pytorch_model-00001-of-00003.bin",
199
  "pytorch_model-00002-of-00003.bin",
200
  "pytorch_model-00003-of-00003.bin",
 
205
  "shared_vocabulary.json",
206
  "special_tokens_map.json",
207
  "spiece.model",
208
+ "vocab.json", #m2m100
209
  ]
210
 
211
  kwargs = {
212
+ "local_files_only": localFilesOnly,
213
+ "allow_patterns": allowPatterns,
214
  #"tqdm_class": disabled_tqdm,
215
  }
216
 
217
+ if outputDir is not None:
218
+ kwargs["local_dir"] = outputDir
219
  kwargs["local_dir_use_symlinks"] = False
220
 
221
+ if cacheDir is not None:
222
+ kwargs["cache_dir"] = cacheDir
223
 
224
  try:
225
+ return huggingface_hub.snapshot_download(repoId, **kwargs)
226
  except (
227
  huggingface_hub.utils.HfHubHTTPError,
228
  requests.exceptions.ConnectionError,
229
  ) as exception:
230
  warnings.warn(
231
  "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
232
+ repoId,
233
  exception,
234
  )
235
  warnings.warn(
 
237
  )
238
 
239
  kwargs["local_files_only"] = True
240
+ return huggingface_hub.snapshot_download(repoId, **kwargs)
src/utils.py CHANGED
@@ -100,46 +100,91 @@ def write_srt(transcript: Iterator[dict], file: TextIO,
100
  flush=True,
101
  )
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
104
  for segment in transcript:
105
  words: list = segment.get('words', [])
106
 
107
  # Append longest speaker ID if available
108
  segment_longest_speaker = segment.get('longest_speaker', None)
 
 
 
 
 
109
  if segment_longest_speaker is not None:
110
  segment_longest_speaker = segment_longest_speaker.replace("SPEAKER", "S")
111
-
 
 
 
 
 
112
  if len(words) == 0:
113
- # Yield the segment as-is or processed
114
- if (maxLineWidth is None or maxLineWidth < 0) and segment_longest_speaker is None:
115
- yield segment
116
- else:
117
- text = segment['text'].strip()
118
-
119
- # Prepend the longest speaker ID if available
120
- if segment_longest_speaker is not None:
121
- text = f"({segment_longest_speaker}) {text}"
122
-
123
- yield {
124
- 'start': segment['start'],
125
- 'end': segment['end'],
126
- 'text': process_text(text, maxLineWidth)
127
- }
128
  # We are done
129
  continue
130
 
131
- subtitle_start = segment['start']
132
- subtitle_end = segment['end']
133
-
134
  if segment_longest_speaker is not None:
135
  # Add the beginning
136
  words.insert(0, {
137
  'start': subtitle_start,
138
- 'end': subtitle_start,
139
- 'word': f"({segment_longest_speaker})"
140
  })
141
 
142
- text_words = [ this_word["word"] for this_word in words ]
143
  subtitle_text = __join_words(text_words, maxLineWidth)
144
 
145
  # Iterate over the words in the segment
@@ -154,15 +199,15 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
154
  # Display the text up to this point
155
  yield {
156
  'start': last,
157
- 'end': start,
158
- 'text': subtitle_text
159
  }
160
 
161
  # Display the text with the current word highlighted
162
  yield {
163
  'start': start,
164
- 'end': end,
165
- 'text': __join_words(
166
  [
167
  {
168
  "word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
@@ -180,17 +225,20 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
180
  # Display the last part of the text
181
  yield {
182
  'start': last,
183
- 'end': subtitle_end,
184
- 'text': subtitle_text
185
  }
186
 
187
  # Just return the subtitle text
188
  else:
189
- yield {
190
  'start': subtitle_start,
191
- 'end': subtitle_end,
192
- 'text': subtitle_text
193
  }
 
 
 
194
 
195
  def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
196
  if maxLineWidth is None or maxLineWidth < 0:
 
100
  flush=True,
101
  )
102
 
103
def write_srt_original(transcript: Iterator[dict], file: TextIO,
                       maxLineWidth=None, highlight_words: bool = False, bilingual: bool = False):
    """
    Write the original-language text of a transcript to a file in SRT format.

    Only cues that carry an 'original' entry are written; every other cue
    produced by the subtitle preprocessor is skipped.

    Parameters
    ----------
    transcript: Iterator[dict]
        Segments to write. Each written cue needs 'start', 'end' and
        'original'; 'text' is also read when bilingual is True.
    file: TextIO
        Open text stream the SRT cues are printed to.
    maxLineWidth: int
        Forwarded unchanged to the subtitle preprocessor.
    highlight_words: bool
        Forwarded unchanged to the subtitle preprocessor.
    bilingual: bool
        If True, the cue's 'text' is written on the line(s) after the
        original text, producing a bilingual subtitle.

    Example usage:
        with open(Path(output_dir) / (audio_basename + "-original.srt"), "w", encoding="utf-8") as srt:
            write_srt_original(result["segments"], file=srt)
    """
    iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)

    # The index counts every cue yielded by the preprocessor, including the
    # ones skipped below, so the written numbers may contain gaps.
    # NOTE(review): presumably this keeps indices aligned with the main SRT
    # file - confirm before renumbering.
    for i, segment in enumerate(iterator, start=1):
        if "original" not in segment:
            continue

        # "-->" is the SRT timestamp separator and must not appear in cue text.
        original = segment['original'].replace('-->', '->')

        # Cue header: index line followed by the timestamp line.
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
            f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}",
            file=file,
            flush=True,
        )

        print(original, file=file, flush=True)

        if bilingual:
            text = segment['text'].replace('-->', '->')
            print(text, file=file, flush=True)

        # A blank line terminates every cue. Previously it was only emitted in
        # bilingual mode, so original-only files had their cues run together.
        print(file=file, flush=True)
142
+
143
  def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
144
  for segment in transcript:
145
  words: list = segment.get('words', [])
146
 
147
  # Append longest speaker ID if available
148
  segment_longest_speaker = segment.get('longest_speaker', None)
149
+
150
+ # Yield the segment as-is or processed
151
+ if len(words) == 0 and (maxLineWidth is None or maxLineWidth < 0) and segment_longest_speaker is None:
152
+ yield segment
153
+
154
  if segment_longest_speaker is not None:
155
  segment_longest_speaker = segment_longest_speaker.replace("SPEAKER", "S")
156
+
157
+ subtitle_start = segment['start']
158
+ subtitle_end = segment['end']
159
+ text = segment['text'].strip()
160
+ original_text = segment['original'].strip() if 'original' in segment else None
161
+
162
  if len(words) == 0:
163
+ # Prepend the longest speaker ID if available
164
+ if segment_longest_speaker is not None:
165
+ text = f"({segment_longest_speaker}) {text}"
166
+
167
+ result = {
168
+ 'start': subtitle_start,
169
+ 'end' : subtitle_end,
170
+ 'text' : process_text(text, maxLineWidth)
171
+ }
172
+ if original_text is not None and len(original_text) > 0:
173
+ result.update({'original': process_text(original_text, maxLineWidth)})
174
+ yield result
175
+
 
 
176
  # We are done
177
  continue
178
 
 
 
 
179
  if segment_longest_speaker is not None:
180
  # Add the beginning
181
  words.insert(0, {
182
  'start': subtitle_start,
183
+ 'end' : subtitle_start,
184
+ 'word' : f"({segment_longest_speaker})"
185
  })
186
 
187
+ text_words = [text] if not highlight_words and original_text is not None and len(original_text) > 0 else [ this_word["word"] for this_word in words ]
188
  subtitle_text = __join_words(text_words, maxLineWidth)
189
 
190
  # Iterate over the words in the segment
 
199
  # Display the text up to this point
200
  yield {
201
  'start': last,
202
+ 'end' : start,
203
+ 'text' : subtitle_text
204
  }
205
 
206
  # Display the text with the current word highlighted
207
  yield {
208
  'start': start,
209
+ 'end' : end,
210
+ 'text' : __join_words(
211
  [
212
  {
213
  "word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
 
225
  # Display the last part of the text
226
  yield {
227
  'start': last,
228
+ 'end' : subtitle_end,
229
+ 'text' : subtitle_text
230
  }
231
 
232
  # Just return the subtitle text
233
  else:
234
+ result = {
235
  'start': subtitle_start,
236
+ 'end' : subtitle_end,
237
+ 'text' : subtitle_text
238
  }
239
+ if original_text is not None and len(original_text) > 0:
240
+ result.update({'original': original_text})
241
+ yield result
242
 
243
  def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
244
  if maxLineWidth is None or maxLineWidth < 0:
src/vad.py CHANGED
@@ -242,9 +242,8 @@ class AbstractTranscription(ABC):
242
 
243
  # Update prompt window
244
  self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
245
-
246
- if detected_language is not None:
247
- result['language'] = detected_language
248
  finally:
249
  # Notify progress listener that we are done
250
  if progressListener is not None:
 
242
 
243
  # Update prompt window
244
  self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
245
+
246
+ result['language'] = detected_language if detected_language is not None else segment_result['language']
 
247
  finally:
248
  # Notify progress listener that we are done
249
  if progressListener is not None:
src/whisper/abstractWhisperContainer.py CHANGED
@@ -71,7 +71,7 @@ class AbstractWhisperContainer:
71
  pass
72
 
73
  @abc.abstractmethod
74
- def create_callback(self, language: str = None, task: str = None,
75
  prompt_strategy: AbstractPromptStrategy = None,
76
  **decodeOptions: dict) -> AbstractWhisperCallback:
77
  """
@@ -79,8 +79,8 @@ class AbstractWhisperContainer:
79
 
80
  Parameters
81
  ----------
82
- language: str
83
- The target language of the transcription. If not specified, the language will be inferred from the audio content.
84
  task: str
85
  The task - either translate or transcribe.
86
  prompt_strategy: AbstractPromptStrategy
 
71
  pass
72
 
73
  @abc.abstractmethod
74
+ def create_callback(self, languageCode: str = None, task: str = None,
75
  prompt_strategy: AbstractPromptStrategy = None,
76
  **decodeOptions: dict) -> AbstractWhisperCallback:
77
  """
 
79
 
80
  Parameters
81
  ----------
82
+ languageCode: str
83
+ The target language code of the transcription. If not specified, the language will be inferred from the audio content.
84
  task: str
85
  The task - either translate or transcribe.
86
  prompt_strategy: AbstractPromptStrategy
src/whisper/fasterWhisperContainer.py CHANGED
@@ -4,7 +4,6 @@ from typing import List, Union
4
  from faster_whisper import WhisperModel, download_model
5
  from src.config import ModelConfig, VadInitialPromptMode
6
  from src.hooks.progressListener import ProgressListener
7
- from src.languages import get_language_from_name
8
  from src.modelCache import ModelCache
9
  from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
10
  from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
@@ -57,7 +56,7 @@ class FasterWhisperContainer(AbstractWhisperContainer):
57
  model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
58
  return model
59
 
60
- def create_callback(self, language: str = None, task: str = None,
61
  prompt_strategy: AbstractPromptStrategy = None,
62
  **decodeOptions: dict) -> AbstractWhisperCallback:
63
  """
@@ -65,8 +64,8 @@ class FasterWhisperContainer(AbstractWhisperContainer):
65
 
66
  Parameters
67
  ----------
68
- language: str
69
- The target language of the transcription. If not specified, the language will be inferred from the audio content.
70
  task: str
71
  The task - either translate or transcribe.
72
  prompt_strategy: AbstractPromptStrategy
@@ -78,14 +77,14 @@ class FasterWhisperContainer(AbstractWhisperContainer):
78
  -------
79
  A WhisperCallback object.
80
  """
81
- return FasterWhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
82
 
83
  class FasterWhisperCallback(AbstractWhisperCallback):
84
- def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
85
  prompt_strategy: AbstractPromptStrategy = None,
86
  **decodeOptions: dict):
87
  self.model_container = model_container
88
- self.language = language
89
  self.task = task
90
  self.prompt_strategy = prompt_strategy
91
  self.decodeOptions = decodeOptions
@@ -108,7 +107,6 @@ class FasterWhisperCallback(AbstractWhisperCallback):
108
  A callback to receive progress updates.
109
  """
110
  model: WhisperModel = self.model_container.get_model()
111
- language_code = self._lookup_language_code(self.language) if self.language else None
112
 
113
  # Copy decode options and remove options that are not supported by faster-whisper
114
  decodeOptions = self.decodeOptions.copy()
@@ -139,7 +137,7 @@ class FasterWhisperCallback(AbstractWhisperCallback):
139
  if self.prompt_strategy else prompt
140
 
141
  segments_generator, info = model.transcribe(audio, \
142
- language=language_code if language_code else detected_language, task=self.task, \
143
  initial_prompt=initial_prompt, \
144
  **decodeOptions
145
  )
@@ -197,11 +195,3 @@ class FasterWhisperCallback(AbstractWhisperCallback):
197
  return suppress_tokens
198
 
199
  return [int(token) for token in suppress_tokens.split(",")]
200
-
201
- def _lookup_language_code(self, language: str):
202
- language = get_language_from_name(language)
203
-
204
- if language is None:
205
- raise ValueError("Invalid language: " + language)
206
-
207
- return language.code
 
4
  from faster_whisper import WhisperModel, download_model
5
  from src.config import ModelConfig, VadInitialPromptMode
6
  from src.hooks.progressListener import ProgressListener
 
7
  from src.modelCache import ModelCache
8
  from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
9
  from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 
56
  model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
57
  return model
58
 
59
+ def create_callback(self, languageCode: str = None, task: str = None,
60
  prompt_strategy: AbstractPromptStrategy = None,
61
  **decodeOptions: dict) -> AbstractWhisperCallback:
62
  """
 
64
 
65
  Parameters
66
  ----------
67
+ languageCode: str
68
+ The target language code of the transcription. If not specified, the language will be inferred from the audio content.
69
  task: str
70
  The task - either translate or transcribe.
71
  prompt_strategy: AbstractPromptStrategy
 
77
  -------
78
  A WhisperCallback object.
79
  """
80
+ return FasterWhisperCallback(self, languageCode=languageCode, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
81
 
82
  class FasterWhisperCallback(AbstractWhisperCallback):
83
+ def __init__(self, model_container: FasterWhisperContainer, languageCode: str = None, task: str = None,
84
  prompt_strategy: AbstractPromptStrategy = None,
85
  **decodeOptions: dict):
86
  self.model_container = model_container
87
+ self.languageCode = languageCode
88
  self.task = task
89
  self.prompt_strategy = prompt_strategy
90
  self.decodeOptions = decodeOptions
 
107
  A callback to receive progress updates.
108
  """
109
  model: WhisperModel = self.model_container.get_model()
 
110
 
111
  # Copy decode options and remove options that are not supported by faster-whisper
112
  decodeOptions = self.decodeOptions.copy()
 
137
  if self.prompt_strategy else prompt
138
 
139
  segments_generator, info = model.transcribe(audio, \
140
+ language=self.languageCode if self.languageCode else detected_language, task=self.task, \
141
  initial_prompt=initial_prompt, \
142
  **decodeOptions
143
  )
 
195
  return suppress_tokens
196
 
197
  return [int(token) for token in suppress_tokens.split(",")]
 
 
 
 
 
 
 
 
src/whisper/whisperContainer.py CHANGED
@@ -70,7 +70,7 @@ class WhisperContainer(AbstractWhisperContainer):
70
 
71
  return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
72
 
73
- def create_callback(self, language: str = None, task: str = None,
74
  prompt_strategy: AbstractPromptStrategy = None,
75
  **decodeOptions: dict) -> AbstractWhisperCallback:
76
  """
@@ -78,8 +78,8 @@ class WhisperContainer(AbstractWhisperContainer):
78
 
79
  Parameters
80
  ----------
81
- language: str
82
- The target language of the transcription. If not specified, the language will be inferred from the audio content.
83
  task: str
84
  The task - either translate or transcribe.
85
  prompt_strategy: AbstractPromptStrategy
@@ -91,7 +91,7 @@ class WhisperContainer(AbstractWhisperContainer):
91
  -------
92
  A WhisperCallback object.
93
  """
94
- return WhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
95
 
96
  def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
97
  from src.conversion.hf_converter import convert_hf_whisper
@@ -160,11 +160,11 @@ class WhisperContainer(AbstractWhisperContainer):
160
  return model_config.path
161
 
162
  class WhisperCallback(AbstractWhisperCallback):
163
- def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None,
164
  prompt_strategy: AbstractPromptStrategy = None,
165
  **decodeOptions: dict):
166
  self.model_container = model_container
167
- self.language = language
168
  self.task = task
169
  self.prompt_strategy = prompt_strategy
170
 
@@ -204,7 +204,7 @@ class WhisperCallback(AbstractWhisperCallback):
204
  if self.prompt_strategy else prompt
205
 
206
  result = model.transcribe(audio, \
207
- language=self.language if self.language else detected_language, task=self.task, \
208
  initial_prompt=initial_prompt, \
209
  **decodeOptions
210
  )
 
70
 
71
  return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
72
 
73
+ def create_callback(self, languageCode: str = None, task: str = None,
74
  prompt_strategy: AbstractPromptStrategy = None,
75
  **decodeOptions: dict) -> AbstractWhisperCallback:
76
  """
 
78
 
79
  Parameters
80
  ----------
81
+ languageCode: str
82
+ The target language code of the transcription. If not specified, the language will be inferred from the audio content.
83
  task: str
84
  The task - either translate or transcribe.
85
  prompt_strategy: AbstractPromptStrategy
 
91
  -------
92
  A WhisperCallback object.
93
  """
94
+ return WhisperCallback(self, languageCode=languageCode, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
95
 
96
  def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
97
  from src.conversion.hf_converter import convert_hf_whisper
 
160
  return model_config.path
161
 
162
  class WhisperCallback(AbstractWhisperCallback):
163
+ def __init__(self, model_container: WhisperContainer, languageCode: str = None, task: str = None,
164
  prompt_strategy: AbstractPromptStrategy = None,
165
  **decodeOptions: dict):
166
  self.model_container = model_container
167
+ self.languageCode = languageCode
168
  self.task = task
169
  self.prompt_strategy = prompt_strategy
170
 
 
204
  if self.prompt_strategy else prompt
205
 
206
  result = model.transcribe(audio, \
207
+ language=self.languageCode if self.languageCode else detected_language, task=self.task, \
208
  initial_prompt=initial_prompt, \
209
  **decodeOptions
210
  )