Eason Lu commited on
Commit
1a902ed
·
1 Parent(s): b37d0d4

adapt different languages for srt.py

Browse files

Former-commit-id: 197fea8adb5bd35d60d2fc0d09ddc1af21cac117

src/srt_util/srt.py CHANGED
@@ -8,9 +8,22 @@ import logging
8
  import openai
9
  from tqdm import tqdm
10
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  class SrtSegment(object):
13
- def __init__(self, *args) -> None:
 
 
 
14
  if isinstance(args[0], dict):
15
  segment = args[0]
16
  self.start = segment['start']
@@ -83,11 +96,11 @@ class SrtSegment(object):
83
 
84
  def remove_trans_punc(self) -> None:
85
  """
86
- remove CN punctuations in translation text
87
  :return: None
88
  """
89
- punc_cn = ",。!?"
90
- translator = str.maketrans(punc_cn, ' ' * len(punc_cn))
91
  self.translation = self.translation.translate(translator)
92
 
93
  def __str__(self) -> str:
@@ -101,11 +114,13 @@ class SrtSegment(object):
101
 
102
 
103
  class SrtScript(object):
104
- def __init__(self, segments) -> None:
105
- self.segments = [SrtSegment(seg) for seg in segments]
 
 
106
 
107
  @classmethod
108
- def parse_from_srt_file(cls, path: str):
109
  with open(path, 'r', encoding="utf-8") as f:
110
  script_lines = [line.rstrip() for line in f.readlines()]
111
  bilingual = False
@@ -119,7 +134,7 @@ class SrtScript(object):
119
  for i in range(0, len(script_lines), 4):
120
  segments.append(list(script_lines[i:i + 4]))
121
 
122
- return cls(segments)
123
 
124
  def merge_segs(self, idx_list) -> SrtSegment:
125
  """
@@ -309,14 +324,14 @@ class SrtScript(object):
309
  seg1_dict['text'] = src_seg1
310
  seg1_dict['start'] = start_seg1
311
  seg1_dict['end'] = end_seg1
312
- seg1 = SrtSegment(seg1_dict)
313
  seg1.translation = trans_seg1
314
 
315
  seg2_dict = {}
316
  seg2_dict['text'] = src_seg2
317
  seg2_dict['start'] = start_seg2
318
  seg2_dict['end'] = end_seg2
319
- seg2 = SrtSegment(seg2_dict)
320
  seg2.translation = trans_seg2
321
 
322
  result_list = []
 
8
  import openai
9
  from tqdm import tqdm
10
 
11
+ punctuation_dict = {
12
+ "EN": ". , ? ! : ; - ( ) [ ] { } ' \"",
13
+ "ES": ". , ? ! : ; - ( ) [ ] { } ' \" ¡ ¿",
14
+ "FR": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
15
+ "DE": ". , ? ! : ; - ( ) [ ] { } ' \" „ “ –",
16
+ "RU": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
17
+ "ZH": "。 , ? ! : ; — ( ) ​``【oaicite:1】``​ 《 》 “ ”",
18
+ "JA": "。 、 ? ! : ; ー ( ) ​``【oaicite:0】``​ 「 」 『 』",
19
+ "AR": ". , ? ! : ; - ( ) [ ] { } ، ؛ ؟ « »",
20
+ }
21
 
22
  class SrtSegment(object):
23
+ def __init__(self, src_lang, tgt_lang, *args) -> None:
24
+ self.src_lang = src_lang
25
+ self.tgt_lang = tgt_lang
26
+
27
  if isinstance(args[0], dict):
28
  segment = args[0]
29
  self.start = segment['start']
 
96
 
97
  def remove_trans_punc(self) -> None:
98
  """
99
+ remove punctuations in translation text
100
  :return: None
101
  """
102
+ punc = punctuation_dict[self.tgt_lang]
103
+ translator = str.maketrans(punc, ' ' * len(punc))
104
  self.translation = self.translation.translate(translator)
105
 
106
  def __str__(self) -> str:
 
114
 
115
 
116
  class SrtScript(object):
117
+ def __init__(self, src_lang, tgt_lang, segments) -> None:
118
+ self.src_lang = src_lang
119
+ self.tgt_lang = tgt_lang
120
+ self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
121
 
122
  @classmethod
123
+ def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
124
  with open(path, 'r', encoding="utf-8") as f:
125
  script_lines = [line.rstrip() for line in f.readlines()]
126
  bilingual = False
 
134
  for i in range(0, len(script_lines), 4):
135
  segments.append(list(script_lines[i:i + 4]))
136
 
137
+ return cls(src_lang, tgt_lang, segments)
138
 
139
  def merge_segs(self, idx_list) -> SrtSegment:
140
  """
 
