Spaces:
Sleeping
Sleeping
Eason Lu
committed on
Commit
•
bd8dd84
1
Parent(s):
e3f9642
add unit test for remove punc
Browse files
Former-commit-id: 2a7749049106a57f8954db71ed014e287a48501a
- .gitignore +2 -1
- configs/task_config.yaml +1 -1
- src/srt_util/srt.py +31 -30
- tests/test_remove_punc.py +21 -0
.gitignore
CHANGED
@@ -13,4 +13,5 @@ log_*.csv
|
|
13 |
log.csv
|
14 |
.chroma
|
15 |
*.ini
|
16 |
-
local_dump/
|
|
|
|
13 |
log.csv
|
14 |
.chroma
|
15 |
*.ini
|
16 |
+
local_dump/
|
17 |
+
.pytest_cache/
|
configs/task_config.yaml
CHANGED
@@ -30,6 +30,6 @@ post_process:
|
|
30 |
output_type:
|
31 |
subtitle: srt
|
32 |
video: False
|
33 |
-
bilingal:
|
34 |
|
35 |
|
|
|
30 |
output_type:
|
31 |
subtitle: srt
|
32 |
video: False
|
33 |
+
bilingal: True
|
34 |
|
35 |
|
src/srt_util/srt.py
CHANGED
@@ -11,42 +11,42 @@ from tqdm import tqdm
|
|
11 |
# punctuation dictionary for supported languages
|
12 |
punctuation_dict = {
|
13 |
"EN": {
|
14 |
-
"punc_str": ". , ? ! : ; - ( ) [ ] { }
|
15 |
"comma": ", ",
|
16 |
"sentence_end": [".", "!", "?", ";"]
|
17 |
},
|
18 |
"ES": {
|
19 |
-
"punc_str": ". , ? ! : ; - ( ) [ ] { }
|
20 |
"comma": ", ",
|
21 |
"sentence_end": [".", "!", "?", ";", "¡", "¿"]
|
22 |
},
|
23 |
"FR": {
|
24 |
-
"punc_str": "
|
25 |
"comma": ", ",
|
26 |
"sentence_end": [".", "!", "?", ";"]
|
27 |
},
|
28 |
"DE": {
|
29 |
-
"punc_str": "
|
30 |
"comma": ", ",
|
31 |
"sentence_end": [".", "!", "?", ";"]
|
32 |
},
|
33 |
"RU": {
|
34 |
-
"punc_str": "
|
35 |
"comma": ", ",
|
36 |
"sentence_end": [".", "!", "?", ";"]
|
37 |
},
|
38 |
"ZH": {
|
39 |
-
"punc_str": "
|
40 |
"comma": ",",
|
41 |
"sentence_end": ["。", "!", "?"]
|
42 |
},
|
43 |
"JA": {
|
44 |
-
"punc_str": "
|
45 |
"comma": "、",
|
46 |
"sentence_end": ["。", "!", "?"]
|
47 |
},
|
48 |
"AR": {
|
49 |
-
"punc_str": "
|
50 |
"comma": "، ",
|
51 |
"sentence_end": [".", "!", "?", ";", "؟"]
|
52 |
},
|
@@ -100,6 +100,7 @@ class SrtSegment(object):
|
|
100 |
self.translation = ""
|
101 |
else:
|
102 |
self.translation = args[0][3]
|
|
|
103 |
|
104 |
def merge_seg(self, seg):
|
105 |
"""
|
@@ -132,9 +133,11 @@ class SrtSegment(object):
|
|
132 |
remove punctuations in translation text
|
133 |
:return: None
|
134 |
"""
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
138 |
|
139 |
def __str__(self) -> str:
|
140 |
return f'{self.duration}\n{self.source_text}\n\n'
|
@@ -233,19 +236,20 @@ class SrtScript(object):
|
|
233 |
src_text += '\n\n'
|
234 |
|
235 |
def inner_func(target, input_str):
|
236 |
-
#
|
237 |
response = openai.ChatCompletion.create(
|
238 |
model="gpt-4",
|
239 |
messages=[
|
240 |
{"role": "system",
|
241 |
-
"content": "
|
242 |
-
{"role": "system", "content": "
|
243 |
-
{"role": "user", "content": '
|
244 |
],
|
245 |
temperature=0.15
|
246 |
)
|
247 |
return response['choices'][0]['message']['content'].strip()
|
248 |
|
|
|
249 |
lines = translate.split('\n\n')
|
250 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
251 |
count = 0
|
@@ -253,6 +257,7 @@ class SrtScript(object):
|
|
253 |
while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
|
254 |
count += 1
|
255 |
print("Solving Unmatched Lines|iteration {}".format(count))
|
|
|
256 |
|
257 |
flag = True
|
258 |
while flag:
|
@@ -262,13 +267,17 @@ class SrtScript(object):
|
|
262 |
except Exception as e:
|
263 |
print("An error has occurred during solving unmatched lines:", e)
|
264 |
print("Retrying...")
|
|
|
|
|
265 |
flag = True
|
266 |
lines = translate.split('\n')
|
267 |
|
268 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
269 |
solved = False
|
270 |
print("Failed Solving unmatched lines, Manually parse needed")
|
|
|
271 |
|
|
|
272 |
if not os.path.exists("./logs"):
|
273 |
os.mkdir("./logs")
|
274 |
if video_link:
|
@@ -287,7 +296,7 @@ class SrtScript(object):
|
|
287 |
log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
|
288 |
log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
|
289 |
len(self.segments)) + ',' + video_name + "\n")
|
290 |
-
print(lines)
|
291 |
|
292 |
for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
|
293 |
# naive way to due with merge translation problem
|
@@ -337,19 +346,13 @@ class SrtScript(object):
|
|
337 |
trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
|
338 |
len(trans_commas) // 2 - 1]
|
339 |
else:
|
340 |
-
|
341 |
-
trans_space = [m.start() for m in re.finditer(' ', translation)]
|
342 |
-
if len(trans_space) > 0:
|
343 |
-
trans_split_idx = trans_space[len(trans_space) // 2] if len(trans_space) % 2 == 1 else trans_space[
|
344 |
-
len(trans_space) // 2 - 1]
|
345 |
-
else:
|
346 |
-
trans_split_idx = len(translation) // 2
|
347 |
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
|
354 |
# split the time duration based on text length
|
355 |
time_split_ratio = trans_split_idx / (len(seg.translation) - 1)
|
@@ -405,8 +408,6 @@ class SrtScript(object):
|
|
405 |
self.segments = segments
|
406 |
logging.info("check_len_and_split finished")
|
407 |
|
408 |
-
pass
|
409 |
-
|
410 |
def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
|
411 |
# DEPRECATED
|
412 |
# if sentence length >= text_threshold, split this segments to two
|
|
|
11 |
# punctuation dictionary for supported languages
|
12 |
punctuation_dict = {
|
13 |
"EN": {
|
14 |
+
"punc_str": ". , ? ! : ; - ( ) [ ] { }",
|
15 |
"comma": ", ",
|
16 |
"sentence_end": [".", "!", "?", ";"]
|
17 |
},
|
18 |
"ES": {
|
19 |
+
"punc_str": ". , ? ! : ; - ( ) [ ] { } ¡ ¿",
|
20 |
"comma": ", ",
|
21 |
"sentence_end": [".", "!", "?", ";", "¡", "¿"]
|
22 |
},
|
23 |
"FR": {
|
24 |
+
"punc_str": ".,?!:;«»—",
|
25 |
"comma": ", ",
|
26 |
"sentence_end": [".", "!", "?", ";"]
|
27 |
},
|
28 |
"DE": {
|
29 |
+
"punc_str": ".,?!:;„“–",
|
30 |
"comma": ", ",
|
31 |
"sentence_end": [".", "!", "?", ";"]
|
32 |
},
|
33 |
"RU": {
|
34 |
+
"punc_str": ".,?!:;-«»—",
|
35 |
"comma": ", ",
|
36 |
"sentence_end": [".", "!", "?", ";"]
|
37 |
},
|
38 |
"ZH": {
|
39 |
+
"punc_str": "。,?!:;()",
|
40 |
"comma": ",",
|
41 |
"sentence_end": ["。", "!", "?"]
|
42 |
},
|
43 |
"JA": {
|
44 |
+
"punc_str": "。、?!:;()",
|
45 |
"comma": "、",
|
46 |
"sentence_end": ["。", "!", "?"]
|
47 |
},
|
48 |
"AR": {
|
49 |
+
"punc_str": ".,?!:;-()[]،؛ ؟ «»",
|
50 |
"comma": "، ",
|
51 |
"sentence_end": [".", "!", "?", ";", "؟"]
|
52 |
},
|
|
|
100 |
self.translation = ""
|
101 |
else:
|
102 |
self.translation = args[0][3]
|
103 |
+
|
104 |
|
105 |
def merge_seg(self, seg):
|
106 |
"""
|
|
|
133 |
remove punctuations in translation text
|
134 |
:return: None
|
135 |
"""
|
136 |
+
punc_str = punctuation_dict[self.tgt_lang]["punc_str"]
|
137 |
+
for punc in punc_str:
|
138 |
+
self.translation = self.translation.replace(punc, ' ')
|
139 |
+
# translator = str.maketrans(punc, ' ' * len(punc))
|
140 |
+
# self.translation = self.translation.translate(translator)
|
141 |
|
142 |
def __str__(self) -> str:
|
143 |
return f'{self.duration}\n{self.source_text}\n\n'
|
|
|
236 |
src_text += '\n\n'
|
237 |
|
238 |
def inner_func(target, input_str):
|
239 |
+
# handling merge sentences issue.
|
240 |
response = openai.ChatCompletion.create(
|
241 |
model="gpt-4",
|
242 |
messages=[
|
243 |
{"role": "system",
|
244 |
+
"content": "Your task is to merge or split sentences into a specified number of lines as required. You need to ensure the meaning of the sentences as much as possible, but when necessary, a sentence can be divided into two lines for output"},
|
245 |
+
{"role": "system", "content": "Note: You only need to output the processed {} sentences. If you need to output a sequence number, please separate it with a colon.".format(self.tgt_lang)},
|
246 |
+
{"role": "user", "content": 'Please split or combine the following sentences into {} sentences:\n{}'.format(target, input_str)}
|
247 |
],
|
248 |
temperature=0.15
|
249 |
)
|
250 |
return response['choices'][0]['message']['content'].strip()
|
251 |
|
252 |
+
# handling merge sentences issue.
|
253 |
lines = translate.split('\n\n')
|
254 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
255 |
count = 0
|
|
|
257 |
while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
|
258 |
count += 1
|
259 |
print("Solving Unmatched Lines|iteration {}".format(count))
|
260 |
+
logging.error("Solving Unmatched Lines|iteration {}".format(count))
|
261 |
|
262 |
flag = True
|
263 |
while flag:
|
|
|
267 |
except Exception as e:
|
268 |
print("An error has occurred during solving unmatched lines:", e)
|
269 |
print("Retrying...")
|
270 |
+
logging.error("An error has occurred during solving unmatched lines:", e)
|
271 |
+
logging.error("Retrying...")
|
272 |
flag = True
|
273 |
lines = translate.split('\n')
|
274 |
|
275 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
276 |
solved = False
|
277 |
print("Failed Solving unmatched lines, Manually parse needed")
|
278 |
+
logging.error("Failed Solving unmatched lines, Manually parse needed")
|
279 |
|
280 |
+
# FIXME: put the error log in our log file
|
281 |
if not os.path.exists("./logs"):
|
282 |
os.mkdir("./logs")
|
283 |
if video_link:
|
|
|
296 |
log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
|
297 |
log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
|
298 |
len(self.segments)) + ',' + video_name + "\n")
|
299 |
+
# print(lines)
|
300 |
|
301 |
for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
|
302 |
# naive way to due with merge translation problem
|
|
|
346 |
trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
|
347 |
len(trans_commas) // 2 - 1]
|
348 |
else:
|
349 |
+
trans_split_idx = len(translation) // 2
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
+
# to avoid split English word
|
352 |
+
for i in range(trans_split_idx, len(translation)):
|
353 |
+
if not translation[i].encode('utf-8').isalpha():
|
354 |
+
trans_split_idx = i
|
355 |
+
break
|
356 |
|
357 |
# split the time duration based on text length
|
358 |
time_split_ratio = trans_split_idx / (len(seg.translation) - 1)
|
|
|
408 |
self.segments = segments
|
409 |
logging.info("check_len_and_split finished")
|
410 |
|
|
|
|
|
411 |
def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
|
412 |
# DEPRECATED
|
413 |
# if sentence length >= text_threshold, split this segments to two
|
tests/test_remove_punc.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('./src')
|
3 |
+
from srt_util.srt import SrtScript, SrtSegment
|
4 |
+
|
5 |
+
zh_test1 = "再次,如果你对一些福利感兴趣,你也可以。"
|
6 |
+
zh_en_test1 = "GG。Classic在我今年解说的最奇葩的系列赛中获得了胜利。"
|
7 |
+
|
8 |
+
def form_srt_class(src_lang, tgt_lang, source_text="", translation="", duration="00:00:00,740 --> 00:00:08,779"):
|
9 |
+
segment = [0, duration, source_text, translation, ""]
|
10 |
+
return SrtScript(src_lang, tgt_lang, [segment])
|
11 |
+
|
12 |
+
def test_zh():
|
13 |
+
srt = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_test1)
|
14 |
+
srt.remove_trans_punctuation()
|
15 |
+
assert srt.segments[0].translation == "再次 如果你对一些福利感兴趣 你也可以 "
|
16 |
+
|
17 |
+
def test_zh_en():
|
18 |
+
srt = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_en_test1)
|
19 |
+
srt.remove_trans_punctuation()
|
20 |
+
assert srt.segments[0].translation == "GG Classic在我今年解说的最奇葩的系列赛中获得了胜利 "
|
21 |
+
|