JiaenLiu committed
Commit 4890407 · 2 Parent(s): 575055d 9b3283d

Merge branch 'eason/refactor' of github.com:project-kxkg/project-t into eason/refactor

Files changed (3)
  1. .gitignore +2 -0
  2. SRT.py +147 -100
  3. pipeline.py +9 -36
.gitignore CHANGED
@@ -5,3 +5,5 @@
 test.py
 test.srt
 test.txt
+log_*.csv
+log.csv
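
Note: the new log_*.csv pattern keeps the per-run CSV logs that the updated SRT.py appends under ./logs (log_link.csv / log_name.csv, see the set_translation change below) out of version control. A minimal sketch of reading such a log back, assuming the column order written in this commit; since str(id_range) itself contains a comma, e.g. "(1, 20)", the row is split from the right instead of parsed with the csv module:

    # Sketch only: inspect the failure log appended by SRT_script.set_translation
    # (assumed location ./logs/log_name.csv, header as written in this commit).
    log_path = "./logs/log_name.csv"
    with open(log_path, encoding="utf-8") as f:
        next(f)  # skip "range_of_text,iterations_solving,solved,file_length,video_name"
        for line in f:
            # split from the right so the "(start, end)" range stays intact
            id_range, iterations, solved, file_length, video_name = line.rstrip("\n").rsplit(",", 4)
            if solved == "False":
                print(f"{video_name}: range {id_range} still unmatched after {iterations} iterations")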
SRT.py CHANGED
@@ -3,6 +3,7 @@ from csv import reader
 from datetime import datetime
 import re
 import openai
+import os
 from collections import deque
 
 class SRT_segment(object):
@@ -50,9 +51,18 @@ class SRT_segment(object):
         self.source_text += seg.source_text
         self.translation += seg.translation
         self.end_time_str = seg.end_time_str
+        self.end = seg.end
+        self.end_ms = seg.end_ms
         self.duration = f"{self.start_time_str} --> {self.end_time_str}"
         pass
-
+
+    def remove_trans_punc(self):
+        # remove punctuations in translation text
+        self.translation = self.translation.replace(',', ' ')
+        self.translation = self.translation.replace('。', ' ')
+        self.translation = self.translation.replace('!', ' ')
+        self.translation = self.translation.replace('?', ' ')
+
     def __str__(self) -> str:
         return f'{self.duration}\n{self.source_text}\n\n'
 
@@ -62,16 +72,6 @@ class SRT_segment(object):
     def get_bilingual_str(self) -> str:
         return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
 
-    # def set_translation(self, trans):
-    #     if trans[0] == ',':
-    #         trans = trans[1:]
-    #     self.translation = trans
-
-    # def set_src_text(self, src_text):
-    #     if src_text[0] == ',':
-    #         src_text = src_text[1:]
-    #     self.source_text = src_text
-
 class SRT_script():
     def __init__(self, segments) -> None:
         self.segments = []
@@ -105,7 +105,7 @@ class SRT_script():
         merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
         sentence = []
         for i, seg in enumerate(self.segments):
-            if seg.source_text[-1] == '.':
+            if seg.source_text[-1] in ['.', '!', '?']:
                 sentence.append(i)
                 merge_list.append(sentence)
                 sentence = []
@@ -117,56 +117,88 @@ class SRT_script():
             segments.append(self.merge_segs(idx_list))
 
         self.segments = segments # need memory release?
-
 
+    def remove_trans_punctuation(self):
+        # Post-process: remove all punc after translation and split
+        for i, seg in enumerate(self.segments):
+            seg.remove_trans_punc()
 
-    def set_translation(self, translate:str, id_range:tuple, model):
+    def set_translation(self, translate:str, id_range:tuple, model, video_name, video_link=None):
         start_seg_id = id_range[0]
         end_seg_id = id_range[1]
 
-        def inner_func(input_str):
+        def inner_func(target,input_str):
             response = openai.ChatCompletion.create(
-                model=model,
+                #model=model,
+                model = "gpt-3.5-turbo",
                 messages = [
-                    {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
-                    {"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
-                    {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
-                    {"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)}
+                    #{"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
+                    #{"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
+                    #{"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
+                    #{"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)}
+                    {"role": "system", "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
+                    {"role": "system", "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
+                    {"role": "user", "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target,input_str)}
                 ],
-                temperature=0.7
+                #temperature=0.7
+                temperature = 0.15
             )
             return response['choices'][0]['message']['content'].strip()
 
+
        lines = translate.split('\n\n')
        if len(lines) < (end_seg_id - start_seg_id + 1):
            count = 0
+            solved = True
            while count<5 and len(lines) != (end_seg_id - start_seg_id + 1):
-
                count += 1
                print("Solving Unmatched Lines|iteration {}".format(count))
-                input_str = "\n"
+                #input_str = "\n"
                #initialize GPT input
-                for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
-                    input_str += 'Sentence %d: ' %(i+1)+ seg.source_text + '\n'
-                    #Append to prompt string
-                    #Adds sentence index let GPT keep track of sentence breaks
-                input_str += translate
+                #for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
+                #    input_str += 'Sentence %d: ' %(i+1)+ seg.source_text + '\n'
+                #    #Append to prompt string
+                #    #Adds sentence index let GPT keep track of sentence breaks
+                #input_str += translate
                #append translate to prompt
-
                flag = True
                while flag:
                    flag = False
+                    #print("translate:")
+                    #print(translate)
                    try:
-                        translate = inner_func(input_str)
+                        #print("target")
+                        #print(end_seg_id - start_seg_id + 1)
+                        translate = inner_func(end_seg_id - start_seg_id + 1,translate)
                    except Exception as e:
                        print("An error has occurred during solving unmatched lines:",e)
                        print("Retrying...")
                        flag = True
-
-                lines = translate.split('\n\n')
+                lines = translate.split('\n')
+                #print("result")
+                #print(len(lines))
+
            if len(lines) < (end_seg_id - start_seg_id + 1):
+                solved = False
                print("Failed Solving unmatched lines, Manually parse needed")
 
+                if not os.path.exists("./logs"):
+                    os.mkdir("./logs")
+                if video_link:
+                    log_file = "./logs/log_link.csv"
+                    log_exist = os.path.exists(log_file)
+                    with open(log_file,"a") as log:
+                        if not log_exist:
+                            log.write("range_of_text,iterations_solving,solved,file_length,video_link" + "\n")
+                        log.write(str(id_range)+','+str(count)+','+str(solved)+','+str(len(self.segments))+','+video_link + "\n")
+                else:
+                    log_file = "./logs/log_name.csv"
+                    log_exist = os.path.exists(log_file)
+                    with open(log_file,"a") as log:
+                        if not log_exist:
+                            log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
+                        log.write(str(id_range)+','+str(count)+','+str(solved)+','+str(len(self.segments))+','+video_name + "\n")
+
        print(lines)
        #print(id_range)
        #for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
@@ -182,23 +214,29 @@ class SRT_script():
             if i < len(lines):
                 if "Note:" in lines[i]: # to avoid note
                     lines.remove(lines[i])
+                    max_num -= 1
                 if i == len(lines) - 1:
                     break
                 try:
-                    seg.translation = lines[i].split(":" or ":")[1]
+                    seg.translation = lines[i].split(":" or ":" or ".")[1]
                 except:
                     seg.translation = lines[i]
-                    #print(lines[i])
-        pass
+
 
-    def split_seg(self, seg, threshold):
+    def split_seg(self, seg, text_threshold, time_threshold):
         # evenly split seg to 2 parts and add new seg into self.segments
-        if seg.source_text[:2] == ', ':
-            seg.source_text = seg.source_text[2:]
+
+        # ignore the initial comma to solve the recursion problem
+        if len(seg.source_text) > 2:
+            if seg.source_text[:2] == ', ':
+                seg.source_text = seg.source_text[2:]
         if seg.translation[0] == ',':
             seg.translation = seg.translation[1:]
+
         source_text = seg.source_text
         translation = seg.translation
+
+        # split the text based on commas
         src_commas = [m.start() for m in re.finditer(',', source_text)]
         trans_commas = [m.start() for m in re.finditer(',', translation)]
         if len(src_commas) != 0:
@@ -215,13 +253,18 @@ class SRT_script():
         else:
             trans_split_idx = len(translation)//2
 
+        # split the time duration based on text length
+        time_split_ratio = trans_split_idx/(len(seg.translation) - 1)
+
         src_seg1 = source_text[:src_split_idx]
         src_seg2 = source_text[src_split_idx:]
         trans_seg1 = translation[:trans_split_idx]
         trans_seg2 = translation[trans_split_idx:]
+
         start_seg1 = seg.start
-        end_seg1 = start_seg2 = seg.start + (seg.end - seg.start)/2
+        end_seg1 = start_seg2 = seg.start + (seg.end - seg.start)*time_split_ratio
         end_seg2 = seg.end
+
         seg1_dict = {}
         seg1_dict['text'] = src_seg1
         seg1_dict['start'] = start_seg1
@@ -237,26 +280,26 @@ class SRT_script():
         seg2.translation = trans_seg2
 
         result_list = []
-        if len(seg1.translation) > threshold:
-            result_list += self.split_seg(seg1, threshold)
+        if len(seg1.translation) > text_threshold and (seg1.end - seg1.start) > time_threshold:
+            result_list += self.split_seg(seg1, text_threshold, time_threshold)
         else:
             result_list.append(seg1)
 
-        if len(seg2.translation) > threshold:
-            result_list += self.split_seg(seg2, threshold)
+        if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
+            result_list += self.split_seg(seg2, text_threshold, time_threshold)
         else:
             result_list.append(seg2)
 
         return result_list
 
 
-    def check_len_and_split(self, threshold=30):
+    def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
         # DEPRECATED
-        # if sentence length >= threshold, split this segments to two
+        # if sentence length >= threshold and sentence duration > time_threshold, split this segments to two
         segments = []
         for seg in self.segments:
-            if len(seg.translation) > threshold:
-                seg_list = self.split_seg(seg, threshold)
+            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
+                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                 segments += seg_list
             else:
                 segments.append(seg)
@@ -265,73 +308,25 @@ class SRT_script():
 
         pass
 
-    def check_len_and_split_range(self, range, threshold=30):
-        # if sentence length >= threshold, split this segments to two
+    def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
+        # if sentence length >= text_threshold, split this segments to two
         start_seg_id = range[0]
         end_seg_id = range[1]
         extra_len = 0
         segments = []
         for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
-            if len(seg.translation) > threshold:
-                seg_list = self.split_seg(seg, threshold)
+            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
+                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                 segments += seg_list
                 extra_len += len(seg_list) - 1
             else:
                 segments.append(seg)
 
         self.segments[start_seg_id-1:end_seg_id] = segments
-
         return extra_len
-
-    def get_source_only(self):
-        # return a string with pure source text
-        result = ""
-        for i, seg in enumerate(self.segments):
-            result+=f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
-
-        return result
-
-    def reform_src_str(self):
-        result = ""
-        for i, seg in enumerate(self.segments):
-            result += f'{i+1}\n'
-            result += str(seg)
-        return result
-
-    def reform_trans_str(self):
-        result = ""
-        for i, seg in enumerate(self.segments):
-            result += f'{i+1}\n'
-            result += seg.get_trans_str()
-        return result
-
-    def form_bilingual_str(self):
-        result = ""
-        for i, seg in enumerate(self.segments):
-            result += f'{i+1}\n'
-            result += seg.get_bilingual_str()
-        return result
-
-    def write_srt_file_src(self, path:str):
-        # write srt file to path
-        with open(path, "w", encoding='utf-8') as f:
-            f.write(self.reform_src_str())
-        pass
-
-    def write_srt_file_translate(self, path:str):
-        with open(path, "w", encoding='utf-8') as f:
-            f.write(self.reform_trans_str())
-        pass
-
-    def write_srt_file_bilingual(self, path:str):
-        with open(path, "w", encoding='utf-8') as f:
-            f.write(self.form_bilingual_str())
-        pass
-
 
     def correct_with_force_term(self):
         ## force term correction
-        # TODO: shortcut translation i.e. VA, ob
-        # TODO: variety of translation
 
         # load term dictionary
         with open("./finetune_data/dict_enzh.csv",'r', encoding='utf-8') as f:
@@ -420,8 +415,57 @@ class SRT_script():
             real_word = word.lower()
             n = 0
         return real_word, len(word)+n
+
+
+    ## WRITE AND READ FUNCTIONS ##
+
+    def get_source_only(self):
+        # return a string with pure source text
+        result = ""
+        for i, seg in enumerate(self.segments):
+            result+=f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
+
+        return result
 
+    def reform_src_str(self):
+        result = ""
+        for i, seg in enumerate(self.segments):
+            result += f'{i+1}\n'
+            result += str(seg)
+        return result
+
+    def reform_trans_str(self):
+        result = ""
+        for i, seg in enumerate(self.segments):
+            result += f'{i+1}\n'
+            result += seg.get_trans_str()
+        return result
+
+    def form_bilingual_str(self):
+        result = ""
+        for i, seg in enumerate(self.segments):
+            result += f'{i+1}\n'
+            result += seg.get_bilingual_str()
+        return result
+
+    def write_srt_file_src(self, path:str):
+        # write srt file to path
+        with open(path, "w", encoding='utf-8') as f:
+            f.write(self.reform_src_str())
+        pass
+
+    def write_srt_file_translate(self, path:str):
+        with open(path, "w", encoding='utf-8') as f:
+            f.write(self.reform_trans_str())
+        pass
+
+    def write_srt_file_bilingual(self, path:str):
+        with open(path, "w", encoding='utf-8') as f:
+            f.write(self.form_bilingual_str())
+        pass
+
     def realtime_write_srt(self,path,range,length, idx):
+        # DEPRECATED
         start_seg_id = range[0]
         end_seg_id = range[1]
         with open(path, "a", encoding='utf-8') as f:
@@ -436,6 +480,7 @@ class SRT_script():
         pass
 
     def realtime_bilingual_write_srt(self,path,range, length,idx):
+        # DEPRECATED
         start_seg_id = range[0]
         end_seg_id = range[1]
         with open(path, "a", encoding='utf-8') as f:
@@ -444,4 +489,6 @@ class SRT_script():
             if i>=range[1] + length :break
             f.write(f'{i+idx}\n')
             f.write(seg.get_bilingual_str())
-        pass
+        pass
+
+
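
Note: split_seg now divides a segment's time span in proportion to where the translation is cut, instead of always halving it, and it only recurses while both the text threshold and the new time threshold are exceeded. A standalone sketch of the ratio computation (a simplification; the real method works on SRT_segment objects and also splits source_text):

    # Sketch of the proportional time split introduced in split_seg (simplified assumption).
    def split_times(start, end, translation, trans_split_idx):
        # the cut in time mirrors the cut in the translated text
        time_split_ratio = trans_split_idx / (len(translation) - 1)
        split_point = start + (end - start) * time_split_ratio
        return (start, split_point), (split_point, end)

    # A 10-second segment whose 21-character translation is cut at index 7
    # gets 0.35 of the duration: (0.0, 3.5) and (3.5, 10.0).
    print(split_times(0.0, 10.0, "前半句比较短,后半句的内容要长一些才可以哦", 7))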
pipeline.py CHANGED
@@ -5,6 +5,8 @@ import os
 from tqdm import tqdm
 from SRT import SRT_script
 import stable_whisper
+import whisper
+
 import subprocess
 
 import time
@@ -47,7 +49,7 @@ if args.video_name == 'placeholder' :
     elif args.audio_file is not None:
         VIDEO_NAME = args.audio_file.split('/')[-1].split('.')[0]
     elif args.srt_file is not None:
-        VIDEO_NAME = args.srt_file.split('/')[-1].split('.')[0]
+        VIDEO_NAME = args.srt_file.split('/')[-1].split('.')[0].split("_")[0]
     else:
         VIDEO_NAME = args.video_name
 else:
@@ -95,14 +97,7 @@ elif args.video_file is not None:
         audio_file= open(args.audio_file, "rb")
         audio_path = args.audio_file
     else:
-        # escaped_video_path = args.video_file.replace('(', '\(').replace(')', '\)').replace(' ', '\ ')
-        # print(escaped_video_path)
-        # os.system(f'ffmpeg -i {escaped_video_path} -f mp3 -ab 192000 -vn {DOWNLOAD_PATH}/audio/{VIDEO_NAME}.mp3')
-        # audio_file= open(f'{DOWNLOAD_PATH}/audio/{VIDEO_NAME}.mp3', "rb")
-        # audio_path = f'{DOWNLOAD_PATH}/audio/{VIDEO_NAME}.mp3'
         output_audio_path = f'{DOWNLOAD_PATH}/audio/{VIDEO_NAME}.mp3'
-        # print(video_path)
-        # print(output_audio_path)
         subprocess.run(['ffmpeg', '-i', video_path, '-f', 'mp3', '-ab', '192000', '-vn', output_audio_path])
         audio_file = open(output_audio_path, "rb")
         audio_path = output_audio_path
@@ -133,7 +128,7 @@ else:
 
     # use stable-whisper
     model = stable_whisper.load_model('base')
-    transcript = model.transcribe(audio_path, regroup = False)
+    transcript = model.transcribe(audio_path, regroup = False, initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
     (
         transcript
         .split_by_punctuation(['.', '。', '?'])
@@ -143,14 +138,9 @@ else:
     )
     # transcript.to_srt_vtt(srt_file_en)
     transcript = transcript.to_dict()
+    # print(transcript)
     srt = SRT_script(transcript['segments']) # read segments to SRT class
 
-    #Write SRT file
-
-    # from whisper.utils import WriteSRT
-    # with open(srt_file_en, 'w', encoding="utf-8") as f:
-    #     writer = WriteSRT(RESULT_PATH)
-    #     writer.write_result(transcript, f)
 else:
     srt = SRT_script.parse_from_srt_file(srt_file_en)
 
@@ -241,21 +231,6 @@ def get_response(model_name, sentence):
     )
 
     return response['choices'][0]['message']['content'].strip()
-
-    # if model_name == "text-davinci-003":
-    #     prompt = f"Please help me translate this into Chinese:\n\n{s}\n\n"
-    #     # print(prompt)
-    #     response = openai.Completion.create(
-    #         model=model_name,
-    #         prompt=prompt,
-    #         temperature=0.1,
-    #         max_tokens=2000,
-    #         top_p=1.0,
-    #         frequency_penalty=0.0,
-    #         presence_penalty=0.0
-    #     )
-    #     return response['choices'][0]['text'].strip()
-    pass
 
 
 # Translate and save
@@ -277,16 +252,14 @@ for sentence, range in tqdm(zip(script_arr, range_arr)):
             time.sleep(30)
             flag = True
     # add read-time output back and modify the post-processing by using one batch as an unit.
-    srt.set_translation(translate, range, model_name)
-    # add_length = srt.check_len_and_split_range(range, threshold)
-    # srt.realtime_write_srt(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt",range, add_length ,segidx)
-    # # save current length as previous length
-    # previous_length = add_length
+    srt.set_translation(translate, range, model_name, VIDEO_NAME, args.link)
+
    # srt.realtime_bilingual_write_srt(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt",range, add_length,segidx)
 
 srt.check_len_and_split()
+srt.remove_trans_punctuation()
 srt.write_srt_file_translate(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt")
-# srt.write_srt_file_bilingual(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt")
+srt.write_srt_file_bilingual(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt")
 
 if not args.only_srt:
     assSub_zh = srt2ass(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt", "default", "No", "Modest")
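
Note: in the translate-and-save loop, set_translation now also receives VIDEO_NAME and args.link so unresolved ranges can be logged, and the post-processing order becomes check_len_and_split(), then the new remove_trans_punctuation(), then writing both the Chinese and bilingual SRT files. The order matters because split_seg still locates cut points at Chinese commas, which remove_trans_punctuation replaces with spaces. A self-contained sketch of that interaction with plain strings (split_on_middle_comma is a hypothetical stand-in for split_seg's comma-based cut):

    # Hypothetical stand-in for split_seg's comma-based cut; the real code works on SRT_segment objects.
    def split_on_middle_comma(translation: str):
        commas = [i for i, ch in enumerate(translation) if ch == ',']
        if not commas:
            mid = len(translation) // 2
            return translation[:mid], translation[mid:]
        cut = min(commas, key=lambda i: abs(i - len(translation) // 2)) + 1
        return translation[:cut], translation[cut:]

    # Mirrors the SRT_segment.remove_trans_punc method added in this commit.
    def remove_trans_punc(translation: str) -> str:
        for ch in (',', '。', '!', '?'):
            translation = translation.replace(ch, ' ')
        return translation

    # Split first (comma still present), then strip punctuation from each half.
    first, second = split_on_middle_comma("前半句比较短,后半句的内容要长一些才可以哦。")
    print(remove_trans_punc(first))   # "前半句比较短 "
    print(remove_trans_punc(second))  # "后半句的内容要长一些才可以哦 "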