Joshua Lochner commited on
Commit
e77b67b
·
1 Parent(s): 79b40d9

Ignore prediction if categorized as nothing by classifier and extractor

Browse files
Files changed (1) hide show
  1. src/predict.py +7 -3
src/predict.py CHANGED
@@ -19,6 +19,7 @@ from model import get_model_tokenizer_classifier, InferenceArguments
19
  logging.basicConfig()
20
  logger = logging.getLogger(__name__)
21
 
 
22
  @dataclass
23
  class PredictArguments(InferenceArguments):
24
  video_id: str = field(
@@ -65,17 +66,20 @@ def filter_and_add_probabilities(predictions, classifier, min_probability):
65
  predicted_probabilities, key=predicted_probabilities.get)
66
  classifier_probability = predicted_probabilities[classifier_category]
67
 
68
- if classifier_category == 'none' and classifier_probability > min_probability:
69
- continue # Ignore
70
-
71
  if (prediction['category'] not in predicted_probabilities) \
72
  or (classifier_category != 'none' and classifier_probability > 0.5): # TODO make param
73
  # Unknown category or we are confident enough to overrule,
74
  # so change category to what was predicted by classifier
75
  prediction['category'] = classifier_category
76
 
 
 
 
77
  prediction['probability'] = predicted_probabilities[prediction['category']]
78
 
 
 
 
79
  # TODO add probabilities, but remove None and normalise rest
80
  prediction['probabilities'] = predicted_probabilities
81
 
 
19
  logging.basicConfig()
20
  logger = logging.getLogger(__name__)
21
 
22
+
23
  @dataclass
24
  class PredictArguments(InferenceArguments):
25
  video_id: str = field(
 
66
  predicted_probabilities, key=predicted_probabilities.get)
67
  classifier_probability = predicted_probabilities[classifier_category]
68
 
 
 
 
69
  if (prediction['category'] not in predicted_probabilities) \
70
  or (classifier_category != 'none' and classifier_probability > 0.5): # TODO make param
71
  # Unknown category or we are confident enough to overrule,
72
  # so change category to what was predicted by classifier
73
  prediction['category'] = classifier_category
74
 
75
+ if prediction['category'] == 'none':
76
+ continue # Ignore if categorised as nothing
77
+
78
  prediction['probability'] = predicted_probabilities[prediction['category']]
79
 
80
+ if min_probability is not None and prediction['probability'] < min_probability:
81
+ continue # Ignore if below threshold
82
+
83
  # TODO add probabilities, but remove None and normalise rest
84
  prediction['probabilities'] = predicted_probabilities
85