Joshua Lochner commited on
Commit
69fe24d
·
1 Parent(s): c2ccf6d

Temporarily disable filtering of predictions using classifier

Browse files
Files changed (1) hide show
  1. src/predict.py +15 -10
src/predict.py CHANGED
@@ -23,14 +23,17 @@ from dataclasses import dataclass, field
23
  from shared import device
24
  import logging
25
 
 
26
 
27
- def seconds_to_time(seconds):
28
  fractional = round(seconds % 1, 3)
29
  fractional = '' if fractional == 0 else str(fractional)[1:]
30
  h, remainder = divmod(abs(int(seconds)), 3600)
31
  m, s = divmod(remainder, 60)
32
- return f"{'-' if seconds < 0 else ''}{h:02}:{m:02}:{s:02}{fractional}"
33
-
 
 
34
 
35
  @dataclass
36
  class TrainingOutputArguments:
@@ -136,7 +139,7 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
136
  segmentation_args
137
  )
138
 
139
- predictions = segments_to_prediction_times(segments, model, tokenizer)
140
 
141
  # Add words back to time_ranges
142
  for prediction in predictions:
@@ -144,8 +147,9 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
144
  prediction['words'] = extract_segment(
145
  words, prediction['start'], prediction['end'])
146
 
147
- if classifier_args is not None:
148
- predictions = filter_predictions(predictions, classifier_args)
 
149
 
150
  return predictions
151
 
@@ -188,7 +192,7 @@ def predict_sponsor_matches(text, model, tokenizer):
188
  return re_findall(SPONSOR_MATCH_RE, sponsorship_text)
189
 
190
 
191
- def segments_to_prediction_times(segments, model, tokenizer):
192
  predicted_time_ranges = []
193
 
194
  # TODO pass to model simultaneously, not in for loop
@@ -234,10 +238,11 @@ def segments_to_prediction_times(segments, model, tokenizer):
234
  end_time = range['end']
235
 
236
  if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
237
- start_time <= prev_prediction['end'] <= end_time or start_time -
238
- prev_prediction['end'] <= MERGE_TIME_WITHIN
239
  ):
240
- # Ending time of last segment is in this segment or c, so we extend last prediction range
 
241
  final_predicted_time_ranges[-1]['end'] = end_time
242
 
243
  else: # No overlap, is a new prediction
 
23
  from shared import device
24
  import logging
25
 
26
+ import re
27
 
28
+ def seconds_to_time(seconds, remove_leading_zeroes=False):
29
  fractional = round(seconds % 1, 3)
30
  fractional = '' if fractional == 0 else str(fractional)[1:]
31
  h, remainder = divmod(abs(int(seconds)), 3600)
32
  m, s = divmod(remainder, 60)
33
+ hms = f'{h:02}:{m:02}:{s:02}'
34
+ if remove_leading_zeroes:
35
+ hms = re.sub(r'^0(?:0:0?)?', '', hms)
36
+ return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
37
 
38
  @dataclass
39
  class TrainingOutputArguments:
 
139
  segmentation_args
140
  )
141
 
142
+ predictions = segments_to_predictions(segments, model, tokenizer)
143
 
144
  # Add words back to time_ranges
145
  for prediction in predictions:
 
147
  prediction['words'] = extract_segment(
148
  words, prediction['start'], prediction['end'])
149
 
150
+ # TODO add back
151
+ # if classifier_args is not None:
152
+ # predictions = filter_predictions(predictions, classifier_args)
153
 
154
  return predictions
155
 
 
192
  return re_findall(SPONSOR_MATCH_RE, sponsorship_text)
193
 
194
 
195
+ def segments_to_predictions(segments, model, tokenizer):
196
  predicted_time_ranges = []
197
 
198
  # TODO pass to model simultaneously, not in for loop
 
238
  end_time = range['end']
239
 
240
  if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
241
+ start_time <= prev_prediction['end'] <= end_time or \
242
+ start_time - prev_prediction['end'] <= MERGE_TIME_WITHIN
243
  ):
244
+ # Ending time of last segment is in this segment or within the merge threshold,
245
+ # so we extend last prediction range
246
  final_predicted_time_ranges[-1]['end'] = end_time
247
 
248
  else: # No overlap, is a new prediction