CanYing0913 committed
Commit 55c7989
Parent: fe8b7a1

SRT cleanup

Update skeleton


Former-commit-id: 4136ffdcd5fcd7f78b4c4c6d323fceebbc1c377c

Files changed (2)
  1. .gitignore +6 -5
  2. SRT.py +188 -146
.gitignore CHANGED
@@ -1,10 +1,11 @@
- /downloads
- /results
+ __pycache__/
  .DS_Store
- /__pycache__
+ .idea/
+ downloads/
+ results/
+ test/
  test.py
  test.srt
  test.txt
  log_*.csv
- log.csv
- /test
+ log.csv
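
Review note on the .gitignore change: the root-anchored entries (/downloads, /results, /__pycache__, /test) are replaced by trailing-slash directory patterns, which match those directories at any depth rather than only at the repository root, and .idea/ is newly ignored. A short illustration of the pattern semantics (paths are made up):

    # results/   ignores results/ at the root and also e.g. src/results/
    # /results   would only have ignored the top-level results/ directory
    results/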
 
SRT.py CHANGED
@@ -1,10 +1,11 @@
- from datetime import timedelta
- from csv import reader
- from datetime import datetime
  import re
  import openai
- import os
- from collections import deque

  class SRT_segment(object):
  def __init__(self, *args) -> None:
@@ -12,22 +13,24 @@ class SRT_segment(object):
  segment = args[0]
  self.start = segment['start']
  self.end = segment['end']
- self.start_ms = int((segment['start']*100)%100*10)
- self.end_ms = int((segment['end']*100)%100*10)

- if self.start_ms == self.end_ms and int(segment['start']) == int(segment['end']): # avoid empty time stamp
- self.end_ms+=500

  self.start_time = timedelta(seconds=int(segment['start']), milliseconds=self.start_ms)
  self.end_time = timedelta(seconds=int(segment['end']), milliseconds=self.end_ms)
  if self.start_ms == 0:
- self.start_time_str = str(0)+str(self.start_time).split('.')[0]+',000'
  else:
- self.start_time_str = str(0)+str(self.start_time).split('.')[0]+','+str(self.start_time).split('.')[1][:3]
  if self.end_ms == 0:
- self.end_time_str = str(0)+str(self.end_time).split('.')[0]+',000'
  else:
- self.end_time_str = str(0)+str(self.end_time).split('.')[0]+','+str(self.end_time).split('.')[1][:3]
  self.source_text = segment['text'].lstrip()
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
  self.translation = ""
@@ -39,15 +42,21 @@ class SRT_segment(object):
  self.end_time_str = self.duration.split(" --> ")[1]

  # parse the time to float
- self.start_ms = int(self.start_time_str.split(',')[1])/10
- self.end_ms = int(self.end_time_str.split(',')[1])/10
  start_list = self.start_time_str.split(',')[0].split(':')
- self.start = int(start_list[0])*3600 + int(start_list[1])*60 + int(start_list[2]) + self.start_ms/100
  end_list = self.end_time_str.split(',')[0].split(':')
- self.end = int(end_list[0])*3600 + int(end_list[1])*60 + int(end_list[2]) + self.end_ms/100
  self.translation = ""
-
  def merge_seg(self, seg):
  self.source_text += f' {seg.source_text}'
  self.translation += f' {seg.translation}'
  self.end_time_str = seg.end_time_str
@@ -56,22 +65,42 @@ class SRT_segment(object):
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
  pass

  def remove_trans_punc(self):
- # remove punctuations in translation text
- self.translation = self.translation.replace(',', ' ')
- self.translation = self.translation.replace('。', ' ')
- self.translation = self.translation.replace('!', ' ')
- self.translation = self.translation.replace('?', ' ')

  def __str__(self) -> str:
- return f'{self.duration}\n{self.source_text}\n\n'
-
  def get_trans_str(self) -> str:
  return f'{self.duration}\n{self.translation}\n\n'
-
  def get_bilingual_str(self) -> str:
  return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'

  class SRT_script():
  def __init__(self, segments) -> None:
  self.segments = []
@@ -80,29 +109,41 @@ class SRT_script():
  self.segments.append(srt_seg)

  @classmethod
- def parse_from_srt_file(cls, path:str):
  with open(path, 'r', encoding="utf-8") as f:
- script_lines = f.read().splitlines()

  segments = []
  for i in range(len(script_lines)):
  if i % 4 == 0:
- segments.append(list(script_lines[i:i+4]))

  return cls(segments)

  def merge_segs(self, idx_list) -> SRT_segment:
- final_seg = self.segments[idx_list[0]]
  if len(idx_list) == 1:
- return final_seg
-
  for idx in range(1, len(idx_list)):
- final_seg.merge_seg(self.segments[idx_list[idx]])
-
- return final_seg

  def form_whole_sentence(self):
- merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
  sentence = []
  for i, seg in enumerate(self.segments):
  if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10:
@@ -116,114 +157,118 @@ class SRT_script():
  for idx_list in merge_list:
  segments.append(self.merge_segs(idx_list))

- self.segments = segments # need memory release?
-
  def remove_trans_punctuation(self):
- # Post-process: remove all punc after translation and split
  for i, seg in enumerate(self.segments):
  seg.remove_trans_punc()

- def set_translation(self, translate:str, id_range:tuple, model, video_name, video_link=None):
  start_seg_id = id_range[0]
  end_seg_id = id_range[1]
-
  src_text = ""
- for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
- src_text+=seg.source_text
- src_text+='\n\n'

- def inner_func(target,input_str):
  response = openai.ChatCompletion.create(
- #model=model,
- model = "gpt-3.5-turbo",
- messages = [
- #{"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
- #{"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
- #{"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
- #{"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)}
-
- {"role": "system", "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
- {"role": "system", "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
- {"role": "user", "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target,input_str)}
- # {"role": "system", "content": "请将以下中文与其英文句子一一对应并输出:"},
- # {"role": "system", "content": "英文:{}".format(src_text)},
- # {"role": "user", "content": "中文:{}\n\n".format(input_str)},
- ],
- temperature = 0.15
- )
  # print(src_text)
  # print(input_str)
  # print(response['choices'][0]['message']['content'].strip())
  # exit()
  return response['choices'][0]['message']['content'].strip()
-
-
  lines = translate.split('\n\n')
  if len(lines) < (end_seg_id - start_seg_id + 1):
  count = 0
  solved = True
- while count<5 and len(lines) != (end_seg_id - start_seg_id + 1):
  count += 1
  print("Solving Unmatched Lines|iteration {}".format(count))
- #input_str = "\n"
- #initialize GPT input
- #for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
  # input_str += 'Sentence %d: ' %(i+1)+ seg.source_text + '\n'
  # #Append to prompt string
  # #Adds sentence index let GPT keep track of sentence breaks
- #input_str += translate
- #append translate to prompt
  flag = True
  while flag:
  flag = False
- #print("translate:")
- #print(translate)
  try:
- #print("target")
- #print(end_seg_id - start_seg_id + 1)
- translate = inner_func(end_seg_id - start_seg_id + 1,translate)
  except Exception as e:
- print("An error has occurred during solving unmatched lines:",e)
  print("Retrying...")
  flag = True
  lines = translate.split('\n')
- #print("result")
- #print(len(lines))
-
  if len(lines) < (end_seg_id - start_seg_id + 1):
  solved = False
  print("Failed Solving unmatched lines, Manually parse needed")
-
  if not os.path.exists("./logs"):
  os.mkdir("./logs")
  if video_link:
  log_file = "./logs/log_link.csv"
  log_exist = os.path.exists(log_file)
- with open(log_file,"a") as log:
  if not log_exist:
  log.write("range_of_text,iterations_solving,solved,file_length,video_link" + "\n")
- log.write(str(id_range)+','+str(count)+','+str(solved)+','+str(len(self.segments))+','+video_link + "\n")
  else:
  log_file = "./logs/log_name.csv"
  log_exist = os.path.exists(log_file)
- with open(log_file,"a") as log:
  if not log_exist:
  log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
- log.write(str(id_range)+','+str(count)+','+str(solved)+','+str(len(self.segments))+','+video_name + "\n")
-
  print(lines)
- #print(id_range)
- #for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
  # print(seg.source_text)
- #print(translate)
-
-
- for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
  # naive way to due with merge translation problem
  # TODO: need a smarter solution

  if i < len(lines):
- if "Note:" in lines[i]: # to avoid note
  lines.remove(lines[i])
  max_num -= 1
  if i == len(lines) - 1:
@@ -233,7 +278,6 @@ class SRT_script():
  except:
  seg.translation = lines[i]

-
  def split_seg(self, seg, text_threshold, time_threshold):
  # evenly split seg to 2 parts and add new seg into self.segments

@@ -251,21 +295,24 @@ class SRT_script():
  src_commas = [m.start() for m in re.finditer(',', source_text)]
  trans_commas = [m.start() for m in re.finditer(',', translation)]
  if len(src_commas) != 0:
- src_split_idx = src_commas[len(src_commas)//2] if len(src_commas) % 2 == 1 else src_commas[len(src_commas)//2 - 1]
  else:
  src_space = [m.start() for m in re.finditer(' ', source_text)]
- if len(src_space) > 0:
- src_split_idx = src_space[len(src_space)//2] if len(src_space) % 2 == 1 else src_space[len(src_space)//2 - 1]
  else:
  src_split_idx = 0

  if len(trans_commas) != 0:
- trans_split_idx = trans_commas[len(trans_commas)//2] if len(trans_commas) % 2 == 1 else trans_commas[len(trans_commas)//2 - 1]
  else:
- trans_split_idx = len(translation)//2
-
  # split the time duration based on text length
- time_split_ratio = trans_split_idx/(len(seg.translation) - 1)

  src_seg1 = source_text[:src_split_idx]
  src_seg2 = source_text[src_split_idx:]
@@ -273,7 +320,7 @@ class SRT_script():
  trans_seg2 = translation[trans_split_idx:]

  start_seg1 = seg.start
- end_seg1 = start_seg2 = seg.start + (seg.end - seg.start)*time_split_ratio
  end_seg2 = seg.end

  seg1_dict = {}
@@ -295,7 +342,7 @@ class SRT_script():
  result_list += self.split_seg(seg1, text_threshold, time_threshold)
  else:
  result_list.append(seg1)
-
  if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
  result_list += self.split_seg(seg2, text_threshold, time_threshold)
  else:
@@ -303,7 +350,6 @@ class SRT_script():

  return result_list

-
  def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
  # DEPRECATED
  # if sentence length >= threshold and sentence duration > time_threshold, split this segments to two
@@ -314,7 +360,7 @@ class SRT_script():
  segments += seg_list
  else:
  segments.append(seg)
-
  self.segments = segments

  pass
@@ -325,23 +371,23 @@ class SRT_script():
  end_seg_id = range[1]
  extra_len = 0
  segments = []
- for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
  if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
  seg_list = self.split_seg(seg, text_threshold, time_threshold)
  segments += seg_list
- extra_len += len(seg_list) - 1
  else:
  segments.append(seg)
-
- self.segments[start_seg_id-1:end_seg_id] = segments
  return extra_len

  def correct_with_force_term(self):
  ## force term correction

  # load term dictionary
- with open("./finetune_data/dict_enzh.csv",'r', encoding='utf-8') as f:
- term_enzh_dict = {rows[0]:rows[1] for rows in reader(f)}

  # change term
  for seg in self.segments:
@@ -359,7 +405,7 @@ class SRT_script():

  def spell_check_term(self):
  ## known bug: I've will be replaced because i've is not in the dict
-
  import enchant
  dict = enchant.Dict('en_US')
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
@@ -372,21 +418,20 @@ class SRT_script():
  if not dict.check(word[:pos]):
  suggest = term_spellDict.suggest(real_word)
  if suggest: # relax spell check
- new_word = word.replace(word[:pos],suggest[0])
- else:
  new_word = word
  ready_words[i] = new_word
  seg.source_text = " ".join(ready_words)
  pass

- def spell_correction(self, word:str, arg:int):
  try:
- arg in [0,1]
  except ValueError:
  print('only 0 or 1 for argument')

-
- def uncover(word:str):
  if word[-2:] == ".\n":
  real_word = word[:-2].lower()
  n = -2
@@ -396,14 +441,14 @@ class SRT_script():
  else:
  real_word = word.lower()
  n = 0
- return real_word, len(word)+n
-
  real_word = uncover(word)[0]
  pos = uncover(word)[1]
  new_word = word
  if arg == 0: # term translate mode
- with open("finetune_data/dict_enzh.csv",'r', encoding='utf-8') as f:
- term_enzh_dict = {rows[0]:rows[1] for rows in reader(f)}
  if real_word in term_enzh_dict:
  new_word = word.replace(word[:pos], term_enzh_dict.get(real_word))
  elif arg == 1: # term spell check mode
@@ -412,10 +457,10 @@ class SRT_script():
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
  if not dict.check(real_word):
  if term_spellDict.suggest(real_word): # relax spell check
- new_word = word.replace(word[:pos],term_spellDict.suggest(real_word)[0])
  return new_word
-
- def get_real_word(self, word:str):
  if word[-2:] == ".\n":
  real_word = word[:-2].lower()
  n = -2
@@ -425,8 +470,7 @@ class SRT_script():
  else:
  real_word = word.lower()
  n = 0
- return real_word, len(word)+n
-

  ## WRITE AND READ FUNCTIONS ##

@@ -434,48 +478,48 @@ class SRT_script():
  # return a string with pure source text
  result = ""
  for i, seg in enumerate(self.segments):
- result+=f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
-
  return result
-
  def reform_src_str(self):
  result = ""
  for i, seg in enumerate(self.segments):
- result += f'{i+1}\n'
  result += str(seg)
  return result

  def reform_trans_str(self):
  result = ""
  for i, seg in enumerate(self.segments):
- result += f'{i+1}\n'
  result += seg.get_trans_str()
  return result
-
  def form_bilingual_str(self):
  result = ""
  for i, seg in enumerate(self.segments):
- result += f'{i+1}\n'
  result += seg.get_bilingual_str()
  return result

- def write_srt_file_src(self, path:str):
  # write srt file to path
  with open(path, "w", encoding='utf-8') as f:
  f.write(self.reform_src_str())
  pass

- def write_srt_file_translate(self, path:str):
  with open(path, "w", encoding='utf-8') as f:
  f.write(self.reform_trans_str())
  pass

- def write_srt_file_bilingual(self, path:str):
  with open(path, "w", encoding='utf-8') as f:
  f.write(self.form_bilingual_str())
  pass

- def realtime_write_srt(self,path,range,length, idx):
  # DEPRECATED
  start_seg_id = range[0]
  end_seg_id = range[1]
@@ -484,22 +528,20 @@ class SRT_script():
  # f.write(f'{i+idx}\n')
  # f.write(seg.get_trans_str())
  for i, seg in enumerate(self.segments):
- if i<range[0]-1 : continue
- if i>=range[1] + length :break
- f.write(f'{i+idx}\n')
  f.write(seg.get_trans_str())
  pass

- def realtime_bilingual_write_srt(self,path,range, length,idx):
  # DEPRECATED
  start_seg_id = range[0]
  end_seg_id = range[1]
  with open(path, "a", encoding='utf-8') as f:
  for i, seg in enumerate(self.segments):
- if i<range[0]-1 : continue
- if i>=range[1] + length :break
- f.write(f'{i+idx}\n')
  f.write(seg.get_bilingual_str())
  pass
-
-
 
+ import os
  import re
+ from copy import copy, deepcopy
+ from csv import reader
+ from datetime import timedelta
+
  import openai
+

  class SRT_segment(object):
  def __init__(self, *args) -> None:

  segment = args[0]
  self.start = segment['start']
  self.end = segment['end']
+ self.start_ms = int((segment['start'] * 100) % 100 * 10)
+ self.end_ms = int((segment['end'] * 100) % 100 * 10)

+ if self.start_ms == self.end_ms and int(segment['start']) == int(segment['end']): # avoid empty time stamp
+ self.end_ms += 500

  self.start_time = timedelta(seconds=int(segment['start']), milliseconds=self.start_ms)
  self.end_time = timedelta(seconds=int(segment['end']), milliseconds=self.end_ms)
  if self.start_ms == 0:
+ self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',000'
  else:
+ self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',' + \
+ str(self.start_time).split('.')[1][:3]
  if self.end_ms == 0:
+ self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',000'
  else:
+ self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',' + str(self.end_time).split('.')[1][
+ :3]
  self.source_text = segment['text'].lstrip()
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
  self.translation = ""

  self.end_time_str = self.duration.split(" --> ")[1]

  # parse the time to float
+ self.start_ms = int(self.start_time_str.split(',')[1]) / 10
+ self.end_ms = int(self.end_time_str.split(',')[1]) / 10
  start_list = self.start_time_str.split(',')[0].split(':')
+ self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
  end_list = self.end_time_str.split(',')[0].split(':')
+ self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
  self.translation = ""
+
  def merge_seg(self, seg):
+ """
+ Merge the segment seg with the current segment in place.
+ :param seg: Another segment that is strictly next to current one.
+ :return: None
+ """
+ # assert seg.start_ms == self.end_ms, f"cannot merge discontinuous segments."
  self.source_text += f' {seg.source_text}'
  self.translation += f' {seg.translation}'
  self.end_time_str = seg.end_time_str

  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
  pass

+ def __add__(self, other):
+ """
+ Merge the segment seg with the current segment, and return the new constructed segment.
+ No in-place modification.
+ :param other: Another segment that is strictly next to added segment.
+ :return: new segment of the two sub-segments
+ """
+ # assert other.start_ms == self.end_ms, f"cannot merge discontinuous segments."
+ result = deepcopy(self)
+ result.source_text += f' {other.source_text}'
+ result.translation += f' {other.translation}'
+ result.end_time_str = other.end_time_str
+ result.end = other.end
+ result.end_ms = other.end_ms
+ result.duration = f"{self.start_time_str} --> {self.end_time_str}"
+ return result
+
  def remove_trans_punc(self):
+ """
+ remove punctuations in translation text
+ :return: None
+ """
+ punc_cn = ",。!?"
+ translator = str.maketrans(punc_cn, ' ' * len(punc_cn))
+ self.translation = self.translation.translate(translator)

  def __str__(self) -> str:
+ return f'{self.duration}\n{self.source_text}\n\n'
+
  def get_trans_str(self) -> str:
  return f'{self.duration}\n{self.translation}\n\n'
+
  def get_bilingual_str(self) -> str:
  return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'

+
  class SRT_script():
  def __init__(self, segments) -> None:
  self.segments = []

  self.segments.append(srt_seg)

  @classmethod
+ def parse_from_srt_file(cls, path: str):
  with open(path, 'r', encoding="utf-8") as f:
+ script_lines = f.read().splitlines()

  segments = []
  for i in range(len(script_lines)):
  if i % 4 == 0:
+ segments.append(list(script_lines[i:i + 4]))

  return cls(segments)

  def merge_segs(self, idx_list) -> SRT_segment:
+ """
+ Merge entire segment list to a single segment
+ :param idx_list: List of index to merge
+ :return: Merged list
+ """
+ if not idx_list:
+ raise NotImplementedError('Empty idx_list')
+ seg_result = deepcopy(self.segments[idx_list[0]])
  if len(idx_list) == 1:
+ return seg_result
+
  for idx in range(1, len(idx_list)):
+ seg_result += self.segments[idx_list[idx]]
+
+ return seg_result

  def form_whole_sentence(self):
+ """
+ Concatenate or Strip sentences and reconstruct segments list. This is because of
+ improper segmentation from openai-whisper.
+ :return: None
+ """
+ merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
  sentence = []
  for i, seg in enumerate(self.segments):
  if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10:

  for idx_list in merge_list:
  segments.append(self.merge_segs(idx_list))

+ self.segments = segments # need memory release?
+
  def remove_trans_punctuation(self):
+ """
+ Post-process: remove all punc after translation and split
+ :return: None
+ """
  for i, seg in enumerate(self.segments):
  seg.remove_trans_punc()

+ def set_translation(self, translate: str, id_range: tuple, model, video_name, video_link=None):
  start_seg_id = id_range[0]
  end_seg_id = id_range[1]
+
  src_text = ""
+ for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
+ src_text += seg.source_text
+ src_text += '\n\n'

+ def inner_func(target, input_str):
  response = openai.ChatCompletion.create(
+ # model=model,
+ model="gpt-3.5-turbo",
+ messages=[
+ # {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
+ # {"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
+ # {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
+ # {"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)}
+
+ {"role": "system",
+ "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
+ {"role": "system", "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
+ {"role": "user", "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target, input_str)}
+ # {"role": "system", "content": "请将以下中文与其英文句子一一对应并输出:"},
+ # {"role": "system", "content": "英文:{}".format(src_text)},
+ # {"role": "user", "content": "中文:{}\n\n".format(input_str)},
+ ],
+ temperature=0.15
+ )
  # print(src_text)
  # print(input_str)
  # print(response['choices'][0]['message']['content'].strip())
  # exit()
  return response['choices'][0]['message']['content'].strip()
+
  lines = translate.split('\n\n')
  if len(lines) < (end_seg_id - start_seg_id + 1):
  count = 0
  solved = True
+ while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
  count += 1
  print("Solving Unmatched Lines|iteration {}".format(count))
+ # input_str = "\n"
+ # initialize GPT input
+ # for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
  # input_str += 'Sentence %d: ' %(i+1)+ seg.source_text + '\n'
  # #Append to prompt string
  # #Adds sentence index let GPT keep track of sentence breaks
+ # input_str += translate
+ # append translate to prompt
  flag = True
  while flag:
  flag = False
+ # print("translate:")
+ # print(translate)
  try:
+ # print("target")
+ # print(end_seg_id - start_seg_id + 1)
+ translate = inner_func(end_seg_id - start_seg_id + 1, translate)
  except Exception as e:
+ print("An error has occurred during solving unmatched lines:", e)
  print("Retrying...")
  flag = True
  lines = translate.split('\n')
+ # print("result")
+ # print(len(lines))
+
  if len(lines) < (end_seg_id - start_seg_id + 1):
  solved = False
  print("Failed Solving unmatched lines, Manually parse needed")
+
  if not os.path.exists("./logs"):
  os.mkdir("./logs")
  if video_link:
  log_file = "./logs/log_link.csv"
  log_exist = os.path.exists(log_file)
+ with open(log_file, "a") as log:
  if not log_exist:
  log.write("range_of_text,iterations_solving,solved,file_length,video_link" + "\n")
+ log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
+ len(self.segments)) + ',' + video_link + "\n")
  else:
  log_file = "./logs/log_name.csv"
  log_exist = os.path.exists(log_file)
+ with open(log_file, "a") as log:
  if not log_exist:
  log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
+ log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
+ len(self.segments)) + ',' + video_name + "\n")
+
  print(lines)
+ # print(id_range)
+ # for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
  # print(seg.source_text)
+ # print(translate)
+
+ for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
  # naive way to due with merge translation problem
  # TODO: need a smarter solution

  if i < len(lines):
+ if "Note:" in lines[i]: # to avoid note
  lines.remove(lines[i])
  max_num -= 1
  if i == len(lines) - 1:

  except:
  seg.translation = lines[i]

  def split_seg(self, seg, text_threshold, time_threshold):
  # evenly split seg to 2 parts and add new seg into self.segments

  src_commas = [m.start() for m in re.finditer(',', source_text)]
  trans_commas = [m.start() for m in re.finditer(',', translation)]
  if len(src_commas) != 0:
+ src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
+ len(src_commas) // 2 - 1]
  else:
  src_space = [m.start() for m in re.finditer(' ', source_text)]
+ if len(src_space) > 0:
+ src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
+ len(src_space) // 2 - 1]
  else:
  src_split_idx = 0

  if len(trans_commas) != 0:
+ trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
+ len(trans_commas) // 2 - 1]
  else:
+ trans_split_idx = len(translation) // 2
+
  # split the time duration based on text length
+ time_split_ratio = trans_split_idx / (len(seg.translation) - 1)

  src_seg1 = source_text[:src_split_idx]
  src_seg2 = source_text[src_split_idx:]

  trans_seg2 = translation[trans_split_idx:]

  start_seg1 = seg.start
+ end_seg1 = start_seg2 = seg.start + (seg.end - seg.start) * time_split_ratio
  end_seg2 = seg.end

  seg1_dict = {}

  result_list += self.split_seg(seg1, text_threshold, time_threshold)
  else:
  result_list.append(seg1)
+
  if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
  result_list += self.split_seg(seg2, text_threshold, time_threshold)
  else:

  return result_list

  def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
  # DEPRECATED
  # if sentence length >= threshold and sentence duration > time_threshold, split this segments to two

  segments += seg_list
  else:
  segments.append(seg)
+
  self.segments = segments

  pass

  end_seg_id = range[1]
  extra_len = 0
  segments = []
+ for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
  if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
  seg_list = self.split_seg(seg, text_threshold, time_threshold)
  segments += seg_list
+ extra_len += len(seg_list) - 1
  else:
  segments.append(seg)
+
+ self.segments[start_seg_id - 1:end_seg_id] = segments
  return extra_len

  def correct_with_force_term(self):
  ## force term correction

  # load term dictionary
+ with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
+ term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}

  # change term
  for seg in self.segments:

  def spell_check_term(self):
  ## known bug: I've will be replaced because i've is not in the dict
+
  import enchant
  dict = enchant.Dict('en_US')
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')

  if not dict.check(word[:pos]):
  suggest = term_spellDict.suggest(real_word)
  if suggest: # relax spell check
+ new_word = word.replace(word[:pos], suggest[0])
+ else:
  new_word = word
  ready_words[i] = new_word
  seg.source_text = " ".join(ready_words)
  pass

+ def spell_correction(self, word: str, arg: int):
  try:
+ arg in [0, 1]
  except ValueError:
  print('only 0 or 1 for argument')

+ def uncover(word: str):
  if word[-2:] == ".\n":
  real_word = word[:-2].lower()
  n = -2

  else:
  real_word = word.lower()
  n = 0
+ return real_word, len(word) + n
+
  real_word = uncover(word)[0]
  pos = uncover(word)[1]
  new_word = word
  if arg == 0: # term translate mode
+ with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
+ term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
  if real_word in term_enzh_dict:
  new_word = word.replace(word[:pos], term_enzh_dict.get(real_word))
  elif arg == 1: # term spell check mode

  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
  if not dict.check(real_word):
  if term_spellDict.suggest(real_word): # relax spell check
+ new_word = word.replace(word[:pos], term_spellDict.suggest(real_word)[0])
  return new_word
+
+ def get_real_word(self, word: str):
  if word[-2:] == ".\n":
  real_word = word[:-2].lower()
  n = -2

  else:
  real_word = word.lower()
  n = 0
+ return real_word, len(word) + n

  ## WRITE AND READ FUNCTIONS ##

  # return a string with pure source text
  result = ""
  for i, seg in enumerate(self.segments):
+ result += f'SENTENCE {i + 1}: {seg.source_text}\n\n\n'
+
  return result
+
  def reform_src_str(self):
  result = ""
  for i, seg in enumerate(self.segments):
+ result += f'{i + 1}\n'
  result += str(seg)
  return result

  def reform_trans_str(self):
  result = ""
  for i, seg in enumerate(self.segments):
+ result += f'{i + 1}\n'
  result += seg.get_trans_str()
  return result
+
  def form_bilingual_str(self):
  result = ""
  for i, seg in enumerate(self.segments):
+ result += f'{i + 1}\n'
  result += seg.get_bilingual_str()
  return result

+ def write_srt_file_src(self, path: str):
  # write srt file to path
  with open(path, "w", encoding='utf-8') as f:
  f.write(self.reform_src_str())
  pass

+ def write_srt_file_translate(self, path: str):
  with open(path, "w", encoding='utf-8') as f:
  f.write(self.reform_trans_str())
  pass

+ def write_srt_file_bilingual(self, path: str):
  with open(path, "w", encoding='utf-8') as f:
  f.write(self.form_bilingual_str())
  pass

+ def realtime_write_srt(self, path, range, length, idx):
  # DEPRECATED
  start_seg_id = range[0]
  end_seg_id = range[1]

  # f.write(f'{i+idx}\n')
  # f.write(seg.get_trans_str())
  for i, seg in enumerate(self.segments):
+ if i < range[0] - 1: continue
+ if i >= range[1] + length: break
+ f.write(f'{i + idx}\n')
  f.write(seg.get_trans_str())
  pass

+ def realtime_bilingual_write_srt(self, path, range, length, idx):
  # DEPRECATED
  start_seg_id = range[0]
  end_seg_id = range[1]
  with open(path, "a", encoding='utf-8') as f:
  for i, seg in enumerate(self.segments):
+ if i < range[0] - 1: continue
+ if i >= range[1] + length: break
+ f.write(f'{i + idx}\n')
  f.write(seg.get_bilingual_str())
  pass
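
Review note on SRT_segment.__init__: the reformatted constructor derives the millisecond field from whisper-style float seconds and renders the SRT "H:MM:SS,mmm" timestamp through datetime.timedelta. A minimal standalone sketch of that arithmetic (the helper name and sample inputs are illustrative, not part of the commit; note that the float remainder trick can shave a millisecond digit for some inputs, e.g. 3.07 gives 69 ms instead of 70):

    from datetime import timedelta

    def srt_timestamp(seconds: float) -> str:
        # same arithmetic as the constructor above: centisecond remainder times 10
        ms = int((seconds * 100) % 100 * 10)
        t = timedelta(seconds=int(seconds), milliseconds=ms)
        if ms == 0:
            return "0" + str(t).split(".")[0] + ",000"
        return "0" + str(t).split(".")[0] + "," + str(t).split(".")[1][:3]

    print(srt_timestamp(3.25))  # 00:00:03,250
    print(srt_timestamp(75.0))  # 00:01:15,000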
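
Review note on the new SRT_segment helpers: remove_trans_punc now builds a single translation table instead of chaining replace() calls, and the added __add__ lets merge_segs accumulate segments with +=. A quick illustration of the str.maketrans technique on its own (the sample string is made up):

    punc_cn = ",。!?"
    table = str.maketrans(punc_cn, " " * len(punc_cn))
    print("你好,世界。真的吗?是!".translate(table))  # -> 你好 世界 真的吗 是 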
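
Review note on set_translation: when the translated text does not split into the same number of lines as the segment range, the method re-prompts gpt-3.5-turbo up to five times and logs unresolved cases to ./logs/*.csv. A stripped-down sketch of that reconciliation loop with a stub in place of the OpenAI call (the helper and the stub are illustrative only, not part of the commit):

    def reconcile_line_count(text, target, reformat, max_tries=5):
        # mirrors set_translation: first split on blank lines, then on '\n' after each retry
        lines = text.split("\n\n")
        tries = 0
        while tries < max_tries and len(lines) != target:
            tries += 1
            text = reformat(target, text)  # in the commit this is inner_func -> ChatCompletion
            lines = text.split("\n")
        return lines

    stub = lambda n, s: "\n".join(f"行{i + 1}" for i in range(n))  # stands in for the GPT call
    print(reconcile_line_count("第一句\n\n第二句", 3, stub))  # ['行1', '行2', '行3']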
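
Review note on split_seg: the reformatted code still picks the middle fullwidth comma of the source and of the translation (falling back to the middle space for the source and to the character midpoint for the translation) and then splits the segment's time span in proportion to that index. A small sketch of the index selection rule (the helper name and sample sentence are illustrative):

    import re

    def middle_comma_index(text, comma=","):
        # same rule as split_seg: odd count -> middle comma, even count -> left of middle
        commas = [m.start() for m in re.finditer(comma, text)]
        if not commas:
            return -1
        mid = len(commas) // 2
        return commas[mid] if len(commas) % 2 == 1 else commas[mid - 1]

    text = "这句话太长了,需要拆开,再翻译"
    idx = middle_comma_index(text)
    print(text[:idx], "|", text[idx:])  # 这句话太长了 | ,需要拆开,再翻译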
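
Taken together, the reworked API can be driven roughly as follows; a minimal usage sketch, assuming the module is importable as SRT and that the input file and term dictionary exist (test.srt and the output name are placeholders, not part of the commit):

    from SRT import SRT_script  # module name assumed from the file path above

    script = SRT_script.parse_from_srt_file("test.srt")  # placeholder input file
    script.form_whole_sentence()                         # re-merge whisper's fragmentary segments
    script.correct_with_force_term()                     # applies ./finetune_data/dict_enzh.csv
    script.write_srt_file_src("test_cleaned.srt")        # placeholder output file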