Macrodove committed on
Commit
eb2a612
·
1 Parent(s): 1cbd56e

Bug fixed, code restructured

Browse files

Former-commit-id: b6060a23f96851752feb0ea174ff87ed297b35c6

Files changed (1) hide show
  1. evaluation/alignment.py +60 -71
evaluation/alignment.py CHANGED
@@ -3,87 +3,76 @@ import numpy as np
3
  sys.path.append('../src')
4
  from srt_util.srt import SrtScript
5
 
6
-
7
- def procedure(anchor,subsec,S_arr,subidx):
8
-
9
- temp = subsec[subidx - 1]
10
- print('------------------------------')
11
- print(anchor)
12
- print(temp)
13
-
14
  cache_idx = 0
15
- while subidx != cache_idx: # loop until alignment stablized
16
- cache_idx = subidx # reinitialize cache
17
- # Inside interval
18
- if subidx >= len(subsec): continue
19
  sub = subsec[subidx]
20
- if (anchor.end < sub.start): continue
21
- if (anchor.start < sub.start) & (sub.end < anchor.end):
22
- S_arr[len(S_arr) - 1] += sub.source_text
 
23
  subidx += 1
24
- elif anchor.end - sub.start > sub.end - anchor.start:
25
- S_arr[len(S_arr) - 1] += sub.source_text
26
- subidx += 1
27
-
28
-
29
- print(sub)
30
- print(S_arr[len(S_arr) - 1])
31
- print('------------------------------')
32
-
33
- subidx -= 1 # reset subidx to last segment
34
-
35
 
36
- def alignment(pred_path,gt_path,threshold = 0.3):
37
  pred = SrtScript.parse_from_srt_file(pred_path).segments
38
  gt = SrtScript.parse_from_srt_file(gt_path).segments
39
- pred_arr = []
40
- gt_arr = []
41
- duration = 0
42
- #count = 0
43
- #for ps,gs in zip(pred,gt):
44
- # duration += ps.end + gs.end - ps.start - gs.start
45
- # count += len(ps.source_text) + len(gs.source_text)
46
- #density = count / duration #word density
47
- idx_p, idx_t = -1, -1
48
  while idx_p < len(pred) or idx_t < len(gt):
49
- idx_p += 1
50
- idx_t += 1
51
- try:
52
- ps = pred[idx_p]
53
- gs = gt[idx_t]
54
- except IndexError:
55
- if idx_t >= len(gt):
56
- pred_arr.append(ps.source_text)
57
- continue
58
- if idx_p >= len(pred):
59
- gs = gt[idx_t]
60
- gt_arr.append(gs.source_text)
61
- continue
62
- #print('init' + str(idx_t) + str(idx_p))
63
- #duration
64
- ps_dur = ps.end - ps.start
65
- gs_dur = gs.end - gs.start
66
- #forward/backward
67
- if ps_dur <= gs_dur:
68
  gt_arr.append(gs.source_text)
69
- if gs.end < ps.start:
70
- idx_p -= 1 # reset idx if no match
71
- continue
 
 
72
  pred_arr.append(ps.source_text)
 
73
  idx_p += 1
74
- procedure(gs,pred,pred_arr,idx_p)
75
- else:
76
- pred_arr.append(ps.source_text)
 
 
 
77
  if ps.end < gs.start:
78
- idx_t -= 1 # reset idx if no match
79
- continue
80
- gt_arr.append(gs.source_text)
81
- idx_t += 1
82
- procedure(ps,gt,gt_arr,idx_t)
83
- #print(pred_arr)
84
- #print(gt_arr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- return zip(pred_arr,gt_arr)
 
 
 
87
 
88
-
89
- alignment('../results/OVB/OVB_en.srt','../results/OVM/OVM_en.srt')
 
3
  sys.path.append('../src')
4
  from srt_util.srt import SrtScript
5
 
