Joshua Lochner committed
Commit · 787a8df
1 Parent(s): 643d00a

Use new classifier for evaluation
Files changed:
- src/evaluate.py (+65 -32)
- src/model.py (+5 -2)
src/evaluate.py
CHANGED
@@ -38,24 +38,15 @@ def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
         prediction['best_overlap'] = 0
         prediction['best_sponsorship'] = None
 
-
-    for sponsor_segment in sponsor_segments:
-        sponsor_segment['best_overlap'] = 0
-        sponsor_segment['best_prediction'] = None
-
-        for prediction in predictions:
-
+        # Assign predictions to actual (labelled) sponsored segments
+        for sponsor_segment in sponsor_segments:
             j = jaccard(prediction['start'], prediction['end'],
                         sponsor_segment['start'], sponsor_segment['end'])
-            if sponsor_segment['best_overlap'] < j:
-                sponsor_segment['best_overlap'] = j
-                sponsor_segment['best_prediction'] = prediction
-
             if prediction['best_overlap'] < j:
                 prediction['best_overlap'] = j
                 prediction['best_sponsorship'] = sponsor_segment
 
-    return sponsor_segments
+    # return sponsor_segments
 
 
 def calculate_metrics(labelled_words, predictions):
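The rewrite flips the loop nesting: each prediction is now compared against every labelled sponsor segment from inside the existing `for prediction in predictions:` loop, and the per-segment `best_prediction` bookkeeping is dropped (incorrect segments are found by the classifier below instead), so the function mutates `predictions` in place rather than returning `sponsor_segments`. The `jaccard` helper is defined elsewhere in evaluate.py and is not part of this diff; a minimal sketch of an interval Jaccard index consistent with the call above (an assumption, not the repository's actual implementation):

def jaccard(x1, x2, y1, y2):
    # Overlap of intervals [x1, x2] and [y1, y2] divided by their union length.
    intersection = max(0.0, min(x2, y2) - max(x1, y1))
    union = (x2 - x1) + (y2 - y1) - intersection
    return intersection / union if union > 0 else 0.0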
@@ -212,19 +203,55 @@ def main():
                     'f-score': total_fscore/len(out_metrics)
                 })
 
-                labelled_predicted_segments = attach_predictions_to_sponsor_segments(
+                attach_predictions_to_sponsor_segments(
                     predictions, sponsor_segments)
 
                 # Identify possible issues:
                 missed_segments = [
                     prediction for prediction in predictions if prediction['best_sponsorship'] is None]
-                incorrect_segments = [
-                    seg for seg in labelled_predicted_segments if seg['best_prediction'] is None]
 
-                #
-
-
-
+                # Now, check for incorrect segments using the classifier
+                incorrect_segments = []
+
+                segments_to_check = []
+                texts = []  # Texts to send through tokenizer
+                for sponsor_segment in sponsor_segments:
+                    segment_words = extract_segment(
+                        words, sponsor_segment['start'], sponsor_segment['end'])
+                    sponsor_segment['text'] = ' '.join(x['cleaned'] for x in segment_words)
+
+                    duration = sponsor_segment['end'] - \
+                        sponsor_segment['start']
+                    wps = len(segment_words) / \
+                        duration if duration > 0 else 0
+                    if wps < 1.5:
+                        continue
+
+                    # Do not worry about those that are locked or have enough votes
+                    # or segment['votes'] > 5:
+                    if sponsor_segment['locked']:
+                        continue
+
+                    texts.append(sponsor_segment['text'])
+                    segments_to_check.append(sponsor_segment)
+
+                if segments_to_check:  # Segments to check
+
+                    segments_scores = classifier(texts)
+
+                    for segment, scores in zip(segments_to_check, segments_scores):
+                        prediction = max(scores, key=lambda x: x['score'])
+                        predicted_category = prediction['label'].lower()
+
+                        if predicted_category == segment['category']:
+                            continue  # Ignore correct segments
+
+                        segment.update({
+                            'predicted': predicted_category,
+                            'scores': scores
+                        })
+
+                        incorrect_segments.append(segment)
 
             else:
                 # logger.warning(f'No labels found for {video_id}')
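Instead of deriving incorrect segments from unmatched predictions, the new code batches each candidate segment's transcript through `classifier` and flags segments whose top-scoring label disagrees with their current database category, skipping locked segments and segments below 1.5 words per second (presumably because such sparse transcripts are unreliable to classify). It assumes `classifier(texts)` returns one list of `{'label', 'score'}` dicts per input text, the shape a Hugging Face text-classification pipeline produces when configured to return all scores. Illustrative values only:

segments_scores = [
    [{'label': 'SPONSOR', 'score': 0.91},
     {'label': 'SELFPROMO', 'score': 0.07},
     {'label': 'INTERACTION', 'score': 0.02}],
]
prediction = max(segments_scores[0], key=lambda x: x['score'])
assert prediction['label'].lower() == 'sponsor'  # compared against segment['category']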
@@ -233,13 +260,15 @@ def main():
                 incorrect_segments = []
 
             if missed_segments or incorrect_segments:
+                for z in missed_segments:
+                    # Attach original text to missed segments
+                    # (Already added to incorrect segments)
+                    z['text'] = ' '.join(x['text']
+                                         for x in z.pop('words', []))
+
                 if evaluation_args.output_as_json:
                     to_print = {'video_id': video_id}
 
-                    for z in missed_segments + incorrect_segments:
-                        z['text'] = ' '.join(x['text']
-                                             for x in z.pop('words', []))
-
                     if missed_segments:
                         to_print['missed'] = missed_segments
 
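Text attachment now runs for both output modes rather than only the JSON branch, and only for missed segments, since incorrect segments already received a 'text' field during the classifier pass. `z.pop('words', [])` also strips the bulky per-word list so serialised output stays compact. A small worked example with a hypothetical segment dict:

z = {'start': 21.0, 'end': 42.0, 'words': [{'text': 'this'}, {'text': 'video'}]}
z['text'] = ' '.join(x['text'] for x in z.pop('words', []))
assert z == {'start': 21.0, 'end': 42.0, 'text': 'this video'}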
@@ -257,8 +286,7 @@ def main():
                     for i, missed_segment in enumerate(missed_segments, start=1):
                         print(f'\t#{i}:', seconds_to_time(
                             missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
-                        print('\t\tText: "', ' '.join(
-                            [w['text'] for w in missed_segment['words']]), '"', sep='')
+                        print('\t\tText: "', missed_segment['text'], '"', sep='')
                         print('\t\tCategory:',
                               missed_segment.get('category'))
                         if 'probability' in missed_segment:
@@ -275,24 +303,29 @@ def main():
                     print(
                         f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')
 
-                #
+                # Incorrect segments (in database, but incorrectly classified)
                 if incorrect_segments:
                     print(' - Incorrect segments:')
                     for i, incorrect_segment in enumerate(incorrect_segments, start=1):
                         print(f'\t#{i}:', seconds_to_time(
                             incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
 
-                        seg_words = extract_segment(
-                            words, incorrect_segment['start'], incorrect_segment['end'])
-                        print('\t\tText: "', ' '.join(
-                            [w['text'] for w in seg_words]), '"', sep='')
+                        print('\t\tText: "', incorrect_segment['text'], '"', sep='')
                         print('\t\tUUID:', incorrect_segment['uuid'])
-                        print('\t\tCategory:',
-                              incorrect_segment['category'])
                         print('\t\tVotes:', incorrect_segment['votes'])
                         print('\t\tViews:', incorrect_segment['views'])
                         print('\t\tLocked:',
                               incorrect_segment['locked'])
+
+                        print('\t\tCurrent Category:',
+                              incorrect_segment['category'])
+                        print('\t\tPredicted Category:',
+                              incorrect_segment['predicted'])
+                        print('\t\tProbabilities:')
+                        for item in incorrect_segment['scores']:
+                            print(
+                                f"\t\t\t{item['label']}: {item['score']}")
+
                     print()
 
     except KeyboardInterrupt:
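With the fields added by the classifier pass, a plain-text report entry for an incorrect segment now prints schematically as follows (placeholder values, one Probabilities line per label; derived from the print calls above):

 - Incorrect segments:
	#1: <start> --> <end>
		Text: "<segment transcript>"
		UUID: <uuid>
		Votes: <votes>
		Views: <views>
		Locked: <locked>

		Current Category: <current category>
		Predicted Category: <predicted category>
		Probabilities:
			<label>: <score>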
src/model.py
CHANGED
@@ -1,6 +1,5 @@
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, TrainingArguments
 from shared import CustomTokens, GeneralArguments
-from functools import lru_cache
 from dataclasses import dataclass, field
 from typing import Optional, Union
 import torch
@@ -72,6 +71,7 @@ class ModelArguments:
     """
 
     model_name_or_path: str = field(
+        default=None,
         metadata={
             'help': 'Path to pretrained model or model identifier from huggingface.co/models'
         }
@@ -104,7 +104,7 @@ class ModelArguments:
     )
 
 import itertools
-from errors import InferenceException
+from errors import InferenceException, ModelLoadError
 
 @dataclass
 class InferenceArguments(ModelArguments):
@@ -191,6 +191,9 @@ def get_model_tokenizer_classifier(inference_args: InferenceArguments, general_a
 
 
 def get_model_tokenizer(model_args: ModelArguments, general_args: Union[GeneralArguments, TrainingArguments] = None, config_args=None, model_type='seq2seq'):
+    if model_args.model_name_or_path is None:
+        raise ModelLoadError('Must specify --model_name_or_path')
+
     if config_args is None:
         config_args = {}
 
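Making `model_name_or_path` default to None (instead of a required dataclass field) lets argument parsing succeed without it, with the new `ModelLoadError` guard producing a clear message when loading is actually attempted. The `classifier` used by evaluate.py presumably comes from `get_model_tokenizer_classifier`, whose body is outside this diff; a hedged sketch of how such a classifier can be assembled from the imports already present in model.py (`build_classifier` is illustrative, not the repository's actual code):

from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

def build_classifier(model_name_or_path):
    # Illustrative only. return_all_scores=True (top_k=None in newer
    # transformers) makes the pipeline return a list of {'label', 'score'}
    # dicts per input text, the shape the new evaluate.py code consumes.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
    return pipeline('text-classification', model=model, tokenizer=tokenizer,
                    return_all_scores=True)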