# ViDove / SRT.py
# Last commit: "split seg done" by Eason Lu (7baae45)
# Page links: raw | history | blame
# File size: 14.1 kB
from datetime import timedelta
from csv import reader
from datetime import datetime
import re
class SRT_segment(object):
    """A single subtitle cue: time span, source text and translation.

    A segment is built from exactly one positional argument:

    * a ``dict`` with ``'start'`` and ``'end'`` (seconds, numeric) and
      ``'text'`` keys, or
    * a ``list`` of SRT lines ``[index, "HH:MM:SS,mmm --> HH:MM:SS,mmm",
      text, ...]`` (only items 1 and 2 are read).

    Attributes set by both construction paths:
        start, end: cue boundaries in seconds.
        start_ms, end_ms: millisecond part (0-999) of each boundary.
        start_time, end_time: the boundaries as ``timedelta`` objects.
        start_time_str, end_time_str: SRT-formatted timestamps.
        duration: ``"<start_time_str> --> <end_time_str>"``.
        source_text: the untranslated text.
        translation: the translated text, initially empty.

    Raises:
        TypeError: if no argument is given or it is neither dict nor list.
    """

    def __init__(self, *args) -> None:
        if not args:
            raise TypeError("SRT_segment expects a dict or a list of SRT lines")
        source = args[0]
        if isinstance(source, dict):
            self._init_from_dict(source)
        elif isinstance(source, list):
            self._init_from_srt_lines(source)
        else:
            raise TypeError(
                "SRT_segment expects a dict or a list of SRT lines, "
                f"got {type(source).__name__}"
            )
        self.translation = ""

    @staticmethod
    def _format_timestamp(total_ms: int) -> str:
        """Render a millisecond count as an SRT ``HH:MM:SS,mmm`` timestamp."""
        hours, rest = divmod(total_ms, 3600 * 1000)
        minutes, rest = divmod(rest, 60 * 1000)
        seconds, millis = divmod(rest, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"

    @staticmethod
    def _parse_timestamp(stamp: str) -> int:
        """Convert an SRT ``HH:MM:SS,mmm`` timestamp to whole milliseconds."""
        pieces = stamp.split(',')
        clock = pieces[0].split(':')
        whole_seconds = int(clock[0]) * 3600 + int(clock[1]) * 60 + int(clock[2])
        return whole_seconds * 1000 + int(pieces[1])

    def _init_from_dict(self, segment: dict) -> None:
        """Populate the segment from a ``{'start', 'end', 'text'}`` dict."""
        self.start = segment['start']
        self.end = segment['end']
        # Round to whole milliseconds instead of truncating, so binary float
        # noise (0.29 * 100 == 28.999999999999996) cannot cost a millisecond.
        start_total_ms = int(round(segment['start'] * 1000))
        end_total_ms = int(round(segment['end'] * 1000))
        if start_total_ms == end_total_ms:  # avoid empty time stamp
            end_total_ms += 500
            # Keep the numeric end in step with the displayed timestamp.
            self.end = segment['end'] + 0.5
        self.start_ms = start_total_ms % 1000
        self.end_ms = end_total_ms % 1000
        self.start_time = timedelta(milliseconds=start_total_ms)
        self.end_time = timedelta(milliseconds=end_total_ms)
        self.start_time_str = self._format_timestamp(start_total_ms)
        self.end_time_str = self._format_timestamp(end_total_ms)
        self.source_text = segment['text']
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def _init_from_srt_lines(self, lines: list) -> None:
        """Populate the segment from ``[index, duration, text, ...]`` lines."""
        self.source_text = lines[2]
        self.duration = lines[1]
        stamps = self.duration.split(" --> ")
        self.start_time_str = stamps[0]
        self.end_time_str = stamps[1]
        # Parse the timestamps back to numbers.
        start_total_ms = self._parse_timestamp(self.start_time_str)
        end_total_ms = self._parse_timestamp(self.end_time_str)
        # Same unit (integer milliseconds) as the dict construction path.
        self.start_ms = start_total_ms % 1000
        self.end_ms = end_total_ms % 1000
        self.start_time = timedelta(milliseconds=start_total_ms)
        self.end_time = timedelta(milliseconds=end_total_ms)
        self.start = start_total_ms / 1000
        self.end = end_total_ms / 1000

    def merge_seg(self, seg):
        """Append ``seg`` to this segment in place.

        Text and translation are concatenated without a separator, and this
        segment's end boundary (numeric and string forms) becomes ``seg``'s.
        """
        self.source_text += seg.source_text
        self.translation += seg.translation
        # Carry over every representation of the end time, not only the
        # string, so later arithmetic on ``end`` sees the merged span.
        self.end = seg.end
        self.end_ms = seg.end_ms
        self.end_time = seg.end_time
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        """Return the cue body (duration line + source text) in SRT layout."""
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Return the cue body with the translation as its text."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Return the cue body with source text followed by translation."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
class SRT_script():
    """An ordered collection of ``SRT_segment`` cues with editing helpers.

    ``segments`` passed to the constructor is an iterable whose items are
    each handed to ``SRT_segment`` (dicts or lists of SRT lines).
    """

    def __init__(self, segments) -> None:
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
        """Build a script from an SRT file laid out as 4 lines per cue.

        Each cue is expected to be ``index / duration / text / blank``.
        Trailing groups too short to hold a duration and a text line (for
        example extra blank lines at the end of the file) are skipped.
        """
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = f.read().splitlines()

        segments = []
        for first in range(0, len(script_lines), 4):
            group = script_lines[first:first + 4]
            if len(group) < 3:
                continue  # incomplete tail, cannot form a cue
            segments.append(group)
        return cls(segments)

    def merge_segs(self, idx_list) -> "SRT_segment":
        """Merge the segments at ``idx_list`` into the first one, in place.

        Returns the (mutated) segment at ``idx_list[0]``.
        """
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self):
        """Merge consecutive segments until each one ends with a period.

        Segments after the last period-terminated one are merged together
        and kept, so no text is lost when the script does not end in ``'.'``.
        """
        # Groups of indices to merge, e.g. [[0], [1, 2, 3, 4], [5, 6], [7]].
        merge_list = []
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            # endswith() is safe on empty text, unlike indexing [-1].
            if seg.source_text.endswith('.'):
                merge_list.append(sentence)
                sentence = []
        if sentence:
            # Unterminated tail: keep it instead of silently dropping it.
            merge_list.append(sentence)

        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    def set_translation(self, translate: str, id_range: tuple):
        """Assign translated lines to a 1-based, inclusive range of segments.

        ``translate`` is split on blank lines (``'\\n\\n'``). Lines containing
        ``"(Note:"`` are discarded. For each remaining line, the text after
        the first ``':'`` (with leading spaces removed) becomes the
        translation; a line without ``':'`` is used whole. If the number of
        raw lines differs from the range size, the range, the source texts
        and the raw translation are printed for inspection.
        """
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        target_segs = self.segments[start_seg_id - 1:end_seg_id]

        lines = translate.split('\n\n')
        if len(lines) != (end_seg_id - start_seg_id + 1):
            print(id_range)
            for seg in target_segs:
                print(seg.source_text)
            print(translate)

        # naive way to deal with the merged-translation problem
        # TODO: need a smarter solution
        lines = [line for line in lines if "(Note:" not in line]  # to avoid note
        for seg, line in zip(target_segs, lines):
            _, colon, tail = line.partition(":")
            seg.translation = tail.lstrip(" ") if colon else line

    @staticmethod
    def _find_split_index(text, separator, keep_separator_left):
        """Return a cut position at the middle ``separator`` of ``text``.

        The cut falls right after the separator when ``keep_separator_left``
        is true, otherwise right before it. Only cuts strictly inside the
        text are considered, so neither side is ever empty. Returns ``None``
        when no usable separator exists.
        """
        cuts = []
        for match in re.finditer(re.escape(separator), text):
            cut = match.end() if keep_separator_left else match.start()
            if 0 < cut < len(text):
                cuts.append(cut)
        if not cuts:
            return None
        # Middle element for odd counts, lower-middle for even counts.
        return cuts[(len(cuts) - 1) // 2]

    def split_seg(self, seg, threshold):
        """Recursively halve ``seg`` until translations fit ``threshold``.

        Source text is cut at its middle comma (else middle space, else the
        midpoint); the translation at its middle comma (else the midpoint).
        The time span is divided evenly between the two halves. Returns the
        resulting list of segments; a translation shorter than two
        characters cannot be divided and is returned as ``[seg]``.
        """
        source_text = seg.source_text
        translation = seg.translation
        if len(translation) < 2:
            return [seg]  # nothing left to split; also stops the recursion

        src_split_idx = self._find_split_index(source_text, ',', True)
        if src_split_idx is None:
            # Keep the space on the right-hand piece (leading-space style).
            src_split_idx = self._find_split_index(source_text, ' ', False)
        if src_split_idx is None:
            src_split_idx = len(source_text) // 2

        # NOTE(review): only the ASCII comma is searched here; translations
        # using a full-width comma fall back to the midpoint - confirm intent.
        trans_split_idx = self._find_split_index(translation, ',', True)
        if trans_split_idx is None:
            trans_split_idx = len(translation) // 2

        midpoint = seg.start + (seg.end - seg.start) / 2

        first = SRT_segment({
            'text': source_text[:src_split_idx],
            'start': seg.start,
            'end': midpoint,
        })
        first.translation = translation[:trans_split_idx]

        second = SRT_segment({
            'text': source_text[src_split_idx:],
            'start': midpoint,
            'end': seg.end,
        })
        second.translation = translation[trans_split_idx:]

        result_list = []
        for part in (first, second):
            if len(part.translation) > threshold:
                result_list += self.split_seg(part, threshold)
            else:
                result_list.append(part)
        return result_list

    def check_len_and_split(self, threshold=30):
        """Split every segment whose translation is longer than ``threshold``."""
        segments = []
        for seg in self.segments:
            if len(seg.translation) > threshold:
                segments += self.split_seg(seg, threshold)
            else:
                segments.append(seg)
        self.segments = segments

    def get_source_only(self):
        """Return the source text of all segments as numbered sentences."""
        return "".join(
            f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
            for i, seg in enumerate(self.segments)
        )

    def reform_src_str(self):
        """Return the whole script as SRT text using the source text."""
        return "".join(
            f'{i+1}\n{seg}' for i, seg in enumerate(self.segments)
        )

    def reform_trans_str(self):
        """Return the whole script as SRT text using the translations."""
        return "".join(
            f'{i+1}\n{seg.get_trans_str()}'
            for i, seg in enumerate(self.segments)
        )

    def form_bilingual_str(self):
        """Return the whole script as SRT text with source and translation."""
        return "".join(
            f'{i+1}\n{seg.get_bilingual_str()}'
            for i, seg in enumerate(self.segments)
        )

    def write_srt_file_src(self, path: str):
        """Write the source-text SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        """Write the translation SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        """Write the bilingual SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self, dict_path="./finetune_data/dict_enzh.csv"):
        """Replace known terms in every segment's source text.

        ``dict_path`` is a CSV whose first column is the lower-cased term and
        whose second column is its replacement; rows with fewer than two
        columns are ignored. Trailing punctuation recognised by
        ``get_real_word`` is preserved.
        """
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation
        with open(dict_path, 'r', encoding='utf-8', newline='') as f:
            term_enzh_dict = {
                row[0]: row[1] for row in reader(f) if len(row) >= 2
            }

        for seg in self.segments:
            ready_words = seg.source_text.split(" ")
            for i, word in enumerate(ready_words):
                real_word, pos = self.get_real_word(word)
                if real_word in term_enzh_dict:
                    # Slice instead of str.replace so only the leading word
                    # part is substituted and the punctuation tail is kept.
                    ready_words[i] = term_enzh_dict[real_word] + word[pos:]
            seg.source_text = " ".join(ready_words)

    def spell_check_term(self, term_list_path='./finetune_data/dict_freq.txt'):
        """Replace misspelled words with the closest entry of a term list.

        A word rejected by the ``en_US`` enchant dictionary is replaced by
        the first suggestion from the personal word list at
        ``term_list_path``; words with no suggestion are left unchanged.
        """
        ## known bug: I've will be replaced because i've is not in the dict
        import enchant
        en_dict = enchant.Dict('en_US')
        term_spell_dict = enchant.PyPWL(term_list_path)

        for seg in self.segments:
            ready_words = seg.source_text.split(" ")
            for i, word in enumerate(ready_words):
                real_word, pos = self.get_real_word(word)
                if not real_word:
                    # Empty tokens (double or leading spaces) cannot be
                    # spell-checked; enchant rejects empty strings.
                    continue
                if en_dict.check(real_word):
                    continue
                suggestions = term_spell_dict.suggest(real_word)
                if suggestions:  # relax spell check
                    ready_words[i] = suggestions[0] + word[pos:]
            seg.source_text = " ".join(ready_words)

    def spell_correction(self, word: str, arg: int):
        """Correct a single word and return the result.

        ``arg == 0`` substitutes the word from the term CSV (term translate
        mode); ``arg == 1`` replaces a misspelled word with the first
        suggestion from the term word list (spell check mode). Any other
        value prints a notice and returns ``word`` unchanged.
        """
        if arg not in (0, 1):
            print('only 0 or 1 for argument')
            return word

        real_word, pos = self.get_real_word(word)
        new_word = word
        if arg == 0:  # term translate mode
            with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8', newline='') as f:
                term_enzh_dict = {
                    row[0]: row[1] for row in reader(f) if len(row) >= 2
                }
            if real_word in term_enzh_dict:
                new_word = term_enzh_dict[real_word] + word[pos:]
        else:  # term spell check mode
            import enchant
            en_dict = enchant.Dict('en_US')
            term_spell_dict = enchant.PyPWL('./finetune_data/dict_freq.txt')
            if real_word and not en_dict.check(real_word):
                suggestions = term_spell_dict.suggest(real_word)
                if suggestions:  # relax spell check
                    new_word = suggestions[0] + word[pos:]
        return new_word

    def get_real_word(self, word: str):
        """Strip trailing punctuation from ``word`` and lower-case it.

        Returns ``(real_word, pos)`` where ``real_word`` is the lower-cased
        word without a trailing ``".\\n"`` or a single trailing ``.``,
        newline, ``,``, ``!`` or ``?``, and ``pos`` is the length of the kept
        prefix (so ``word[pos:]`` is the removed tail).
        """
        if word[-2:] == ".\n":
            stripped = 2
        elif word[-1:] in [".", "\n", ",", "!", "?"]:
            stripped = 1
        else:
            stripped = 0
        pos = len(word) - stripped
        return word[:pos].lower(), pos