import os
import re
from copy import deepcopy
from csv import reader
from datetime import timedelta
import logging

import openai
from tqdm import tqdm

import dict_util
# punctuation dictionary for supported languages
punctuation_dict = {
    "EN": {
        "punc_str": ".,?!:;-()[]{}",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "ES": {
        "punc_str": ".,?!:;-()[]{}¡¿",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";", "¡", "¿"]
    },
    "FR": {
        "punc_str": ".,?!:;«»—",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "DE": {
        "punc_str": ".,?!:;„“–",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "RU": {
        "punc_str": ".,?!:;-«»—",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "ZH": {
        "punc_str": "。,?!:;()",
        "comma": ",",
        "sentence_end": ["。", "!", "?"]
    },
    "JA": {
        "punc_str": "。、?!:;()",
        "comma": "、",
        "sentence_end": ["。", "!", "?"]
    },
    "AR": {
        "punc_str": ".,?!:;-()[]،؛؟«»",
        "comma": "، ",
        "sentence_end": [".", "!", "?", ";", "؟"]
    },
}
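# Illustrative lookups (not executed here): the code below consults this table
# by language code, e.g.
#   punctuation_dict["ZH"]["comma"]         -> ","
#   punctuation_dict["EN"]["sentence_end"]  -> [".", "!", "?", ";"]
# A segment whose text ends with one of "sentence_end" is treated as closing a
# full sentence (see SrtScript.form_whole_sentence below).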
dict_path = "./domain_dict"


class SrtSegment(object):
    def __init__(self, src_lang, tgt_lang, *args) -> None:
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        if isinstance(args[0], dict):
            # constructed from a whisper-style segment dict:
            # {'start': float, 'end': float, 'text': str}
            segment = args[0]
            self.start = segment['start']
            self.end = segment['end']
            # fractional part of the timestamp, in milliseconds
            self.start_ms = int((segment['start'] * 100) % 100 * 10)
            self.end_ms = int((segment['end'] * 100) % 100 * 10)
            if self.start_ms == self.end_ms and int(segment['start']) == int(segment['end']):
                # avoid an empty time stamp by padding the end by 500 ms
                self.end_ms += 500
            self.start_time = timedelta(seconds=int(segment['start']), milliseconds=self.start_ms)
            self.end_time = timedelta(seconds=int(segment['end']), milliseconds=self.end_ms)
            # format as SRT time strings ("0:00:01" -> "00:00:01,000")
            if self.start_ms == 0:
                self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',000'
            else:
                self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',' + \
                                      str(self.start_time).split('.')[1][:3]
            if self.end_ms == 0:
                self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',000'
            else:
                self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',' + \
                                    str(self.end_time).split('.')[1][:3]
            self.source_text = segment['text'].lstrip()
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
            self.translation = ""
        elif isinstance(args[0], list):
            # constructed from the lines of one SRT block:
            # [index, duration, source text, (translation)]
            self.source_text = args[0][2]
            self.duration = args[0][1]
            self.start_time_str = self.duration.split(" --> ")[0]
            self.end_time_str = self.duration.split(" --> ")[1]
            # parse the time strings to floats; here *_ms holds the fractional
            # part in centiseconds
            self.start_ms = int(self.start_time_str.split(',')[1]) / 10
            self.end_ms = int(self.end_time_str.split(',')[1]) / 10
            start_list = self.start_time_str.split(',')[0].split(':')
            self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
            end_list = self.end_time_str.split(',')[0].split(':')
            self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
            if len(args[0]) < 5:
                self.translation = ""
            else:
                self.translation = args[0][3]
    def merge_seg(self, seg):
        """
        Merge the segment seg with the current segment in place.
        :param seg: Another segment that is strictly next to the current one.
        :return: None
        """
        # assert seg.start_ms == self.end_ms, "cannot merge discontinuous segments."
        self.source_text += f' {seg.source_text}'
        self.translation += f' {seg.translation}'
        self.end_time_str = seg.end_time_str
        self.end = seg.end
        self.end_ms = seg.end_ms
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"
    def __add__(self, other):
        """
        Merge another segment with the current one and return the newly
        constructed segment; neither operand is modified.
        This implements the '+' operator.
        :param other: Another segment that is strictly next to the current one.
        :return: new segment combining the two sub-segments
        """
        result = deepcopy(self)
        result.merge_seg(other)
        return result
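    # Illustrative use of the '+' operator (hypothetical segments, not part of
    # this module's control flow): merging two adjacent whisper segments keeps
    # the first start time and the second end time.
    #   a = SrtSegment("EN", "ZH", {'start': 0.0, 'end': 1.2, 'text': 'Hello'})
    #   b = SrtSegment("EN", "ZH", {'start': 1.2, 'end': 2.0, 'text': 'world.'})
    #   merged = a + b   # merged.source_text == 'Hello world.'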
    def remove_trans_punc(self) -> None:
        """
        Remove punctuation from the translated text (each mark is replaced by a space).
        :return: None
        """
        punc_str = punctuation_dict[self.tgt_lang]["punc_str"]
        for punc in punc_str:
            self.translation = self.translation.replace(punc, ' ')
    def __str__(self) -> str:
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
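    # For reference, get_bilingual_str() yields one SRT-style block such as
    # (values illustrative):
    #   00:00:01,000 --> 00:00:02,500
    #   Hello world.
    #   你好,世界。
    # followed by a blank line; the numeric index is prepended by the
    # reform_*/form_* helpers of SrtScript below.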
class SrtScript(object):
    def __init__(self, src_lang, tgt_lang, segments, domain="General") -> None:
        self.domain = domain
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]

        if self.domain != "General":
            if os.path.exists(f"{dict_path}/{self.domain}"):
                # load the domain-specific terminology dictionary
                self.dict = dict_util.term_dict(f"{dict_path}/{self.domain}", src_lang, tgt_lang)
            else:
                logging.error(f"domain {self.domain} doesn't exist; falling back to the general domain. "
                              f"This disables correct_with_force_term and spell_check_term.")
                self.domain = "General"
    @classmethod
    def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = [line.rstrip() for line in f.readlines()]

        # a bilingual SRT block has 5 lines (index, duration, source,
        # translation, blank); a monolingual one has 4
        bilingual = False
        if script_lines[2] != '' and script_lines[3] != '':
            bilingual = True

        segments = []
        if bilingual:
            for i in range(0, len(script_lines), 5):
                segments.append(list(script_lines[i:i + 5]))
        else:
            for i in range(0, len(script_lines), 4):
                segments.append(list(script_lines[i:i + 4]))

        return cls(src_lang, tgt_lang, segments)
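    # Illustrative call (hypothetical path): build a script object directly
    # from an existing .srt file instead of whisper segments.
    #   script = SrtScript.parse_from_srt_file("EN", "ZH", "demo.srt")
    #   print(len(script.segments))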
    def merge_segs(self, idx_list) -> SrtSegment:
        """
        Merge a list of segments into a single segment.
        :param idx_list: list of segment indices to merge
        :return: the merged segment
        """
        if not idx_list:
            raise ValueError('Empty idx_list')

        seg_result = deepcopy(self.segments[idx_list[0]])
        if len(idx_list) == 1:
            return seg_result

        for idx in range(1, len(idx_list)):
            seg_result += self.segments[idx_list[idx]]
        return seg_result
    def form_whole_sentence(self):
        """
        Concatenate or strip sentences and reconstruct the segments list.
        This compensates for improper segmentation from openai-whisper.
        :return: None
        """
        logging.info("Forming whole sentences...")
        merge_list = []  # groups of indices to merge, e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        ending_puncs = punctuation_dict[self.src_lang]["sentence_end"]

        # collect the indices belonging to each whole sentence into merge_list
        for i, seg in enumerate(self.segments):
            if seg.source_text and seg.source_text[-1] in ending_puncs \
                    and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
                sentence.append(i)
                merge_list.append(sentence)
                sentence = []
            else:
                sentence.append(i)
        if sentence:
            # keep a trailing chunk that never reached an end-of-sentence punctuation
            merge_list.append(sentence)

        # reconstruct the segments list, one entry per whole sentence
        segments = []
        for idx_list in merge_list:
            if len(idx_list) > 1:
                logging.info("merging segments: %s", idx_list)
            segments.append(self.merge_segs(idx_list))
        self.segments = segments
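    # For instance, the whisper chunks ["Thanks for watching", "and see you next time."]
    # (indices [0, 1]) end with '.' only at index 1, so merge_list becomes
    # [[0, 1]] and the two segments are merged into one whole-sentence segment.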
    def remove_trans_punctuation(self):
        """
        Post-process: remove all punctuation from the translations after translating and splitting.
        :return: None
        """
        for seg in self.segments:
            seg.remove_trans_punc()
        logging.info("Removed punctuation in translation.")
    def set_translation(self, translate: str, id_range: tuple, model, video_name, video_link=None):
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]

        src_text = ""
        for seg in self.segments[start_seg_id - 1:end_seg_id]:
            src_text += seg.source_text
            src_text += '\n\n'

        def inner_func(target, input_str):
            # ask the model to re-chunk the translation into the expected number of lines
            response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[
                    {"role": "system",
                     "content": "Your task is to merge or split sentences into a specified number of lines as required. You need to preserve the meaning of the sentences as much as possible, but when necessary a sentence can be divided into two lines of output."},
                    {"role": "system",
                     "content": "Note: You only need to output the processed {} sentences. If you need to output a sequence number, please separate it with a colon.".format(self.tgt_lang)},
                    {"role": "user",
                     "content": 'Please split or combine the following sentences into {} sentences:\n{}'.format(target, input_str)}
                ],
                temperature=0.15
            )
            return response['choices'][0]['message']['content'].strip()
        # handle mismatches between the number of translated lines and segments
        lines = translate.split('\n\n')
        if len(lines) < (end_seg_id - start_seg_id + 1):
            count = 0
            solved = True
            while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
                count += 1
                print("Solving Unmatched Lines|iteration {}".format(count))
                logging.error("Solving Unmatched Lines|iteration %s", count)

                flag = True
                while flag:
                    flag = False
                    try:
                        translate = inner_func(end_seg_id - start_seg_id + 1, translate)
                    except Exception as e:
                        print("An error has occurred during solving unmatched lines:", e)
                        print("Retrying...")
                        logging.error("An error has occurred during solving unmatched lines: %s", e)
                        logging.error("Retrying...")
                        flag = True
                lines = translate.split('\n')

            if len(lines) < (end_seg_id - start_seg_id + 1):
                solved = False
                print("Failed to solve unmatched lines; manual parsing needed")
                logging.error("Failed to solve unmatched lines; manual parsing needed")
            # FIXME: put the error log in our log file
            if not os.path.exists("./logs"):
                os.mkdir("./logs")
            if video_link:
                log_file = "./logs/log_link.csv"
                log_exist = os.path.exists(log_file)
                with open(log_file, "a") as log:
                    if not log_exist:
                        log.write("range_of_text,iterations_solving,solved,file_length,video_link" + "\n")
                    log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' +
                              str(len(self.segments)) + ',' + video_link + "\n")
            else:
                log_file = "./logs/log_name.csv"
                log_exist = os.path.exists(log_file)
                with open(log_file, "a") as log:
                    if not log_exist:
                        log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
                    log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' +
                              str(len(self.segments)) + ',' + video_name + "\n")
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            # naive way to deal with the merged-translation problem
            # TODO: need a smarter solution
            if i < len(lines):
                if "Note:" in lines[i]:  # drop model-generated notes
                    lines.remove(lines[i])
                    if i >= len(lines) - 1:
                        break
                if lines[i] and lines[i][0] in [' ', '\n']:
                    lines[i] = lines[i][1:]
                seg.translation = lines[i]
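    # Illustrative call (hypothetical values): write the model's translation of
    # segments 1..20 back onto the script, running the line-count repair above
    # if the model merged sentences.
    #   script.set_translation(gpt_output, (1, 20), "gpt-4", "demo_video")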
    def split_seg(self, seg, text_threshold, time_threshold):
        # evenly split a segment into 2 parts to be inserted into self.segments
        # drop a leading comma first, to avoid recursing on the same split point
        src_comma_str = punctuation_dict[self.src_lang]["comma"]
        tgt_comma_str = punctuation_dict[self.tgt_lang]["comma"]
        if len(seg.source_text) > 2:
            if seg.source_text[:len(src_comma_str)] == src_comma_str:
                seg.source_text = seg.source_text[len(src_comma_str):]
            if seg.translation[:len(tgt_comma_str)] == tgt_comma_str:
                seg.translation = seg.translation[len(tgt_comma_str):]
        source_text = seg.source_text
        translation = seg.translation

        # prefer splitting at the middle comma
        src_commas = [m.start() for m in re.finditer(re.escape(src_comma_str), source_text)]
        trans_commas = [m.start() for m in re.finditer(re.escape(tgt_comma_str), translation)]
        if len(src_commas) != 0:
            src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 \
                else src_commas[len(src_commas) // 2 - 1]
        else:
            # fall back to splitting at the middle space
            src_space = [m.start() for m in re.finditer(' ', source_text)]
            if len(src_space) > 0:
                src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 \
                    else src_space[len(src_space) // 2 - 1]
            else:
                src_split_idx = 0

        if len(trans_commas) != 0:
            trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 \
                else trans_commas[len(trans_commas) // 2 - 1]
        else:
            trans_split_idx = len(translation) // 2
        # avoid splitting in the middle of an English word
        for i in range(trans_split_idx, len(translation)):
            if not translation[i].encode('utf-8').isalpha():
                trans_split_idx = i
                break

        # split the time duration proportionally to the translation length
        time_split_ratio = trans_split_idx / (len(seg.translation) - 1)

        src_seg1 = source_text[:src_split_idx]
        src_seg2 = source_text[src_split_idx:]
        trans_seg1 = translation[:trans_split_idx]
        trans_seg2 = translation[trans_split_idx:]

        start_seg1 = seg.start
        end_seg1 = start_seg2 = seg.start + (seg.end - seg.start) * time_split_ratio
        end_seg2 = seg.end

        seg1_dict = {'text': src_seg1, 'start': start_seg1, 'end': end_seg1}
        seg1 = SrtSegment(self.src_lang, self.tgt_lang, seg1_dict)
        seg1.translation = trans_seg1

        seg2_dict = {'text': src_seg2, 'start': start_seg2, 'end': end_seg2}
        seg2 = SrtSegment(self.src_lang, self.tgt_lang, seg2_dict)
        seg2.translation = trans_seg2

        # recurse while either half is still too long
        result_list = []
        if len(seg1.translation) > text_threshold and (seg1.end - seg1.start) > time_threshold:
            result_list += self.split_seg(seg1, text_threshold, time_threshold)
        else:
            result_list.append(seg1)
        if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
            result_list += self.split_seg(seg2, text_threshold, time_threshold)
        else:
            result_list.append(seg2)
        return result_list
    def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
        # if the translation length exceeds text_threshold and the duration
        # exceeds time_threshold, split the segment in two
        logging.info("performing check_len_and_split")
        segments = []
        for i, seg in enumerate(self.segments):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                logging.info("splitting segment %s into %s parts", i + 1, len(seg_list))
                segments += seg_list
            else:
                segments.append(seg)
        self.segments = segments
        logging.info("check_len_and_split finished")
    def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
        # DEPRECATED
        # if the translation length exceeds text_threshold, split the segment in two
        start_seg_id = range[0]
        end_seg_id = range[1]
        extra_len = 0
        segments = []
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                segments += seg_list
                extra_len += len(seg_list) - 1
            else:
                segments.append(seg)
        self.segments[start_seg_id - 1:end_seg_id] = segments
        return extra_len
    def correct_with_force_term(self):
        ## force term correction
        logging.info("performing force term correction")
        # check domain
        if self.domain == "General":
            logging.info("General domain cannot perform correct_with_force_term; skipping this step.")
            return

        # replace longer terms first so their substrings do not shadow them
        keywords = list(self.dict.keys())
        keywords.sort(key=lambda x: len(x), reverse=True)
        for word in keywords:
            for i, seg in enumerate(self.segments):
                if word in seg.source_text.lower():
                    seg.source_text = re.sub(fr"({re.escape(word)}es|{re.escape(word)}s?)\b",
                                             "{}".format(self.dict.get(word)),
                                             seg.source_text, flags=re.IGNORECASE)
                    logging.info("replace term: %s --> %s in segment %s", word, self.dict.get(word), i + 1)
                    logging.info("source text becomes: %s", seg.source_text)
    comp_dict = {}

    def fetchfunc(self, word, threshold):
        import enchant
        result = word
        distance = 0
        threshold = threshold * len(word)
        if len(self.comp_dict) == 0:
            # lazily load the frequency dictionary used for fuzzy matching
            with open("./finetune_data/dict_freq.txt", 'r', encoding='utf-8') as f:
                self.comp_dict = {rows[0]: 1 for rows in reader(f)}
        temp = ""
        for matched in self.comp_dict:
            # only compare single words with single words, phrases with phrases
            if (" " in matched and " " in word) or (" " not in matched and " " not in word):
                if enchant.utils.levenshtein(word, matched) < enchant.utils.levenshtein(word, temp):
                    temp = matched
        if enchant.utils.levenshtein(word, temp) < threshold:
            distance = enchant.utils.levenshtein(word, temp)
            result = temp
        return distance, result
    def extract_words(self, sentence, n):
        # split the sentence into chunks of at most n consecutive words
        # e.g. sentence = "this, is a sentence", n = 2 gives
        # [["this,", "is"], ["is", "a"], ["a", "sentence"],
        #  ["this,"], ["is"], ["a"], ["sentence"]]
        words = sentence.split()
        res = []
        for j in range(n, 0, -1):
            res += [words[i:i + j] for i in range(len(words) - j + 1)]
        return res
    def spell_check_term(self):
        logging.info("performing spell check")
        # check domain
        if self.domain == "General":
            logging.info("General domain cannot perform spell_check_term; skipping this step.")
            return

        import enchant
        en_dict = enchant.Dict('en_US')
        term_spell_dict = enchant.PyPWL('./finetune_data/dict_freq.txt')

        for seg in tqdm(self.segments):
            ready_words = self.extract_words(seg.source_text, 2)
            for i in range(len(ready_words)):
                word_list = ready_words[i]
                word, real_word, pos = self.get_real_word(word_list)
                if not en_dict.check(real_word) and not term_spell_dict.check(real_word):
                    distance, correct_term = self.fetchfunc(real_word, 0.3)
                    if distance != 0:
                        seg.source_text = re.sub(re.escape(word[:pos]), correct_term,
                                                 seg.source_text, flags=re.IGNORECASE)
                        logging.info("replace: %s --> %s\t distance = %s", word[:pos], correct_term, distance)
    def get_real_word(self, word_list: list):
        word = " ".join(word_list)  # e.g. ["this,", "is"] -> "this, is"
        if word[-2:] == ".\n":
            real_word = word[:-2].lower()
            n = -2
        elif word[-1:] in [".", "\n", ",", "!", "?"]:
            real_word = word[:-1].lower()
            n = -1
        else:
            real_word = word.lower()
            n = 0
        return word, real_word, len(word) + n
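    # Illustrative result (computed by hand, not executed here):
    #   get_real_word(["this,", "is."]) -> ("this, is.", "this, is", 8)
    # i.e. the raw chunk, the lowercased chunk with trailing punctuation
    # stripped, and the length of the chunk without that punctuation.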
    ## WRITE AND READ FUNCTIONS ##

    def get_source_only(self):
        # return a string with pure source text
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{seg.source_text}\n\n\n'  # f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
        return result

    def reform_src_str(self):
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i + 1}\n'
            result += str(seg)
        return result

    def reform_trans_str(self):
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i + 1}\n'
            result += seg.get_trans_str()
        return result

    def form_bilingual_str(self):
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i + 1}\n'
            result += seg.get_bilingual_str()
        return result
    def write_srt_file_src(self, path: str):
        # write the source-only srt file to path
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        logging.info("writing to %s", path)
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        logging.info("writing to %s", path)
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())
    def realtime_write_srt(self, path, range, length, idx):
        # DEPRECATED
        start_seg_id = range[0]
        end_seg_id = range[1]
        with open(path, "a", encoding='utf-8') as f:
            for i, seg in enumerate(self.segments):
                if i < start_seg_id - 1:
                    continue
                if i >= end_seg_id + length:
                    break
                f.write(f'{i + idx}\n')
                f.write(seg.get_trans_str())

    def realtime_bilingual_write_srt(self, path, range, length, idx):
        # DEPRECATED
        start_seg_id = range[0]
        end_seg_id = range[1]
        with open(path, "a", encoding='utf-8') as f:
            for i, seg in enumerate(self.segments):
                if i < start_seg_id - 1:
                    continue
                if i >= end_seg_id + length:
                    break
                f.write(f'{i + idx}\n')
                f.write(seg.get_bilingual_str())

def split_script(script_in, chunk_size=1000):
    # split the full script into chunks of at most chunk_size characters,
    # breaking only at sentence boundaries ('\n\n'); returns the chunks and
    # the 1-based (start, end) sentence ranges they cover
    script_split = script_in.split('\n\n')
    script_arr = []
    range_arr = []
    start = 1
    end = 0
    script = ""
    for sentence in script_split:
        if len(script) + len(sentence) + 1 <= chunk_size:
            script += sentence + '\n\n'
            end += 1
        else:
            range_arr.append((start, end))
            start = end + 1
            end += 1
            script_arr.append(script.strip())
            script = sentence + '\n\n'
    if script.strip():
        script_arr.append(script.strip())
        range_arr.append((start, len(script_split) - 1))

    assert len(script_arr) == len(range_arr)
    return script_arr, range_arr
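

# A minimal, self-contained sketch of how split_script chunks a script
# (illustrative only; the 40-character chunk_size is chosen just for the demo):
if __name__ == "__main__":
    demo = "First sentence.\n\nSecond sentence.\n\nThird sentence.\n\n"
    chunks, ranges = split_script(demo, chunk_size=40)
    for chunk, rng in zip(chunks, ranges):
        # prints (1, 2) with the first two sentences, then (3, 3) with the third
        print(rng, repr(chunk))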