Spaces:
Build error
Build error
Refactor recall calculation and add type hints
Browse files
ham.py
CHANGED
@@ -1,18 +1,18 @@
|
|
1 |
"""This module provides functions for calculating hierarchical variants of precicion, recall and F1."""
|
2 |
|
3 |
-
from typing import List,
|
4 |
|
5 |
|
6 |
-
def find_ancestors(node: str, hierarchy:
|
7 |
"""
|
8 |
Find the ancestors of a given node in a hierarchy.
|
9 |
|
10 |
Args:
|
11 |
node (str): The node for which to find ancestors.
|
12 |
-
hierarchy (
|
13 |
|
14 |
Returns:
|
15 |
-
|
16 |
"""
|
17 |
ancestors = set()
|
18 |
nodes_to_visit = [node]
|
@@ -54,45 +54,52 @@ def calculate_hierarchical_precision_recall(
|
|
54 |
Args:
|
55 |
reference_codes (List[str]): The list of reference codes.
|
56 |
predicted_codes (List[str]): The list of predicted codes.
|
57 |
-
hierarchy (Dict[str,
|
58 |
|
59 |
Returns:
|
60 |
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
61 |
"""
|
62 |
extended_real = {}
|
|
|
63 |
|
64 |
# Extend the sets of reference codes with their ancestors
|
65 |
for code in reference_codes:
|
66 |
-
|
67 |
-
extended_real[code] = weight
|
68 |
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
69 |
extended_real[ancestor] = max(
|
70 |
extended_real.get(ancestor, 0), ancestor_weight
|
71 |
)
|
72 |
|
73 |
-
extended_predicted = {}
|
74 |
-
|
75 |
# Extend the sets of predicted codes with their ancestors
|
76 |
for code in predicted_codes:
|
77 |
-
|
78 |
-
extended_predicted[code] = weight
|
79 |
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
80 |
extended_predicted[ancestor] = max(
|
81 |
extended_predicted.get(ancestor, 0), ancestor_weight
|
82 |
)
|
83 |
|
84 |
-
# Calculate weighted correct predictions
|
85 |
-
|
86 |
for code, weight in extended_predicted.items():
|
87 |
if code in extended_real:
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
total_predicted_weights = sum(extended_predicted.values())
|
91 |
total_real_weights = sum(extended_real.values())
|
92 |
|
93 |
# Calculate hierarchical precision and recall using weighted sums
|
94 |
-
hP =
|
95 |
-
|
|
|
|
|
|
|
|
|
96 |
|
97 |
return hP, hR
|
98 |
|
|
|
1 |
"""This module provides functions for calculating hierarchical variants of precicion, recall and F1."""
|
2 |
|
3 |
+
from typing import List, Dict, Tuple, Set
|
4 |
|
5 |
|
6 |
+
def find_ancestors(node: str, hierarchy: Dict[str, Set[str]]) -> Set[str]:
|
7 |
"""
|
8 |
Find the ancestors of a given node in a hierarchy.
|
9 |
|
10 |
Args:
|
11 |
node (str): The node for which to find ancestors.
|
12 |
+
hierarchy (Dict[str, Set[str]]): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.
|
13 |
|
14 |
Returns:
|
15 |
+
Set[str]: A set of ancestors of the given node.
|
16 |
"""
|
17 |
ancestors = set()
|
18 |
nodes_to_visit = [node]
|
|
|
54 |
Args:
|
55 |
reference_codes (List[str]): The list of reference codes.
|
56 |
predicted_codes (List[str]): The list of predicted codes.
|
57 |
+
hierarchy (Dict[str, Dict[str, float]]): The hierarchy definition where keys are nodes and values are dictionaries of parent nodes with distances.
|
58 |
|
59 |
Returns:
|
60 |
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
61 |
"""
|
62 |
extended_real = {}
|
63 |
+
extended_predicted = {}
|
64 |
|
65 |
# Extend the sets of reference codes with their ancestors
|
66 |
for code in reference_codes:
|
67 |
+
extended_real[code] = 1.0 # Full weight for exact match
|
|
|
68 |
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
69 |
extended_real[ancestor] = max(
|
70 |
extended_real.get(ancestor, 0), ancestor_weight
|
71 |
)
|
72 |
|
|
|
|
|
73 |
# Extend the sets of predicted codes with their ancestors
|
74 |
for code in predicted_codes:
|
75 |
+
extended_predicted[code] = 1.0
|
|
|
76 |
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
77 |
extended_predicted[ancestor] = max(
|
78 |
extended_predicted.get(ancestor, 0), ancestor_weight
|
79 |
)
|
80 |
|
81 |
+
# Calculate weighted correct predictions for precision
|
82 |
+
correct_weights_precision = 0
|
83 |
for code, weight in extended_predicted.items():
|
84 |
if code in extended_real:
|
85 |
+
correct_weights_precision += min(weight, extended_real[code])
|
86 |
+
|
87 |
+
# Calculate weighted correct predictions for recall
|
88 |
+
correct_weights_recall = 0
|
89 |
+
for code, weight in extended_real.items():
|
90 |
+
if code in extended_predicted:
|
91 |
+
correct_weights_recall += min(weight, extended_predicted[code])
|
92 |
|
93 |
total_predicted_weights = sum(extended_predicted.values())
|
94 |
total_real_weights = sum(extended_real.values())
|
95 |
|
96 |
# Calculate hierarchical precision and recall using weighted sums
|
97 |
+
hP = (
|
98 |
+
correct_weights_precision / total_predicted_weights
|
99 |
+
if total_predicted_weights
|
100 |
+
else 0
|
101 |
+
)
|
102 |
+
hR = correct_weights_recall / total_real_weights if total_real_weights else 0
|
103 |
|
104 |
return hP, hR
|
105 |
|