IRISLAB committed
Commit 84ffa23
1 Parent(s): 17d8211

Upload 28 files

.gitignore ADDED
@@ -0,0 +1,6 @@
+ venv/
+ ui/__pycache__/
+ outputs/
+ modules/__pycache__/
+ models/
+ modules/yt_tmp.wav
app.py ADDED
@@ -0,0 +1,358 @@
1
+ import gradio as gr
2
+ import os
3
+ import argparse
4
+
5
+ from modules.whisper_Inference import WhisperInference
6
+ from modules.faster_whisper_inference import FasterWhisperInference
7
+ from modules.nllb_inference import NLLBInference
8
+ from ui.htmls import *
9
+ from modules.youtube_manager import get_ytmetas
10
+ from modules.deepl_api import DeepLAPI
11
+ from modules.whisper_parameter import *
12
+
13
+
14
+ class App:
15
+ def __init__(self, args):
16
+ self.args = args
17
+ self.app = gr.Blocks(css=CSS, theme=self.args.theme)
18
+ self.whisper_inf = self.init_whisper()
19
+ print(f"Use \"{self.args.whisper_type}\" implementation")
20
+ print(f"Device \"{self.whisper_inf.device}\" is detected")
21
+ self.nllb_inf = NLLBInference()
22
+ self.deepl_api = DeepLAPI()
23
+
24
+ def init_whisper(self):
25
+ whisper_type = self.args.whisper_type.lower().strip()
26
+
27
+ if whisper_type in ["faster_whisper", "faster-whisper"]:
28
+ whisper_inf = FasterWhisperInference()
29
+ whisper_inf.model_dir = self.args.faster_whisper_model_dir
30
+ elif whisper_type in ["whisper"]:
31
+ whisper_inf = WhisperInference()
32
+ whisper_inf.model_dir = self.args.whisper_model_dir
33
+ else:
34
+ whisper_inf = FasterWhisperInference()
35
+ whisper_inf.model_dir = self.args.faster_whisper_model_dir
36
+ return whisper_inf
37
+
38
+ @staticmethod
39
+ def open_folder(folder_path: str):
40
+ if os.path.exists(folder_path):
41
+ os.system(f"start {folder_path}")
42
+ else:
43
+ print(f"The folder {folder_path} does not exist.")
44
+
45
+ @staticmethod
46
+ def on_change_models(model_size: str):
47
+ translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
48
+ if model_size not in translatable_model:
49
+ return gr.Checkbox(visible=False, value=False, interactive=False)
50
+ else:
51
+ return gr.Checkbox(visible=True, value=False, label="Translate to English?", interactive=True)
52
+
53
+ def launch(self):
54
+ with self.app:
55
+ with gr.Row():
56
+ with gr.Column():
57
+ gr.Markdown(MARKDOWN, elem_id="md_project")
58
+ with gr.Tabs():
59
+ with gr.TabItem("File"): # tab1
60
+ with gr.Row():
61
+ input_file = gr.Files(type="filepath", label="Upload File here")
62
+ with gr.Row():
63
+ dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
64
+ label="Model")
65
+ dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
66
+ value="Automatic Detection", label="Language")
67
+ dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
68
+ with gr.Row():
69
+ cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
70
+ with gr.Row():
71
+ cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
72
+ with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
73
+ cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
74
+ sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
75
+ nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
76
+ nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
77
+ nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
78
+ nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
79
+ nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
80
+ with gr.Accordion("Advanced_Parameters", open=False):
81
+ nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
82
+ nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
83
+ nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
84
+ dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
85
+ nb_best_of = gr.Number(label="Best Of", value=5, interactive=True)
86
+ nb_patience = gr.Number(label="Patience", value=1, interactive=True)
87
+ cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
88
+ tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
89
+ sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
90
+ nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
91
+ with gr.Row():
92
+ btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
93
+ with gr.Row():
94
+ tb_indicator = gr.Textbox(label="Output", scale=5)
95
+ files_subtitles = gr.Files(label="Downloadable output file", scale=3, interactive=False)
96
+ btn_openfolder = gr.Button('📂', scale=1)
97
+
98
+ params = [input_file, dd_file_format, cb_timestamp]
99
+ whisper_params = WhisperGradioComponents(model_size=dd_model,
100
+ lang=dd_lang,
101
+ is_translate=cb_translate,
102
+ beam_size=nb_beam_size,
103
+ log_prob_threshold=nb_log_prob_threshold,
104
+ no_speech_threshold=nb_no_speech_threshold,
105
+ compute_type=dd_compute_type,
106
+ best_of=nb_best_of,
107
+ patience=nb_patience,
108
+ condition_on_previous_text=cb_condition_on_previous_text,
109
+ initial_prompt=tb_initial_prompt,
110
+ temperature=sd_temperature,
111
+ compression_ratio_threshold=nb_compression_ratio_threshold,
112
+ vad_filter=cb_vad_filter,
113
+ threshold=sd_threshold,
114
+ min_speech_duration_ms=nb_min_speech_duration_ms,
115
+ max_speech_duration_s=nb_max_speech_duration_s,
116
+ min_silence_duration_ms=nb_min_silence_duration_ms,
117
+ window_size_sample=nb_window_size_sample,
118
+ speech_pad_ms=nb_speech_pad_ms)
119
+
120
+ btn_run.click(fn=self.whisper_inf.transcribe_file,
121
+ inputs=params + whisper_params.to_list(),
122
+ outputs=[tb_indicator, files_subtitles])
123
+ btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
124
+ dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
125
+
126
+ with gr.TabItem("Youtube"): # tab2
127
+ with gr.Row():
128
+ tb_youtubelink = gr.Textbox(label="Youtube Link")
129
+ with gr.Row(equal_height=True):
130
+ with gr.Column():
131
+ img_thumbnail = gr.Image(label="Youtube Thumbnail")
132
+ with gr.Column():
133
+ tb_title = gr.Label(label="Youtube Title")
134
+ tb_description = gr.Textbox(label="Youtube Description", max_lines=15)
135
+ with gr.Row():
136
+ dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
137
+ label="Model")
138
+ dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
139
+ value="Automatic Detection", label="Language")
140
+ dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
141
+ with gr.Row():
142
+ cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
143
+ with gr.Row():
144
+ cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
145
+ interactive=True)
146
+ with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
147
+ cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
148
+ sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
149
+ nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
150
+ nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
151
+ nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
152
+ nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
153
+ nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
154
+ with gr.Accordion("Advanced_Parameters", open=False):
155
+ nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
156
+ nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
157
+ nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
158
+ dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
159
+ nb_best_of = gr.Number(label="Best Of", value=5, interactive=True)
160
+ nb_patience = gr.Number(label="Patience", value=1, interactive=True)
161
+ cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
162
+ tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
163
+ sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
164
+ nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
165
+ with gr.Row():
166
+ btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
167
+ with gr.Row():
168
+ tb_indicator = gr.Textbox(label="Output", scale=5)
169
+ files_subtitles = gr.Files(label="Downloadable output file", scale=3)
170
+ btn_openfolder = gr.Button('📂', scale=1)
171
+
172
+ params = [tb_youtubelink, dd_file_format, cb_timestamp]
173
+ whisper_params = WhisperGradioComponents(model_size=dd_model,
174
+ lang=dd_lang,
175
+ is_translate=cb_translate,
176
+ beam_size=nb_beam_size,
177
+ log_prob_threshold=nb_log_prob_threshold,
178
+ no_speech_threshold=nb_no_speech_threshold,
179
+ compute_type=dd_compute_type,
180
+ best_of=nb_best_of,
181
+ patience=nb_patience,
182
+ condition_on_previous_text=cb_condition_on_previous_text,
183
+ initial_prompt=tb_initial_prompt,
184
+ temperature=sd_temperature,
185
+ compression_ratio_threshold=nb_compression_ratio_threshold,
186
+ vad_filter=cb_vad_filter,
187
+ threshold=sd_threshold,
188
+ min_speech_duration_ms=nb_min_speech_duration_ms,
189
+ max_speech_duration_s=nb_max_speech_duration_s,
190
+ min_silence_duration_ms=nb_min_silence_duration_ms,
191
+ window_size_sample=nb_window_size_sample,
192
+ speech_pad_ms=nb_speech_pad_ms)
193
+ btn_run.click(fn=self.whisper_inf.transcribe_youtube,
194
+ inputs=params + whisper_params.to_list(),
195
+ outputs=[tb_indicator, files_subtitles])
196
+ tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
197
+ outputs=[img_thumbnail, tb_title, tb_description])
198
+ btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
199
+ dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
200
+
201
+ with gr.TabItem("Mic"): # tab3
202
+ with gr.Row():
203
+ mic_input = gr.Microphone(label="Record with Mic", type="filepath", interactive=True)
204
+ with gr.Row():
205
+ dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
206
+ label="Model")
207
+ dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
208
+ value="Automatic Detection", label="Language")
209
+ dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
210
+ with gr.Row():
211
+ cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
212
+ with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
213
+ cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
214
+ sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
215
+ nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
216
+ nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
217
+ nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
218
+ nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
219
+ nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
220
+ with gr.Accordion("Advanced_Parameters", open=False):
221
+ nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
222
+ nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
223
+ nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
224
+ dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
225
+ nb_best_of = gr.Number(label="Best Of", value=5, interactive=True)
226
+ nb_patience = gr.Number(label="Patience", value=1, interactive=True)
227
+ cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
228
+ tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
229
+ sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
+ nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
230
+ with gr.Row():
231
+ btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
232
+ with gr.Row():
233
+ tb_indicator = gr.Textbox(label="Output", scale=5)
234
+ files_subtitles = gr.Files(label="Downloadable output file", scale=3)
235
+ btn_openfolder = gr.Button('📂', scale=1)
236
+
237
+ params = [mic_input, dd_file_format]
238
+ whisper_params = WhisperGradioComponents(model_size=dd_model,
239
+ lang=dd_lang,
240
+ is_translate=cb_translate,
241
+ beam_size=nb_beam_size,
242
+ log_prob_threshold=nb_log_prob_threshold,
243
+ no_speech_threshold=nb_no_speech_threshold,
244
+ compute_type=dd_compute_type,
245
+ best_of=nb_best_of,
246
+ patience=nb_patience,
247
+ condition_on_previous_text=cb_condition_on_previous_text,
248
+ initial_prompt=tb_initial_prompt,
249
+ temperature=sd_temperature,
250
+ compression_ratio_threshold=nb_compression_ratio_threshold,
251
+ vad_filter=cb_vad_filter,
252
+ threshold=sd_threshold,
253
+ min_speech_duration_ms=nb_min_speech_duration_ms,
254
+ max_speech_duration_s=nb_max_speech_duration_s,
255
+ min_silence_duration_ms=nb_min_silence_duration_ms,
256
+ window_size_sample=nb_window_size_sample,
257
+ speech_pad_ms=nb_speech_pad_ms)
258
+ btn_run.click(fn=self.whisper_inf.transcribe_mic,
259
+ inputs=params + whisper_params.to_list(),
260
+ outputs=[tb_indicator, files_subtitles])
261
+ btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
262
+ dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
263
+
264
+ with gr.TabItem("T2T Translation"): # tab 4
265
+ with gr.Row():
266
+ file_subs = gr.Files(type="filepath", label="Upload Subtitle Files to translate here",
267
+ file_types=['.vtt', '.srt'])
268
+
269
+ with gr.TabItem("DeepL API"): # sub tab1
270
+ with gr.Row():
271
+ tb_authkey = gr.Textbox(label="Your Auth Key (API KEY)",
272
+ value="")
273
+ with gr.Row():
274
+ dd_deepl_sourcelang = gr.Dropdown(label="Source Language", value="Automatic Detection",
275
+ choices=list(
276
+ self.deepl_api.available_source_langs.keys()))
277
+ dd_deepl_targetlang = gr.Dropdown(label="Target Language", value="English",
278
+ choices=list(
279
+ self.deepl_api.available_target_langs.keys()))
280
+ with gr.Row():
281
+ cb_deepl_ispro = gr.Checkbox(label="Pro User?", value=False)
282
+ with gr.Row():
283
+ btn_run = gr.Button("TRANSLATE SUBTITLE FILE", variant="primary")
284
+ with gr.Row():
285
+ tb_indicator = gr.Textbox(label="Output", scale=5)
286
+ files_subtitles = gr.Files(label="Downloadable output file", scale=3)
287
+ btn_openfolder = gr.Button('📂', scale=1)
288
+
289
+ btn_run.click(fn=self.deepl_api.translate_deepl,
290
+ inputs=[tb_authkey, file_subs, dd_deepl_sourcelang, dd_deepl_targetlang,
291
+ cb_deepl_ispro],
292
+ outputs=[tb_indicator, files_subtitles])
293
+
294
+ btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
295
+ inputs=None,
296
+ outputs=None)
297
+
298
+ with gr.TabItem("NLLB"): # sub tab2
299
+ with gr.Row():
300
+ dd_nllb_model = gr.Dropdown(label="Model", value="facebook/nllb-200-1.3B",
301
+ choices=self.nllb_inf.available_models)
302
+ dd_nllb_sourcelang = gr.Dropdown(label="Source Language",
303
+ choices=self.nllb_inf.available_source_langs)
304
+ dd_nllb_targetlang = gr.Dropdown(label="Target Language",
305
+ choices=self.nllb_inf.available_target_langs)
306
+ with gr.Row():
307
+ cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
308
+ interactive=True)
309
+ with gr.Row():
310
+ btn_run = gr.Button("TRANSLATE SUBTITLE FILE", variant="primary")
311
+ with gr.Row():
312
+ tb_indicator = gr.Textbox(label="Output", scale=5)
313
+ files_subtitles = gr.Files(label="Downloadable output file", scale=3)
314
+ btn_openfolder = gr.Button('📂', scale=1)
315
+ with gr.Column():
316
+ md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")
317
+
318
+ btn_run.click(fn=self.nllb_inf.translate_file,
319
+ inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang, cb_timestamp],
320
+ outputs=[tb_indicator, files_subtitles])
321
+
322
+ btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
323
+ inputs=None,
324
+ outputs=None)
325
+
326
+ # Launch the app with optional gradio settings
327
+ launch_args = {}
328
+ if self.args.share:
329
+ launch_args['share'] = self.args.share
330
+ if self.args.server_name:
331
+ launch_args['server_name'] = self.args.server_name
332
+ if self.args.server_port:
333
+ launch_args['server_port'] = self.args.server_port
334
+ if self.args.username and self.args.password:
335
+ launch_args['auth'] = (self.args.username, self.args.password)
336
+ launch_args['inbrowser'] = True
337
+
338
+ self.app.queue(api_open=False).launch(**launch_args)
339
+
340
+
341
+ # Create the parser for command-line arguments
342
+ parser = argparse.ArgumentParser()
343
+ parser.add_argument('--whisper_type', type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper"]')
344
+ parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
345
+ parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
346
+ parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
347
+ parser.add_argument('--username', type=str, default=None, help='Gradio authentication username')
348
+ parser.add_argument('--password', type=str, default=None, help='Gradio authentication password')
349
+ parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
350
+ parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
351
+ parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
352
+ parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
353
+ parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
354
+ _args = parser.parse_args()
355
+
356
+ if __name__ == "__main__":
357
+ app = App(args=_args)
358
+ app.launch()
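The argparse flags defined at the bottom of app.py feed the launch_args dict built in App.launch(). A minimal sketch of starting the WebUI with them; the host, port, and flag values below are illustrative assumptions, not part of the commit:

# Illustrative only: launch the WebUI from a shell (values are examples).
#   python app.py --whisper_type faster-whisper --server_name 0.0.0.0 --server_port 7860 --share
#
# The programmatic equivalent mirrors the __main__ block above:
from app import App, _args   # _args is the module-level parse_args() result

webui = App(args=_args)      # builds the Gradio Blocks layout and the inference backends
webui.launch()               # queues the app and calls gradio's launch(**launch_args)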
models/models will be saved here.txt ADDED
File without changes
modules/__init__.py ADDED
File without changes
modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (152 Bytes).
modules/__pycache__/deepl_api.cpython-312.pyc ADDED
Binary file (7.51 kB).
modules/__pycache__/faster_whisper_inference.cpython-312.pyc ADDED
Binary file (7.5 kB).
modules/__pycache__/nllb_inference.cpython-312.pyc ADDED
Binary file (11.9 kB).
modules/__pycache__/subtitle_manager.cpython-312.pyc ADDED
Binary file (6.03 kB).
modules/__pycache__/translation_base.cpython-312.pyc ADDED
Binary file (7.46 kB).
modules/__pycache__/whisper_Inference.cpython-312.pyc ADDED
Binary file (4.76 kB).
modules/__pycache__/whisper_base.cpython-312.pyc ADDED
Binary file (14.5 kB).
modules/__pycache__/whisper_parameter.cpython-312.pyc ADDED
Binary file (2.96 kB).
modules/__pycache__/youtube_manager.cpython-312.pyc ADDED
Binary file (1.03 kB).
modules/deepl_api.py ADDED
@@ -0,0 +1,196 @@
1
+ import requests
2
+ import time
3
+ import os
4
+ from datetime import datetime
5
+ import gradio as gr
6
+
7
+ from modules.subtitle_manager import *
8
+
9
+ """
10
+ This is written with reference to the DeepL API documentation.
11
+ If you want to know the information of the DeepL API, see here: https://www.deepl.com/docs-api/documents
12
+ """
13
+
14
+ DEEPL_AVAILABLE_TARGET_LANGS = {
15
+ 'Bulgarian': 'BG',
16
+ 'Czech': 'CS',
17
+ 'Danish': 'DA',
18
+ 'German': 'DE',
19
+ 'Greek': 'EL',
20
+ 'English': 'EN',
21
+ 'English (British)': 'EN-GB',
22
+ 'English (American)': 'EN-US',
23
+ 'Spanish': 'ES',
24
+ 'Estonian': 'ET',
25
+ 'Finnish': 'FI',
26
+ 'French': 'FR',
27
+ 'Hungarian': 'HU',
28
+ 'Indonesian': 'ID',
29
+ 'Italian': 'IT',
30
+ 'Japanese': 'JA',
31
+ 'Korean': 'KO',
32
+ 'Lithuanian': 'LT',
33
+ 'Latvian': 'LV',
34
+ 'Norwegian (Bokmål)': 'NB',
35
+ 'Dutch': 'NL',
36
+ 'Polish': 'PL',
37
+ 'Portuguese': 'PT',
38
+ 'Portuguese (Brazilian)': 'PT-BR',
39
+ 'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT',
40
+ 'Romanian': 'RO',
41
+ 'Russian': 'RU',
42
+ 'Slovak': 'SK',
43
+ 'Slovenian': 'SL',
44
+ 'Swedish': 'SV',
45
+ 'Turkish': 'TR',
46
+ 'Ukrainian': 'UK',
47
+ 'Chinese (simplified)': 'ZH'
48
+ }
49
+
50
+ DEEPL_AVAILABLE_SOURCE_LANGS = {
51
+ 'Automatic Detection': None,
52
+ 'Bulgarian': 'BG',
53
+ 'Czech': 'CS',
54
+ 'Danish': 'DA',
55
+ 'German': 'DE',
56
+ 'Greek': 'EL',
57
+ 'English': 'EN',
58
+ 'Spanish': 'ES',
59
+ 'Estonian': 'ET',
60
+ 'Finnish': 'FI',
61
+ 'French': 'FR',
62
+ 'Hungarian': 'HU',
63
+ 'Indonesian': 'ID',
64
+ 'Italian': 'IT',
65
+ 'Japanese': 'JA',
66
+ 'Korean': 'KO',
67
+ 'Lithuanian': 'LT',
68
+ 'Latvian': 'LV',
69
+ 'Norwegian (Bokmål)': 'NB',
70
+ 'Dutch': 'NL',
71
+ 'Polish': 'PL',
72
+ 'Portuguese (all Portuguese varieties mixed)': 'PT',
73
+ 'Romanian': 'RO',
74
+ 'Russian': 'RU',
75
+ 'Slovak': 'SK',
76
+ 'Slovenian': 'SL',
77
+ 'Swedish': 'SV',
78
+ 'Turkish': 'TR',
79
+ 'Ukrainian': 'UK',
80
+ 'Chinese': 'ZH'
81
+ }
82
+
83
+
84
+ class DeepLAPI:
85
+ def __init__(self):
86
+ self.api_interval = 1
87
+ self.max_text_batch_size = 50
88
+ self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
89
+ self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS
90
+
91
+ def translate_deepl(self,
92
+ auth_key: str,
93
+ fileobjs: list,
94
+ source_lang: str,
95
+ target_lang: str,
96
+ is_pro: bool,
97
+ progress=gr.Progress()) -> list:
98
+ """
99
+ Translate subtitle files using DeepL API
100
+ Parameters
101
+ ----------
102
+ auth_key: str
103
+ API Key for DeepL from gr.Textbox()
104
+ fileobjs: list
105
+ List of subtitle files to translate from gr.Files()
106
+ source_lang: str
107
+ Source language of the file to translate from gr.Dropdown()
108
+ target_lang: str
109
+ Target language of the file to translate from gr.Dropdown()
110
+ is_pro: bool
111
+ Boolean value indicating whether the user has a DeepL Pro subscription, from gr.Checkbox().
112
+ progress: gr.Progress
113
+ Indicator to show progress directly in gradio.
114
+ Returns
115
+ ----------
116
+ A List of
117
+ String to return to gr.Textbox()
118
+ Files to return to gr.Files()
119
+ """
120
+
121
+ files_info = {}
122
+ for fileobj in fileobjs:
123
+ file_path = fileobj.name
124
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
125
+
126
+ if file_ext == ".srt":
127
+ parsed_dicts = parse_srt(file_path=file_path)
128
+
129
+ batch_size = self.max_text_batch_size
130
+ for batch_start in range(0, len(parsed_dicts), batch_size):
131
+ batch_end = min(batch_start + batch_size, len(parsed_dicts))
132
+ sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
133
+ translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
134
+ target_lang, is_pro)
135
+ for i, translated_text in enumerate(translated_texts):
136
+ parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
137
+ progress(batch_end / len(parsed_dicts), desc="Translating..")
138
+
139
+ subtitle = get_serialized_srt(parsed_dicts)
140
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
141
+
142
+ file_name = file_name[:-9]
143
+ output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
144
+ write_file(subtitle, output_path)
145
+
146
+ elif file_ext == ".vtt":
147
+ parsed_dicts = parse_vtt(file_path=file_path)
148
+
149
+ batch_size = self.max_text_batch_size
150
+ for batch_start in range(0, len(parsed_dicts), batch_size):
151
+ batch_end = min(batch_start + batch_size, len(parsed_dicts))
152
+ sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
153
+ translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
154
+ target_lang, is_pro)
155
+ for i, translated_text in enumerate(translated_texts):
156
+ parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
157
+ progress(batch_end / len(parsed_dicts), desc="Translating..")
158
+
159
+ subtitle = get_serialized_vtt(parsed_dicts)
160
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
161
+
162
+ file_name = file_name[:-9]
163
+ output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.vtt")
164
+
165
+ write_file(subtitle, output_path)
166
+
167
+ files_info[file_name] = subtitle
168
+ total_result = ''
169
+ for file_name, subtitle in files_info.items():
170
+ total_result += '------------------------------------\n'
171
+ total_result += f'{file_name}\n\n'
172
+ total_result += f'{subtitle}'
173
+
174
+ gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
175
+ return [gr_str, output_path]
176
+
177
+ def request_deepl_translate(self,
178
+ auth_key: str,
179
+ text: list,
180
+ source_lang: str,
181
+ target_lang: str,
182
+ is_pro: bool):
183
+ """Request API response to DeepL server"""
184
+
185
+ url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
186
+ headers = {
187
+ 'Authorization': f'DeepL-Auth-Key {auth_key}'
188
+ }
189
+ data = {
190
+ 'text': text,
191
+ 'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang],
192
+ 'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang]
193
+ }
194
+ response = requests.post(url, headers=headers, data=data).json()
195
+ time.sleep(self.api_interval)
196
+ return response["translations"]
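Outside the Gradio UI, the same request path can be exercised directly through request_deepl_translate(); a minimal sketch, assuming a valid key (the key string and language names below are placeholders):

# Minimal sketch; "YOUR_DEEPL_KEY" is a placeholder, not a real credential.
from modules.deepl_api import DeepLAPI

deepl = DeepLAPI()
translations = deepl.request_deepl_translate(
    auth_key="YOUR_DEEPL_KEY",
    text=["Hello.", "See you tomorrow."],   # up to max_text_batch_size lines per request
    source_lang="English",
    target_lang="Korean",
    is_pro=False,                           # False -> api-free.deepl.com endpoint
)
print([t["text"] for t in translations])    # DeepL returns a list of {"text": ...} dicts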
modules/faster_whisper_inference.py ADDED
@@ -0,0 +1,154 @@
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ from typing import BinaryIO, Union, Tuple, List
5
+
6
+ import faster_whisper
7
+ from faster_whisper.vad import VadOptions
8
+ import ctranslate2
9
+ import whisper
10
+ import gradio as gr
11
+
12
+ from modules.whisper_parameter import *
13
+ from modules.whisper_base import WhisperBase
14
+
15
+ # Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
16
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
17
+
18
+
19
+ class FasterWhisperInference(WhisperBase):
20
+ def __init__(self):
21
+ super().__init__(
22
+ model_dir=os.path.join("models", "Whisper", "faster-whisper")
23
+ )
24
+ self.model_paths = self.get_model_paths()
25
+ self.available_models = self.model_paths.keys()
26
+ self.available_compute_types = ctranslate2.get_supported_compute_types(
27
+ "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
28
+
29
+ def transcribe(self,
30
+ audio: Union[str, BinaryIO, np.ndarray],
31
+ progress: gr.Progress,
32
+ *whisper_params,
33
+ ) -> Tuple[List[dict], float]:
34
+ """
35
+ transcribe method for faster-whisper.
36
+
37
+ Parameters
38
+ ----------
39
+ audio: Union[str, BinaryIO, np.ndarray]
40
+ Audio path or file binary or Audio numpy array
41
+ progress: gr.Progress
42
+ Indicator to show progress directly in gradio.
43
+ *whisper_params: tuple
44
+ Gradio components related to Whisper. see whisper_data_class.py for details.
45
+
46
+ Returns
47
+ ----------
48
+ segments_result: List[dict]
49
+ list of dicts that includes start, end timestamps and transcribed text
50
+ elapsed_time: float
51
+ elapsed time for transcription
52
+ """
53
+ start_time = time.time()
54
+
55
+ params = WhisperValues(*whisper_params)
56
+
57
+ if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
58
+ self.update_model(params.model_size, params.compute_type, progress)
59
+
60
+ if params.lang == "Automatic Detection":
61
+ params.lang = None
62
+ else:
63
+ language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
64
+ params.lang = language_code_dict[params.lang]
65
+
66
+ vad_options = VadOptions(
67
+ threshold=params.threshold,
68
+ min_speech_duration_ms=params.min_speech_duration_ms,
69
+ max_speech_duration_s=params.max_speech_duration_s,
70
+ min_silence_duration_ms=params.min_silence_duration_ms,
71
+ window_size_samples=params.window_size_samples,
72
+ speech_pad_ms=params.speech_pad_ms
73
+ )
74
+
75
+ segments, info = self.model.transcribe(
76
+ audio=audio,
77
+ language=params.lang,
78
+ task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
79
+ beam_size=params.beam_size,
80
+ log_prob_threshold=params.log_prob_threshold,
81
+ no_speech_threshold=params.no_speech_threshold,
82
+ best_of=params.best_of,
83
+ patience=params.patience,
84
+ temperature=params.temperature,
85
+ compression_ratio_threshold=params.compression_ratio_threshold,
86
+ vad_filter=params.vad_filter,
87
+ vad_parameters=vad_options
88
+ )
89
+ progress(0, desc="Loading audio..")
90
+
91
+ segments_result = []
92
+ for segment in segments:
93
+ progress(segment.start / info.duration, desc="Transcribing..")
94
+ segments_result.append({
95
+ "start": segment.start,
96
+ "end": segment.end,
97
+ "text": segment.text
98
+ })
99
+
100
+ elapsed_time = time.time() - start_time
101
+ return segments_result, elapsed_time
102
+
103
+ def update_model(self,
104
+ model_size: str,
105
+ compute_type: str,
106
+ progress: gr.Progress
107
+ ):
108
+ """
109
+ Update current model setting
110
+
111
+ Parameters
112
+ ----------
113
+ model_size: str
114
+ Size of whisper model
115
+ compute_type: str
116
+ Compute type for transcription.
117
+ see more info : https://opennmt.net/CTranslate2/quantization.html
118
+ progress: gr.Progress
119
+ Indicator to show progress directly in gradio.
120
+ """
121
+ progress(0, desc="Initializing Model..")
122
+ self.current_model_size = self.model_paths[model_size]
123
+ self.current_compute_type = compute_type
124
+ self.model = faster_whisper.WhisperModel(
125
+ device=self.device,
126
+ model_size_or_path=self.current_model_size,
127
+ download_root=self.model_dir,
128
+ compute_type=self.current_compute_type
129
+ )
130
+
131
+ def get_model_paths(self):
132
+ """
133
+ Get available models from models path including fine-tuned model.
134
+
135
+ Returns
136
+ ----------
137
+ Name list of models
138
+ """
139
+ model_paths = {model:model for model in whisper.available_models()}
140
+ faster_whisper_prefix = "models--Systran--faster-whisper-"
141
+
142
+ existing_models = os.listdir(self.model_dir)
143
+ wrong_dirs = [".locks"]
144
+ existing_models = list(set(existing_models) - set(wrong_dirs))
145
+
146
+ webui_dir = os.getcwd()
147
+
148
+ for model_name in existing_models:
149
+ if faster_whisper_prefix in model_name:
150
+ model_name = model_name[len(faster_whisper_prefix):]
151
+
152
+ if model_name not in whisper.available_models():
153
+ model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
154
+ return model_paths
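FasterWhisperInference.transcribe() above is essentially a thin wrapper around faster_whisper.WhisperModel.transcribe() plus the Silero VAD options; a stand-alone sketch of the same call, with an illustrative model size, audio path, and thresholds:

# Stand-alone sketch of the call wrapped by FasterWhisperInference.transcribe().
import faster_whisper
from faster_whisper.vad import VadOptions

model = faster_whisper.WhisperModel("large-v2", device="cuda", compute_type="float16")
segments, info = model.transcribe(
    "sample.wav",                      # illustrative audio path
    beam_size=1,
    vad_filter=True,
    vad_parameters=VadOptions(threshold=0.5, min_silence_duration_ms=2000),
)
for segment in segments:               # segments is a generator; iterating runs the decode
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")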
modules/nllb_inference.py ADDED
@@ -0,0 +1,253 @@
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
+ import gradio as gr
3
+ import os
4
+
5
+ from modules.translation_base import TranslationBase
6
+
7
+
8
+ class NLLBInference(TranslationBase):
9
+ def __init__(self):
10
+ super().__init__(
11
+ model_dir=os.path.join("models", "NLLB")
12
+ )
13
+ self.tokenizer = None
14
+ self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
15
+ self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
16
+ self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
17
+ self.pipeline = None
18
+
19
+ def translate(self,
20
+ text: str
21
+ ):
22
+ result = self.pipeline(text)
23
+ return result[0]['translation_text']
24
+
25
+ def update_model(self,
26
+ model_size: str,
27
+ src_lang: str,
28
+ tgt_lang: str,
29
+ progress: gr.Progress
30
+ ):
31
+ if model_size != self.current_model_size or self.model is None:
32
+ print("\nInitializing NLLB Model..\n")
33
+ progress(0, desc="Initializing NLLB Model..")
34
+ self.current_model_size = model_size
35
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
36
+ cache_dir=self.model_dir)
37
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
38
+ cache_dir=os.path.join(self.model_dir, "tokenizers"))
39
+ src_lang = NLLB_AVAILABLE_LANGS[src_lang]
40
+ tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
41
+ self.pipeline = pipeline("translation",
42
+ model=self.model,
43
+ tokenizer=self.tokenizer,
44
+ src_lang=src_lang,
45
+ tgt_lang=tgt_lang,
46
+ device=self.device)
47
+
48
+ NLLB_AVAILABLE_LANGS = {
49
+ "Acehnese (Arabic script)": "ace_Arab",
50
+ "Acehnese (Latin script)": "ace_Latn",
51
+ "Mesopotamian Arabic": "acm_Arab",
52
+ "Ta’izzi-Adeni Arabic": "acq_Arab",
53
+ "Tunisian Arabic": "aeb_Arab",
54
+ "Afrikaans": "afr_Latn",
55
+ "South Levantine Arabic": "ajp_Arab",
56
+ "Akan": "aka_Latn",
57
+ "Amharic": "amh_Ethi",
58
+ "North Levantine Arabic": "apc_Arab",
59
+ "Modern Standard Arabic": "arb_Arab",
60
+ "Modern Standard Arabic (Romanized)": "arb_Latn",
61
+ "Najdi Arabic": "ars_Arab",
62
+ "Moroccan Arabic": "ary_Arab",
63
+ "Egyptian Arabic": "arz_Arab",
64
+ "Assamese": "asm_Beng",
65
+ "Asturian": "ast_Latn",
66
+ "Awadhi": "awa_Deva",
67
+ "Central Aymara": "ayr_Latn",
68
+ "South Azerbaijani": "azb_Arab",
69
+ "North Azerbaijani": "azj_Latn",
70
+ "Bashkir": "bak_Cyrl",
71
+ "Bambara": "bam_Latn",
72
+ "Balinese": "ban_Latn",
73
+ "Belarusian": "bel_Cyrl",
74
+ "Bemba": "bem_Latn",
75
+ "Bengali": "ben_Beng",
76
+ "Bhojpuri": "bho_Deva",
77
+ "Banjar (Arabic script)": "bjn_Arab",
78
+ "Banjar (Latin script)": "bjn_Latn",
79
+ "Standard Tibetan": "bod_Tibt",
80
+ "Bosnian": "bos_Latn",
81
+ "Buginese": "bug_Latn",
82
+ "Bulgarian": "bul_Cyrl",
83
+ "Catalan": "cat_Latn",
84
+ "Cebuano": "ceb_Latn",
85
+ "Czech": "ces_Latn",
86
+ "Chokwe": "cjk_Latn",
87
+ "Central Kurdish": "ckb_Arab",
88
+ "Crimean Tatar": "crh_Latn",
89
+ "Welsh": "cym_Latn",
90
+ "Danish": "dan_Latn",
91
+ "German": "deu_Latn",
92
+ "Southwestern Dinka": "dik_Latn",
93
+ "Dyula": "dyu_Latn",
94
+ "Dzongkha": "dzo_Tibt",
95
+ "Greek": "ell_Grek",
96
+ "English": "eng_Latn",
97
+ "Esperanto": "epo_Latn",
98
+ "Estonian": "est_Latn",
99
+ "Basque": "eus_Latn",
100
+ "Ewe": "ewe_Latn",
101
+ "Faroese": "fao_Latn",
102
+ "Fijian": "fij_Latn",
103
+ "Finnish": "fin_Latn",
104
+ "Fon": "fon_Latn",
105
+ "French": "fra_Latn",
106
+ "Friulian": "fur_Latn",
107
+ "Nigerian Fulfulde": "fuv_Latn",
108
+ "Scottish Gaelic": "gla_Latn",
109
+ "Irish": "gle_Latn",
110
+ "Galician": "glg_Latn",
111
+ "Guarani": "grn_Latn",
112
+ "Gujarati": "guj_Gujr",
113
+ "Haitian Creole": "hat_Latn",
114
+ "Hausa": "hau_Latn",
115
+ "Hebrew": "heb_Hebr",
116
+ "Hindi": "hin_Deva",
117
+ "Chhattisgarhi": "hne_Deva",
118
+ "Croatian": "hrv_Latn",
119
+ "Hungarian": "hun_Latn",
120
+ "Armenian": "hye_Armn",
121
+ "Igbo": "ibo_Latn",
122
+ "Ilocano": "ilo_Latn",
123
+ "Indonesian": "ind_Latn",
124
+ "Icelandic": "isl_Latn",
125
+ "Italian": "ita_Latn",
126
+ "Javanese": "jav_Latn",
127
+ "Japanese": "jpn_Jpan",
128
+ "Kabyle": "kab_Latn",
129
+ "Jingpho": "kac_Latn",
130
+ "Kamba": "kam_Latn",
131
+ "Kannada": "kan_Knda",
132
+ "Kashmiri (Arabic script)": "kas_Arab",
133
+ "Kashmiri (Devanagari script)": "kas_Deva",
134
+ "Georgian": "kat_Geor",
135
+ "Central Kanuri (Arabic script)": "knc_Arab",
136
+ "Central Kanuri (Latin script)": "knc_Latn",
137
+ "Kazakh": "kaz_Cyrl",
138
+ "Kabiyè": "kbp_Latn",
139
+ "Kabuverdianu": "kea_Latn",
140
+ "Khmer": "khm_Khmr",
141
+ "Kikuyu": "kik_Latn",
142
+ "Kinyarwanda": "kin_Latn",
143
+ "Kyrgyz": "kir_Cyrl",
144
+ "Kimbundu": "kmb_Latn",
145
+ "Northern Kurdish": "kmr_Latn",
146
+ "Kikongo": "kon_Latn",
147
+ "Korean": "kor_Hang",
148
+ "Lao": "lao_Laoo",
149
+ "Ligurian": "lij_Latn",
150
+ "Limburgish": "lim_Latn",
151
+ "Lingala": "lin_Latn",
152
+ "Lithuanian": "lit_Latn",
153
+ "Lombard": "lmo_Latn",
154
+ "Latgalian": "ltg_Latn",
155
+ "Luxembourgish": "ltz_Latn",
156
+ "Luba-Kasai": "lua_Latn",
157
+ "Ganda": "lug_Latn",
158
+ "Luo": "luo_Latn",
159
+ "Mizo": "lus_Latn",
160
+ "Standard Latvian": "lvs_Latn",
161
+ "Magahi": "mag_Deva",
162
+ "Maithili": "mai_Deva",
163
+ "Malayalam": "mal_Mlym",
164
+ "Marathi": "mar_Deva",
165
+ "Minangkabau (Arabic script)": "min_Arab",
166
+ "Minangkabau (Latin script)": "min_Latn",
167
+ "Macedonian": "mkd_Cyrl",
168
+ "Plateau Malagasy": "plt_Latn",
169
+ "Maltese": "mlt_Latn",
170
+ "Meitei (Bengali script)": "mni_Beng",
171
+ "Halh Mongolian": "khk_Cyrl",
172
+ "Mossi": "mos_Latn",
173
+ "Maori": "mri_Latn",
174
+ "Burmese": "mya_Mymr",
175
+ "Dutch": "nld_Latn",
176
+ "Norwegian Nynorsk": "nno_Latn",
177
+ "Norwegian Bokmål": "nob_Latn",
178
+ "Nepali": "npi_Deva",
179
+ "Northern Sotho": "nso_Latn",
180
+ "Nuer": "nus_Latn",
181
+ "Nyanja": "nya_Latn",
182
+ "Occitan": "oci_Latn",
183
+ "West Central Oromo": "gaz_Latn",
184
+ "Odia": "ory_Orya",
185
+ "Pangasinan": "pag_Latn",
186
+ "Eastern Panjabi": "pan_Guru",
187
+ "Papiamento": "pap_Latn",
188
+ "Western Persian": "pes_Arab",
189
+ "Polish": "pol_Latn",
190
+ "Portuguese": "por_Latn",
191
+ "Dari": "prs_Arab",
192
+ "Southern Pashto": "pbt_Arab",
193
+ "Ayacucho Quechua": "quy_Latn",
194
+ "Romanian": "ron_Latn",
195
+ "Rundi": "run_Latn",
196
+ "Russian": "rus_Cyrl",
197
+ "Sango": "sag_Latn",
198
+ "Sanskrit": "san_Deva",
199
+ "Santali": "sat_Olck",
200
+ "Sicilian": "scn_Latn",
201
+ "Shan": "shn_Mymr",
202
+ "Sinhala": "sin_Sinh",
203
+ "Slovak": "slk_Latn",
204
+ "Slovenian": "slv_Latn",
205
+ "Samoan": "smo_Latn",
206
+ "Shona": "sna_Latn",
207
+ "Sindhi": "snd_Arab",
208
+ "Somali": "som_Latn",
209
+ "Southern Sotho": "sot_Latn",
210
+ "Spanish": "spa_Latn",
211
+ "Tosk Albanian": "als_Latn",
212
+ "Sardinian": "srd_Latn",
213
+ "Serbian": "srp_Cyrl",
214
+ "Swati": "ssw_Latn",
215
+ "Sundanese": "sun_Latn",
216
+ "Swedish": "swe_Latn",
217
+ "Swahili": "swh_Latn",
218
+ "Silesian": "szl_Latn",
219
+ "Tamil": "tam_Taml",
220
+ "Tatar": "tat_Cyrl",
221
+ "Telugu": "tel_Telu",
222
+ "Tajik": "tgk_Cyrl",
223
+ "Tagalog": "tgl_Latn",
224
+ "Thai": "tha_Thai",
225
+ "Tigrinya": "tir_Ethi",
226
+ "Tamasheq (Latin script)": "taq_Latn",
227
+ "Tamasheq (Tifinagh script)": "taq_Tfng",
228
+ "Tok Pisin": "tpi_Latn",
229
+ "Tswana": "tsn_Latn",
230
+ "Tsonga": "tso_Latn",
231
+ "Turkmen": "tuk_Latn",
232
+ "Tumbuka": "tum_Latn",
233
+ "Turkish": "tur_Latn",
234
+ "Twi": "twi_Latn",
235
+ "Central Atlas Tamazight": "tzm_Tfng",
236
+ "Uyghur": "uig_Arab",
237
+ "Ukrainian": "ukr_Cyrl",
238
+ "Umbundu": "umb_Latn",
239
+ "Urdu": "urd_Arab",
240
+ "Northern Uzbek": "uzn_Latn",
241
+ "Venetian": "vec_Latn",
242
+ "Vietnamese": "vie_Latn",
243
+ "Waray": "war_Latn",
244
+ "Wolof": "wol_Latn",
245
+ "Xhosa": "xho_Latn",
246
+ "Eastern Yiddish": "ydd_Hebr",
247
+ "Yoruba": "yor_Latn",
248
+ "Yue Chinese": "yue_Hant",
249
+ "Chinese (Simplified)": "zho_Hans",
250
+ "Chinese (Traditional)": "zho_Hant",
251
+ "Standard Malay": "zsm_Latn",
252
+ "Zulu": "zul_Latn",
253
+ }
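NLLBInference.update_model() above builds a standard transformers translation pipeline keyed by the FLORES-200 codes in NLLB_AVAILABLE_LANGS; the direct equivalent looks roughly like this (model and language choices are illustrative):

# Rough direct equivalent of what NLLBInference.update_model() sets up.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

model_name = "facebook/nllb-200-distilled-600M"        # smallest of the three models listed above
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
translator = pipeline("translation", model=model, tokenizer=tokenizer,
                      src_lang="eng_Latn", tgt_lang="kor_Hang")   # codes from NLLB_AVAILABLE_LANGS
print(translator("Hello, world.")[0]["translation_text"])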
modules/subtitle_manager.py ADDED
@@ -0,0 +1,135 @@
1
+ import re
2
+
3
+
4
+ def timeformat_srt(time):
5
+ hours = time // 3600
6
+ minutes = (time - hours * 3600) // 60
7
+ seconds = time - hours * 3600 - minutes * 60
8
+ milliseconds = (time - int(time)) * 1000
9
+ return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
10
+
11
+
12
+ def timeformat_vtt(time):
13
+ hours = time // 3600
14
+ minutes = (time - hours * 3600) // 60
15
+ seconds = time - hours * 3600 - minutes * 60
16
+ milliseconds = (time - int(time)) * 1000
17
+ return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
18
+
19
+
20
+ def write_file(subtitle, output_file):
21
+ with open(output_file, 'w', encoding='utf-8') as f:
22
+ f.write(subtitle)
23
+
24
+
25
+ def get_srt(segments):
26
+ output = ""
27
+ for i, segment in enumerate(segments):
28
+ output += f"{i + 1}\n"
29
+ output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
30
+ if segment['text'].startswith(' '):
31
+ segment['text'] = segment['text'][1:]
32
+ output += f"{segment['text']}\n\n"
33
+ return output
34
+
35
+
36
+ def get_vtt(segments):
37
+ output = "WebVTT\n\n"
38
+ for i, segment in enumerate(segments):
39
+ output += f"{i + 1}\n"
40
+ output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
41
+ if segment['text'].startswith(' '):
42
+ segment['text'] = segment['text'][1:]
43
+ output += f"{segment['text']}\n\n"
44
+ return output
45
+
46
+
47
+ def get_txt(segments):
48
+ output = ""
49
+ for i, segment in enumerate(segments):
50
+ if segment['text'].startswith(' '):
51
+ segment['text'] = segment['text'][1:]
52
+ output += f"{segment['text']}\n"
53
+ return output
54
+
55
+
56
+ def parse_srt(file_path):
57
+ """Reads SRT file and returns as dict"""
58
+ with open(file_path, 'r', encoding='utf-8') as file:
59
+ srt_data = file.read()
60
+
61
+ data = []
62
+ blocks = srt_data.split('\n\n')
63
+
64
+ for block in blocks:
65
+ if block.strip() != '':
66
+ lines = block.strip().split('\n')
67
+ index = lines[0]
68
+ timestamp = lines[1]
69
+ sentence = ' '.join(lines[2:])
70
+
71
+ data.append({
72
+ "index": index,
73
+ "timestamp": timestamp,
74
+ "sentence": sentence
75
+ })
76
+ return data
77
+
78
+
79
+ def parse_vtt(file_path):
80
+ """Reads WebVTT file and returns as dict"""
81
+ with open(file_path, 'r', encoding='utf-8') as file:
82
+ webvtt_data = file.read()
83
+
84
+ data = []
85
+ blocks = webvtt_data.split('\n\n')
86
+
87
+ for block in blocks:
88
+ if block.strip() != '' and not block.strip().startswith("WebVTT"):
89
+ lines = block.strip().split('\n')
90
+ index = lines[0]
91
+ timestamp = lines[1]
92
+ sentence = ' '.join(lines[2:])
93
+
94
+ data.append({
95
+ "index": index,
96
+ "timestamp": timestamp,
97
+ "sentence": sentence
98
+ })
99
+
100
+ return data
101
+
102
+
103
+ def get_serialized_srt(dicts):
104
+ output = ""
105
+ for dic in dicts:
106
+ output += f'{dic["index"]}\n'
107
+ output += f'{dic["timestamp"]}\n'
108
+ output += f'{dic["sentence"]}\n\n'
109
+ return output
110
+
111
+
112
+ def get_serialized_vtt(dicts):
113
+ output = "WebVTT\n\n"
114
+ for dic in dicts:
115
+ output += f'{dic["index"]}\n'
116
+ output += f'{dic["timestamp"]}\n'
117
+ output += f'{dic["sentence"]}\n\n'
118
+ return output
119
+
120
+
121
+ def safe_filename(name):
122
+ from app import _args
123
+ INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
124
+ safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
125
+ if not _args.colab:
126
+ return safe_name
127
+ # Truncate the filename if it exceeds the max_length (20)
128
+ if len(safe_name) > 20:
129
+ file_extension = safe_name.split('.')[-1]
130
+ if len(file_extension) + 1 < 20:
131
+ truncated_name = safe_name[:20 - len(file_extension) - 1]
132
+ safe_name = truncated_name + '.' + file_extension
133
+ else:
134
+ safe_name = safe_name[:20]
135
+ return safe_name
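The parse_* and get_serialized_* helpers above are meant to round-trip: parse a subtitle file into dicts, rewrite each sentence, and serialize it back out. A small sketch with illustrative file names:

# Round-trip sketch for the helpers above; file names are illustrative.
from modules.subtitle_manager import parse_srt, get_serialized_srt, write_file

blocks = parse_srt(file_path="sample.srt")          # [{"index", "timestamp", "sentence"}, ...]
for block in blocks:
    block["sentence"] = block["sentence"].upper()   # any per-cue transformation goes here
write_file(get_serialized_srt(blocks), "sample_upper.srt")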
modules/translation_base.py ADDED
@@ -0,0 +1,148 @@
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from abc import ABC, abstractmethod
5
+ from typing import List
6
+ from datetime import datetime
7
+
8
+ from modules.whisper_parameter import *
9
+ from modules.subtitle_manager import *
10
+
11
+
12
+ class TranslationBase(ABC):
13
+ def __init__(self,
14
+ model_dir: str):
15
+ super().__init__()
16
+ self.model = None
17
+ self.model_dir = model_dir
18
+ os.makedirs(self.model_dir, exist_ok=True)
19
+ self.current_model_size = None
20
+ self.device = self.get_device()
21
+
22
+ @abstractmethod
23
+ def translate(self,
24
+ text: str
25
+ ):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def update_model(self,
30
+ model_size: str,
31
+ src_lang: str,
32
+ tgt_lang: str,
33
+ progress: gr.Progress
34
+ ):
35
+ pass
36
+
37
+ def translate_file(self,
38
+ fileobjs: list,
39
+ model_size: str,
40
+ src_lang: str,
41
+ tgt_lang: str,
42
+ add_timestamp: bool,
43
+ progress=gr.Progress()) -> list:
44
+ """
45
+ Translate subtitle file from source language to target language
46
+
47
+ Parameters
48
+ ----------
49
+ fileobjs: list
50
+ List of files to transcribe from gr.Files()
51
+ model_size: str
52
+ Translation model size from gr.Dropdown()
53
+ src_lang: str
54
+ Source language of the file to translate from gr.Dropdown()
55
+ tgt_lang: str
56
+ Target language of the file to translate from gr.Dropdown()
57
+ add_timestamp: bool
58
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
59
+ progress: gr.Progress
60
+ Indicator to show progress directly in gradio.
61
+ I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
62
+
63
+ Returns
64
+ ----------
65
+ A List of
66
+ String to return to gr.Textbox()
67
+ Files to return to gr.Files()
68
+ """
69
+ try:
70
+ self.update_model(model_size=model_size,
71
+ src_lang=src_lang,
72
+ tgt_lang=tgt_lang,
73
+ progress=progress)
74
+
75
+ files_info = {}
76
+ for fileobj in fileobjs:
77
+ file_path = fileobj.name
78
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
79
+ if file_ext == ".srt":
80
+ parsed_dicts = parse_srt(file_path=file_path)
81
+ total_progress = len(parsed_dicts)
82
+ for index, dic in enumerate(parsed_dicts):
83
+ progress(index / total_progress, desc="Translating..")
84
+ translated_text = self.translate(dic["sentence"])
85
+ dic["sentence"] = translated_text
86
+ subtitle = get_serialized_srt(parsed_dicts)
87
+
88
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
89
+ if add_timestamp:
90
+ output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
91
+ else:
92
+ output_path = os.path.join("outputs", "translations", f"{file_name}.srt")
93
+
94
+ elif file_ext == ".vtt":
95
+ parsed_dicts = parse_vtt(file_path=file_path)
96
+ total_progress = len(parsed_dicts)
97
+ for index, dic in enumerate(parsed_dicts):
98
+ progress(index / total_progress, desc="Translating..")
99
+ translated_text = self.translate(dic["sentence"])
100
+ dic["sentence"] = translated_text
101
+ subtitle = get_serialized_vtt(parsed_dicts)
102
+
103
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
104
+ if add_timestamp:
105
+ output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.vtt")
106
+ else:
107
+ output_path = os.path.join("outputs", "translations", f"{file_name}.vtt")
108
+
109
+ write_file(subtitle, output_path)
110
+ files_info[file_name] = subtitle
111
+
112
+ total_result = ''
113
+ for file_name, subtitle in files_info.items():
114
+ total_result += '------------------------------------\n'
115
+ total_result += f'{file_name}\n\n'
116
+ total_result += f'{subtitle}'
117
+
118
+ gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
119
+ return [gr_str, output_path]
120
+ except Exception as e:
121
+ print(f"Error: {str(e)}")
122
+ finally:
123
+ self.release_cuda_memory()
124
+ self.remove_input_files([fileobj.name for fileobj in fileobjs])
125
+
126
+ @staticmethod
127
+ def get_device():
128
+ if torch.cuda.is_available():
129
+ return "cuda"
130
+ elif torch.backends.mps.is_available():
131
+ return "mps"
132
+ else:
133
+ return "cpu"
134
+
135
+ @staticmethod
136
+ def release_cuda_memory():
137
+ if torch.cuda.is_available():
138
+ torch.cuda.empty_cache()
139
+ torch.cuda.reset_max_memory_allocated()
140
+
141
+ @staticmethod
142
+ def remove_input_files(file_paths: List[str]):
143
+ if not file_paths:
144
+ return
145
+
146
+ for file_path in file_paths:
147
+ if file_path and os.path.exists(file_path):
148
+ os.remove(file_path)
modules/whisper_Inference.py ADDED
@@ -0,0 +1,97 @@
1
+ import whisper
2
+ import gradio as gr
3
+ import time
4
+ import os
5
+ from typing import BinaryIO, Union, Tuple, List
6
+ import numpy as np
7
+ import torch
8
+
9
+ from modules.whisper_base import WhisperBase
10
+ from modules.whisper_parameter import *
11
+
12
+
13
+ class WhisperInference(WhisperBase):
14
+ def __init__(self):
15
+ super().__init__(
16
+ model_dir=os.path.join("models", "Whisper")
17
+ )
18
+
19
+ def transcribe(self,
20
+ audio: Union[str, np.ndarray, torch.Tensor],
21
+ progress: gr.Progress,
22
+ *whisper_params,
23
+ ) -> Tuple[List[dict], float]:
24
+ """
25
+ transcribe method for whisper.
26
+
27
+ Parameters
28
+ ----------
29
+ audio: Union[str, np.ndarray, torch.Tensor]
30
+ Audio path or file binary or Audio numpy array
31
+ progress: gr.Progress
32
+ Indicator to show progress directly in gradio.
33
+ *whisper_params: tuple
34
+ Gradio components related to Whisper. see whisper_data_class.py for details.
35
+
36
+ Returns
37
+ ----------
38
+ segments_result: List[dict]
39
+ list of dicts that includes start, end timestamps and transcribed text
40
+ elapsed_time: float
41
+ elapsed time for transcription
42
+ """
43
+ start_time = time.time()
44
+ params = WhisperValues(*whisper_params)
45
+
46
+ if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
47
+ self.update_model(params.model_size, params.compute_type, progress)
48
+
49
+ if params.lang == "Automatic Detection":
50
+ params.lang = None
51
+
52
+ def progress_callback(progress_value):
53
+ progress(progress_value, desc="Transcribing..")
54
+
55
+ segments_result = self.model.transcribe(audio=audio,
56
+ language=params.lang,
57
+ verbose=False,
58
+ beam_size=params.beam_size,
59
+ logprob_threshold=params.log_prob_threshold,
60
+ no_speech_threshold=params.no_speech_threshold,
61
+ task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
62
+ fp16=True if params.compute_type == "float16" else False,
63
+ best_of=params.best_of,
64
+ patience=params.patience,
65
+ temperature=params.temperature,
66
+ compression_ratio_threshold=params.compression_ratio_threshold,
67
+ progress_callback=progress_callback,)["segments"]
68
+ elapsed_time = time.time() - start_time
69
+
70
+ return segments_result, elapsed_time
71
+
72
+ def update_model(self,
73
+ model_size: str,
74
+ compute_type: str,
75
+ progress: gr.Progress,
76
+ ):
77
+ """
78
+ Update current model setting
79
+
80
+ Parameters
81
+ ----------
82
+ model_size: str
83
+ Size of whisper model
84
+ compute_type: str
85
+ Compute type for transcription.
86
+ see more info : https://opennmt.net/CTranslate2/quantization.html
87
+ progress: gr.Progress
88
+ Indicator to show progress directly in gradio.
89
+ """
90
+ progress(0, desc="Initializing Model..")
91
+ self.current_compute_type = compute_type
92
+ self.current_model_size = model_size
93
+ self.model = whisper.load_model(
94
+ name=model_size,
95
+ device=self.device,
96
+ download_root=self.model_dir
97
+ )
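WhisperInference delegates to openai-whisper's load_model()/transcribe(); note that the progress_callback argument used above comes from the author's forked whisper build linked in the docstrings. A minimal sketch with the vanilla package (model size and audio path are illustrative):

# Minimal sketch using the vanilla openai-whisper API (no progress_callback).
import whisper

model = whisper.load_model("base", download_root="models/Whisper")   # illustrative size and path
result = model.transcribe("sample.wav", fp16=False, temperature=0)
for segment in result["segments"]:
    print(f"[{segment['start']:.2f} -> {segment['end']:.2f}] {segment['text']}")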
modules/whisper_base.py ADDED
@@ -0,0 +1,333 @@
1
+ import os
2
+ import torch
3
+ from typing import List
4
+ import whisper
5
+ import gradio as gr
6
+ from abc import ABC, abstractmethod
7
+ from typing import BinaryIO, Union, Tuple, List
8
+ import numpy as np
9
+ from datetime import datetime
10
+
11
+ from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
12
+ from modules.youtube_manager import get_ytdata, get_ytaudio
13
+ from modules.whisper_parameter import *
14
+
15
+
16
+ class WhisperBase(ABC):
17
+ def __init__(self,
18
+ model_dir: str):
19
+ self.model = None
20
+ self.current_model_size = None
21
+ self.model_dir = model_dir
22
+ os.makedirs(self.model_dir, exist_ok=True)
23
+ self.available_models = whisper.available_models()
24
+ self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
25
+ self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
26
+ self.device = self.get_device()
27
+ self.available_compute_types = ["float16", "float32"]
28
+ self.current_compute_type = "float16" if self.device == "cuda" else "float32"
29
+
30
+ @abstractmethod
31
+ def transcribe(self,
32
+ audio: Union[str, BinaryIO, np.ndarray],
33
+ progress: gr.Progress,
34
+ *whisper_params,
35
+ ):
36
+ pass
37
+
38
+ @abstractmethod
39
+ def update_model(self,
40
+ model_size: str,
41
+ compute_type: str,
42
+ progress: gr.Progress
43
+ ):
44
+ pass
45
+
46
+ def transcribe_file(self,
47
+ files: list,
48
+ file_format: str,
49
+ add_timestamp: bool,
50
+ progress=gr.Progress(),
51
+ *whisper_params,
52
+ ) -> list:
53
+ """
54
+ Write subtitle file from Files
55
+
56
+ Parameters
57
+ ----------
58
+ files: list
59
+ List of files to transcribe from gr.Files()
60
+ file_format: str
61
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
62
+ add_timestamp: bool
63
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
64
+ progress: gr.Progress
65
+ Indicator to show progress directly in gradio.
66
+ *whisper_params: tuple
67
+ Gradio components related to Whisper. See modules/whisper_parameter.py for details.
68
+
69
+ Returns
70
+ ----------
71
+ result_str:
72
+ Result of transcription to return to gr.Textbox()
73
+ result_file_path:
74
+ Output file path to return to gr.Files()
75
+ """
76
+ try:
77
+ files_info = {}
78
+ for file in files:
79
+ transcribed_segments, time_for_task = self.transcribe(
80
+ file.name,
81
+ progress,
82
+ *whisper_params,
83
+ )
84
+
85
+ file_name, file_ext = os.path.splitext(os.path.basename(file.name))
86
+ file_name = safe_filename(file_name)
87
+ subtitle, file_path = self.generate_and_write_file(
88
+ file_name=file_name,
89
+ transcribed_segments=transcribed_segments,
90
+ add_timestamp=add_timestamp,
91
+ file_format=file_format
92
+ )
93
+ files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
94
+
95
+ total_result = ''
96
+ total_time = 0
97
+ for file_name, info in files_info.items():
98
+ total_result += '------------------------------------\n'
99
+ total_result += f'{file_name}\n\n'
100
+ total_result += f'{info["subtitle"]}'
101
+ total_time += info["time_for_task"]
102
+
103
+ result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
104
+ result_file_path = [info['path'] for info in files_info.values()]
105
+
106
+ return [result_str, result_file_path]
107
+
108
+ except Exception as e:
109
+ print(f"Error transcribing file: {e}")
110
+ finally:
111
+ self.release_cuda_memory()
112
+ if files:
113
+ self.remove_input_files([file.name for file in files])
114
+
115
+ def transcribe_mic(self,
116
+ mic_audio: str,
117
+ file_format: str,
118
+ progress=gr.Progress(),
119
+ *whisper_params,
120
+ ) -> list:
121
+ """
122
+ Write subtitle file from microphone
123
+
124
+ Parameters
125
+ ----------
126
+ mic_audio: str
127
+ Audio file path from gr.Microphone()
128
+ file_format: str
129
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
130
+ progress: gr.Progress
131
+ Indicator to show progress directly in gradio.
132
+ *whisper_params: tuple
133
+ Gradio components related to Whisper. See modules/whisper_parameter.py for details.
134
+
135
+ Returns
136
+ ----------
137
+ result_str:
138
+ Result of transcription to return to gr.Textbox()
139
+ result_file_path:
140
+ Output file path to return to gr.Files()
141
+ """
142
+ try:
143
+ progress(0, desc="Loading Audio..")
144
+ transcribed_segments, time_for_task = self.transcribe(
145
+ mic_audio,
146
+ progress,
147
+ *whisper_params,
148
+ )
149
+ progress(1, desc="Completed!")
150
+
151
+ subtitle, result_file_path = self.generate_and_write_file(
152
+ file_name="Mic",
153
+ transcribed_segments=transcribed_segments,
154
+ add_timestamp=True,
155
+ file_format=file_format
156
+ )
157
+
158
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
159
+ return [result_str, result_file_path]
160
+ except Exception as e:
161
+ print(f"Error transcribing file: {e}")
162
+ finally:
163
+ self.release_cuda_memory()
164
+ self.remove_input_files([mic_audio])
165
+
166
+ def transcribe_youtube(self,
167
+ youtube_link: str,
168
+ file_format: str,
169
+ add_timestamp: bool,
170
+ progress=gr.Progress(),
171
+ *whisper_params,
172
+ ) -> list:
173
+ """
174
+ Write subtitle file from Youtube
175
+
176
+ Parameters
177
+ ----------
178
+ youtube_link: str
179
+ URL of the Youtube video to transcribe from gr.Textbox()
180
+ file_format: str
181
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
182
+ add_timestamp: bool
183
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
184
+ progress: gr.Progress
185
+ Indicator to show progress directly in gradio.
186
+ *whisper_params: tuple
187
+ Gradio components related to Whisper. See modules/whisper_parameter.py for details.
188
+
189
+ Returns
190
+ ----------
191
+ result_str:
192
+ Result of transcription to return to gr.Textbox()
193
+ result_file_path:
194
+ Output file path to return to gr.Files()
195
+ """
196
+ try:
197
+ progress(0, desc="Loading Audio from Youtube..")
198
+ yt = get_ytdata(youtube_link)
199
+ audio = get_ytaudio(yt)
200
+
201
+ transcribed_segments, time_for_task = self.transcribe(
202
+ audio,
203
+ progress,
204
+ *whisper_params,
205
+ )
206
+
207
+ progress(1, desc="Completed!")
208
+
209
+ file_name = safe_filename(yt.title)
210
+ subtitle, result_file_path = self.generate_and_write_file(
211
+ file_name=file_name,
212
+ transcribed_segments=transcribed_segments,
213
+ add_timestamp=add_timestamp,
214
+ file_format=file_format
215
+ )
216
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
217
+
218
+ return [result_str, result_file_path]
219
+
220
+ except Exception as e:
221
+ print(f"Error transcribing file: {e}")
222
+ finally:
223
+ try:
224
+ if 'yt' not in locals():
225
+ yt = get_ytdata(youtube_link)
226
+ file_path = get_ytaudio(yt)
227
+ else:
228
+ file_path = get_ytaudio(yt)
229
+
230
+ self.release_cuda_memory()
231
+ self.remove_input_files([file_path])
232
+ except Exception as cleanup_error:
233
+ pass
234
+
235
+ @staticmethod
236
+ def generate_and_write_file(file_name: str,
237
+ transcribed_segments: list,
238
+ add_timestamp: bool,
239
+ file_format: str,
240
+ ) -> Tuple[str, str]:
241
+ """
242
+ Writes subtitle file
243
+
244
+ Parameters
245
+ ----------
246
+ file_name: str
247
+ Output file name
248
+ transcribed_segments: list
249
+ Text segments transcribed from audio
250
+ add_timestamp: bool
251
+ Determines whether to add a timestamp to the end of the filename.
252
+ file_format: str
253
+ File format to write. Supported formats: [SRT, WebVTT, txt]
254
+
255
+ Returns
256
+ ----------
257
+ content: str
258
+ Result of the transcription
259
+ output_path: str
260
+ output file path
261
+ """
262
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
263
+ if add_timestamp:
264
+ output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
265
+ else:
266
+ output_path = os.path.join("outputs", f"{file_name}")
267
+
268
+ if file_format == "SRT":
269
+ content = get_srt(transcribed_segments)
270
+ output_path += '.srt'
271
+ write_file(content, output_path)
272
+
273
+ elif file_format == "WebVTT":
274
+ content = get_vtt(transcribed_segments)
275
+ output_path += '.vtt'
276
+ write_file(content, output_path)
277
+
278
+ elif file_format == "txt":
279
+ content = get_txt(transcribed_segments)
280
+ output_path += '.txt'
281
+ write_file(content, output_path)
282
+ return content, output_path
283
+
284
+ @staticmethod
285
+ def format_time(elapsed_time: float) -> str:
286
+ """
287
+ Get {hours} {minutes} {seconds} time format string
288
+
289
+ Parameters
290
+ ----------
291
+ elapsed_time: float
292
+ Elapsed time for transcription
293
+
294
+ Returns
295
+ ----------
296
+ Time format string
297
+ """
298
+ hours, rem = divmod(elapsed_time, 3600)
299
+ minutes, seconds = divmod(rem, 60)
300
+
301
+ time_str = ""
302
+ if hours:
303
+ time_str += f"{hours} hours "
304
+ if minutes:
305
+ time_str += f"{minutes} minutes "
306
+ seconds = round(seconds)
307
+ time_str += f"{seconds} seconds"
308
+
309
+ return time_str.strip()
310
+
311
+ @staticmethod
312
+ def get_device():
313
+ if torch.cuda.is_available():
314
+ return "cuda"
315
+ elif torch.backends.mps.is_available():
316
+ return "mps"
317
+ else:
318
+ return "cpu"
319
+
320
+ @staticmethod
321
+ def release_cuda_memory():
322
+ if torch.cuda.is_available():
323
+ torch.cuda.empty_cache()
324
+ torch.cuda.reset_max_memory_allocated()
325
+
326
+ @staticmethod
327
+ def remove_input_files(file_paths: List[str]):
328
+ if not file_paths:
329
+ return
330
+
331
+ for file_path in file_paths:
332
+ if file_path and os.path.exists(file_path):
333
+ os.remove(file_path)
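The static helpers at the bottom of `WhisperBase` can be exercised without loading any model. A minimal sketch, assuming `modules/subtitle_manager.py` consumes segments shaped like Whisper's output (the segments and file name below are fabricated for illustration):

```python
import os
from modules.whisper_base import WhisperBase

os.makedirs("outputs", exist_ok=True)  # generate_and_write_file writes here

segments = [
    {"start": 0.0, "end": 2.5, "text": "Hello there."},
    {"start": 2.5, "end": 5.0, "text": "General Kenobi."},
]
content, path = WhisperBase.generate_and_write_file(
    file_name="demo",
    transcribed_segments=segments,
    add_timestamp=False,
    file_format="SRT",
)
print(path)                           # outputs/demo.srt
print(WhisperBase.format_time(3725))  # "1 hours 2 minutes 5 seconds"
```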
modules/whisper_parameter.py ADDED
@@ -0,0 +1,153 @@
1
+ from dataclasses import dataclass, fields
2
+ import gradio as gr
3
+ from typing import Optional
4
+
5
+
6
+ @dataclass
7
+ class WhisperGradioComponents:
8
+ model_size: gr.Dropdown
9
+ lang: gr.Dropdown
10
+ is_translate: gr.Checkbox
11
+ beam_size: gr.Number
12
+ log_prob_threshold: gr.Number
13
+ no_speech_threshold: gr.Number
14
+ compute_type: gr.Dropdown
15
+ best_of: gr.Number
16
+ patience: gr.Number
17
+ condition_on_previous_text: gr.Checkbox
18
+ initial_prompt: gr.Textbox
19
+ temperature: gr.Slider
20
+ compression_ratio_threshold: gr.Number
21
+ vad_filter: gr.Checkbox
22
+ threshold: gr.Slider
23
+ min_speech_duration_ms: gr.Number
24
+ max_speech_duration_s: gr.Number
25
+ min_silence_duration_ms: gr.Number
26
+ window_size_samples: gr.Number
27
+ speech_pad_ms: gr.Number
28
+ """
29
+ A data class for the Gradio components of the Whisper parameters. Used "before" Gradio pre-processing.
30
+ See more about Gradio pre-processing: https://www.gradio.app/docs/components
31
+
32
+ Attributes
33
+ ----------
34
+ model_size: gr.Dropdown
35
+ Whisper model size.
36
+
37
+ lang: gr.Dropdown
38
+ Source language of the file to transcribe.
39
+
40
+ is_translate: gr.Checkbox
41
+ Boolean value that determines whether to translate to English.
42
+ It's Whisper's feature to translate speech from another language directly into English end-to-end.
43
+
44
+ beam_size: gr.Number
45
+ Int value that is used for decoding option.
46
+
47
+ log_prob_threshold: gr.Number
48
+ If the average log probability over sampled tokens is below this value, treat as failed.
49
+
50
+ no_speech_threshold: gr.Number
51
+ If the no_speech probability is higher than this value AND
52
+ the average log probability over sampled tokens is below `log_prob_threshold`,
53
+ consider the segment as silent.
54
+
55
+ compute_type: gr.Dropdown
56
+ compute type for transcription.
57
+ see more info : https://opennmt.net/CTranslate2/quantization.html
58
+
59
+ best_of: gr.Number
60
+ Number of candidates when sampling with non-zero temperature.
61
+
62
+ patience: gr.Number
63
+ Beam search patience factor.
64
+
65
+ condition_on_previous_text: gr.Checkbox
66
+ if True, the previous output of the model is provided as a prompt for the next window;
67
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
68
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
69
+
70
+ initial_prompt: gr.Textbox
71
+ Optional text to provide as a prompt for the first window. This can be used to provide, or
72
+ "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
73
+ to make it more likely to predict those words correctly.
74
+
75
+ temperature: gr.Slider
76
+ Temperature for sampling. It can be a tuple of temperatures,
77
+ which will be successively used upon failures according to either
78
+ `compression_ratio_threshold` or `log_prob_threshold`.
79
+
80
+ compression_ratio_threshold: gr.Number
81
+ If the gzip compression ratio is above this value, treat as failed
82
+
83
+ vad_filter: gr.Checkbox
84
+ Enable the voice activity detection (VAD) to filter out parts of the audio
85
+ without speech. This step uses the Silero VAD model
86
+ https://github.com/snakers4/silero-vad.
87
+
88
+ threshold: gr.Slider
89
+ This parameter is related to Silero VAD. Speech threshold.
90
+ Silero VAD outputs speech probabilities for each audio chunk,
91
+ probabilities ABOVE this value are considered as SPEECH. It is better to tune this
92
+ parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
93
+
94
+ min_speech_duration_ms: gr.Number
95
+ This parameter is related to Silero VAD. Final speech chunks shorter than min_speech_duration_ms are thrown out.
96
+
97
+ max_speech_duration_s: gr.Number
98
+ This parameter is related to Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
99
+ than max_speech_duration_s will be split at the timestamp of the last silence that
100
+ lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
101
+ split aggressively just before max_speech_duration_s.
102
+
103
+ min_silence_duration_ms: gr.Number
104
+ This parameter is related to Silero VAD. At the end of each speech chunk, wait for min_silence_duration_ms
105
+ before separating it
106
+
107
+ window_size_samples: gr.Number
108
+ This parameter is related with Silero VAD. Audio chunks of window_size_samples size are fed to the silero VAD model.
109
+ WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
110
+ Values other than these may affect model performance!!
111
+
112
+ speech_pad_ms: gr.Number
113
+ This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
114
+ """
115
+
116
+ def to_list(self) -> list:
117
+ """
118
+ Converts the data class attributes into a list. Used "before" Gradio pre-processing.
119
+ See more about Gradio pre-processing: https://www.gradio.app/docs/components
120
+
121
+ Returns
122
+ ----------
123
+ A list of Gradio components
124
+ """
125
+ return [getattr(self, f.name) for f in fields(self)]
126
+
127
+
128
+ @dataclass
129
+ class WhisperValues:
130
+ model_size: str
131
+ lang: str
132
+ is_translate: bool
133
+ beam_size: int
134
+ log_prob_threshold: float
135
+ no_speech_threshold: float
136
+ compute_type: str
137
+ best_of: int
138
+ patience: float
139
+ condition_on_previous_text: bool
140
+ initial_prompt: Optional[str]
141
+ temperature: float
142
+ compression_ratio_threshold: float
143
+ vad_filter: bool
144
+ threshold: float
145
+ min_speech_duration_ms: int
146
+ max_speech_duration_s: float
147
+ min_silence_duration_ms: int
148
+ window_size_samples: int
149
+ speech_pad_ms: int
150
+ """
151
+ A data class holding Whisper parameter values. Used "after" Gradio pre-processing.
152
+ See more about Gradio pre-processing: https://www.gradio.app/docs/components
153
+ """
modules/youtube_manager.py ADDED
@@ -0,0 +1,15 @@
1
+ from pytube import YouTube
2
+ import os
3
+
4
+
5
+ def get_ytdata(link):
6
+ return YouTube(link)
7
+
8
+
9
+ def get_ytmetas(link):
10
+ yt = YouTube(link)
11
+ return yt.thumbnail_url, yt.title, yt.description
12
+
13
+
14
+ def get_ytaudio(ytdata: YouTube):
15
+ return ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav"))
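These three helpers wrap pytube end to end: metadata for the UI preview, then an audio-only download that the transcription backends consume. A minimal usage sketch; the URL is a placeholder, not taken from this repository:

```python
from modules.youtube_manager import get_ytdata, get_ytmetas, get_ytaudio

link = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"  # placeholder URL

thumbnail_url, title, description = get_ytmetas(link)

yt = get_ytdata(link)
audio_path = get_ytaudio(yt)  # downloads the audio-only stream to modules/yt_tmp.wav
print(title, audio_path)
```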
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+ torch
3
+ git+https://github.com/jhj0517/jhj0517-whisper.git
4
+ faster-whisper==1.0.2
5
+ transformers
6
+ gradio==4.29.0
7
+ pytube
ui/__init__.py ADDED
File without changes
ui/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (147 Bytes).
 
ui/__pycache__/htmls.cpython-312.pyc ADDED
Binary file (2.01 kB).
 
ui/htmls.py ADDED
@@ -0,0 +1,97 @@
1
+ CSS = """
2
+ .bmc-button {
3
+ padding: 2px 5px;
4
+ border-radius: 5px;
5
+ background-color: #FF813F;
6
+ color: white;
7
+ box-shadow: 0px 1px 2px rgba(0, 0, 0, 0.3);
8
+ text-decoration: none;
9
+ display: inline-block;
10
+ font-size: 20px;
11
+ margin: 2px;
12
+ cursor: pointer;
13
+ -webkit-transition: background-color 0.3s ease;
14
+ -ms-transition: background-color 0.3s ease;
15
+ transition: background-color 0.3s ease;
16
+ }
17
+ .bmc-button:hover,
18
+ .bmc-button:active,
19
+ .bmc-button:focus {
20
+ background-color: #FF5633;
21
+ }
22
+ .markdown {
23
+ margin-bottom: 0;
24
+ padding-bottom: 0;
25
+ }
26
+ .tabs {
27
+ margin-top: 0;
28
+ padding-top: 0;
29
+ }
30
+
31
+ #md_project a {
32
+ color: black;
33
+ text-decoration: none;
34
+ }
35
+ #md_project a:hover {
36
+ text-decoration: underline;
37
+ }
38
+ """
39
+
40
+ MARKDOWN = """
41
+ ### [Whisper Web-UI](https://github.com/jhj0517/Whsiper-WebUI)
42
+ """
43
+
44
+
45
+ NLLB_VRAM_TABLE = """
46
+ <!DOCTYPE html>
47
+ <html lang="en">
48
+ <head>
49
+ <meta charset="UTF-8">
50
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
51
+ <style>
52
+ table {
53
+ border-collapse: collapse;
54
+ width: 100%;
55
+ }
56
+ th, td {
57
+ border: 1px solid #dddddd;
58
+ text-align: left;
59
+ padding: 8px;
60
+ }
61
+ th {
62
+ background-color: #f2f2f2;
63
+ }
64
+ </style>
65
+ </head>
66
+ <body>
67
+
68
+ <details>
69
+ <summary>VRAM usage for each model</summary>
70
+ <table>
71
+ <thead>
72
+ <tr>
73
+ <th>Model name</th>
74
+ <th>Required VRAM</th>
75
+ </tr>
76
+ </thead>
77
+ <tbody>
78
+ <tr>
79
+ <td>nllb-200-3.3B</td>
80
+ <td>~16GB</td>
81
+ </tr>
82
+ <tr>
83
+ <td>nllb-200-1.3B</td>
84
+ <td>~8GB</td>
85
+ </tr>
86
+ <tr>
87
+ <td>nllb-200-distilled-600M</td>
88
+ <td>~4GB</td>
89
+ </tr>
90
+ </tbody>
91
+ </table>
92
+ <p><strong>Note:</strong> Be mindful of your VRAM! The table above provides an approximate VRAM usage for each model.</p>
93
+ </details>
94
+
95
+ </body>
96
+ </html>
97
+ """