danieldux commited on
Commit
a944252
·
1 Parent(s): e4caed1

Refactor recall calculation and add type hints

Browse files
Files changed (1) hide show
  1. ham.py +23 -16
ham.py CHANGED
@@ -1,18 +1,18 @@
1
  """This module provides functions for calculating hierarchical variants of precicion, recall and F1."""
2
 
3
- from typing import List, Set, Dict, Tuple
4
 
5
 
6
- def find_ancestors(node: str, hierarchy: dict) -> set:
7
  """
8
  Find the ancestors of a given node in a hierarchy.
9
 
10
  Args:
11
  node (str): The node for which to find ancestors.
12
- hierarchy (dict): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.
13
 
14
  Returns:
15
- set: A set of ancestors of the given node.
16
  """
17
  ancestors = set()
18
  nodes_to_visit = [node]
@@ -54,45 +54,52 @@ def calculate_hierarchical_precision_recall(
54
  Args:
55
  reference_codes (List[str]): The list of reference codes.
56
  predicted_codes (List[str]): The list of predicted codes.
57
- hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
58
 
59
  Returns:
60
  Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
61
  """
62
  extended_real = {}
 
63
 
64
  # Extend the sets of reference codes with their ancestors
65
  for code in reference_codes:
66
- weight = 1.0 # Full weight for exact match
67
- extended_real[code] = weight
68
  for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
69
  extended_real[ancestor] = max(
70
  extended_real.get(ancestor, 0), ancestor_weight
71
  )
72
 
73
- extended_predicted = {}
74
-
75
  # Extend the sets of predicted codes with their ancestors
76
  for code in predicted_codes:
77
- weight = 1.0
78
- extended_predicted[code] = weight
79
  for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
80
  extended_predicted[ancestor] = max(
81
  extended_predicted.get(ancestor, 0), ancestor_weight
82
  )
83
 
84
- # Calculate weighted correct predictions
85
- correct_weights = 0
86
  for code, weight in extended_predicted.items():
87
  if code in extended_real:
88
- correct_weights += min(weight, extended_real[code])
 
 
 
 
 
 
89
 
90
  total_predicted_weights = sum(extended_predicted.values())
91
  total_real_weights = sum(extended_real.values())
92
 
93
  # Calculate hierarchical precision and recall using weighted sums
94
- hP = correct_weights / total_predicted_weights if total_predicted_weights else 0
95
- hR = correct_weights / total_real_weights if total_real_weights else 0
 
 
 
 
96
 
97
  return hP, hR
98
 
 
1
  """This module provides functions for calculating hierarchical variants of precicion, recall and F1."""
2
 
3
+ from typing import List, Dict, Tuple, Set
4
 
5
 
6
+ def find_ancestors(node: str, hierarchy: Dict[str, Set[str]]) -> Set[str]:
7
  """
8
  Find the ancestors of a given node in a hierarchy.
9
 
10
  Args:
11
  node (str): The node for which to find ancestors.
12
+ hierarchy (Dict[str, Set[str]]): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.
13
 
14
  Returns:
15
+ Set[str]: A set of ancestors of the given node.
16
  """
17
  ancestors = set()
18
  nodes_to_visit = [node]
 
54
  Args:
55
  reference_codes (List[str]): The list of reference codes.
56
  predicted_codes (List[str]): The list of predicted codes.
57
+ hierarchy (Dict[str, Dict[str, float]]): The hierarchy definition where keys are nodes and values are dictionaries of parent nodes with distances.
58
 
59
  Returns:
60
  Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
61
  """
62
  extended_real = {}
63
+ extended_predicted = {}
64
 
65
  # Extend the sets of reference codes with their ancestors
66
  for code in reference_codes:
67
+ extended_real[code] = 1.0 # Full weight for exact match
 
68
  for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
69
  extended_real[ancestor] = max(
70
  extended_real.get(ancestor, 0), ancestor_weight
71
  )
72
 
 
 
73
  # Extend the sets of predicted codes with their ancestors
74
  for code in predicted_codes:
75
+ extended_predicted[code] = 1.0
 
76
  for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
77
  extended_predicted[ancestor] = max(
78
  extended_predicted.get(ancestor, 0), ancestor_weight
79
  )
80
 
81
+ # Calculate weighted correct predictions for precision
82
+ correct_weights_precision = 0
83
  for code, weight in extended_predicted.items():
84
  if code in extended_real:
85
+ correct_weights_precision += min(weight, extended_real[code])
86
+
87
+ # Calculate weighted correct predictions for recall
88
+ correct_weights_recall = 0
89
+ for code, weight in extended_real.items():
90
+ if code in extended_predicted:
91
+ correct_weights_recall += min(weight, extended_predicted[code])
92
 
93
  total_predicted_weights = sum(extended_predicted.values())
94
  total_real_weights = sum(extended_real.values())
95
 
96
  # Calculate hierarchical precision and recall using weighted sums
97
+ hP = (
98
+ correct_weights_precision / total_predicted_weights
99
+ if total_predicted_weights
100
+ else 0
101
+ )
102
+ hR = correct_weights_recall / total_real_weights if total_real_weights else 0
103
 
104
  return hP, hR
105