6
def procedure(anchor, subsec, S_arr, subidx):
    """Merge consecutive entries of ``subsec`` into the last item of ``S_arr``.

    Starting at ``subsec[subidx]``, each entry's ``source_text`` is appended
    (plain string concatenation, no separator) to ``S_arr[-1]`` for as long as
    the entry is accepted against ``anchor``. Scanning stops at the first
    entry that is rejected or when ``subsec`` is exhausted.

    An entry is rejected when it starts after ``anchor.end``. Otherwise it is
    accepted when it lies entirely inside ``anchor`` or when
    ``anchor.end - sub.start > sub.end - anchor.start``.

    Args:
        anchor: Object with ``start`` and ``end`` attributes.
        subsec: Sequence of objects with ``start``, ``end`` and
            ``source_text`` attributes.
        S_arr: Non-empty list of strings; its last element is extended
            in place.
        subidx: Index in ``subsec`` of the first candidate entry.

    Returns:
        Index of the last entry that was merged, or ``subidx - 1`` when
        nothing was merged.
    """
    while subidx < len(subsec):
        sub = subsec[subidx]

        # Candidate begins after the anchor is over: nothing more to merge.
        if anchor.end < sub.start:
            break

        inside_anchor = anchor.start <= sub.start and sub.end <= anchor.end
        # Algebraically the same as comparing interval midpoints:
        # anchor.start + anchor.end > sub.start + sub.end.
        centred_before_anchor = (
            anchor.end - sub.start > sub.end - anchor.start
        )
        if not (inside_anchor or centred_before_anchor):
            break

        S_arr[-1] += sub.source_text
        subidx += 1

    return subidx - 1
 
 
 
 
 
 
 
 
 
 
19
 
20
def alignment(pred_path, gt_path):
    """Pair up the segment texts of a predicted and a ground-truth SRT file.

    Both files are parsed with ``SrtScript.parse_from_srt_file`` and their
    ``segments`` are walked in parallel. Each step emits one entry into both
    output lists:

    * if one segment ends before the other starts, its text is emitted alone
      and paired with ``''``;
    * otherwise both texts are emitted, and the segment with the longer (or
      equal) duration acts as the anchor: following segments of the other
      stream are merged into that stream's entry via ``procedure``;
    * once one stream is exhausted, the remaining segments of the other are
      emitted paired with ``''``.

    Args:
        pred_path: Path of the predicted ``.srt`` file.
        gt_path: Path of the ground-truth ``.srt`` file.

    Returns:
        A ``zip`` iterator of ``(pred_text, gt_text)`` pairs. It can be
        consumed only once.
    """
    pred = SrtScript.parse_from_srt_file(pred_path).segments
    gt = SrtScript.parse_from_srt_file(gt_path).segments
    pred_arr, gt_arr = [], []
    idx_p, idx_t = 0, 0

    while idx_p < len(pred) or idx_t < len(gt):
        ps = pred[idx_p] if idx_p < len(pred) else None
        gs = gt[idx_t] if idx_t < len(gt) else None

        # Explicit None checks: exhaustion must not depend on whether a
        # segment object happens to be truthy.
        if ps is None:
            pred_arr.append('')
            gt_arr.append(gs.source_text)
            idx_t += 1
            continue

        if gs is None:
            pred_arr.append(ps.source_text)
            gt_arr.append('')
            idx_p += 1
            continue

        pred_before_gt = ps.end < gs.start
        gt_before_pred = gs.end < ps.start
        gt_is_anchor = (ps.end - ps.start) <= (gs.end - gs.start)

        # The order of the two "no overlap" checks depends on which side is
        # the anchor; it only matters for malformed segments (end < start).
        if gt_is_anchor:
            lone_pred = pred_before_gt
            lone_gt = not lone_pred and gt_before_pred
        else:
            lone_gt = gt_before_pred
            lone_pred = not lone_gt and pred_before_gt

        if lone_pred:
            # Prediction has no ground-truth counterpart; only pred advances.
            pred_arr.append(ps.source_text)
            gt_arr.append('')
            idx_p += 1
        elif lone_gt:
            # Ground truth has no prediction counterpart; only gt advances.
            pred_arr.append('')
            gt_arr.append(gs.source_text)
            idx_t += 1
        else:
            pred_arr.append(ps.source_text)
            gt_arr.append(gs.source_text)
            if gt_is_anchor:
                # procedure returns the last merged index; step past it.
                idx_p = procedure(gs, pred, pred_arr, idx_p + 1) + 1
                idx_t += 1
            else:
                idx_t = procedure(ps, gt, gt_arr, idx_t + 1) + 1
                idx_p += 1

    return zip(pred_arr, gt_arr)
76
 
77
+ # Test Case
78
+ #alignment('../results/...PATH1.../FILE.srt', '../results/PATH2/FILE.srt')