Spaces:
Running
Running
import sys | |
import numpy as np | |
sys.path.append('../src') | |
from srt_util.srt import SrtScript | |
from srt_util.srt import SrtSegment | |
def procedure(anchor, subsec, S_arr, subidx):
    """Greedily fold the segments that follow an anchor into the current alignment.

    Starting at ``subsec[subidx]``, each candidate judged to belong to
    ``anchor`` is merged into the last entry of ``S_arr`` via
    ``S_arr[-1] += sub``. Merging stops at the first candidate that starts
    after ``anchor`` ends, at the first candidate that fails the
    containment / overlap test, or when ``subsec`` is exhausted.

    Args:
        anchor: segment (exposing ``start`` and ``end``) the candidates are
            aligned against.
        subsec: sequence of candidate segments (exposing ``start``/``end``).
        S_arr: output list; its last element accumulates merged candidates.
        subidx: index in ``subsec`` of the first candidate to examine.

    Returns:
        Index of the last candidate merged, or ``subidx - 1`` if none was.
        Callers add 1 to the result to reach the next unconsumed segment.
    """
    # Explicit bound instead of the old "index stopped changing" sentinel,
    # which used 0 as its initial value and so skipped work when subidx == 0.
    while subidx < len(subsec):
        sub = subsec[subidx]
        # Candidate starts after the anchor ends: no overlap, stop merging.
        if anchor.end < sub.start:
            break
        contained = anchor.start <= sub.start and sub.end <= anchor.end
        # NOTE(review): described as a "heavier overlap" test, but this is
        # algebraically anchor.start + anchor.end > sub.start + sub.end,
        # i.e. the candidate's midpoint lies before the anchor's midpoint.
        # Behavior kept as-is; confirm this is the intended criterion.
        front_heavy = anchor.end - sub.start > sub.end - anchor.start
        if not (contained or front_heavy):
            break
        # Use sub.source_text here to accumulate strings instead of segments.
        S_arr[-1] += sub
        subidx += 1
    return subidx - 1
def alignment_obsolete(pred_path, gt_path):
    """Align two SRT files segment by segment (legacy, duration-anchored).

    Superseded by ``alignment``; kept for reference.

    At each step the current predicted segment ``ps`` and ground-truth
    segment ``gs`` are compared, and the longer of the two acts as the
    anchor. Segments that do not overlap are paired with an empty filler
    segment (or absorbed into the previous row). Overlapping segments are
    paired, and ``procedure`` then greedily folds further segments from the
    shorter side into that pair.

    To get strings instead of segments, append ``.source_text`` to the
    segments being stored.

    Args:
        pred_path: path to the predicted SRT file.
        gt_path: path to the ground-truth SRT file.

    Returns:
        A ``zip`` iterator of ``(pred_segment, gt_segment)`` pairs.
    """
    # NOTE(review): one filler object is shared by every filler slot. If
    # SrtSegment implements an in-place __iadd__, the `gt_arr[-1] += gs` /
    # `pred_arr[-1] += ps` lines below would mutate it everywhere — confirm.
    empt = SrtSegment([0, '00:00:00,000 --> 00:00:00,000', '', '', ''])
    pred = SrtScript.parse_from_srt_file(pred_path).segments
    gt = SrtScript.parse_from_srt_file(gt_path).segments
    pred_arr, gt_arr = [], []
    idx_p, idx_t = 0, 0  # current index into pred / gt
    # Every pass ends with idx_p += 1 and idx_t += 1; an `idx -= 1` inside a
    # branch therefore means "keep that index where it is for the next pass".
    while idx_p < len(pred) or idx_t < len(gt):
        ps = pred[idx_p] if idx_p < len(pred) else None
        gs = gt[idx_t] if idx_t < len(gt) else None
        if ps is None:
            # pred exhausted: pair each remaining gt segment with the filler.
            gt_arr.append(gs)
            pred_arr.append(empt)
            idx_t += 1
            continue
        if gs is None:
            # gt exhausted: pair each remaining pred segment with the filler.
            pred_arr.append(ps)
            gt_arr.append(empt)
            idx_p += 1
            continue
        if ps.end - ps.start <= gs.end - gs.start:
            # gs is the anchor.
            if ps.end < gs.start:
                # ps lies entirely before gs: pair it with the filler, keep gs.
                pred_arr.append(ps)
                gt_arr.append(empt)
                idx_t -= 1
            elif gs.end >= ps.start:
                # Overlap: pair them and fold following pred segments in.
                gt_arr.append(gs)
                pred_arr.append(ps)
                idx_p = procedure(gs, pred, pred_arr, idx_p + 1)
            elif gt_arr:
                # gs lies entirely before ps: absorb it into the previous
                # gt row and keep ps for the next pass.
                gt_arr[-1] += gs
                idx_p -= 1
            else:
                # No previous row to absorb into (previously an IndexError):
                # pair gs with the filler instead and keep ps.
                gt_arr.append(gs)
                pred_arr.append(empt)
                idx_p -= 1
        else:
            # ps is the anchor; mirror image of the branch above.
            if gs.end < ps.start:
                # gs lies entirely before ps: pair it with the filler, keep ps.
                gt_arr.append(gs)
                pred_arr.append(empt)
                idx_p -= 1
            elif ps.end >= gs.start:
                # Overlap: pair them and fold following gt segments in.
                pred_arr.append(ps)
                gt_arr.append(gs)
                idx_t = procedure(ps, gt, gt_arr, idx_t + 1)
            elif pred_arr:
                # ps lies entirely before gs: absorb it into the previous
                # pred row and keep gs for the next pass.
                pred_arr[-1] += ps
                idx_t -= 1
            else:
                # No previous row to absorb into (previously an IndexError):
                # pair ps with the filler instead and keep gs.
                pred_arr.append(ps)
                gt_arr.append(empt)
                idx_t -= 1
        idx_p += 1
        idx_t += 1
    return zip(pred_arr, gt_arr)
