Joshua Lochner committed on
Commit d7725ec
Parent: 6a8bf30

Fix incorrect segment output format

Files changed (1): src/evaluate.py (+16 -10)
src/evaluate.py CHANGED
@@ -261,7 +261,7 @@ def main():
         # Check for incorrect segments using the classifier
 
         segments_to_check = []
-        texts = []  # Texts to send through tokenizer
+        cleaned_texts = []  # Texts to send through tokenizer
         for sponsor_segment in sponsor_segments:
             segment_words = extract_segment(
                 words, sponsor_segment['start'], sponsor_segment['end'])
@@ -280,17 +280,22 @@ def main():
             if sponsor_segment['locked']:
                 continue
 
-            sponsor_segment['cleaned_text'] = clean_text(
-                sponsor_segment['text'])
-            texts.append(sponsor_segment['cleaned_text'])
+            cleaned_texts.append(
+                clean_text(sponsor_segment['text']))
             segments_to_check.append(sponsor_segment)
 
         if segments_to_check:  # Some segments to check
 
-            segments_scores = classifier(texts)
+            segments_scores = classifier(cleaned_texts)
 
             num_correct = 0
             for segment, scores in zip(segments_to_check, segments_scores):
+
+                fixed_scores = {
+                    score['label']: score['score']
+                    for score in scores
+                }
+
                 all_metrics['classifier_segment_count'] += 1
 
                 prediction = max(scores, key=lambda x: x['score'])
@@ -302,7 +307,7 @@ def main():
 
                 segment.update({
                     'predicted': predicted_category,
-                    'scores': scores
+                    'scores': fixed_scores
                 })
 
                 incorrect_segments.append(segment)
@@ -313,8 +318,9 @@ def main():
 
             all_metrics['classifier_segment_correct'] += num_correct
 
-            postfix_info['classifier_accuracy'] = all_metrics['classifier_segment_correct'] / \
-                all_metrics['classifier_segment_count']
+            if all_metrics['classifier_segment_count'] > 0:
+                postfix_info['classifier_accuracy'] = all_metrics['classifier_segment_correct'] / \
+                    all_metrics['classifier_segment_count']
 
         out_metrics.append(current_metrics)
         progress.set_postfix(postfix_info)
@@ -383,9 +389,9 @@ def main():
             safe_print('\t\tPredicted Category:',
                        incorrect_segment['predicted'])
             safe_print('\t\tProbabilities:')
-            for item in incorrect_segment['scores']:
+            for label, score in incorrect_segment['scores'].items():
                 safe_print(
-                    f"\t\t\t{label}: {score}")
+                    f"\t\t\t{label}: {score}")
 
             safe_print()
397