Eason Lu committed on
Commit
bd8dd84
1 Parent(s): e3f9642

add unit test for remove punc

Browse files

Former-commit-id: 2a7749049106a57f8954db71ed014e287a48501a

.gitignore CHANGED
@@ -13,4 +13,5 @@ log_*.csv
13
  log.csv
14
  .chroma
15
  *.ini
16
- local_dump/
 
 
13
  log.csv
14
  .chroma
15
  *.ini
16
+ local_dump/
17
+ .pytest_cache/
configs/task_config.yaml CHANGED
@@ -30,6 +30,6 @@ post_process:
30
  output_type:
31
  subtitle: srt
32
  video: False
33
- bilingal: False
34
 
35
 
 
30
  output_type:
31
  subtitle: srt
32
  video: False
33
+ bilingal: True
34
 
35
 
src/srt_util/srt.py CHANGED
@@ -11,42 +11,42 @@ from tqdm import tqdm
11
  # punctuation dictionary for supported languages
12
  punctuation_dict = {
13
  "EN": {
14
- "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \"",
15
  "comma": ", ",
16
  "sentence_end": [".", "!", "?", ";"]
17
  },
18
  "ES": {
19
- "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \" ¡ ¿",
20
  "comma": ", ",
21
  "sentence_end": [".", "!", "?", ";", "¡", "¿"]
22
  },
23
  "FR": {
24
- "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
25
  "comma": ", ",
26
  "sentence_end": [".", "!", "?", ";"]
27
  },
28
  "DE": {
29
- "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \" „ “ –",
30
  "comma": ", ",
31
  "sentence_end": [".", "!", "?", ";"]
32
  },
33
  "RU": {
34
- "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
35
  "comma": ", ",
36
  "sentence_end": [".", "!", "?", ";"]
37
  },
38
  "ZH": {
39
- "punc_str": "。 , ? ! : ; — ( ) ​``【oaicite:1】``​ 《 》 “ ”",
40
  "comma": ",",
41
  "sentence_end": ["。", "!", "?"]
42
  },
43
  "JA": {
44
- "punc_str": "。 、 ? ! : ; ー ( ) ​``【oaicite:0】``​ 「 」 『 』",
45
  "comma": "、",
46
  "sentence_end": ["。", "!", "?"]
47
  },
48
  "AR": {
49
- "punc_str": ". , ? ! : ; - ( ) [ ] { } ، ؛ ؟ « »",
50
  "comma": "، ",
51
  "sentence_end": [".", "!", "?", ";", "؟"]
52
  },
@@ -100,6 +100,7 @@ class SrtSegment(object):
100
  self.translation = ""
101
  else:
102
  self.translation = args[0][3]
 
103
 
104
  def merge_seg(self, seg):
105
  """
@@ -132,9 +133,11 @@ class SrtSegment(object):
132
  remove punctuations in translation text
133
  :return: None
134
  """
135
- punc = punctuation_dict[self.tgt_lang]["punc_str"]
136
- translator = str.maketrans(punc, ' ' * len(punc))
137
- self.translation = self.translation.translate(translator)
 
 
138
 
139
  def __str__(self) -> str:
140
  return f'{self.duration}\n{self.source_text}\n\n'
@@ -233,19 +236,20 @@ class SrtScript(object):
233
  src_text += '\n\n'
234
 
235
  def inner_func(target, input_str):
236
- # TODO: accomodate different languages
237
  response = openai.ChatCompletion.create(
238
  model="gpt-4",
239
  messages=[
240
  {"role": "system",
241
- "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
242
- {"role": "system", "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
243
- {"role": "user", "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target, input_str)}
244
  ],
245
  temperature=0.15
246
  )
247
  return response['choices'][0]['message']['content'].strip()
248
 
 
249
  lines = translate.split('\n\n')
250
  if len(lines) < (end_seg_id - start_seg_id + 1):
251
  count = 0
@@ -253,6 +257,7 @@ class SrtScript(object):
253
  while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
254
  count += 1
255
  print("Solving Unmatched Lines|iteration {}".format(count))
 
256
 
257
  flag = True
258
  while flag:
@@ -262,13 +267,17 @@ class SrtScript(object):
262
  except Exception as e:
263
  print("An error has occurred during solving unmatched lines:", e)
264
  print("Retrying...")
 
 
265
  flag = True
266
  lines = translate.split('\n')
267
 
268
  if len(lines) < (end_seg_id - start_seg_id + 1):
269
  solved = False
270
  print("Failed Solving unmatched lines, Manually parse needed")
 
271
 
 
272
  if not os.path.exists("./logs"):
273
  os.mkdir("./logs")
274
  if video_link:
@@ -287,7 +296,7 @@ class SrtScript(object):
287
  log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
288
  log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
289
  len(self.segments)) + ',' + video_name + "\n")
290
- print(lines)
291
 
292
  for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
293
  # naive way to due with merge translation problem
@@ -337,19 +346,13 @@ class SrtScript(object):
337
  trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
338
  len(trans_commas) // 2 - 1]
339
  else:
340
- # split the text based on spaces
341
- trans_space = [m.start() for m in re.finditer(' ', translation)]
342
- if len(trans_space) > 0:
343
- trans_split_idx = trans_space[len(trans_space) // 2] if len(trans_space) % 2 == 1 else trans_space[
344
- len(trans_space) // 2 - 1]
345
- else:
346
- trans_split_idx = len(translation) // 2
347
 
348
- # to avoid split English word
349
- for i in range(trans_split_idx, len(translation)):
350
- if not translation[i].encode('utf-8').isalpha():
351
- trans_split_idx = i
352
- break
353
 
354
  # split the time duration based on text length
355
  time_split_ratio = trans_split_idx / (len(seg.translation) - 1)
@@ -405,8 +408,6 @@ class SrtScript(object):
405
  self.segments = segments
406
  logging.info("check_len_and_split finished")
407
 
408
- pass
409
-
410
  def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
411
  # DEPRECATED
412
  # if sentence length >= text_threshold, split this segments to two
 
11
  # punctuation dictionary for supported languages
12
  punctuation_dict = {
13
  "EN": {
14
+ "punc_str": ". , ? ! : ; - ( ) [ ] { }",
15
  "comma": ", ",
16
  "sentence_end": [".", "!", "?", ";"]
17
  },
18
  "ES": {
19
+ "punc_str": ". , ? ! : ; - ( ) [ ] { } ¡ ¿",
20
  "comma": ", ",
21
  "sentence_end": [".", "!", "?", ";", "¡", "¿"]
22
  },
23
  "FR": {
24
+ "punc_str": ".,?!:;«»—",
25
  "comma": ", ",
26
  "sentence_end": [".", "!", "?", ";"]
27
  },
28
  "DE": {
29
+ "punc_str": ".,?!:;„“–",
30
  "comma": ", ",
31
  "sentence_end": [".", "!", "?", ";"]
32
  },
33
  "RU": {
34
+ "punc_str": ".,?!:;-«»—",
35
  "comma": ", ",
36
  "sentence_end": [".", "!", "?", ";"]
37
  },
38
  "ZH": {
39
+ "punc_str": "。,?!:;()",
40
  "comma": ",",
41
  "sentence_end": ["。", "!", "?"]
42
  },
43
  "JA": {
44
+ "punc_str": "。、?!:;()",
45
  "comma": "、",
46
  "sentence_end": ["。", "!", "?"]
47
  },
48
  "AR": {
49
+ "punc_str": ".,?!:;-()[]،؛ ؟ «»",
50
  "comma": "، ",
51
  "sentence_end": [".", "!", "?", ";", "؟"]
52
  },
 
100
  self.translation = ""
101
  else:
102
  self.translation = args[0][3]
103
+
104
 
105
  def merge_seg(self, seg):
106
  """
 
133
  remove punctuations in translation text
134
  :return: None
135
  """
136
+ punc_str = punctuation_dict[self.tgt_lang]["punc_str"]
137
+ for punc in punc_str:
138
+ self.translation = self.translation.replace(punc, ' ')
139
+ # translator = str.maketrans(punc, ' ' * len(punc))
140
+ # self.translation = self.translation.translate(translator)
141
 
142
  def __str__(self) -> str:
143
  return f'{self.duration}\n{self.source_text}\n\n'
 
236
  src_text += '\n\n'
237
 
238
  def inner_func(target, input_str):
239
+ # handling merge sentences issue.
240
  response = openai.ChatCompletion.create(
241
  model="gpt-4",
242
  messages=[
243
  {"role": "system",
244
+ "content": "Your task is to merge or split sentences into a specified number of lines as required. You need to ensure the meaning of the sentences as much as possible, but when necessary, a sentence can be divided into two lines for output"},
245
+ {"role": "system", "content": "Note: You only need to output the processed {} sentences. If you need to output a sequence number, please separate it with a colon.".format(self.tgt_lang)},
246
+ {"role": "user", "content": 'Please split or combine the following sentences into {} sentences:\n{}'.format(target, input_str)}
247
  ],
248
  temperature=0.15
249
  )
250
  return response['choices'][0]['message']['content'].strip()
251
 
252
+ # handling merge sentences issue.
253
  lines = translate.split('\n\n')
254
  if len(lines) < (end_seg_id - start_seg_id + 1):
255
  count = 0
 
257
  while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
258
  count += 1
259
  print("Solving Unmatched Lines|iteration {}".format(count))
260
+ logging.error("Solving Unmatched Lines|iteration {}".format(count))
261
 
262
  flag = True
263
  while flag:
 
267
  except Exception as e:
268
  print("An error has occurred during solving unmatched lines:", e)
269
  print("Retrying...")
270
+ logging.error("An error has occurred during solving unmatched lines:", e)
271
+ logging.error("Retrying...")
272
  flag = True
273
  lines = translate.split('\n')
274
 
275
  if len(lines) < (end_seg_id - start_seg_id + 1):
276
  solved = False
277
  print("Failed Solving unmatched lines, Manually parse needed")
278
+ logging.error("Failed Solving unmatched lines, Manually parse needed")
279
 
280
+ # FIXME: put the error log in our log file
281
  if not os.path.exists("./logs"):
282
  os.mkdir("./logs")
283
  if video_link:
 
296
  log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
297
  log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
298
  len(self.segments)) + ',' + video_name + "\n")
299
+ # print(lines)
300
 
301
  for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
302
  # naive way to due with merge translation problem
 
346
  trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
347
  len(trans_commas) // 2 - 1]
348
  else:
349
+ trans_split_idx = len(translation) // 2
 
 
 
 
 
 
350
 
351
+ # to avoid split English word
352
+ for i in range(trans_split_idx, len(translation)):
353
+ if not translation[i].encode('utf-8').isalpha():
354
+ trans_split_idx = i
355
+ break
356
 
357
  # split the time duration based on text length
358
  time_split_ratio = trans_split_idx / (len(seg.translation) - 1)
 
408
  self.segments = segments
409
  logging.info("check_len_and_split finished")
410
 
 
 
411
  def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
412
  # DEPRECATED
413
  # if sentence length >= text_threshold, split this segments to two
tests/test_remove_punc.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('./src')
3
+ from srt_util.srt import SrtScript, SrtSegment
4
+
5
+ zh_test1 = "再次,如果你对一些福利感兴趣,你也可以。"
6
+ zh_en_test1 = "GG。Classic在我今年解说的最奇葩的系列赛中获得了胜利。"
7
+
8
+ def form_srt_class(src_lang, tgt_lang, source_text="", translation="", duration="00:00:00,740 --> 00:00:08,779"):
9
+ segment = [0, duration, source_text, translation, ""]
10
+ return SrtScript(src_lang, tgt_lang, [segment])
11
+
12
+ def test_zh():
13
+ srt = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_test1)
14
+ srt.remove_trans_punctuation()
15
+ assert srt.segments[0].translation == "再次 如果你对一些福利感兴趣 你也可以 "
16
+
17
+ def test_zh_en():
18
+ srt = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_en_test1)
19
+ srt.remove_trans_punctuation()
20
+ assert srt.segments[0].translation == "GG Classic在我今年解说的最奇葩的系列赛中获得了胜利 "
21
+