def alignment(pred_path, gt_path, threshold=0.5):
    """Align a predicted SRT file with a ground-truth SRT file by end time.

    Both files are walked in lockstep. For the current pair:

    1. Following predicted segments are merged into the predicted side while
       each ends no later than the ground-truth segment's end + ``threshold``.
    2. Following ground-truth segments are then merged into the ground-truth
       side while each ends no later than the (possibly merged) predicted
       end + ``threshold``.

    When one file runs out, an empty filler segment stands in for that side,
    and its own end time is used as the merge limit for the other side.

    Args:
        pred_path: path to the predicted SRT file.
        gt_path: path to the ground-truth SRT file.
        threshold: slack added to an end time when deciding whether to merge.
            The original header says seconds; this presumably matches the
            units of SrtSegment.end (verify).

    Returns:
        A single-use ``zip`` iterator of ``(pred_segment, gt_segment)`` pairs.
    """
    filler = SrtSegment([0, '00:00:00,000 --> 00:00:00,000', '', '', ''])
    pred_segs = SrtScript.parse_from_srt_file(pred_path).segments
    gt_segs = SrtScript.parse_from_srt_file(gt_path).segments

    def absorb(segs, pos, merged, limit):
        # Fold segs[pos + 1], segs[pos + 2], ... into `merged` while each one
        # ends no later than `limit`; return the last folded index and result.
        while pos + 1 < len(segs) and segs[pos + 1].end <= limit:
            pos += 1
            merged += segs[pos]
        return pos, merged

    pred_out, gt_out = [], []
    p = t = 0
    while p < len(pred_segs) or t < len(gt_segs):
        cur_pred = pred_segs[p] if p < len(pred_segs) else filler
        cur_gt = gt_segs[t] if t < len(gt_segs) else filler
        p, cur_pred = absorb(pred_segs, p, cur_pred, cur_gt.end + threshold)
        t, cur_gt = absorb(gt_segs, t, cur_gt, cur_pred.end + threshold)
        pred_out.append(cur_pred)
        gt_out.append(cur_gt)
        p += 1
        t += 1
    return zip(pred_out, gt_out)
# Test Case | |
#alignment('test_translation_s2.srt', 'test_translation_zh.srt') | |