Joshua Lochner committed
Commit 787a8df · Parent(s): 643d00a

Use new classifier for evaluation

Files changed (2):
  1. src/evaluate.py (+65 −32)
  2. src/model.py (+5 −2)
src/evaluate.py CHANGED
@@ -38,24 +38,15 @@ def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
         prediction['best_overlap'] = 0
         prediction['best_sponsorship'] = None
 
-    # Assign predictions to actual (labelled) sponsored segments
-    for sponsor_segment in sponsor_segments:
-        sponsor_segment['best_overlap'] = 0
-        sponsor_segment['best_prediction'] = None
-
-        for prediction in predictions:
-
+        # Assign predictions to actual (labelled) sponsored segments
+        for sponsor_segment in sponsor_segments:
             j = jaccard(prediction['start'], prediction['end'],
                         sponsor_segment['start'], sponsor_segment['end'])
-            if sponsor_segment['best_overlap'] < j:
-                sponsor_segment['best_overlap'] = j
-                sponsor_segment['best_prediction'] = prediction
-
             if prediction['best_overlap'] < j:
                 prediction['best_overlap'] = j
                 prediction['best_sponsorship'] = sponsor_segment
 
-    return sponsor_segments
+    # return sponsor_segments
 
 
 def calculate_metrics(labelled_words, predictions):
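
Note: the restructured loop above keys everything off each prediction's best Jaccard overlap with a labelled segment; a prediction that overlaps no labelled segment keeps best_sponsorship = None, which is what the missed-segment check below relies on. The jaccard helper itself is not shown in this commit; a minimal sketch of an interval Jaccard index (intersection over union of two [start, end] ranges; the repo's actual implementation may differ) would be:

    def jaccard(x1, x2, y1, y2):
        # Intersection over union of the intervals [x1, x2] and [y1, y2];
        # 0 when the intervals do not overlap at all.
        intersection = max(0, min(x2, y2) - max(x1, y1))
        union = (x2 - x1) + (y2 - y1) - intersection
        return intersection / union if union > 0 else 0
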
@@ -212,19 +203,55 @@ def main():
                     'f-score': total_fscore/len(out_metrics)
                 })
 
-                labelled_predicted_segments = attach_predictions_to_sponsor_segments(
+                attach_predictions_to_sponsor_segments(
                     predictions, sponsor_segments)
 
                 # Identify possible issues:
                 missed_segments = [
                     prediction for prediction in predictions if prediction['best_sponsorship'] is None]
-                incorrect_segments = [
-                    seg for seg in labelled_predicted_segments if seg['best_prediction'] is None]
 
-                # Add words to incorrect segments
-                for seg in incorrect_segments:
-                    seg['words'] = extract_segment(
-                        words, seg['start'], seg['end'])
+                # Now, check for incorrect segments using the classifier
+                incorrect_segments = []
+
+                segments_to_check = []
+                texts = []  # Texts to send through tokenizer
+                for sponsor_segment in sponsor_segments:
+                    segment_words = extract_segment(
+                        words, sponsor_segment['start'], sponsor_segment['end'])
+                    sponsor_segment['text'] = ' '.join(x['cleaned'] for x in segment_words)
+
+                    duration = sponsor_segment['end'] - \
+                        sponsor_segment['start']
+                    wps = len(segment_words) / \
+                        duration if duration > 0 else 0
+                    if wps < 1.5:
+                        continue
+
+                    # Do not worry about those that are locked or have enough votes
+                    # or segment['votes'] > 5:
+                    if sponsor_segment['locked']:
+                        continue
+
+                    texts.append(sponsor_segment['text'])
+                    segments_to_check.append(sponsor_segment)
+
+                if segments_to_check:  # Segments to check
+
+                    segments_scores = classifier(texts)
+
+                    for segment, scores in zip(segments_to_check, segments_scores):
+                        prediction = max(scores, key=lambda x: x['score'])
+                        predicted_category = prediction['label'].lower()
+
+                        if predicted_category == segment['category']:
+                            continue  # Ignore correct segments
+
+                        segment.update({
+                            'predicted': predicted_category,
+                            'scores': scores
+                        })
+
+                        incorrect_segments.append(segment)
 
             else:
                 # logger.warning(f'No labels found for {video_id}')
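
Note: classifier(texts) is expected to return, for each input text, a list of {'label', 'score'} dicts covering every category, since the code takes max(scores, key=...) over them. That matches a Hugging Face text-classification pipeline configured to return all class scores; a sketch under that assumption (the model path is illustrative only — the repo builds its classifier via get_model_tokenizer_classifier in src/model.py):

    from transformers import pipeline

    # Hypothetical construction, for illustration of the expected output shape.
    classifier = pipeline(
        'text-classification',
        model='path/to/classifier-checkpoint',
        return_all_scores=True)  # one [{'label': ..., 'score': ...}, ...] per input

    segments_scores = classifier(['thanks to our sponsor for supporting this video'])
    best = max(segments_scores[0], key=lambda x: x['score'])

Also note the words-per-second filter: a 30-second segment whose extracted transcript has only 20 words gives 20/30 ≈ 0.67 wps, below the 1.5 threshold, so it is skipped as too sparse to classify reliably.
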
@@ -233,13 +260,15 @@ def main():
                 incorrect_segments = []
 
             if missed_segments or incorrect_segments:
+                for z in missed_segments:
+                    # Attach original text to missed segments
+                    # (Already added to incorrect segments)
+                    z['text'] = ' '.join(x['text']
+                                         for x in z.pop('words', []))
+
                 if evaluation_args.output_as_json:
                     to_print = {'video_id': video_id}
 
-                    for z in missed_segments + incorrect_segments:
-                        z['text'] = ' '.join(x['text']
-                                             for x in z.pop('words', []))
-
                     if missed_segments:
                         to_print['missed'] = missed_segments
 
@@ -257,8 +286,7 @@ def main():
                     for i, missed_segment in enumerate(missed_segments, start=1):
                         print(f'\t#{i}:', seconds_to_time(
                             missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
-                        print('\t\tText: "', ' '.join(
-                            [w['text'] for w in missed_segment['words']]), '"', sep='')
+                        print('\t\tText: "', missed_segment['text'], '"', sep='')
                         print('\t\tCategory:',
                               missed_segment.get('category'))
                         if 'probability' in missed_segment:
@@ -275,24 +303,29 @@ def main():
                     print(
                         f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')
 
-                # Potentially incorrect segments (model didn't predict, but in database)
+                # Incorrect segments (in database, but incorrectly classified)
                 if incorrect_segments:
                     print(' - Incorrect segments:')
                     for i, incorrect_segment in enumerate(incorrect_segments, start=1):
                         print(f'\t#{i}:', seconds_to_time(
                             incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
 
-                        seg_words = extract_segment(
-                            words, incorrect_segment['start'], incorrect_segment['end'])
-                        print('\t\tText: "', ' '.join(
-                            [w['text'] for w in seg_words]), '"', sep='')
+                        print('\t\tText: "', incorrect_segment['text'], '"', sep='')
                         print('\t\tUUID:', incorrect_segment['uuid'])
-                        print('\t\tCategory:',
-                              incorrect_segment['category'])
                         print('\t\tVotes:', incorrect_segment['votes'])
                         print('\t\tViews:', incorrect_segment['views'])
                         print('\t\tLocked:',
                               incorrect_segment['locked'])
+
+                        print('\t\tCurrent Category:',
+                              incorrect_segment['category'])
+                        print('\t\tPredicted Category:',
+                              incorrect_segment['predicted'])
+                        print('\t\tProbabilities:')
+                        for item in incorrect_segment['scores']:
+                            print(
+                                f"\t\t\t{item['label']}: {item['score']}")
+
                         print()
 
     except KeyboardInterrupt:
src/model.py CHANGED
@@ -1,6 +1,5 @@
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, TrainingArguments
 from shared import CustomTokens, GeneralArguments
-from functools import lru_cache
 from dataclasses import dataclass, field
 from typing import Optional, Union
 import torch
@@ -72,6 +71,7 @@ class ModelArguments:
     """
 
     model_name_or_path: str = field(
+        default=None,
         metadata={
             'help': 'Path to pretrained model or model identifier from huggingface.co/models'
         }
@@ -104,7 +104,7 @@ class ModelArguments:
     )
 
 import itertools
-from errors import InferenceException
+from errors import InferenceException, ModelLoadError
 
 @dataclass
 class InferenceArguments(ModelArguments):
@@ -191,6 +191,9 @@ def get_model_tokenizer_classifier(inference_args: InferenceArguments, general_a
 
 
 def get_model_tokenizer(model_args: ModelArguments, general_args: Union[GeneralArguments, TrainingArguments] = None, config_args=None, model_type='seq2seq'):
+    if model_args.model_name_or_path is None:
+        raise ModelLoadError('Must specify --model_name_or_path')
+
     if config_args is None:
         config_args = {}
 
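
Note: with default=None plus the new guard, a missing model path no longer fails at argument-parsing time but at model-load time, with a clear error. A sketch of the intended failure mode (assuming ModelArguments can be instantiated directly and that ModelLoadError lives in the repo's errors module, as the import above suggests):

    from errors import ModelLoadError
    from model import ModelArguments, get_model_tokenizer

    args = ModelArguments()  # model_name_or_path now defaults to None
    try:
        model, tokenizer = get_model_tokenizer(args)
    except ModelLoadError as e:
        print(e)  # Must specify --model_name_or_path
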