Upload 20 files

- modules/__init__.py +0 -0
- modules/__pycache__/__init__.cpython-312.pyc +0 -0
- modules/__pycache__/deepl_api.cpython-312.pyc +0 -0
- modules/__pycache__/faster_whisper_inference.cpython-312.pyc +0 -0
- modules/__pycache__/nllb_inference.cpython-312.pyc +0 -0
- modules/__pycache__/subtitle_manager.cpython-312.pyc +0 -0
- modules/__pycache__/translation_base.cpython-312.pyc +0 -0
- modules/__pycache__/whisper_Inference.cpython-312.pyc +0 -0
- modules/__pycache__/whisper_base.cpython-312.pyc +0 -0
- modules/__pycache__/whisper_parameter.cpython-312.pyc +0 -0
- modules/__pycache__/youtube_manager.cpython-312.pyc +0 -0
- modules/deepl_api.py +196 -0
- modules/faster_whisper_inference.py +154 -0
- modules/nllb_inference.py +253 -0
- modules/subtitle_manager.py +135 -0
- modules/translation_base.py +148 -0
- modules/whisper_Inference.py +97 -0
- modules/whisper_base.py +333 -0
- modules/whisper_parameter.py +153 -0
- modules/youtube_manager.py +15 -0
modules/__init__.py
ADDED
File without changes

modules/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (152 Bytes)

modules/__pycache__/deepl_api.cpython-312.pyc
ADDED
Binary file (7.51 kB)

modules/__pycache__/faster_whisper_inference.cpython-312.pyc
ADDED
Binary file (7.5 kB)

modules/__pycache__/nllb_inference.cpython-312.pyc
ADDED
Binary file (11.9 kB)

modules/__pycache__/subtitle_manager.cpython-312.pyc
ADDED
Binary file (6.03 kB)

modules/__pycache__/translation_base.cpython-312.pyc
ADDED
Binary file (7.46 kB)

modules/__pycache__/whisper_Inference.cpython-312.pyc
ADDED
Binary file (4.76 kB)

modules/__pycache__/whisper_base.cpython-312.pyc
ADDED
Binary file (14.5 kB)

modules/__pycache__/whisper_parameter.cpython-312.pyc
ADDED
Binary file (2.96 kB)

modules/__pycache__/youtube_manager.cpython-312.pyc
ADDED
Binary file (1.03 kB)
modules/deepl_api.py
ADDED
@@ -0,0 +1,196 @@
import requests
import time
import os
from datetime import datetime
import gradio as gr

from modules.subtitle_manager import *

"""
This is written with reference to the DeepL API documentation.
If you want to know the information of the DeepL API, see here: https://www.deepl.com/docs-api/documents
"""

DEEPL_AVAILABLE_TARGET_LANGS = {
    'Bulgarian': 'BG',
    'Czech': 'CS',
    'Danish': 'DA',
    'German': 'DE',
    'Greek': 'EL',
    'English': 'EN',
    'English (British)': 'EN-GB',
    'English (American)': 'EN-US',
    'Spanish': 'ES',
    'Estonian': 'ET',
    'Finnish': 'FI',
    'French': 'FR',
    'Hungarian': 'HU',
    'Indonesian': 'ID',
    'Italian': 'IT',
    'Japanese': 'JA',
    'Korean': 'KO',
    'Lithuanian': 'LT',
    'Latvian': 'LV',
    'Norwegian (Bokmål)': 'NB',
    'Dutch': 'NL',
    'Polish': 'PL',
    'Portuguese': 'PT',
    'Portuguese (Brazilian)': 'PT-BR',
    'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT',
    'Romanian': 'RO',
    'Russian': 'RU',
    'Slovak': 'SK',
    'Slovenian': 'SL',
    'Swedish': 'SV',
    'Turkish': 'TR',
    'Ukrainian': 'UK',
    'Chinese (simplified)': 'ZH'
}

DEEPL_AVAILABLE_SOURCE_LANGS = {
    'Automatic Detection': None,
    'Bulgarian': 'BG',
    'Czech': 'CS',
    'Danish': 'DA',
    'German': 'DE',
    'Greek': 'EL',
    'English': 'EN',
    'Spanish': 'ES',
    'Estonian': 'ET',
    'Finnish': 'FI',
    'French': 'FR',
    'Hungarian': 'HU',
    'Indonesian': 'ID',
    'Italian': 'IT',
    'Japanese': 'JA',
    'Korean': 'KO',
    'Lithuanian': 'LT',
    'Latvian': 'LV',
    'Norwegian (Bokmål)': 'NB',
    'Dutch': 'NL',
    'Polish': 'PL',
    'Portuguese (all Portuguese varieties mixed)': 'PT',
    'Romanian': 'RO',
    'Russian': 'RU',
    'Slovak': 'SK',
    'Slovenian': 'SL',
    'Swedish': 'SV',
    'Turkish': 'TR',
    'Ukrainian': 'UK',
    'Chinese': 'ZH'
}


class DeepLAPI:
    def __init__(self):
        self.api_interval = 1
        self.max_text_batch_size = 50
        self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
        self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS

    def translate_deepl(self,
                        auth_key: str,
                        fileobjs: list,
                        source_lang: str,
                        target_lang: str,
                        is_pro: bool,
                        progress=gr.Progress()) -> list:
        """
        Translate subtitle files using DeepL API

        Parameters
        ----------
        auth_key: str
            API Key for DeepL from gr.Textbox()
        fileobjs: list
            List of files to translate from gr.Files()
        source_lang: str
            Source language of the file to translate from gr.Dropdown()
        target_lang: str
            Target language of the file to translate from gr.Dropdown()
        is_pro: bool
            Boolean value from gr.Checkbox() indicating whether the API key belongs to a DeepL Pro account.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        A List of
        String to return to gr.Textbox()
        Files to return to gr.Files()
        """

        files_info = {}
        for fileobj in fileobjs:
            file_path = fileobj.name
            file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))

            if file_ext == ".srt":
                parsed_dicts = parse_srt(file_path=file_path)

                batch_size = self.max_text_batch_size
                for batch_start in range(0, len(parsed_dicts), batch_size):
                    batch_end = min(batch_start + batch_size, len(parsed_dicts))
                    sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
                    translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
                                                                    target_lang, is_pro)
                    for i, translated_text in enumerate(translated_texts):
                        parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_srt(parsed_dicts)
                timestamp = datetime.now().strftime("%m%d%H%M%S")

                file_name = file_name[:-9]
                output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
                write_file(subtitle, output_path)

            elif file_ext == ".vtt":
                parsed_dicts = parse_vtt(file_path=file_path)

                batch_size = self.max_text_batch_size
                for batch_start in range(0, len(parsed_dicts), batch_size):
                    batch_end = min(batch_start + batch_size, len(parsed_dicts))
                    sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
                    translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
                                                                    target_lang, is_pro)
                    for i, translated_text in enumerate(translated_texts):
                        parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_vtt(parsed_dicts)
                timestamp = datetime.now().strftime("%m%d%H%M%S")

                file_name = file_name[:-9]
                output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.vtt")

                write_file(subtitle, output_path)

            files_info[file_name] = subtitle
        total_result = ''
        for file_name, subtitle in files_info.items():
            total_result += '------------------------------------\n'
            total_result += f'{file_name}\n\n'
            total_result += f'{subtitle}'

        gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
        return [gr_str, output_path]

    def request_deepl_translate(self,
                                auth_key: str,
                                text: list,
                                source_lang: str,
                                target_lang: str,
                                is_pro: bool):
        """Request API response to DeepL server"""

        url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
        headers = {
            'Authorization': f'DeepL-Auth-Key {auth_key}'
        }
        data = {
            'text': text,
            'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang],
            'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang]
        }
        response = requests.post(url, headers=headers, data=data).json()
        time.sleep(self.api_interval)
        return response["translations"]
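A minimal sketch (not part of the commit) of calling the request helper above outside Gradio; it assumes this repository is on PYTHONPATH and that "YOUR_DEEPL_KEY" is a placeholder for a valid DeepL API key.

# Hypothetical standalone usage of DeepLAPI.request_deepl_translate()
from modules.deepl_api import DeepLAPI

deepl = DeepLAPI()
translations = deepl.request_deepl_translate(
    auth_key="YOUR_DEEPL_KEY",        # placeholder, not a real key
    text=["Hello, world."],
    source_lang="English",
    target_lang="Korean",
    is_pro=False,                     # free-tier endpoint
)
print(translations[0]["text"])        # DeepL returns a list of {"text": ...} dicts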
modules/faster_whisper_inference.py
ADDED
@@ -0,0 +1,154 @@
import os
import time
import numpy as np
from typing import BinaryIO, Union, Tuple, List

import faster_whisper
from faster_whisper.vad import VadOptions
import ctranslate2
import whisper
import gradio as gr

from modules.whisper_parameter import *
from modules.whisper_base import WhisperBase

# Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


class FasterWhisperInference(WhisperBase):
    def __init__(self):
        super().__init__(
            model_dir=os.path.join("models", "Whisper", "faster-whisper")
        )
        self.model_paths = self.get_model_paths()
        self.available_models = self.model_paths.keys()
        self.available_compute_types = ctranslate2.get_supported_compute_types(
            "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress,
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for faster-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. see whisper_data_class.py for details.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()

        params = WhisperValues(*whisper_params)

        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        if params.lang == "Automatic Detection":
            params.lang = None
        else:
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
            params.lang = language_code_dict[params.lang]

        vad_options = VadOptions(
            threshold=params.threshold,
            min_speech_duration_ms=params.min_speech_duration_ms,
            max_speech_duration_s=params.max_speech_duration_s,
            min_silence_duration_ms=params.min_silence_duration_ms,
            window_size_samples=params.window_size_samples,
            speech_pad_ms=params.speech_pad_ms
        )

        segments, info = self.model.transcribe(
            audio=audio,
            language=params.lang,
            task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
            beam_size=params.beam_size,
            log_prob_threshold=params.log_prob_threshold,
            no_speech_threshold=params.no_speech_threshold,
            best_of=params.best_of,
            patience=params.patience,
            temperature=params.temperature,
            compression_ratio_threshold=params.compression_ratio_threshold,
            vad_filter=params.vad_filter,
            vad_parameters=vad_options
        )
        progress(0, desc="Loading audio..")

        segments_result = []
        for segment in segments:
            progress(segment.start / info.duration, desc="Transcribing..")
            segments_result.append({
                "start": segment.start,
                "end": segment.end,
                "text": segment.text
            })

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        self.current_model_size = self.model_paths[model_size]
        self.current_compute_type = compute_type
        self.model = faster_whisper.WhisperModel(
            device=self.device,
            model_size_or_path=self.current_model_size,
            download_root=self.model_dir,
            compute_type=self.current_compute_type
        )

    def get_model_paths(self):
        """
        Get available models from models path including fine-tuned model.

        Returns
        ----------
        Name list of models
        """
        model_paths = {model: model for model in whisper.available_models()}
        faster_whisper_prefix = "models--Systran--faster-whisper-"

        existing_models = os.listdir(self.model_dir)
        wrong_dirs = [".locks"]
        existing_models = list(set(existing_models) - set(wrong_dirs))

        webui_dir = os.getcwd()

        for model_name in existing_models:
            if faster_whisper_prefix in model_name:
                model_name = model_name[len(faster_whisper_prefix):]

            if model_name not in whisper.available_models():
                model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
        return model_paths
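A minimal sketch (not from the commit) of the two calls this class wraps: the language-name to language-code mapping built from whisper.tokenizer.LANGUAGES and a VAD-filtered faster-whisper transcription. It assumes faster-whisper and openai-whisper are installed and that "audio.wav" is a placeholder file.

# Hypothetical direct use of faster-whisper with the same VAD options
import faster_whisper
import whisper
from faster_whisper.vad import VadOptions

# Invert code -> name into name -> code, as FasterWhisperInference.transcribe() does
lang_code = {v: k for k, v in whisper.tokenizer.LANGUAGES.items()}["english"]   # -> "en"

model = faster_whisper.WhisperModel("tiny", device="cpu", compute_type="int8")
segments, info = model.transcribe("audio.wav", language=lang_code,
                                  vad_filter=True,
                                  vad_parameters=VadOptions(threshold=0.5))
for seg in segments:
    print(f"[{seg.start:.2f} -> {seg.end:.2f}] {seg.text}")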
modules/nllb_inference.py
ADDED
@@ -0,0 +1,253 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import gradio as gr
import os

from modules.translation_base import TranslationBase


class NLLBInference(TranslationBase):
    def __init__(self):
        super().__init__(
            model_dir=os.path.join("models", "NLLB")
        )
        self.tokenizer = None
        self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
        self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.pipeline = None

    def translate(self,
                  text: str
                  ):
        result = self.pipeline(text)
        return result[0]['translation_text']

    def update_model(self,
                     model_size: str,
                     src_lang: str,
                     tgt_lang: str,
                     progress: gr.Progress
                     ):
        if model_size != self.current_model_size or self.model is None:
            print("\nInitializing NLLB Model..\n")
            progress(0, desc="Initializing NLLB Model..")
            self.current_model_size = model_size
            self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
                                                               cache_dir=self.model_dir)
            self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
                                                           cache_dir=os.path.join(self.model_dir, "tokenizers"))
        src_lang = NLLB_AVAILABLE_LANGS[src_lang]
        tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
        self.pipeline = pipeline("translation",
                                 model=self.model,
                                 tokenizer=self.tokenizer,
                                 src_lang=src_lang,
                                 tgt_lang=tgt_lang,
                                 device=self.device)

NLLB_AVAILABLE_LANGS = {
    "Acehnese (Arabic script)": "ace_Arab",
    "Acehnese (Latin script)": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta’izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic (Romanized)": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar (Arabic script)": "bjn_Arab",
    "Banjar (Latin script)": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri (Arabic script)": "kas_Arab",
    "Kashmiri (Devanagari script)": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri (Arabic script)": "knc_Arab",
    "Central Kanuri (Latin script)": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    "Minangkabau (Arabic script)": "min_Arab",
    "Minangkabau (Latin script)": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei (Bengali script)": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Dari": "prs_Arab",
    "Southern Pashto": "pbt_Arab",
    "Ayacucho Quechua": "quy_Latn",
    "Romanian": "ron_Latn",
    "Rundi": "run_Latn",
    "Russian": "rus_Cyrl",
    "Sango": "sag_Latn",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sicilian": "scn_Latn",
    "Shan": "shn_Mymr",
    "Sinhala": "sin_Sinh",
    "Slovak": "slk_Latn",
    "Slovenian": "slv_Latn",
    "Samoan": "smo_Latn",
    "Shona": "sna_Latn",
    "Sindhi": "snd_Arab",
    "Somali": "som_Latn",
    "Southern Sotho": "sot_Latn",
    "Spanish": "spa_Latn",
    "Tosk Albanian": "als_Latn",
    "Sardinian": "srd_Latn",
    "Serbian": "srp_Cyrl",
    "Swati": "ssw_Latn",
    "Sundanese": "sun_Latn",
    "Swedish": "swe_Latn",
    "Swahili": "swh_Latn",
    "Silesian": "szl_Latn",
    "Tamil": "tam_Taml",
    "Tatar": "tat_Cyrl",
    "Telugu": "tel_Telu",
    "Tajik": "tgk_Cyrl",
    "Tagalog": "tgl_Latn",
    "Thai": "tha_Thai",
    "Tigrinya": "tir_Ethi",
    "Tamasheq (Latin script)": "taq_Latn",
    "Tamasheq (Tifinagh script)": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese (Simplified)": "zho_Hans",
    "Chinese (Traditional)": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}
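A minimal sketch (not from the commit) of the transformers pipeline that update_model() assembles for one language pair, using the FLORES-200 codes from NLLB_AVAILABLE_LANGS. It assumes transformers is installed and that the model weights will be downloaded on first use.

# Hypothetical standalone equivalent of NLLBInference.update_model() + translate()
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

name = "facebook/nllb-200-distilled-600M"
model = AutoModelForSeq2SeqLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)
translator = pipeline("translation", model=model, tokenizer=tokenizer,
                      src_lang="eng_Latn", tgt_lang="kor_Hang")   # English -> Korean
print(translator("Subtitles are ready.")[0]["translation_text"])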
modules/subtitle_manager.py
ADDED
@@ -0,0 +1,135 @@
import re


def timeformat_srt(time):
    hours = time // 3600
    minutes = (time - hours * 3600) // 60
    seconds = time - hours * 3600 - minutes * 60
    milliseconds = (time - int(time)) * 1000
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"


def timeformat_vtt(time):
    hours = time // 3600
    minutes = (time - hours * 3600) // 60
    seconds = time - hours * 3600 - minutes * 60
    milliseconds = (time - int(time)) * 1000
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"


def write_file(subtitle, output_file):
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(subtitle)


def get_srt(segments):
    output = ""
    for i, segment in enumerate(segments):
        output += f"{i + 1}\n"
        output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
        if segment['text'].startswith(' '):
            segment['text'] = segment['text'][1:]
        output += f"{segment['text']}\n\n"
    return output


def get_vtt(segments):
    output = "WebVTT\n\n"
    for i, segment in enumerate(segments):
        output += f"{i + 1}\n"
        output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
        if segment['text'].startswith(' '):
            segment['text'] = segment['text'][1:]
        output += f"{segment['text']}\n\n"
    return output


def get_txt(segments):
    output = ""
    for i, segment in enumerate(segments):
        if segment['text'].startswith(' '):
            segment['text'] = segment['text'][1:]
        output += f"{segment['text']}\n"
    return output


def parse_srt(file_path):
    """Reads SRT file and returns as dict"""
    with open(file_path, 'r', encoding='utf-8') as file:
        srt_data = file.read()

    data = []
    blocks = srt_data.split('\n\n')

    for block in blocks:
        if block.strip() != '':
            lines = block.strip().split('\n')
            index = lines[0]
            timestamp = lines[1]
            sentence = ' '.join(lines[2:])

            data.append({
                "index": index,
                "timestamp": timestamp,
                "sentence": sentence
            })
    return data


def parse_vtt(file_path):
    """Reads WebVTT file and returns as dict"""
    with open(file_path, 'r', encoding='utf-8') as file:
        webvtt_data = file.read()

    data = []
    blocks = webvtt_data.split('\n\n')

    for block in blocks:
        if block.strip() != '' and not block.strip().startswith("WebVTT"):
            lines = block.strip().split('\n')
            index = lines[0]
            timestamp = lines[1]
            sentence = ' '.join(lines[2:])

            data.append({
                "index": index,
                "timestamp": timestamp,
                "sentence": sentence
            })

    return data


def get_serialized_srt(dicts):
    output = ""
    for dic in dicts:
        output += f'{dic["index"]}\n'
        output += f'{dic["timestamp"]}\n'
        output += f'{dic["sentence"]}\n\n'
    return output


def get_serialized_vtt(dicts):
    output = "WebVTT\n\n"
    for dic in dicts:
        output += f'{dic["index"]}\n'
        output += f'{dic["timestamp"]}\n'
        output += f'{dic["sentence"]}\n\n'
    return output


def safe_filename(name):
    from app import _args
    INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
    safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
    if not _args.colab:
        return safe_name
    # Truncate the filename if it exceeds the max_length (20)
    if len(safe_name) > 20:
        file_extension = safe_name.split('.')[-1]
        if len(file_extension) + 1 < 20:
            truncated_name = safe_name[:20 - len(file_extension) - 1]
            safe_name = truncated_name + '.' + file_extension
        else:
            safe_name = safe_name[:20]
    return safe_name
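A minimal sketch (not from the commit) of the round trip these helpers support: serialize transcription segments with get_srt(), write them, then re-read and re-serialize them the way the translation code does. It assumes this repository is on PYTHONPATH; the file names are placeholders.

# Hypothetical round trip: segments -> .srt -> parsed dicts -> .srt
from modules.subtitle_manager import get_srt, write_file, parse_srt, get_serialized_srt

segments = [{"start": 0.0, "end": 2.5, "text": " Hello there."},
            {"start": 2.5, "end": 4.0, "text": " Welcome."}]
write_file(get_srt(segments), "example.srt")

parsed = parse_srt(file_path="example.srt")   # list of {"index", "timestamp", "sentence"}
parsed[0]["sentence"] = "Bonjour."            # e.g. swap in a translated sentence
write_file(get_serialized_srt(parsed), "example-translated.srt")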
modules/translation_base.py
ADDED
@@ -0,0 +1,148 @@
import os
import torch
import gradio as gr
from abc import ABC, abstractmethod
from typing import List
from datetime import datetime

from modules.whisper_parameter import *
from modules.subtitle_manager import *


class TranslationBase(ABC):
    def __init__(self,
                 model_dir: str):
        super().__init__()
        self.model = None
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.current_model_size = None
        self.device = self.get_device()

    @abstractmethod
    def translate(self,
                  text: str
                  ):
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     src_lang: str,
                     tgt_lang: str,
                     progress: gr.Progress
                     ):
        pass

    def translate_file(self,
                       fileobjs: list,
                       model_size: str,
                       src_lang: str,
                       tgt_lang: str,
                       add_timestamp: bool,
                       progress=gr.Progress()) -> list:
        """
        Translate subtitle file from source language to target language

        Parameters
        ----------
        fileobjs: list
            List of files to translate from gr.Files()
        model_size: str
            Whisper model size from gr.Dropdown()
        src_lang: str
            Source language of the file to translate from gr.Dropdown()
        tgt_lang: str
            Target language of the file to translate from gr.Dropdown()
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
            I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback

        Returns
        ----------
        A List of
        String to return to gr.Textbox()
        Files to return to gr.Files()
        """
        try:
            self.update_model(model_size=model_size,
                              src_lang=src_lang,
                              tgt_lang=tgt_lang,
                              progress=progress)

            files_info = {}
            for fileobj in fileobjs:
                file_path = fileobj.name
                file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
                if file_ext == ".srt":
                    parsed_dicts = parse_srt(file_path=file_path)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating..")
                        translated_text = self.translate(dic["sentence"])
                        dic["sentence"] = translated_text
                    subtitle = get_serialized_srt(parsed_dicts)

                    timestamp = datetime.now().strftime("%m%d%H%M%S")
                    if add_timestamp:
                        output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
                    else:
                        output_path = os.path.join("outputs", "translations", f"{file_name}.srt")

                elif file_ext == ".vtt":
                    parsed_dicts = parse_vtt(file_path=file_path)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating..")
                        translated_text = self.translate(dic["sentence"])
                        dic["sentence"] = translated_text
                    subtitle = get_serialized_vtt(parsed_dicts)

                    timestamp = datetime.now().strftime("%m%d%H%M%S")
                    if add_timestamp:
                        output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
                    else:
                        output_path = os.path.join("outputs", "translations", f"{file_name}.vtt")

                write_file(subtitle, output_path)
                files_info[file_name] = subtitle

            total_result = ''
            for file_name, subtitle in files_info.items():
                total_result += '------------------------------------\n'
                total_result += f'{file_name}\n\n'
                total_result += f'{subtitle}'

            gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
            return [gr_str, output_path]
        except Exception as e:
            print(f"Error: {str(e)}")
        finally:
            self.release_cuda_memory()
            self.remove_input_files([fileobj.name for fileobj in fileobjs])

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def release_cuda_memory():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)
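A minimal sketch (not from the commit) of the contract this base class defines: a subclass only supplies update_model() and translate(), and translate_file() handles parsing, progress and output paths. The "UpperCaseTranslation" class below is a hypothetical stand-in, not a real translator.

# Hypothetical toy subclass illustrating the TranslationBase ABC
import gradio as gr
from modules.translation_base import TranslationBase

class UpperCaseTranslation(TranslationBase):
    def __init__(self):
        super().__init__(model_dir="models/dummy")   # placeholder model directory

    def update_model(self, model_size: str, src_lang: str, tgt_lang: str, progress: gr.Progress):
        pass  # nothing to load for this toy "model"

    def translate(self, text: str):
        return text.upper()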
modules/whisper_Inference.py
ADDED
@@ -0,0 +1,97 @@
import whisper
import gradio as gr
import time
import os
from typing import BinaryIO, Union, Tuple, List
import numpy as np
import torch

from modules.whisper_base import WhisperBase
from modules.whisper_parameter import *


class WhisperInference(WhisperBase):
    def __init__(self):
        super().__init__(
            model_dir=os.path.join("models", "Whisper")
        )

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress,
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for Whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. see whisper_data_class.py for details.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()
        params = WhisperValues(*whisper_params)

        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        if params.lang == "Automatic Detection":
            params.lang = None

        def progress_callback(progress_value):
            progress(progress_value, desc="Transcribing..")

        segments_result = self.model.transcribe(audio=audio,
                                                language=params.lang,
                                                verbose=False,
                                                beam_size=params.beam_size,
                                                logprob_threshold=params.log_prob_threshold,
                                                no_speech_threshold=params.no_speech_threshold,
                                                task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
                                                fp16=True if params.compute_type == "float16" else False,
                                                best_of=params.best_of,
                                                patience=params.patience,
                                                temperature=params.temperature,
                                                compression_ratio_threshold=params.compression_ratio_threshold,
                                                progress_callback=progress_callback,)["segments"]
        elapsed_time = time.time() - start_time

        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress,
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        self.current_compute_type = compute_type
        self.current_model_size = model_size
        self.model = whisper.load_model(
            name=model_size,
            device=self.device,
            download_root=self.model_dir
        )
modules/whisper_base.py
ADDED
@@ -0,0 +1,333 @@
import os
import torch
from typing import List
import whisper
import gradio as gr
from abc import ABC, abstractmethod
from typing import BinaryIO, Union, Tuple, List
import numpy as np
from datetime import datetime

from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
from modules.youtube_manager import get_ytdata, get_ytaudio
from modules.whisper_parameter import *


class WhisperBase(ABC):
    def __init__(self,
                 model_dir: str):
        self.model = None
        self.current_model_size = None
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
        self.device = self.get_device()
        self.available_compute_types = ["float16", "float32"]
        self.current_compute_type = "float16" if self.device == "cuda" else "float32"

    @abstractmethod
    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress,
                   *whisper_params,
                   ):
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress
                     ):
        pass

    def transcribe_file(self,
                        files: list,
                        file_format: str,
                        add_timestamp: bool,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
        """
        Write subtitle file from Files

        Parameters
        ----------
        files: list
            List of files to transcribe from gr.Files()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. see whisper_data_class.py for details.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            files_info = {}
            for file in files:
                transcribed_segments, time_for_task = self.transcribe(
                    file.name,
                    progress,
                    *whisper_params,
                )

                file_name, file_ext = os.path.splitext(os.path.basename(file.name))
                file_name = safe_filename(file_name)
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format=file_format
                )
                files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}

            total_result = ''
            total_time = 0
            for file_name, info in files_info.items():
                total_result += '------------------------------------\n'
                total_result += f'{file_name}\n\n'
                total_result += f'{info["subtitle"]}'
                total_time += info["time_for_task"]

            result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
            result_file_path = [info['path'] for info in files_info.values()]

            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()
            if not files:
                self.remove_input_files([file.name for file in files])

    def transcribe_mic(self,
                       mic_audio: str,
                       file_format: str,
                       progress=gr.Progress(),
                       *whisper_params,
                       ) -> list:
        """
        Write subtitle file from microphone

        Parameters
        ----------
        mic_audio: str
            Audio file path from gr.Microphone()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. see whisper_data_class.py for details.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio..")
            transcribed_segments, time_for_task = self.transcribe(
                mic_audio,
                progress,
                *whisper_params,
            )
            progress(1, desc="Completed!")

            subtitle, result_file_path = self.generate_and_write_file(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=True,
                file_format=file_format
            )

            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]
        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()
            self.remove_input_files([mic_audio])

    def transcribe_youtube(self,
                           youtube_link: str,
                           file_format: str,
                           add_timestamp: bool,
                           progress=gr.Progress(),
                           *whisper_params,
                           ) -> list:
        """
        Write subtitle file from Youtube

        Parameters
        ----------
        youtube_link: str
            URL of the Youtube video to transcribe from gr.Textbox()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. see whisper_data_class.py for details.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio from Youtube..")
            yt = get_ytdata(youtube_link)
            audio = get_ytaudio(yt)

            transcribed_segments, time_for_task = self.transcribe(
                audio,
                progress,
                *whisper_params,
            )

            progress(1, desc="Completed!")

            file_name = safe_filename(yt.title)
            subtitle, result_file_path = self.generate_and_write_file(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format
            )
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"

            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            try:
                if 'yt' not in locals():
                    yt = get_ytdata(youtube_link)
                    file_path = get_ytaudio(yt)
                else:
                    file_path = get_ytaudio(yt)

                self.release_cuda_memory()
                self.remove_input_files([file_path])
            except Exception as cleanup_error:
                pass

    @staticmethod
    def generate_and_write_file(file_name: str,
                                transcribed_segments: list,
                                add_timestamp: bool,
                                file_format: str,
                                ) -> str:
        """
        Writes subtitle file

        Parameters
        ----------
        file_name: str
            Output file name
        transcribed_segments: list
            Text segments transcribed from audio
        add_timestamp: bool
            Determines whether to add a timestamp to the end of the filename.
        file_format: str
            File format to write. Supported formats: [SRT, WebVTT, txt]

        Returns
        ----------
        content: str
            Result of the transcription
        output_path: str
            output file path
        """
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        if add_timestamp:
            output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
        else:
            output_path = os.path.join("outputs", f"{file_name}")

        if file_format == "SRT":
            content = get_srt(transcribed_segments)
            output_path += '.srt'
            write_file(content, output_path)

        elif file_format == "WebVTT":
            content = get_vtt(transcribed_segments)
            output_path += '.vtt'
            write_file(content, output_path)

        elif file_format == "txt":
            content = get_txt(transcribed_segments)
            output_path += '.txt'
            write_file(content, output_path)
        return content, output_path

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
        Get {hours} {minutes} {seconds} time format string

        Parameters
        ----------
        elapsed_time: str
            Elapsed time for transcription

        Returns
        ----------
        Time format string
        """
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)

        time_str = ""
        if hours:
            time_str += f"{hours} hours "
        if minutes:
            time_str += f"{minutes} minutes "
        seconds = round(seconds)
        time_str += f"{seconds} seconds"

        return time_str.strip()

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def release_cuda_memory():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)
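A minimal sketch (not from the commit) showing that generate_and_write_file() is a plain static method and can be exercised without Gradio or a model. It assumes this repository is on PYTHONPATH; the "outputs" directory is created here because write_file() expects it to exist.

# Hypothetical direct call to the subtitle-writing helper
import os
from modules.whisper_base import WhisperBase

os.makedirs("outputs", exist_ok=True)
segments = [{"start": 0.0, "end": 1.2, "text": " Testing one two."}]
content, path = WhisperBase.generate_and_write_file(
    file_name="demo",
    transcribed_segments=segments,
    add_timestamp=True,
    file_format="SRT",
)
print(path)   # e.g. outputs/demo-<timestamp>.srt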
modules/whisper_parameter.py
ADDED
@@ -0,0 +1,153 @@
from dataclasses import dataclass, fields
import gradio as gr
from typing import Optional


@dataclass
class WhisperGradioComponents:
    model_size: gr.Dropdown
    lang: gr.Dropdown
    is_translate: gr.Checkbox
    beam_size: gr.Number
    log_prob_threshold: gr.Number
    no_speech_threshold: gr.Number
    compute_type: gr.Dropdown
    best_of: gr.Number
    patience: gr.Number
    condition_on_previous_text: gr.Checkbox
    initial_prompt: gr.Textbox
    temperature: gr.Slider
    compression_ratio_threshold: gr.Number
    vad_filter: gr.Checkbox
    threshold: gr.Slider
    min_speech_duration_ms: gr.Number
    max_speech_duration_s: gr.Number
    min_silence_duration_ms: gr.Number
    window_size_sample: gr.Number
    speech_pad_ms: gr.Number
    """
    A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
    See more about Gradio pre-processing: https://www.gradio.app/docs/components

    Attributes
    ----------
    model_size: gr.Dropdown
        Whisper model size.

    lang: gr.Dropdown
        Source language of the file to transcribe.

    is_translate: gr.Checkbox
        Boolean value that determines whether to translate to English.
        It's Whisper's feature to translate speech from another language directly into English end-to-end.

    beam_size: gr.Number
        Int value that is used for decoding option.

    log_prob_threshold: gr.Number
        If the average log probability over sampled tokens is below this value, treat as failed.

    no_speech_threshold: gr.Number
        If the no_speech probability is higher than this value AND
        the average log probability over sampled tokens is below `log_prob_threshold`,
        consider the segment as silent.

    compute_type: gr.Dropdown
        compute type for transcription.
        see more info : https://opennmt.net/CTranslate2/quantization.html

    best_of: gr.Number
        Number of candidates when sampling with non-zero temperature.

    patience: gr.Number
        Beam search patience factor.

    condition_on_previous_text: gr.Checkbox
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    initial_prompt: gr.Textbox
        Optional text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those words correctly.

    temperature: gr.Slider
        Temperature for sampling. It can be a tuple of temperatures,
        which will be successively used upon failures according to either
        `compression_ratio_threshold` or `log_prob_threshold`.

    compression_ratio_threshold: gr.Number
        If the gzip compression ratio is above this value, treat as failed.

    vad_filter: gr.Checkbox
        Enable the voice activity detection (VAD) to filter out parts of the audio
        without speech. This step is using the Silero VAD model
        https://github.com/snakers4/silero-vad.

    threshold: gr.Slider
        This parameter is related with Silero VAD. Speech threshold.
        Silero VAD outputs speech probabilities for each audio chunk,
        probabilities ABOVE this value are considered as SPEECH. It is better to tune this
        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

    min_speech_duration_ms: gr.Number
        This parameter is related with Silero VAD. Final speech chunks shorter than min_speech_duration_ms are thrown out.

    max_speech_duration_s: gr.Number
        This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
        split aggressively just before max_speech_duration_s.

    min_silence_duration_ms: gr.Number
        This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
        before separating it.

    window_size_samples: gr.Number
        This parameter is related with Silero VAD. Audio chunks of window_size_samples size are fed to the silero VAD model.
        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
        Values other than these may affect model performance!!

    speech_pad_ms: gr.Number
        This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side.
    """

    def to_list(self) -> list:
        """
        Converts the data class attributes into a list. Use "before" Gradio pre-processing.
        See more about Gradio pre-processing: https://www.gradio.app/docs/components

        Returns
        ----------
        A list of Gradio components
        """
        return [getattr(self, f.name) for f in fields(self)]


@dataclass
class WhisperValues:
    model_size: str
    lang: str
    is_translate: bool
    beam_size: int
    log_prob_threshold: float
    no_speech_threshold: float
    compute_type: str
    best_of: int
    patience: float
    condition_on_previous_text: bool
    initial_prompt: Optional[str]
    temperature: float
    compression_ratio_threshold: float
    vad_filter: bool
    threshold: float
    min_speech_duration_ms: int
    max_speech_duration_s: float
    min_silence_duration_ms: int
    window_size_samples: int
    speech_pad_ms: int
    """
    A data class to use Whisper parameters. Use "after" Gradio pre-processing.
    See more about Gradio pre-processing: https://www.gradio.app/docs/components
    """
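A minimal sketch (not from the commit) of the hand-off these two dataclasses describe: WhisperGradioComponents.to_list() supplies an event's input components, and the handler rebuilds the values positionally. The handler below is hypothetical.

# Hypothetical event handler receiving the unpacked component values
from modules.whisper_parameter import WhisperValues

def on_transcribe(files, file_format, add_timestamp, *whisper_param_values):
    params = WhisperValues(*whisper_param_values)   # same field order as WhisperGradioComponents
    return f"model={params.model_size}, lang={params.lang}, beam_size={params.beam_size}"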
modules/youtube_manager.py
ADDED
@@ -0,0 +1,15 @@
from pytube import YouTube
import os


def get_ytdata(link):
    return YouTube(link)


def get_ytmetas(link):
    yt = YouTube(link)
    return yt.thumbnail_url, yt.title, yt.description


def get_ytaudio(ytdata: YouTube):
    return ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav"))
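A minimal sketch (not from the commit) of how these helpers are meant to be used; the URL is a placeholder and the download goes to modules/yt_tmp.wav as hard-coded above. It assumes pytube can resolve the given video.

# Hypothetical usage of the YouTube helpers
from modules.youtube_manager import get_ytdata, get_ytmetas, get_ytaudio

url = "https://www.youtube.com/watch?v=VIDEO_ID"   # placeholder URL
thumbnail_url, title, description = get_ytmetas(url)
audio_path = get_ytaudio(get_ytdata(url))          # downloads audio to modules/yt_tmp.wav
print(title, audio_path)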