Peter Rupnik committed
Commit 8941911 · 1 Parent(s): 1cdf067

Add frames_to_intervals function with filtering

Files changed (1): README.md (+59 -9)
README.md CHANGED
 
# Evaluation

Although the output of the model is a series of 0s and 1s, one label per 20 ms frame, the evaluation was done at
the event level: spans of consecutive 1 outputs were bundled together into one event. When a true and a predicted
event partially overlap, this is counted as a true positive.
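
As a concrete illustration, here is a minimal sketch of this event-level matching; the helper names `to_events` and `overlaps` are ours for illustration and are not part of the model card:

```
# Illustrative sketch, not the card's evaluation script: bundle runs of
# consecutive 1s into events, then count a predicted event as a true
# positive if it overlaps any true event.
def to_events(frames: list[int]) -> list[tuple[int, int]]:
    """Return (start, end) frame spans of consecutive 1s, end exclusive."""
    events, start = [], None
    for i, f in enumerate(frames):
        if f == 1 and start is None:
            start = i
        elif f == 0 and start is not None:
            events.append((start, i))
            start = None
    if start is not None:
        events.append((start, len(frames)))
    return events


def overlaps(a: tuple[int, int], b: tuple[int, int]) -> bool:
    return a[0] < b[1] and b[0] < a[1]


true_events = to_events([0, 1, 1, 0, 0, 1, 1, 1])
pred_events = to_events([0, 0, 1, 1, 0, 1, 0, 0])
tp = sum(any(overlaps(p, t) for t in true_events) for p in pred_events)
print(true_events, pred_events, tp)  # [(1, 3), (5, 8)] [(2, 4), (5, 6)] 2
```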

## Evaluation on ROG corpus

In evaluation, we only evaluate positive events, i.e. the filled-pause class:
```
              precision    recall  f1-score   support
…
```
 
Evaluation on 800 human-annotated instances of ParlaSpeech-HR and ParlaSpeech-RS:
```
Performance on RS:
Classification report for human vs model on event level:
              precision    recall  f1-score   support

           1       0.95      0.99      0.97       542

Performance on HR:
Classification report for human vs model on event level:
              precision    recall  f1-score   support

           1       0.93      0.98      0.95       531
```
The metrics reported are on the event level, which means that if a true and a predicted filled pause at least
partially overlap, we count them as a true-positive event.
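
The reports above have the shape of scikit-learn's `classification_report` restricted to the positive class. As a hedged sketch of how such a one-class report could be produced, reusing `to_events`, `overlaps`, `true_events`, and `pred_events` from the sketch above (the model card does not publish its evaluation script, so this is our illustration):

```
# Illustrative only: derive one event-level label pair per event,
# then report only class 1 (filled pauses).
from sklearn.metrics import classification_report


def event_level_labels(true_events, pred_events):
    y_true, y_pred = [], []
    for t in true_events:  # each true event is either found or missed
        hit = any(overlaps(t, p) for p in pred_events)
        y_true.append(1)
        y_pred.append(1 if hit else 0)
    for p in pred_events:  # unmatched predictions are false positives
        if not any(overlaps(p, t) for t in true_events):
            y_true.append(0)
            y_pred.append(1)
    return y_true, y_pred


y_true, y_pred = event_level_labels(true_events, pred_events)
print(classification_report(y_true, y_pred, labels=[1]))
```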
 
 
```
# NOTE: the opening of this example (imports, model and feature-extractor
# setup, and the data passed to Dataset.from_dict) is elided in this diff.
ds = Dataset.from_dict(
    ...
).cast_column("audio", Audio(sampling_rate=16_000, mono=True))

def frames_to_intervals(
    frames: list[int], drop_short=True, drop_initial=True, short_cutoff_s=0.08
) -> list[tuple[float]]:
    """Transforms a list of ones or zeros, corresponding to annotations on frame
    level, to a list of intervals ([start second, end second]).

    Allows for additional filtering on duration (false positives are often short)
    and on start times (false positives starting at 0.0 are often an artifact of
    poor segmentation).

    :param list[int] frames: Input frame labels
    :param bool drop_short: Drop everything shorter than short_cutoff_s, defaults to True
    :param bool drop_initial: Drop predictions starting at 0.0, defaults to True
    :param float short_cutoff_s: Duration in seconds of the shortest allowable prediction, defaults to 0.08
    :return list[tuple[float]]: List of intervals [start_s, end_s]
    """
    from itertools import pairwise  # requires Python 3.10+

    import pandas as pd

    results = []
    ndf = pd.DataFrame(
        data={
            "time_s": [0.020 * i for i in range(len(frames))],  # one frame = 20 ms
            "frames": frames,
        }
    )
    ndf = ndf.dropna()
    # Indices where the label flips between 0 and 1 delimit candidate runs
    indices_of_change = ndf.frames.diff()[ndf.frames.diff() != 0].index.values
    for si, ei in pairwise(indices_of_change):
        # Keep only runs whose majority label is 1 (filled-pause runs)
        if ndf.loc[si : ei - 1, "frames"].mode()[0] == 0:
            pass
        else:
            results.append(
                (
                    round(ndf.loc[si, "time_s"], 3),
                    round(ndf.loc[ei - 1, "time_s"], 3),
                )
            )
    if drop_short and (len(results) > 0):
        results = [i for i in results if (i[1] - i[0] >= short_cutoff_s)]
    if drop_initial and (len(results) > 0):
        results = [i for i in results if i[0] != 0.0]
    return results

def evaluator(chunks):
    sampling_rate = chunks["audio"][0]["sampling_rate"]
    with torch.no_grad():
        inputs = (
            ...  # feature-extraction call elided in this diff
        ).to(device)
        logits = model(**inputs).logits
        y_pred = np.array(logits.cpu()).argmax(axis=-1)
    intervals = [frames_to_intervals(i) for i in y_pred]
    return {"y_pred": y_pred.tolist(), "intervals": intervals}


ds = ds.map(evaluator, batched=True)
print(ds["y_pred"][0])
# Prints a list of 20 ms frames: [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, ...]
# with 0 indicating no filled pause detected in that frame

print(ds["intervals"][0])
# Prints the identified intervals as a list of [start_s, end_s] pairs:
# [[0.08, 0.28], ...]
```
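
To see the filtering behaviour of `frames_to_intervals` in isolation, here is a small self-contained example reusing the definition above; the frame sequence is invented for illustration:

```
# A 2-frame (40 ms) blip at t=0.0 followed by a longer run from t=0.20 s.
frames = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]

print(frames_to_intervals(frames, drop_short=False, drop_initial=False))
# [(0.0, 0.02), (0.2, 0.34)]

print(frames_to_intervals(frames))
# [(0.2, 0.34)] -- the default filters drop the short, utterance-initial blip
```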