Macrodove committed on
Commit
5543434
·
1 Parent(s): 105baa8

Fixed alignment_obsolete; rewrote alignment with a new approach

Browse files
Files changed (1) hide show
  1. evaluation/alignment.py +42 -8
evaluation/alignment.py CHANGED
@@ -30,7 +30,7 @@ def procedure(anchor, subsec, S_arr, subidx):
30
  # Input: path1, path2
31
  # Output: aligned array of SRTsegment corresponding to path1 path2
32
  # Note: Modify comment with .source_text to get output array with string only
33
- def alignment(pred_path, gt_path):
34
  empt = SrtSegment([0,'00:00:00,000 --> 00:00:00,000','','',''])
35
  pred = SrtScript.parse_from_srt_file(pred_path).segments
36
  gt = SrtScript.parse_from_srt_file(gt_path).segments
@@ -67,12 +67,14 @@ def alignment(pred_path, gt_path):
67
  gt_arr.append(empt) # append filler
68
  idx_t -= 1 # reset ground truth index
69
  else:
70
- gt_arr.append(gs)#.source_text
71
  if gs.end >= ps.start:
 
72
  pred_arr.append(ps)#.source_text
73
  idx_p = procedure(gs, pred, pred_arr, idx_p + 1)
74
- else: # filler pairing
75
- pred_arr.append(empt)
 
76
  idx_p -= 1
77
  else:
78
  # same overlap checking procedure
@@ -81,19 +83,51 @@ def alignment(pred_path, gt_path):
81
  pred_arr.append(empt) # filler
82
  idx_p -= 1 # reset
83
  else:
84
- pred_arr.append(ps)#.source_text
85
  if ps.end >= gs.start:
 
86
  gt_arr.append(gs)#.source_text
87
  idx_t = procedure(ps, gt, gt_arr, idx_t + 1)
88
  else: # filler pairing
89
- gt_arr.append(empt)
90
  idx_t -= 1
91
 
92
  idx_p += 1
93
  idx_t += 1
94
- #for a in pred_arr:
95
- # print(a.source_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  return zip(pred_arr, gt_arr)
97
 
 
98
  # Test Case
99
  #alignment('test_translation_zh.srt', 'test_translation_bi.srt')
 
30
  # Input: path1, path2
31
  # Output: aligned array of SRTsegment corresponding to path1 path2
32
  # Note: Modify comment with .source_text to get output array with string only
33
+ def alignment_obsolete(pred_path, gt_path):
34
  empt = SrtSegment([0,'00:00:00,000 --> 00:00:00,000','','',''])
35
  pred = SrtScript.parse_from_srt_file(pred_path).segments
36
  gt = SrtScript.parse_from_srt_file(gt_path).segments
 
67
  gt_arr.append(empt) # append filler
68
  idx_t -= 1 # reset ground truth index
69
  else:
70
+
71
  if gs.end >= ps.start:
72
+ gt_arr.append(gs)#.source_text
73
  pred_arr.append(ps)#.source_text
74
  idx_p = procedure(gs, pred, pred_arr, idx_p + 1)
75
+ else:
76
+ gt_arr[len(gt_arr) - 1] += gs#.source_text
77
+ #pred_arr.append(empt)
78
  idx_p -= 1
79
  else:
80
  # same overlap checking procedure
 
83
  pred_arr.append(empt) # filler
84
  idx_p -= 1 # reset
85
  else:
 
86
  if ps.end >= gs.start:
87
+ pred_arr.append(ps)#.source_text
88
  gt_arr.append(gs)#.source_text
89
  idx_t = procedure(ps, gt, gt_arr, idx_t + 1)
90
  else: # filler pairing
91
+ pred_arr[len(pred_arr) - 1] += ps
92
  idx_t -= 1
93
 
94
  idx_p += 1
95
  idx_t += 1
96
+ #for a in gt_arr:
97
+ # print(a.translation)
98
+ return zip(pred_arr, gt_arr)
99
+
100
# Input: pred_path, gt_path, threshold (0.5 by default; seconds per the original note)
# Output: zip iterator of (prediction segment, ground-truth segment) pairs
def alignment(pred_path, gt_path, threshold=0.5):
    """Pair the segments of two SRT files, merging neighbours with close end times.

    Both files are parsed with ``SrtScript.parse_from_srt_file`` and walked in
    lockstep. On every step the current prediction segment absorbs each
    following prediction segment whose ``end`` lies within ``threshold`` of the
    current ground-truth ``end``; afterwards the ground-truth segment absorbs
    following ground-truth segments in the same way, measured against the
    (possibly merged) prediction ``end``. When one side is exhausted before the
    other, an empty placeholder segment stands in for it.

    Args:
        pred_path: path to the predicted SRT file.
        gt_path: path to the ground-truth SRT file.
        threshold: maximum absolute difference between ``end`` values for a
            following segment to be merged (units are those of ``.end``;
            the original comment says seconds -- TODO confirm).

    Returns:
        A ``zip`` iterator over the paired prediction and ground-truth segments.
    """
    # Placeholder used once one of the two segment lists has run out.
    filler = SrtSegment([0, '00:00:00,000 --> 00:00:00,000', '', '', ''])
    pred_segments = SrtScript.parse_from_srt_file(pred_path).segments
    gt_segments = SrtScript.parse_from_srt_file(gt_path).segments
    n_pred = len(pred_segments)
    n_gt = len(gt_segments)

    paired_pred = []
    paired_gt = []
    p = 0
    t = 0

    while not (p >= n_pred and t >= n_gt):
        if p < n_pred:
            cur_pred = pred_segments[p]
        else:
            cur_pred = filler
        if t < n_gt:
            cur_gt = gt_segments[t]
        else:
            cur_gt = filler

        # Absorb following prediction segments that end close to the
        # ground-truth end. NOTE(review): whether ``+=`` mutates the segment
        # in place depends on SrtSegment, which is not visible here.
        next_p = p + 1
        while next_p < n_pred:
            candidate = pred_segments[next_p]
            if abs(candidate.end - cur_gt.end) > threshold:
                break
            cur_pred += candidate
            next_p += 1

        # Same merging step for the ground truth, measured against the
        # prediction end as it stands after the merge above.
        next_t = t + 1
        while next_t < n_gt:
            candidate = gt_segments[next_t]
            if abs(candidate.end - cur_pred.end) > threshold:
                break
            cur_gt += candidate
            next_t += 1

        paired_pred.append(cur_pred)
        paired_gt.append(cur_gt)

        # next_p / next_t already point one past the last consumed segment.
        p = next_p
        t = next_t

    return zip(paired_pred, paired_gt)
130
 
131
+
132
  # Test Case
133
  #alignment('test_translation_zh.srt', 'test_translation_bi.srt')