324
  seg1_dict['text'] = src_seg1
325
  seg1_dict['start'] = start_seg1
326
  seg1_dict['end'] = end_seg1
327
+ seg1 = SrtSegment(self.src_lang, self.tgt_lang, seg1_dict)
328
  seg1.translation = trans_seg1
329
 
330
  seg2_dict = {}
331
  seg2_dict['text'] = src_seg2
332
  seg2_dict['start'] = start_seg2
333
  seg2_dict['end'] = end_seg2
334
+ seg2 = SrtSegment(self.src_lang, self.tgt_lang, seg2_dict)
335
  seg2.translation = trans_seg2
336
 
337
  result_list = []
src/task.py CHANGED
@@ -129,10 +129,10 @@ class Task:
129
  # TODO: setup ASR module like translator
130
  self.status = TaskStatus.INITIALIZING_ASR
131
  self.t_s = time()
132
- # self.SRT_Script = SrtScript
133
  method = self.ASR_setting["whisper_config"]["method"]
134
  whisper_model = self.ASR_setting["whisper_config"]["whisper_model"]
135
- src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id})_{self.source_lang}.srt")
136
  if not Path.exists(src_srt_path):
137
  # extract script from audio
138
  logging.info("extract script from audio")
@@ -157,7 +157,7 @@ class Task:
157
  # after get the transcript, release the gpu resource
158
  torch.cuda.empty_cache()
159
 
160
- self.SRT_Script = SrtScript(transcript['segments'])
161
  # save the srt script to local
162
  self.SRT_Script.write_srt_file_src(src_srt_path)
163
 
 
129
  # TODO: setup ASR module like translator
130
  self.status = TaskStatus.INITIALIZING_ASR
131
  self.t_s = time()
132
+
133
  method = self.ASR_setting["whisper_config"]["method"]
134
  whisper_model = self.ASR_setting["whisper_config"]["whisper_model"]
135
+ src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id}_{self.source_lang}.srt")
136
  if not Path.exists(src_srt_path):
137
  # extract script from audio
138
  logging.info("extract script from audio")
 
157
  # after get the transcript, release the gpu resource
158
  torch.cuda.empty_cache()
159
 
160
+ self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'])
161
  # save the srt script to local
162
  self.SRT_Script.write_srt_file_src(src_srt_path)
163
 
src/translators/translation.py CHANGED
@@ -18,8 +18,6 @@ def check_translation(sentence, translation):
18
  translation_count = translation.count('\n\n') + 1
19
 
20
  if sentence_count != translation_count:
21
- # print("sentence length: ", len(sentence), sentence_count)
22
- # print("translation length: ", len(translation), translation_count)
23
  return False
24
  else:
25
  return True
@@ -34,7 +32,7 @@ def prompt_selector(src_lang, tgt_lang, domain):
34
  tgt_lang = language_map[tgt_lang]
35
  prompt = f"""
36
  you are a translation assistant, your job is to translate a video in domain of {domain} from {src_lang} to {tgt_lang},
37
- you will be provided with a segement in {[src_lang]} parsed by line, where your translation text should keep the original
38
  meaning and the number of lines.
39
  """
40
  return prompt
@@ -56,11 +54,11 @@ def translate(srt, script_arr, range_arr, model_name, video_name=None, attempts_
56
  :param task: Prompt.
57
  :param temp: Model temperature.
58
  """
59
- #logging.info("Start translating...")
60
  if input is None:
61
  raise Exception("Warning! No Input have passed to LLM!")
62
  if task is None:
63
- task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
64
  logging.info(f"translation prompt: {task}")
65
  previous_length = 0
66
  for sentence, range_ in tqdm(zip(script_arr, range_arr)):
 
18
  translation_count = translation.count('\n\n') + 1
19
 
20
  if sentence_count != translation_count:
 
 
21
  return False
22
  else:
23
  return True
 
32
  tgt_lang = language_map[tgt_lang]
33
  prompt = f"""
34
  you are a translation assistant, your job is to translate a video in domain of {domain} from {src_lang} to {tgt_lang},
35
+ you will be provided with a segement in {src_lang} parsed by line, where your translation text should keep the original
36
  meaning and the number of lines.
37
  """
38
  return prompt
 
54
  :param task: Prompt.
55
  :param temp: Model temperature.
56
  """
57
+
58
  if input is None:
59
  raise Exception("Warning! No Input have passed to LLM!")
60
  if task is None:
61
+ task = "你是一个翻译助理,你的任务是翻译视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
62
  logging.info(f"translation prompt: {task}")
63
  previous_length = 0
64
  for sentence, range_ in tqdm(zip(script_arr, range_arr)):