import os
import re
import logging
from copy import copy, deepcopy
from csv import reader
from datetime import timedelta

import openai
from tqdm import tqdm


class SRT_segment(object):
    def __init__(self, *args) -> None:
        if isinstance(args[0], dict):
            # built from a whisper-style segment dict: {'start': float, 'end': float, 'text': str}
            segment = args[0]
            self.start = segment['start']
            self.end = segment['end']
            # extract the fractional part as milliseconds (centisecond precision)
            self.start_ms = int((segment['start'] * 100) % 100 * 10)
            self.end_ms = int((segment['end'] * 100) % 100 * 10)
            if self.start_ms == self.end_ms and int(segment['start']) == int(segment['end']):
                # avoid an empty time stamp: pad the end by 500 ms
                self.end_ms += 500
            self.start_time = timedelta(seconds=int(segment['start']), milliseconds=self.start_ms)
            self.end_time = timedelta(seconds=int(segment['end']), milliseconds=self.end_ms)
            # str(timedelta) prints "H:MM:SS[.ffffff]"; prepend '0' to get "HH:MM:SS"
            # (assumes durations under 10 hours) and rewrite the fraction as ",mmm"
            if self.start_ms == 0:
                self.start_time_str = '0' + str(self.start_time).split('.')[0] + ',000'
            else:
                self.start_time_str = '0' + str(self.start_time).split('.')[0] + ',' + \
                                      str(self.start_time).split('.')[1][:3]
            if self.end_ms == 0:
                self.end_time_str = '0' + str(self.end_time).split('.')[0] + ',000'
            else:
                self.end_time_str = '0' + str(self.end_time).split('.')[0] + ',' + \
                                    str(self.end_time).split('.')[1][:3]
            self.source_text = segment['text'].lstrip()
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
            self.translation = ""
        elif isinstance(args[0], list):
            # built from one parsed SRT block: [index, duration, text, blank]
            self.source_text = args[0][2]
            self.duration = args[0][1]
            self.start_time_str = self.duration.split(" --> ")[0]
            self.end_time_str = self.duration.split(" --> ")[1]
            # parse the timestamps back to floats; *_ms is kept in milliseconds,
            # matching the dict branch above
            self.start_ms = int(self.start_time_str.split(',')[1])
            self.end_ms = int(self.end_time_str.split(',')[1])
            start_list = self.start_time_str.split(',')[0].split(':')
            self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) \
                + self.start_ms / 1000
            end_list = self.end_time_str.split(',')[0].split(':')
            self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) \
                + self.end_ms / 1000
            self.translation = ""

    def merge_seg(self, seg):
        """
        Merge the segment seg into the current segment in place.
        :param seg: another segment that is strictly next to the current one
        :return: None
        """
        # assert seg.start_ms == self.end_ms, "cannot merge discontinuous segments."
        self.source_text += f' {seg.source_text}'
        self.translation += f' {seg.translation}'
        self.end_time_str = seg.end_time_str
        self.end = seg.end
        self.end_ms = seg.end_ms
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __add__(self, other):
        """
        Merge another segment with the current one and return the newly
        constructed segment; no in-place modification.
        :param other: another segment that is strictly next to the current one
        :return: a new segment covering both sub-segments
        """
        result = deepcopy(self)
        result.source_text += f' {other.source_text}'
        result.translation += f' {other.translation}'
        result.end_time_str = other.end_time_str
        result.end = other.end
        result.end_ms = other.end_ms
        result.duration = f"{self.start_time_str} --> {result.end_time_str}"
        return result

    def remove_trans_punc(self):
        """
        Replace CJK punctuation in the translation text with spaces.
        :return: None
        """
        punc_cn = ",。!?"
        translator = str.maketrans(punc_cn, ' ' * len(punc_cn))
        self.translation = self.translation.translate(translator)

    def __str__(self) -> str:
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
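
# Illustrative sketch, not part of the original module: how a whisper-style
# dict becomes an SRT block, and how two adjacent segments merge. The text
# and timing values are invented for the example.
def _demo_srt_segment():
    a = SRT_segment({'start': 0.0, 'end': 1.5, 'text': ' Hello there,'})
    b = SRT_segment({'start': 1.5, 'end': 3.0, 'text': ' general Kenobi.'})
    print(a)        # 00:00:00,000 --> 00:00:01,500 / "Hello there,"
    merged = a + b  # non-destructive merge via __add__
    print(merged)   # 00:00:00,000 --> 00:00:03,000 / "Hello there, general Kenobi."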

class SRT_script():
    def __init__(self, segments) -> None:
        self.segments = []
        for seg in segments:
            srt_seg = SRT_segment(seg)
            self.segments.append(srt_seg)

    @classmethod
    def parse_from_srt_file(cls, path: str):
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = [line.rstrip() for line in f.readlines()]

        # a standard SRT block is 4 lines: index, duration, text, blank line
        # (this assumes one text line per block)
        segments = []
        for i in range(len(script_lines)):
            if i % 4 == 0:
                segments.append(list(script_lines[i:i + 4]))

        return cls(segments)

    def merge_segs(self, idx_list) -> SRT_segment:
        """
        Merge the segments at the given indices into a single segment.
        :param idx_list: list of indices to merge
        :return: the merged segment
        """
        if not idx_list:
            raise ValueError('idx_list is empty')

        seg_result = deepcopy(self.segments[idx_list[0]])
        if len(idx_list) == 1:
            return seg_result

        for idx in range(1, len(idx_list)):
            seg_result += self.segments[idx_list[idx]]

        return seg_result

    def form_whole_sentence(self):
        """
        Concatenate or strip sentences and reconstruct the segment list.
        This compensates for improper segmentation by openai-whisper.
        :return: None
        """
        logging.info("Forming whole sentences...")
        merge_list = []  # a list of index lists to merge, e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        for i, seg in enumerate(self.segments):
            text = seg.source_text
            # treat '.', '!' or '?' as a sentence end, but skip very short lines
            # and the abbreviation "vs."
            if text and text[-1] in ['.', '!', '?'] and len(text) > 10 and not text.endswith('vs.'):
                sentence.append(i)
                merge_list.append(sentence)
                sentence = []
            else:
                sentence.append(i)
        if sentence:
            # keep trailing segments that never reached a sentence end
            merge_list.append(sentence)

        segments = []
        for idx_list in merge_list:
            if len(idx_list) > 1:
                logging.info("merging segments: %s", idx_list)
            segments.append(self.merge_segs(idx_list))

        self.segments = segments

    def remove_trans_punctuation(self):
        """
        Post-process: remove punctuation from every translation.
        :return: None
        """
        for seg in self.segments:
            seg.remove_trans_punc()
        logging.info("Removed punctuation in translation.")
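    # A sketch of what set_translation reconciles (the sample strings are
    # invented): suppose segments 1-3 were translated as one chunk and GPT
    # returned only two lines, "这是第一句\n\n这是第二句" ("first sentence" /
    # "second sentence"). inner_func is then asked, up to 5 times, to re-split
    # the text into exactly 3 lines before the result is written back one
    # line per segment.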
    def set_translation(self, translate: str, id_range: tuple, model, video_name, video_link=None):
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]

        src_text = ""
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            src_text += seg.source_text
            src_text += '\n\n'

        def inner_func(target, input_str):
            # Ask GPT to merge/split the translation into exactly `target` lines.
            # System prompts (zh): "Your task is to merge or split sentences into
            # the requested number of lines; preserve the meaning as much as
            # possible, but split a sentence across two lines when necessary." /
            # "Note: output only the processed Chinese sentences; if you output
            # numbering, separate it with a colon."
            response = openai.ChatCompletion.create(
                # model=model,
                model="gpt-4",
                messages=[
                    {"role": "system",
                     "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
                    {"role": "system",
                     "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
                    {"role": "user",
                     "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target, input_str)}
                ],
                temperature=0.15
            )
            return response['choices'][0]['message']['content'].strip()

        lines = translate.split('\n\n')

        if len(lines) < (end_seg_id - start_seg_id + 1):
            count = 0
            solved = True
            while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
                count += 1
                print("Solving Unmatched Lines|iteration {}".format(count))

                flag = True
                while flag:
                    flag = False
                    try:
                        translate = inner_func(end_seg_id - start_seg_id + 1, translate)
                    except Exception as e:
                        print("An error has occurred during solving unmatched lines:", e)
                        print("Retrying...")
                        flag = True
                lines = translate.split('\n')

            if len(lines) < (end_seg_id - start_seg_id + 1):
                solved = False
                print("Failed to solve unmatched lines, manual parsing needed")

            # log the reconciliation outcome for later inspection
            if not os.path.exists("./logs"):
                os.mkdir("./logs")
            if video_link:
                log_file = "./logs/log_link.csv"
                log_exist = os.path.exists(log_file)
                with open(log_file, "a") as log:
                    if not log_exist:
                        log.write("range_of_text,iterations_solving,solved,file_length,video_link" + "\n")
                    log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' +
                              str(len(self.segments)) + ',' + video_link + "\n")
            else:
                log_file = "./logs/log_name.csv"
                log_exist = os.path.exists(log_file)
                with open(log_file, "a") as log:
                    if not log_exist:
                        log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
                    log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' +
                              str(len(self.segments)) + ',' + video_name + "\n")

        print(lines)
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            # naive way to deal with the merged-translation problem
            # TODO: need a smarter solution
            if i < len(lines):
                if "Note:" in lines[i]:  # skip GPT's own notes
                    lines.remove(lines[i])
                    if i == len(lines) - 1:
                        break
                lines[i] = lines[i].lstrip(' \n')
                seg.translation = lines[i]
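    # A hedged, invented example of the split below: a segment spanning
    # 0.0-4.0 s whose translation is "这是一个比较长的句子,所以需要拆分"
    # ("this is a fairly long sentence, so it needs splitting") is cut at the
    # middle fullwidth comma, and the 4 s duration is divided in proportion
    # to where that comma sits in the translation.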
    def split_seg(self, seg, text_threshold, time_threshold):
        # evenly split seg into 2 parts and add the new segs into self.segments;
        # drop an initial comma first to avoid endless recursion
        if len(seg.source_text) > 2:
            if seg.source_text[:2] == ', ':
                seg.source_text = seg.source_text[2:]
            if seg.translation and seg.translation[0] == ',':
                seg.translation = seg.translation[1:]

        source_text = seg.source_text
        translation = seg.translation

        # split the text at the middle comma (fullwidth ',' in the translation)
        src_commas = [m.start() for m in re.finditer(',', source_text)]
        trans_commas = [m.start() for m in re.finditer(',', translation)]
        if len(src_commas) != 0:
            src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 \
                else src_commas[len(src_commas) // 2 - 1]
        else:
            # no comma in the source: fall back to the middle space
            src_space = [m.start() for m in re.finditer(' ', source_text)]
            if len(src_space) > 0:
                src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 \
                    else src_space[len(src_space) // 2 - 1]
            else:
                src_split_idx = 0

        if len(trans_commas) != 0:
            trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 \
                else trans_commas[len(trans_commas) // 2 - 1]
        else:
            trans_split_idx = len(translation) // 2

        # avoid splitting inside an English word: advance to the next character
        # that is not an ASCII letter
        for i in range(trans_split_idx, len(translation)):
            if not translation[i].encode('utf-8').isalpha():
                trans_split_idx = i
                break

        # split the time span in proportion to the translation split point
        time_split_ratio = trans_split_idx / max(len(seg.translation) - 1, 1)

        src_seg1 = source_text[:src_split_idx]
        src_seg2 = source_text[src_split_idx:]
        trans_seg1 = translation[:trans_split_idx]
        trans_seg2 = translation[trans_split_idx:]

        start_seg1 = seg.start
        end_seg1 = start_seg2 = seg.start + (seg.end - seg.start) * time_split_ratio
        end_seg2 = seg.end

        seg1 = SRT_segment({'text': src_seg1, 'start': start_seg1, 'end': end_seg1})
        seg1.translation = trans_seg1
        seg2 = SRT_segment({'text': src_seg2, 'start': start_seg2, 'end': end_seg2})
        seg2.translation = trans_seg2

        # recurse while either half is still too long in both text and time
        result_list = []
        if len(seg1.translation) > text_threshold and (seg1.end - seg1.start) > time_threshold:
            result_list += self.split_seg(seg1, text_threshold, time_threshold)
        else:
            result_list.append(seg1)

        if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
            result_list += self.split_seg(seg2, text_threshold, time_threshold)
        else:
            result_list.append(seg2)

        return result_list

    def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
        # if a translation is longer than text_threshold and the segment lasts
        # longer than time_threshold, split the segment in two
        logging.info("performing check_len_and_split")
        segments = []
        for i, seg in enumerate(self.segments):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                logging.info("splitting segment %s into %s parts", i + 1, len(seg_list))
                segments += seg_list
            else:
                segments.append(seg)

        self.segments = segments
        logging.info("check_len_and_split finished")

    def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
        # DEPRECATED
        # if a translation is longer than text_threshold, split the segment in two
        start_seg_id = range[0]
        end_seg_id = range[1]
        extra_len = 0
        segments = []
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                segments += seg_list
                extra_len += len(seg_list) - 1
            else:
                segments.append(seg)

        self.segments[start_seg_id - 1:end_seg_id] = segments
        return extra_len

    def correct_with_force_term(self):
        # force term correction
        logging.info("performing force term correction")

        # load the term dictionary
        with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
            term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}

        # replace longer terms first so substrings do not shadow them
        keywords = list(term_enzh_dict.keys())
        keywords.sort(key=lambda x: len(x), reverse=True)

        for word in keywords:
            for i, seg in enumerate(self.segments):
                if word in seg.source_text.lower():
                    seg.source_text = re.sub(fr"({re.escape(word)}es|{re.escape(word)}s?)\b",
                                             term_enzh_dict.get(word),
                                             seg.source_text, flags=re.IGNORECASE)
                    logging.info("replace term: " + word + " --> " + term_enzh_dict.get(word) +
                                 " in time stamp {}".format(i + 1))
                    logging.info("source text becomes: " + seg.source_text)

    comp_dict = {}

    def fetchfunc(self, word, threshold):
        import enchant
        result = word
        distance = 0
        threshold = threshold * len(word)
        # lazily load the frequency dictionary on first use
        if len(self.comp_dict) == 0:
            with open("./finetune_data/dict_freq.txt", 'r', encoding='utf-8') as f:
                self.comp_dict = {rows[0]: 1 for rows in reader(f)}
        temp = ""
        for matched in self.comp_dict:
            # only compare phrases with phrases and single words with single words
            if (" " in matched and " " in word) or (" " not in matched and " " not in word):
                # NOTE: assumed completion (the original comparison was cut off):
                # keep the dictionary entry closest to `word` within the
                # length-scaled threshold
                d = enchant.utils.levenshtein(word, matched)
                if d < threshold and (temp == "" or d < enchant.utils.levenshtein(word, temp)):
                    temp = matched
        if temp:
            distance = enchant.utils.levenshtein(word, temp)
            result = temp
        return distance, result

    def realtime_write_srt(self, path, range, length, idx):
        # DEPRECATED (signature assumed, mirroring realtime_bilingual_write_srt)
        with open(path, "a", encoding='utf-8') as f:
            for i, seg in enumerate(self.segments):
                if i < range[0] - 1:
                    continue
                if i >= range[1] + length:
                    break
                f.write(f'{i + idx}\n')
                f.write(seg.get_trans_str())

    def realtime_bilingual_write_srt(self, path, range, length, idx):
        # DEPRECATED
        with open(path, "a", encoding='utf-8') as f:
            for i, seg in enumerate(self.segments):
                if i < range[0] - 1:
                    continue
                if i >= range[1] + length:
                    break
                f.write(f'{i + idx}\n')
                f.write(seg.get_bilingual_str())
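
# Minimal usage sketch (illustrative only: "demo.srt" and the thresholds are
# invented, and set_translation would normally receive GPT output rather
# than a hand-written string).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    script = SRT_script.parse_from_srt_file("demo.srt")
    script.form_whole_sentence()
    # script.set_translation(gpt_output, (1, len(script.segments)), "gpt-4", "demo")
    script.check_len_and_split(text_threshold=30, time_threshold=1.0)
    for seg in script.segments:
        print(seg.get_bilingual_str(), end="")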