Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
·
69fe24d
1
Parent(s):
c2ccf6d
Temporarily disable filtering of predictions using classifier
Browse files- src/predict.py +15 -10
src/predict.py
CHANGED
@@ -23,14 +23,17 @@ from dataclasses import dataclass, field
|
|
23 |
from shared import device
|
24 |
import logging
|
25 |
|
|
|
26 |
|
27 |
-
def seconds_to_time(seconds):
|
28 |
fractional = round(seconds % 1, 3)
|
29 |
fractional = '' if fractional == 0 else str(fractional)[1:]
|
30 |
h, remainder = divmod(abs(int(seconds)), 3600)
|
31 |
m, s = divmod(remainder, 60)
|
32 |
-
|
33 |
-
|
|
|
|
|
34 |
|
35 |
@dataclass
|
36 |
class TrainingOutputArguments:
|
@@ -136,7 +139,7 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
|
|
136 |
segmentation_args
|
137 |
)
|
138 |
|
139 |
-
predictions =
|
140 |
|
141 |
# Add words back to time_ranges
|
142 |
for prediction in predictions:
|
@@ -144,8 +147,9 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
|
|
144 |
prediction['words'] = extract_segment(
|
145 |
words, prediction['start'], prediction['end'])
|
146 |
|
147 |
-
|
148 |
-
|
|
|
149 |
|
150 |
return predictions
|
151 |
|
@@ -188,7 +192,7 @@ def predict_sponsor_matches(text, model, tokenizer):
|
|
188 |
return re_findall(SPONSOR_MATCH_RE, sponsorship_text)
|
189 |
|
190 |
|
191 |
-
def
|
192 |
predicted_time_ranges = []
|
193 |
|
194 |
# TODO pass to model simultaneously, not in for loop
|
@@ -234,10 +238,11 @@ def segments_to_prediction_times(segments, model, tokenizer):
|
|
234 |
end_time = range['end']
|
235 |
|
236 |
if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
|
237 |
-
start_time <= prev_prediction['end'] <= end_time or
|
238 |
-
prev_prediction['end'] <= MERGE_TIME_WITHIN
|
239 |
):
|
240 |
-
# Ending time of last segment is in this segment or
|
|
|
241 |
final_predicted_time_ranges[-1]['end'] = end_time
|
242 |
|
243 |
else: # No overlap, is a new prediction
|
|
|
23 |
from shared import device
|
24 |
import logging
|
25 |
|
26 |
+
import re
|
27 |
|
28 |
+
def seconds_to_time(seconds, remove_leading_zeroes=False):
|
29 |
fractional = round(seconds % 1, 3)
|
30 |
fractional = '' if fractional == 0 else str(fractional)[1:]
|
31 |
h, remainder = divmod(abs(int(seconds)), 3600)
|
32 |
m, s = divmod(remainder, 60)
|
33 |
+
hms = f'{h:02}:{m:02}:{s:02}'
|
34 |
+
if remove_leading_zeroes:
|
35 |
+
hms = re.sub(r'^0(?:0:0?)?', '', hms)
|
36 |
+
return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
|
37 |
|
38 |
@dataclass
|
39 |
class TrainingOutputArguments:
|
|
|
139 |
segmentation_args
|
140 |
)
|
141 |
|
142 |
+
predictions = segments_to_predictions(segments, model, tokenizer)
|
143 |
|
144 |
# Add words back to time_ranges
|
145 |
for prediction in predictions:
|
|
|
147 |
prediction['words'] = extract_segment(
|
148 |
words, prediction['start'], prediction['end'])
|
149 |
|
150 |
+
# TODO add back
|
151 |
+
# if classifier_args is not None:
|
152 |
+
# predictions = filter_predictions(predictions, classifier_args)
|
153 |
|
154 |
return predictions
|
155 |
|
|
|
192 |
return re_findall(SPONSOR_MATCH_RE, sponsorship_text)
|
193 |
|
194 |
|
195 |
+
def segments_to_predictions(segments, model, tokenizer):
|
196 |
predicted_time_ranges = []
|
197 |
|
198 |
# TODO pass to model simultaneously, not in for loop
|
|
|
238 |
end_time = range['end']
|
239 |
|
240 |
if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
|
241 |
+
start_time <= prev_prediction['end'] <= end_time or \
|
242 |
+
start_time - prev_prediction['end'] <= MERGE_TIME_WITHIN
|
243 |
):
|
244 |
+
# Ending time of last segment is in this segment or within the merge threshold,
|
245 |
+
# so we extend last prediction range
|
246 |
final_predicted_time_ranges[-1]['end'] = end_time
|
247 |
|
248 |
else: # No overlap, is a new prediction
|