Spaces:
Sleeping
Sleeping
Eason Lu
committed on
Commit
·
1a902ed
1
Parent(s):
b37d0d4
adapt different languages for srt.py
Browse files
Former-commit-id: 197fea8adb5bd35d60d2fc0d09ddc1af21cac117
- src/srt_util/srt.py +25 -10
- src/task.py +3 -3
- src/translators/translation.py +3 -5
src/srt_util/srt.py
CHANGED
@@ -8,9 +8,22 @@ import logging
|
|
8 |
import openai
|
9 |
from tqdm import tqdm
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
class SrtSegment(object):
|
13 |
-
def __init__(self, *args) -> None:
|
|
|
|
|
|
|
14 |
if isinstance(args[0], dict):
|
15 |
segment = args[0]
|
16 |
self.start = segment['start']
|
@@ -83,11 +96,11 @@ class SrtSegment(object):
|
|
83 |
|
84 |
def remove_trans_punc(self) -> None:
|
85 |
"""
|
86 |
-
remove
|
87 |
:return: None
|
88 |
"""
|
89 |
-
|
90 |
-
translator = str.maketrans(
|
91 |
self.translation = self.translation.translate(translator)
|
92 |
|
93 |
def __str__(self) -> str:
|
@@ -101,11 +114,13 @@ class SrtSegment(object):
|
|
101 |
|
102 |
|
103 |
class SrtScript(object):
|
104 |
-
def __init__(self, segments) -> None:
|
105 |
-
self.
|
|
|
|
|
106 |
|
107 |
@classmethod
|
108 |
-
def parse_from_srt_file(cls, path: str):
|
109 |
with open(path, 'r', encoding="utf-8") as f:
|
110 |
script_lines = [line.rstrip() for line in f.readlines()]
|
111 |
bilingual = False
|
@@ -119,7 +134,7 @@ class SrtScript(object):
|
|
119 |
for i in range(0, len(script_lines), 4):
|
120 |
segments.append(list(script_lines[i:i + 4]))
|
121 |
|
122 |
-
return cls(segments)
|
123 |
|
124 |
def merge_segs(self, idx_list) -> SrtSegment:
|
125 |
"""
|
@@ -309,14 +324,14 @@ class SrtScript(object):
|
|
309 |
seg1_dict['text'] = src_seg1
|
310 |
seg1_dict['start'] = start_seg1
|
311 |
seg1_dict['end'] = end_seg1
|
312 |
-
seg1 = SrtSegment(seg1_dict)
|
313 |
seg1.translation = trans_seg1
|
314 |
|
315 |
seg2_dict = {}
|
316 |
seg2_dict['text'] = src_seg2
|
317 |
seg2_dict['start'] = start_seg2
|
318 |
seg2_dict['end'] = end_seg2
|
319 |
-
seg2 = SrtSegment(seg2_dict)
|
320 |
seg2.translation = trans_seg2
|
321 |
|
322 |
result_list = []
|
|
|
8 |
import openai
|
9 |
from tqdm import tqdm
|
10 |
|
11 |
+
punctuation_dict = {
|
12 |
+
"EN": ". , ? ! : ; - ( ) [ ] { } ' \"",
|
13 |
+
"ES": ". , ? ! : ; - ( ) [ ] { } ' \" ¡ ¿",
|
14 |
+
"FR": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
|
15 |
+
"DE": ". , ? ! : ; - ( ) [ ] { } ' \" „ “ –",
|
16 |
+
"RU": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
|
17 |
+
"ZH": "。 , ? ! : ; — ( ) ​``【oaicite:1】``​ 《 》 “ ”",
|
18 |
+
"JA": "。 、 ? ! : ; ー ( ) ​``【oaicite:0】``​ 「 」 『 』",
|
19 |
+
"AR": ". , ? ! : ; - ( ) [ ] { } ، ؛ ؟ « »",
|
20 |
+
}
|
21 |
|
22 |
class SrtSegment(object):
|
23 |
+
def __init__(self, src_lang, tgt_lang, *args) -> None:
|
24 |
+
self.src_lang = src_lang
|
25 |
+
self.tgt_lang = tgt_lang
|
26 |
+
|
27 |
if isinstance(args[0], dict):
|
28 |
segment = args[0]
|
29 |
self.start = segment['start']
|
|
|
96 |
|
97 |
def remove_trans_punc(self) -> None:
|
98 |
"""
|
99 |
+
remove punctuations in translation text
|
100 |
:return: None
|
101 |
"""
|
102 |
+
punc = punctuation_dict[self.tgt_lang]
|
103 |
+
translator = str.maketrans(punc, ' ' * len(punc))
|
104 |
self.translation = self.translation.translate(translator)
|
105 |
|
106 |
def __str__(self) -> str:
|
|
|
114 |
|
115 |
|
116 |
class SrtScript(object):
|
117 |
+
def __init__(self, src_lang, tgt_lang, segments) -> None:
|
118 |
+
self.src_lang = src_lang
|
119 |
+
self.tgt_lang = tgt_lang
|
120 |
+
self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
|
121 |
|
122 |
@classmethod
|
123 |
+
def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
|
124 |
with open(path, 'r', encoding="utf-8") as f:
|
125 |
script_lines = [line.rstrip() for line in f.readlines()]
|
126 |
bilingual = False
|
|
|
134 |
for i in range(0, len(script_lines), 4):
|
135 |
segments.append(list(script_lines[i:i + 4]))
|
136 |
|
137 |
+
return cls(src_lang, tgt_lang, segments)
|
138 |
|
139 |
def merge_segs(self, idx_list) -> SrtSegment:
|
140 |
"""
|
|
|
324 |
seg1_dict['text'] = src_seg1
|
325 |
seg1_dict['start'] = start_seg1
|
326 |
seg1_dict['end'] = end_seg1
|
327 |
+
seg1 = SrtSegment(self.src_lang, self.tgt_lang, seg1_dict)
|
328 |
seg1.translation = trans_seg1
|
329 |
|
330 |
seg2_dict = {}
|
331 |
seg2_dict['text'] = src_seg2
|
332 |
seg2_dict['start'] = start_seg2
|
333 |
seg2_dict['end'] = end_seg2
|
334 |
+
seg2 = SrtSegment(self.src_lang, self.tgt_lang, seg2_dict)
|
335 |
seg2.translation = trans_seg2
|
336 |
|
337 |
result_list = []
|
src/task.py
CHANGED
@@ -129,10 +129,10 @@ class Task:
|
|
129 |
# TODO: setup ASR module like translator
|
130 |
self.status = TaskStatus.INITIALIZING_ASR
|
131 |
self.t_s = time()
|
132 |
-
|
133 |
method = self.ASR_setting["whisper_config"]["method"]
|
134 |
whisper_model = self.ASR_setting["whisper_config"]["whisper_model"]
|
135 |
-
src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id}
|
136 |
if not Path.exists(src_srt_path):
|
137 |
# extract script from audio
|
138 |
logging.info("extract script from audio")
|
@@ -157,7 +157,7 @@ class Task:
|
|
157 |
# after get the transcript, release the gpu resource
|
158 |
torch.cuda.empty_cache()
|
159 |
|
160 |
-
self.SRT_Script = SrtScript(transcript['segments'])
|
161 |
# save the srt script to local
|
162 |
self.SRT_Script.write_srt_file_src(src_srt_path)
|
163 |
|
|
|
129 |
# TODO: setup ASR module like translator
|
130 |
self.status = TaskStatus.INITIALIZING_ASR
|
131 |
self.t_s = time()
|
132 |
+
|
133 |
method = self.ASR_setting["whisper_config"]["method"]
|
134 |
whisper_model = self.ASR_setting["whisper_config"]["whisper_model"]
|
135 |
+
src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id}_{self.source_lang}.srt")
|
136 |
if not Path.exists(src_srt_path):
|
137 |
# extract script from audio
|
138 |
logging.info("extract script from audio")
|
|
|
157 |
# after get the transcript, release the gpu resource
|
158 |
torch.cuda.empty_cache()
|
159 |
|
160 |
+
self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'])
|
161 |
# save the srt script to local
|
162 |
self.SRT_Script.write_srt_file_src(src_srt_path)
|
163 |
|
src/translators/translation.py
CHANGED
@@ -18,8 +18,6 @@ def check_translation(sentence, translation):
|
|
18 |
translation_count = translation.count('\n\n') + 1
|
19 |
|
20 |
if sentence_count != translation_count:
|
21 |
-
# print("sentence length: ", len(sentence), sentence_count)
|
22 |
-
# print("translation length: ", len(translation), translation_count)
|
23 |
return False
|
24 |
else:
|
25 |
return True
|
@@ -34,7 +32,7 @@ def prompt_selector(src_lang, tgt_lang, domain):
|
|
34 |
tgt_lang = language_map[tgt_lang]
|
35 |
prompt = f"""
|
36 |
you are a translation assistant, your job is to translate a video in domain of {domain} from {src_lang} to {tgt_lang},
|
37 |
-
you will be provided with a segement in {
|
38 |
meaning and the number of lines.
|
39 |
"""
|
40 |
return prompt
|
@@ -56,11 +54,11 @@ def translate(srt, script_arr, range_arr, model_name, video_name=None, attempts_
|
|
56 |
:param task: Prompt.
|
57 |
:param temp: Model temperature.
|
58 |
"""
|
59 |
-
|
60 |
if input is None:
|
61 |
raise Exception("Warning! No Input have passed to LLM!")
|
62 |
if task is None:
|
63 |
-
task = "
|
64 |
logging.info(f"translation prompt: {task}")
|
65 |
previous_length = 0
|
66 |
for sentence, range_ in tqdm(zip(script_arr, range_arr)):
|
|
|
18 |
translation_count = translation.count('\n\n') + 1
|
19 |
|
20 |
if sentence_count != translation_count:
|
|
|
|
|
21 |
return False
|
22 |
else:
|
23 |
return True
|
|
|
32 |
tgt_lang = language_map[tgt_lang]
|
33 |
prompt = f"""
|
34 |
you are a translation assistant, your job is to translate a video in domain of {domain} from {src_lang} to {tgt_lang},
|
35 |
+
you will be provided with a segement in {src_lang} parsed by line, where your translation text should keep the original
|
36 |
meaning and the number of lines.
|
37 |
"""
|
38 |
return prompt
|
|
|
54 |
:param task: Prompt.
|
55 |
:param temp: Model temperature.
|
56 |
"""
|
57 |
+
|
58 |
if input is None:
|
59 |
raise Exception("Warning! No Input have passed to LLM!")
|
60 |
if task is None:
|
61 |
+
task = "你是一个翻译助理,你的任务是翻译视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
|
62 |
logging.info(f"translation prompt: {task}")
|
63 |
previous_length = 0
|
64 |
for sentence, range_ in tqdm(zip(script_arr, range_arr)):
|