IRISLAB committed
Commit 785fa6f
1 Parent(s): 7271105

Upload 20 files

modules/__init__.py ADDED
File without changes
modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (152 Bytes).
modules/__pycache__/deepl_api.cpython-312.pyc ADDED
Binary file (7.51 kB).
modules/__pycache__/faster_whisper_inference.cpython-312.pyc ADDED
Binary file (7.5 kB).
modules/__pycache__/nllb_inference.cpython-312.pyc ADDED
Binary file (11.9 kB).
modules/__pycache__/subtitle_manager.cpython-312.pyc ADDED
Binary file (6.03 kB).
modules/__pycache__/translation_base.cpython-312.pyc ADDED
Binary file (7.46 kB).
modules/__pycache__/whisper_Inference.cpython-312.pyc ADDED
Binary file (4.76 kB).
modules/__pycache__/whisper_base.cpython-312.pyc ADDED
Binary file (14.5 kB).
modules/__pycache__/whisper_parameter.cpython-312.pyc ADDED
Binary file (2.96 kB).
modules/__pycache__/youtube_manager.cpython-312.pyc ADDED
Binary file (1.03 kB).
modules/deepl_api.py ADDED
@@ -0,0 +1,196 @@
import requests
import time
import os
from datetime import datetime
import gradio as gr

from modules.subtitle_manager import *

"""
This is written with reference to the DeepL API documentation.
If you want to know more about the DeepL API, see here: https://www.deepl.com/docs-api/documents
"""

DEEPL_AVAILABLE_TARGET_LANGS = {
    'Bulgarian': 'BG',
    'Czech': 'CS',
    'Danish': 'DA',
    'German': 'DE',
    'Greek': 'EL',
    'English': 'EN',
    'English (British)': 'EN-GB',
    'English (American)': 'EN-US',
    'Spanish': 'ES',
    'Estonian': 'ET',
    'Finnish': 'FI',
    'French': 'FR',
    'Hungarian': 'HU',
    'Indonesian': 'ID',
    'Italian': 'IT',
    'Japanese': 'JA',
    'Korean': 'KO',
    'Lithuanian': 'LT',
    'Latvian': 'LV',
    'Norwegian (Bokmål)': 'NB',
    'Dutch': 'NL',
    'Polish': 'PL',
    'Portuguese': 'PT',
    'Portuguese (Brazilian)': 'PT-BR',
    'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT',
    'Romanian': 'RO',
    'Russian': 'RU',
    'Slovak': 'SK',
    'Slovenian': 'SL',
    'Swedish': 'SV',
    'Turkish': 'TR',
    'Ukrainian': 'UK',
    'Chinese (simplified)': 'ZH'
}

DEEPL_AVAILABLE_SOURCE_LANGS = {
    'Automatic Detection': None,
    'Bulgarian': 'BG',
    'Czech': 'CS',
    'Danish': 'DA',
    'German': 'DE',
    'Greek': 'EL',
    'English': 'EN',
    'Spanish': 'ES',
    'Estonian': 'ET',
    'Finnish': 'FI',
    'French': 'FR',
    'Hungarian': 'HU',
    'Indonesian': 'ID',
    'Italian': 'IT',
    'Japanese': 'JA',
    'Korean': 'KO',
    'Lithuanian': 'LT',
    'Latvian': 'LV',
    'Norwegian (Bokmål)': 'NB',
    'Dutch': 'NL',
    'Polish': 'PL',
    'Portuguese (all Portuguese varieties mixed)': 'PT',
    'Romanian': 'RO',
    'Russian': 'RU',
    'Slovak': 'SK',
    'Slovenian': 'SL',
    'Swedish': 'SV',
    'Turkish': 'TR',
    'Ukrainian': 'UK',
    'Chinese': 'ZH'
}


class DeepLAPI:
    def __init__(self):
        self.api_interval = 1
        self.max_text_batch_size = 50
        self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
        self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS

    def translate_deepl(self,
                        auth_key: str,
                        fileobjs: list,
                        source_lang: str,
                        target_lang: str,
                        is_pro: bool,
                        progress=gr.Progress()) -> list:
        """
        Translate subtitle files using the DeepL API

        Parameters
        ----------
        auth_key: str
            API key for DeepL from gr.Textbox()
        fileobjs: list
            List of files to translate from gr.Files()
        source_lang: str
            Source language of the files to translate from gr.Dropdown()
        target_lang: str
            Target language of the files to translate from gr.Dropdown()
        is_pro: bool
            Boolean value from gr.Checkbox() indicating whether the API key belongs to a Pro account.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        A List of
        String to return to gr.Textbox()
        Files to return to gr.Files()
        """

        files_info = {}
        for fileobj in fileobjs:
            file_path = fileobj.name
            file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))

            if file_ext == ".srt":
                parsed_dicts = parse_srt(file_path=file_path)

                batch_size = self.max_text_batch_size
                for batch_start in range(0, len(parsed_dicts), batch_size):
                    batch_end = min(batch_start + batch_size, len(parsed_dicts))
                    sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
                    translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
                                                                    target_lang, is_pro)
                    for i, translated_text in enumerate(translated_texts):
                        parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_srt(parsed_dicts)
                timestamp = datetime.now().strftime("%m%d%H%M%S")

                file_name = file_name[:-9]
                output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
                write_file(subtitle, output_path)

            elif file_ext == ".vtt":
                parsed_dicts = parse_vtt(file_path=file_path)

                batch_size = self.max_text_batch_size
                for batch_start in range(0, len(parsed_dicts), batch_size):
                    batch_end = min(batch_start + batch_size, len(parsed_dicts))
                    sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
                    translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
                                                                    target_lang, is_pro)
                    for i, translated_text in enumerate(translated_texts):
                        parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_vtt(parsed_dicts)
                timestamp = datetime.now().strftime("%m%d%H%M%S")

                file_name = file_name[:-9]
                output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.vtt")

                write_file(subtitle, output_path)

            files_info[file_name] = subtitle

        total_result = ''
        for file_name, subtitle in files_info.items():
            total_result += '------------------------------------\n'
            total_result += f'{file_name}\n\n'
            total_result += f'{subtitle}'

        gr_str = f"Done! Subtitle is in the outputs/translations folder.\n\n{total_result}"
        return [gr_str, output_path]

    def request_deepl_translate(self,
                                auth_key: str,
                                text: list,
                                source_lang: str,
                                target_lang: str,
                                is_pro: bool):
        """Request a translation from the DeepL API server"""

        url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
        headers = {
            'Authorization': f'DeepL-Auth-Key {auth_key}'
        }
        data = {
            'text': text,
            'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang],
            'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang]
        }
        response = requests.post(url, headers=headers, data=data).json()
        time.sleep(self.api_interval)
        return response["translations"]
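
For reference, the request helper above can be exercised on its own. A minimal sketch, assuming a valid free-tier key is available in a DEEPL_API_KEY environment variable (the variable name and sample sentence are illustrative, not part of this commit):

import os

from modules.deepl_api import DeepLAPI

deepl = DeepLAPI()
# Language names must be keys of the dictionaries above; they are mapped to
# DeepL codes ("English" -> "EN", "Korean" -> "KO") inside request_deepl_translate.
translations = deepl.request_deepl_translate(
    auth_key=os.environ["DEEPL_API_KEY"],  # assumed environment variable
    text=["Hello, world."],                # the endpoint accepts a list of strings
    source_lang="English",
    target_lang="Korean",
    is_pro=False,                          # False -> api-free.deepl.com endpoint
)
print(translations[0]["text"])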
modules/faster_whisper_inference.py ADDED
@@ -0,0 +1,154 @@
import os
import time
import numpy as np
from typing import BinaryIO, Union, Tuple, List

import faster_whisper
from faster_whisper.vad import VadOptions
import ctranslate2
import whisper
import gradio as gr

from modules.whisper_parameter import *
from modules.whisper_base import WhisperBase

# Temporary fix for the issue: https://github.com/jhj0517/Whisper-WebUI/issues/144
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


class FasterWhisperInference(WhisperBase):
    def __init__(self):
        super().__init__(
            model_dir=os.path.join("models", "Whisper", "faster-whisper")
        )
        self.model_paths = self.get_model_paths()
        self.available_models = self.model_paths.keys()
        self.available_compute_types = ctranslate2.get_supported_compute_types(
            "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress,
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for faster-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path, file binary, or audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. See whisper_parameter.py for details.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start and end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()

        params = WhisperValues(*whisper_params)

        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        if params.lang == "Automatic Detection":
            params.lang = None
        else:
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
            params.lang = language_code_dict[params.lang]

        vad_options = VadOptions(
            threshold=params.threshold,
            min_speech_duration_ms=params.min_speech_duration_ms,
            max_speech_duration_s=params.max_speech_duration_s,
            min_silence_duration_ms=params.min_silence_duration_ms,
            window_size_samples=params.window_size_samples,
            speech_pad_ms=params.speech_pad_ms
        )

        segments, info = self.model.transcribe(
            audio=audio,
            language=params.lang,
            task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
            beam_size=params.beam_size,
            log_prob_threshold=params.log_prob_threshold,
            no_speech_threshold=params.no_speech_threshold,
            best_of=params.best_of,
            patience=params.patience,
            temperature=params.temperature,
            compression_ratio_threshold=params.compression_ratio_threshold,
            vad_filter=params.vad_filter,
            vad_parameters=vad_options
        )
        progress(0, desc="Loading audio..")

        segments_result = []
        for segment in segments:
            progress(segment.start / info.duration, desc="Transcribing..")
            segments_result.append({
                "start": segment.start,
                "end": segment.end,
                "text": segment.text
            })

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        self.current_model_size = self.model_paths[model_size]
        self.current_compute_type = compute_type
        self.model = faster_whisper.WhisperModel(
            device=self.device,
            model_size_or_path=self.current_model_size,
            download_root=self.model_dir,
            compute_type=self.current_compute_type
        )

    def get_model_paths(self):
        """
        Get available models from the models path, including fine-tuned models.

        Returns
        ----------
        Dict that maps model names to model paths
        """
        model_paths = {model: model for model in whisper.available_models()}
        faster_whisper_prefix = "models--Systran--faster-whisper-"

        existing_models = os.listdir(self.model_dir)
        wrong_dirs = [".locks"]
        existing_models = list(set(existing_models) - set(wrong_dirs))

        webui_dir = os.getcwd()

        for model_name in existing_models:
            if faster_whisper_prefix in model_name:
                model_name = model_name[len(faster_whisper_prefix):]

            if model_name not in whisper.available_models():
                model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
        return model_paths
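
As a quick illustration of get_model_paths() above (the directory name below is hypothetical, not a file in this commit): stock model names stay mapped to themselves so faster-whisper downloads them by name, while any unrecognized directory under the model folder is registered with its full path so a fine-tuned model can be loaded from disk.

from modules.faster_whisper_inference import FasterWhisperInference

inferencer = FasterWhisperInference()
# e.g. {"large-v2": "large-v2", ...,
#       "my-finetuned-whisper": "<cwd>/models/Whisper/faster-whisper/my-finetuned-whisper"}
print(inferencer.get_model_paths())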
modules/nllb_inference.py ADDED
@@ -0,0 +1,253 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import gradio as gr
import os

from modules.translation_base import TranslationBase


class NLLBInference(TranslationBase):
    def __init__(self):
        super().__init__(
            model_dir=os.path.join("models", "NLLB")
        )
        self.tokenizer = None
        self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
        self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.pipeline = None

    def translate(self,
                  text: str
                  ):
        result = self.pipeline(text)
        return result[0]['translation_text']

    def update_model(self,
                     model_size: str,
                     src_lang: str,
                     tgt_lang: str,
                     progress: gr.Progress
                     ):
        if model_size != self.current_model_size or self.model is None:
            print("\nInitializing NLLB Model..\n")
            progress(0, desc="Initializing NLLB Model..")
            self.current_model_size = model_size
            self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
                                                               cache_dir=self.model_dir)
            self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
                                                           cache_dir=os.path.join(self.model_dir, "tokenizers"))
        src_lang = NLLB_AVAILABLE_LANGS[src_lang]
        tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
        self.pipeline = pipeline("translation",
                                 model=self.model,
                                 tokenizer=self.tokenizer,
                                 src_lang=src_lang,
                                 tgt_lang=tgt_lang,
                                 device=self.device)


NLLB_AVAILABLE_LANGS = {
    "Acehnese (Arabic script)": "ace_Arab",
    "Acehnese (Latin script)": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta’izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic (Romanized)": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar (Arabic script)": "bjn_Arab",
    "Banjar (Latin script)": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri (Arabic script)": "kas_Arab",
    "Kashmiri (Devanagari script)": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri (Arabic script)": "knc_Arab",
    "Central Kanuri (Latin script)": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    "Minangkabau (Arabic script)": "min_Arab",
    "Minangkabau (Latin script)": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei (Bengali script)": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Dari": "prs_Arab",
    "Southern Pashto": "pbt_Arab",
    "Ayacucho Quechua": "quy_Latn",
    "Romanian": "ron_Latn",
    "Rundi": "run_Latn",
    "Russian": "rus_Cyrl",
    "Sango": "sag_Latn",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sicilian": "scn_Latn",
    "Shan": "shn_Mymr",
    "Sinhala": "sin_Sinh",
    "Slovak": "slk_Latn",
    "Slovenian": "slv_Latn",
    "Samoan": "smo_Latn",
    "Shona": "sna_Latn",
    "Sindhi": "snd_Arab",
    "Somali": "som_Latn",
    "Southern Sotho": "sot_Latn",
    "Spanish": "spa_Latn",
    "Tosk Albanian": "als_Latn",
    "Sardinian": "srd_Latn",
    "Serbian": "srp_Cyrl",
    "Swati": "ssw_Latn",
    "Sundanese": "sun_Latn",
    "Swedish": "swe_Latn",
    "Swahili": "swh_Latn",
    "Silesian": "szl_Latn",
    "Tamil": "tam_Taml",
    "Tatar": "tat_Cyrl",
    "Telugu": "tel_Telu",
    "Tajik": "tgk_Cyrl",
    "Tagalog": "tgl_Latn",
    "Thai": "tha_Thai",
    "Tigrinya": "tir_Ethi",
    "Tamasheq (Latin script)": "taq_Latn",
    "Tamasheq (Tifinagh script)": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese (Simplified)": "zho_Hans",
    "Chinese (Traditional)": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}
modules/subtitle_manager.py ADDED
@@ -0,0 +1,135 @@
import re


def timeformat_srt(time):
    hours = time // 3600
    minutes = (time - hours * 3600) // 60
    seconds = time - hours * 3600 - minutes * 60
    milliseconds = (time - int(time)) * 1000
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"


def timeformat_vtt(time):
    hours = time // 3600
    minutes = (time - hours * 3600) // 60
    seconds = time - hours * 3600 - minutes * 60
    milliseconds = (time - int(time)) * 1000
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"


def write_file(subtitle, output_file):
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(subtitle)


def get_srt(segments):
    output = ""
    for i, segment in enumerate(segments):
        output += f"{i + 1}\n"
        output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
        if segment['text'].startswith(' '):
            segment['text'] = segment['text'][1:]
        output += f"{segment['text']}\n\n"
    return output


def get_vtt(segments):
    output = "WebVTT\n\n"
    for i, segment in enumerate(segments):
        output += f"{i + 1}\n"
        output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
        if segment['text'].startswith(' '):
            segment['text'] = segment['text'][1:]
        output += f"{segment['text']}\n\n"
    return output


def get_txt(segments):
    output = ""
    for i, segment in enumerate(segments):
        if segment['text'].startswith(' '):
            segment['text'] = segment['text'][1:]
        output += f"{segment['text']}\n"
    return output


def parse_srt(file_path):
    """Reads an SRT file and returns its cues as a list of dicts"""
    with open(file_path, 'r', encoding='utf-8') as file:
        srt_data = file.read()

    data = []
    blocks = srt_data.split('\n\n')

    for block in blocks:
        if block.strip() != '':
            lines = block.strip().split('\n')
            index = lines[0]
            timestamp = lines[1]
            sentence = ' '.join(lines[2:])

            data.append({
                "index": index,
                "timestamp": timestamp,
                "sentence": sentence
            })
    return data


def parse_vtt(file_path):
    """Reads a WebVTT file and returns its cues as a list of dicts"""
    with open(file_path, 'r', encoding='utf-8') as file:
        webvtt_data = file.read()

    data = []
    blocks = webvtt_data.split('\n\n')

    for block in blocks:
        if block.strip() != '' and not block.strip().startswith("WebVTT"):
            lines = block.strip().split('\n')
            index = lines[0]
            timestamp = lines[1]
            sentence = ' '.join(lines[2:])

            data.append({
                "index": index,
                "timestamp": timestamp,
                "sentence": sentence
            })

    return data


def get_serialized_srt(dicts):
    output = ""
    for dic in dicts:
        output += f'{dic["index"]}\n'
        output += f'{dic["timestamp"]}\n'
        output += f'{dic["sentence"]}\n\n'
    return output


def get_serialized_vtt(dicts):
    output = "WebVTT\n\n"
    for dic in dicts:
        output += f'{dic["index"]}\n'
        output += f'{dic["timestamp"]}\n'
        output += f'{dic["sentence"]}\n\n'
    return output


def safe_filename(name):
    from app import _args
    INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
    safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
    if not _args.colab:
        return safe_name
    # Truncate the filename if it exceeds the max length (20)
    if len(safe_name) > 20:
        file_extension = safe_name.split('.')[-1]
        if len(file_extension) + 1 < 20:
            truncated_name = safe_name[:20 - len(file_extension) - 1]
            safe_name = truncated_name + '.' + file_extension
        else:
            safe_name = safe_name[:20]
    return safe_name
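
The parse/serialize helpers above are used as a round trip by the translation modules later in this commit. A minimal sketch with a placeholder path:

from modules.subtitle_manager import parse_srt, get_serialized_srt, write_file

entries = parse_srt(file_path="outputs/example.srt")       # placeholder path
for entry in entries:
    entry["sentence"] = entry["sentence"].upper()           # any per-cue edit, e.g. a translation
write_file(get_serialized_srt(entries), "outputs/example-edited.srt")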
modules/translation_base.py ADDED
@@ -0,0 +1,148 @@
import os
import torch
import gradio as gr
from abc import ABC, abstractmethod
from typing import List
from datetime import datetime

from modules.whisper_parameter import *
from modules.subtitle_manager import *


class TranslationBase(ABC):
    def __init__(self,
                 model_dir: str):
        super().__init__()
        self.model = None
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.current_model_size = None
        self.device = self.get_device()

    @abstractmethod
    def translate(self,
                  text: str
                  ):
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     src_lang: str,
                     tgt_lang: str,
                     progress: gr.Progress
                     ):
        pass

    def translate_file(self,
                       fileobjs: list,
                       model_size: str,
                       src_lang: str,
                       tgt_lang: str,
                       add_timestamp: bool,
                       progress=gr.Progress()) -> list:
        """
        Translate subtitle files from the source language to the target language

        Parameters
        ----------
        fileobjs: list
            List of files to translate from gr.Files()
        model_size: str
            Translation model size from gr.Dropdown()
        src_lang: str
            Source language of the file to translate from gr.Dropdown()
        tgt_lang: str
            Target language of the file to translate from gr.Dropdown()
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        A List of
        String to return to gr.Textbox()
        Files to return to gr.Files()
        """
        try:
            self.update_model(model_size=model_size,
                              src_lang=src_lang,
                              tgt_lang=tgt_lang,
                              progress=progress)

            files_info = {}
            for fileobj in fileobjs:
                file_path = fileobj.name
                file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
                if file_ext == ".srt":
                    parsed_dicts = parse_srt(file_path=file_path)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating..")
                        translated_text = self.translate(dic["sentence"])
                        dic["sentence"] = translated_text
                    subtitle = get_serialized_srt(parsed_dicts)

                    timestamp = datetime.now().strftime("%m%d%H%M%S")
                    if add_timestamp:
                        output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
                    else:
                        output_path = os.path.join("outputs", "translations", f"{file_name}.srt")

                elif file_ext == ".vtt":
                    parsed_dicts = parse_vtt(file_path=file_path)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating..")
                        translated_text = self.translate(dic["sentence"])
                        dic["sentence"] = translated_text
                    subtitle = get_serialized_vtt(parsed_dicts)

                    timestamp = datetime.now().strftime("%m%d%H%M%S")
                    if add_timestamp:
                        output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.vtt")
                    else:
                        output_path = os.path.join("outputs", "translations", f"{file_name}.vtt")

                write_file(subtitle, output_path)
                files_info[file_name] = subtitle

            total_result = ''
            for file_name, subtitle in files_info.items():
                total_result += '------------------------------------\n'
                total_result += f'{file_name}\n\n'
                total_result += f'{subtitle}'

            gr_str = f"Done! Subtitle is in the outputs/translations folder.\n\n{total_result}"
            return [gr_str, output_path]
        except Exception as e:
            print(f"Error: {str(e)}")
        finally:
            self.release_cuda_memory()
            self.remove_input_files([fileobj.name for fileobj in fileobjs])

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def release_cuda_memory():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)
modules/whisper_Inference.py ADDED
@@ -0,0 +1,97 @@
import whisper
import gradio as gr
import time
import os
from typing import BinaryIO, Union, Tuple, List
import numpy as np
import torch

from modules.whisper_base import WhisperBase
from modules.whisper_parameter import *


class WhisperInference(WhisperBase):
    def __init__(self):
        super().__init__(
            model_dir=os.path.join("models", "Whisper")
        )

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress,
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for openai-whisper.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            Audio path, audio numpy array, or audio tensor
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. See whisper_parameter.py for details.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start and end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()
        params = WhisperValues(*whisper_params)

        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        if params.lang == "Automatic Detection":
            params.lang = None

        def progress_callback(progress_value):
            progress(progress_value, desc="Transcribing..")

        segments_result = self.model.transcribe(audio=audio,
                                                language=params.lang,
                                                verbose=False,
                                                beam_size=params.beam_size,
                                                logprob_threshold=params.log_prob_threshold,
                                                no_speech_threshold=params.no_speech_threshold,
                                                task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
                                                fp16=True if params.compute_type == "float16" else False,
                                                best_of=params.best_of,
                                                patience=params.patience,
                                                temperature=params.temperature,
                                                compression_ratio_threshold=params.compression_ratio_threshold,
                                                progress_callback=progress_callback)["segments"]
        elapsed_time = time.time() - start_time

        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress,
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        self.current_compute_type = compute_type
        self.current_model_size = model_size
        self.model = whisper.load_model(
            name=model_size,
            device=self.device,
            download_root=self.model_dir
        )
modules/whisper_base.py ADDED
@@ -0,0 +1,333 @@
import os
import torch
import whisper
import gradio as gr
from abc import ABC, abstractmethod
from typing import BinaryIO, Union, Tuple, List
import numpy as np
from datetime import datetime

from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
from modules.youtube_manager import get_ytdata, get_ytaudio
from modules.whisper_parameter import *


class WhisperBase(ABC):
    def __init__(self,
                 model_dir: str):
        self.model = None
        self.current_model_size = None
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
        self.device = self.get_device()
        self.available_compute_types = ["float16", "float32"]
        self.current_compute_type = "float16" if self.device == "cuda" else "float32"

    @abstractmethod
    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress,
                   *whisper_params,
                   ):
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress
                     ):
        pass

    def transcribe_file(self,
                        files: list,
                        file_format: str,
                        add_timestamp: bool,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
        """
        Write subtitle file from Files

        Parameters
        ----------
        files: list
            List of files to transcribe from gr.Files()
        file_format: str
            Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. See whisper_parameter.py for details.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            files_info = {}
            for file in files:
                transcribed_segments, time_for_task = self.transcribe(
                    file.name,
                    progress,
                    *whisper_params,
                )

                file_name, file_ext = os.path.splitext(os.path.basename(file.name))
                file_name = safe_filename(file_name)
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format=file_format
                )
                files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}

            total_result = ''
            total_time = 0
            for file_name, info in files_info.items():
                total_result += '------------------------------------\n'
                total_result += f'{file_name}\n\n'
                total_result += f'{info["subtitle"]}'
                total_time += info["time_for_task"]

            result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
            result_file_path = [info['path'] for info in files_info.values()]

            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()
            if files:
                self.remove_input_files([file.name for file in files])

    def transcribe_mic(self,
                       mic_audio: str,
                       file_format: str,
                       progress=gr.Progress(),
                       *whisper_params,
                       ) -> list:
        """
        Write subtitle file from microphone

        Parameters
        ----------
        mic_audio: str
            Audio file path from gr.Microphone()
        file_format: str
            Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. See whisper_parameter.py for details.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio..")
            transcribed_segments, time_for_task = self.transcribe(
                mic_audio,
                progress,
                *whisper_params,
            )
            progress(1, desc="Completed!")

            subtitle, result_file_path = self.generate_and_write_file(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=True,
                file_format=file_format
            )

            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]
        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()
            self.remove_input_files([mic_audio])

    def transcribe_youtube(self,
                           youtube_link: str,
                           file_format: str,
                           add_timestamp: bool,
                           progress=gr.Progress(),
                           *whisper_params,
                           ) -> list:
        """
        Write subtitle file from Youtube

        Parameters
        ----------
        youtube_link: str
            URL of the Youtube video to transcribe from gr.Textbox()
        file_format: str
            Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. See whisper_parameter.py for details.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio from Youtube..")
            yt = get_ytdata(youtube_link)
            audio = get_ytaudio(yt)

            transcribed_segments, time_for_task = self.transcribe(
                audio,
                progress,
                *whisper_params,
            )

            progress(1, desc="Completed!")

            file_name = safe_filename(yt.title)
            subtitle, result_file_path = self.generate_and_write_file(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format
            )
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"

            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            try:
                if 'yt' not in locals():
                    yt = get_ytdata(youtube_link)
                    file_path = get_ytaudio(yt)
                else:
                    file_path = get_ytaudio(yt)

                self.release_cuda_memory()
                self.remove_input_files([file_path])
            except Exception as cleanup_error:
                pass

    @staticmethod
    def generate_and_write_file(file_name: str,
                                transcribed_segments: list,
                                add_timestamp: bool,
                                file_format: str,
                                ) -> Tuple[str, str]:
        """
        Writes subtitle file

        Parameters
        ----------
        file_name: str
            Output file name
        transcribed_segments: list
            Text segments transcribed from audio
        add_timestamp: bool
            Determines whether to add a timestamp to the end of the filename.
        file_format: str
            File format to write. Supported formats: [SRT, WebVTT, txt]

        Returns
        ----------
        content: str
            Result of the transcription
        output_path: str
            Output file path
        """
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        if add_timestamp:
            output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
        else:
            output_path = os.path.join("outputs", f"{file_name}")

        if file_format == "SRT":
            content = get_srt(transcribed_segments)
            output_path += '.srt'
            write_file(content, output_path)

        elif file_format == "WebVTT":
            content = get_vtt(transcribed_segments)
            output_path += '.vtt'
            write_file(content, output_path)

        elif file_format == "txt":
            content = get_txt(transcribed_segments)
            output_path += '.txt'
            write_file(content, output_path)
        return content, output_path

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
        Get {hours} {minutes} {seconds} time format string

        Parameters
        ----------
        elapsed_time: float
            Elapsed time for transcription

        Returns
        ----------
        Time format string
        """
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)

        time_str = ""
        if hours:
            time_str += f"{hours} hours "
        if minutes:
            time_str += f"{minutes} minutes "
        seconds = round(seconds)
        time_str += f"{seconds} seconds"

        return time_str.strip()

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def release_cuda_memory():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)
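
generate_and_write_file is a static method, so it can be exercised without a concrete subclass. A small sketch with made-up segments, assuming the outputs directory already exists as it does in the repository:

from modules.whisper_base import WhisperBase

segments = [
    {"start": 0.0, "end": 2.5, "text": " Hello there."},
    {"start": 2.5, "end": 5.0, "text": " This is a test."},
]
content, path = WhisperBase.generate_and_write_file(
    file_name="example",
    transcribed_segments=segments,
    add_timestamp=True,
    file_format="SRT",
)
# content holds the serialized SRT text; path is something like outputs/example-<timestamp>.srt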
modules/whisper_parameter.py ADDED
@@ -0,0 +1,153 @@
from dataclasses import dataclass, fields
import gradio as gr
from typing import Optional


@dataclass
class WhisperGradioComponents:
    model_size: gr.Dropdown
    lang: gr.Dropdown
    is_translate: gr.Checkbox
    beam_size: gr.Number
    log_prob_threshold: gr.Number
    no_speech_threshold: gr.Number
    compute_type: gr.Dropdown
    best_of: gr.Number
    patience: gr.Number
    condition_on_previous_text: gr.Checkbox
    initial_prompt: gr.Textbox
    temperature: gr.Slider
    compression_ratio_threshold: gr.Number
    vad_filter: gr.Checkbox
    threshold: gr.Slider
    min_speech_duration_ms: gr.Number
    max_speech_duration_s: gr.Number
    min_silence_duration_ms: gr.Number
    window_size_sample: gr.Number
    speech_pad_ms: gr.Number
    """
    A data class for the Gradio components of the Whisper parameters. Use "before" Gradio pre-processing.
    See more about Gradio pre-processing: https://www.gradio.app/docs/components

    Attributes
    ----------
    model_size: gr.Dropdown
        Whisper model size.

    lang: gr.Dropdown
        Source language of the file to transcribe.

    is_translate: gr.Checkbox
        Boolean value that determines whether to translate to English.
        It's Whisper's feature to translate speech from another language directly into English end-to-end.

    beam_size: gr.Number
        Int value that is used for decoding option.

    log_prob_threshold: gr.Number
        If the average log probability over sampled tokens is below this value, treat as failed.

    no_speech_threshold: gr.Number
        If the no_speech probability is higher than this value AND
        the average log probability over sampled tokens is below `log_prob_threshold`,
        consider the segment as silent.

    compute_type: gr.Dropdown
        Compute type for transcription.
        see more info : https://opennmt.net/CTranslate2/quantization.html

    best_of: gr.Number
        Number of candidates when sampling with non-zero temperature.

    patience: gr.Number
        Beam search patience factor.

    condition_on_previous_text: gr.Checkbox
        If True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    initial_prompt: gr.Textbox
        Optional text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those words correctly.

    temperature: gr.Slider
        Temperature for sampling. It can be a tuple of temperatures,
        which will be successively used upon failures according to either
        `compression_ratio_threshold` or `log_prob_threshold`.

    compression_ratio_threshold: gr.Number
        If the gzip compression ratio is above this value, treat as failed.

    vad_filter: gr.Checkbox
        Enable voice activity detection (VAD) to filter out parts of the audio
        without speech. This step uses the Silero VAD model
        https://github.com/snakers4/silero-vad.

    threshold: gr.Slider
        This parameter is related to Silero VAD. Speech threshold.
        Silero VAD outputs speech probabilities for each audio chunk;
        probabilities ABOVE this value are considered SPEECH. It is better to tune this
        parameter for each dataset separately, but a "lazy" 0.5 is pretty good for most datasets.

    min_speech_duration_ms: gr.Number
        This parameter is related to Silero VAD. Final speech chunks shorter than min_speech_duration_ms are thrown out.

    max_speech_duration_s: gr.Number
        This parameter is related to Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than 100 ms (if any), to prevent aggressive cutting. Otherwise, they will be
        split aggressively just before max_speech_duration_s.

    min_silence_duration_ms: gr.Number
        This parameter is related to Silero VAD. At the end of each speech chunk, wait for min_silence_duration_ms
        before separating it.

    window_size_samples: gr.Number
        This parameter is related to Silero VAD. Audio chunks of window_size_samples size are fed to the Silero VAD model.
        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for a 16000 sample rate.
        Values other than these may affect model performance!!

    speech_pad_ms: gr.Number
        This parameter is related to Silero VAD. Final speech chunks are padded by speech_pad_ms on each side.
    """

    def to_list(self) -> list:
        """
        Converts the data class attributes into a list. Use "before" Gradio pre-processing.
        See more about Gradio pre-processing: https://www.gradio.app/docs/components

        Returns
        ----------
        A list of Gradio components
        """
        return [getattr(self, f.name) for f in fields(self)]


@dataclass
class WhisperValues:
    model_size: str
    lang: str
    is_translate: bool
    beam_size: int
    log_prob_threshold: float
    no_speech_threshold: float
    compute_type: str
    best_of: int
    patience: float
    condition_on_previous_text: bool
    initial_prompt: Optional[str]
    temperature: float
    compression_ratio_threshold: float
    vad_filter: bool
    threshold: float
    min_speech_duration_ms: int
    max_speech_duration_s: float
    min_silence_duration_ms: int
    window_size_samples: int
    speech_pad_ms: int
    """
    A data class to use Whisper parameters. Use "after" Gradio pre-processing.
    See more about Gradio pre-processing: https://www.gradio.app/docs/components
    """
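
The two dataclasses mirror each other positionally: the UI flattens the components with to_list() when wiring the event handler, and each inference implementation re-packs the raw values Gradio passes back with WhisperValues(*whisper_params). A condensed sketch of the re-packing side (all values are placeholders):

from modules.whisper_parameter import WhisperValues

whisper_params = (
    "large-v2", "Automatic Detection", False,      # model_size, lang, is_translate
    1, -1.0, 0.6, "float16", 5, 1.0, True, None,   # beam_size .. initial_prompt
    0.0, 2.4, True, 0.5, 250, float("inf"),        # temperature .. max_speech_duration_s
    2000, 1024, 400,                               # min_silence_duration_ms, window_size_samples, speech_pad_ms
)
params = WhisperValues(*whisper_params)
assert params.beam_size == 1 and params.vad_filter is True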
modules/youtube_manager.py ADDED
@@ -0,0 +1,15 @@
from pytube import YouTube
import os


def get_ytdata(link):
    return YouTube(link)


def get_ytmetas(link):
    yt = YouTube(link)
    return yt.thumbnail_url, yt.title, yt.description


def get_ytaudio(ytdata: YouTube):
    return ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav"))
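
A brief usage sketch of the helpers above (the URL is a placeholder): get_ytaudio downloads the audio-only stream to modules/yt_tmp.wav and returns the downloaded path, which the transcription code then consumes.

from modules.youtube_manager import get_ytdata, get_ytmetas, get_ytaudio

url = "https://www.youtube.com/watch?v=XXXXXXXXXXX"   # placeholder URL
thumbnail_url, title, description = get_ytmetas(url)
audio_path = get_ytaudio(get_ytdata(url))             # path to the downloaded audio (modules/yt_tmp.wav)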