Spaces:
Build error
Build error
def ancestors(class_label, hierarchy):
    """Return every ancestor of ``class_label`` in ``hierarchy``.

    ``hierarchy`` maps a label to an iterable of its parent labels. A label
    that is missing from the mapping, or whose parents are empty/falsy, has
    no ancestors. The root is not added implicitly: it appears in the result
    only if some label on the path lists it as a parent.

    The walk is iterative and tracks visited labels, so:
      * shared ancestors in a multi-parent hierarchy are expanded only once;
      * very deep hierarchies do not hit Python's recursion limit;
      * a cyclic mapping terminates (a label on a cycle is reported as one
        of its own ancestors) instead of recursing forever.

    Args:
        class_label: the label whose ancestors are wanted.
        hierarchy: mapping from label to an iterable of parent labels.

    Returns:
        A set of ancestor labels (empty if there are none).
    """
    found = set()
    pending = [class_label]
    while pending:
        label = pending.pop()
        # Use ``in`` + indexing (not ``.get``) to accept any mapping-like
        # object and to avoid inserting keys into e.g. a defaultdict.
        parents = hierarchy[label] if label in hierarchy else None
        if not parents:
            continue
        for parent in parents:
            if parent not in found:
                found.add(parent)
                pending.append(parent)
    return found
def extend_with_ancestors(class_labels, hierarchy):
    """Return ``class_labels`` together with all of their ancestors.

    Args:
        class_labels: iterable of labels (any iterable, including a one-shot
            generator/iterator).
        hierarchy: mapping from label to an iterable of parent labels, as
            accepted by ``ancestors``.

    Returns:
        A new set containing every input label plus each label's ancestors.
    """
    # Materialise once: iterating ``class_labels`` twice would silently skip
    # the ancestor step when a generator/iterator is passed in.
    labels = set(class_labels)
    extended_set = set(labels)
    for label in labels:
        extended_set.update(ancestors(label, hierarchy))
    return extended_set
def hierarchical_precision_recall(true_labels, predicted_labels, hierarchy):
    """Calculate hierarchical precision (hP) and recall (hR).

    Each sample's true and predicted label collections are first extended
    with all of their ancestors. The metrics are then pooled over samples:

        hP = sum_i |T_i & P_i| / sum_i |P_i|
        hR = sum_i |T_i & P_i| / sum_i |T_i|

    where T_i / P_i are the extended true / predicted sets of sample i.

    Args:
        true_labels: iterable of per-sample collections of true labels.
        predicted_labels: iterable of per-sample collections of predicted
            labels, paired with ``true_labels`` by position.
        hierarchy: mapping from label to an iterable of parent labels.

    Returns:
        Tuple ``(hP, hR)``. Each value is 0.0 when its denominator is zero.

    Raises:
        ValueError: if the two inputs contain different numbers of samples.
    """
    true_extended = [extend_with_ancestors(ci, hierarchy) for ci in true_labels]
    predicted_extended = [
        extend_with_ancestors(c_prime_i, hierarchy) for c_prime_i in predicted_labels
    ]
    # zip() would drop unpaired samples from the intersection while the
    # denominators below still counted them, silently skewing both metrics.
    if len(true_extended) != len(predicted_extended):
        raise ValueError(
            "true_labels and predicted_labels must have the same length "
            f"(got {len(true_extended)} and {len(predicted_extended)})"
        )

    intersect_sum = sum(
        len(ci & c_prime_i) for ci, c_prime_i in zip(true_extended, predicted_extended)
    )
    predicted_sum = sum(map(len, predicted_extended))
    true_sum = sum(map(len, true_extended))

    hP = intersect_sum / predicted_sum if predicted_sum > 0 else 0.0
    hR = intersect_sum / true_sum if true_sum > 0 else 0.0
    return hP, hR
def hierarchical_f_measure(hP, hR, beta=1.0):
    """Calculate the hierarchical F-beta measure from hP and hR.

        hF = (beta**2 + 1) * hP * hR / (beta**2 * hP + hR)

    ``beta`` > 1 weights recall more heavily; ``beta`` < 1 weights precision.

    Args:
        hP: hierarchical precision.
        hR: hierarchical recall.
        beta: relative weight of recall versus precision (default 1.0).

    Returns:
        The F-beta score, or 0.0 when the denominator is zero.
    """
    denominator = beta**2 * hP + hR
    # Guard the actual divisor: with beta == 0 it reduces to hR alone, so a
    # check on hP + hR still allowed hP > 0, hR == 0 to divide by zero.
    if denominator == 0:
        return 0.0
    return (beta**2 + 1) * hP * hR / denominator