Spaces:
Build error
Build error
"""This module provides functions for calculating hierarchical variants of precision, recall and F1."""
from typing import List, Dict, Tuple, Set | |
def find_ancestors(node: str, hierarchy: Dict[str, Set[str]]) -> Set[str]:
    """
    Find all transitive ancestors of a node in a hierarchy.

    Starting from ``node``, parents, grandparents and so on are collected by
    an iterative depth-first walk. Each ancestor is expanded at most once, so
    ancestors shared by several paths are not re-walked, and a cyclic
    hierarchy terminates instead of looping forever.

    Args:
        node (str): The node for which to find ancestors.
        hierarchy (Dict[str, Set[str]]): Maps each node to its direct parents.
            Nodes missing from the mapping are treated as having no parents.

    Returns:
        Set[str]: All ancestors of ``node``. ``node`` itself is included only
        if it lies on a cycle in ``hierarchy``.
    """
    ancestors: Set[str] = set()
    stack = [node]
    while stack:
        current = stack.pop()
        for parent in hierarchy.get(current, ()):
            # Only expand a parent the first time it is seen: this bounds the
            # walk by the number of distinct ancestors and breaks cycles.
            if parent not in ancestors:
                ancestors.add(parent)
                stack.append(parent)
    return ancestors
def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
    """
    Return the upward closure of ``classes`` within ``hierarchy``.

    The result holds every class from ``classes`` plus every ancestor that
    ``find_ancestors`` reports for it. A new set is built; ``classes`` itself
    is not modified.

    Args:
        classes (set): The classes to start from.
        hierarchy (dict): The hierarchy of classes, in the form accepted by
            ``find_ancestors``.

    Returns:
        set: A new set containing the classes and all of their ancestors.
    """
    closure = set(classes)
    return closure.union(*(find_ancestors(label, hierarchy) for label in classes))
def _weighted_closure(
    codes: List[str], hierarchy: Dict[str, Dict[str, float]]
) -> Dict[str, float]:
    """
    Map each code, and each ancestor listed for it, to its largest weight.

    Explicit codes receive weight 1.0; ancestors receive the weight recorded
    for them in ``hierarchy`` (ancestors are taken as listed, not expanded
    transitively). When a node is reached more than once (a repeated code, a
    shared ancestor, or a code that is also another code's ancestor) the
    maximum weight is kept, so the result does not depend on list order.
    """
    weights: Dict[str, float] = {}
    for code in codes:
        weights[code] = max(weights.get(code, 0.0), 1.0)
        for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
            weights[ancestor] = max(weights.get(ancestor, 0.0), ancestor_weight)
    return weights


def calculate_hierarchical_precision_recall(
    reference_codes: List[str],
    predicted_codes: List[str],
    hierarchy: Dict[str, Dict[str, float]],
) -> Tuple[float, float]:
    """
    Calculate hierarchical precision and recall for a set of predicted codes.

    Both code lists are extended with the ancestors listed for them in
    ``hierarchy``. Each listed code carries weight 1.0 and each ancestor
    carries its recorded weight; a node reached several ways keeps its
    largest weight. A node present in both extended sets contributes the
    smaller of its two weights to the overlap. Precision is that overlap
    divided by the total predicted weight; recall is the overlap divided by
    the total reference weight.

    Args:
        reference_codes (List[str]): The list of reference codes.
        predicted_codes (List[str]): The list of predicted codes.
        hierarchy (Dict[str, Dict[str, float]]): Maps a code to a dictionary of
            its ancestor codes and the weight each ancestor contributes.
            Codes absent from the mapping have no ancestors.

    Returns:
        Tuple[float, float]: ``(hP, hR)``. Each value is 0.0 when its
        denominator (total predicted or total reference weight) is zero,
        e.g. for an empty code list.
    """
    extended_real = _weighted_closure(reference_codes, hierarchy)
    extended_predicted = _weighted_closure(predicted_codes, hierarchy)

    # Precision and recall share the same numerator, the weighted overlap;
    # only their denominators differ, so compute the overlap once.
    overlap = sum(
        min(weight, extended_real[code])
        for code, weight in extended_predicted.items()
        if code in extended_real
    )
    total_predicted_weights = sum(extended_predicted.values())
    total_real_weights = sum(extended_real.values())

    hP = overlap / total_predicted_weights if total_predicted_weights else 0.0
    hR = overlap / total_real_weights if total_real_weights else 0.0
    return hP, hR
def hierarchical_f_measure(hP: float, hR: float, beta: float = 1.0) -> float:
    """
    Calculate the hierarchical F-beta measure from hierarchical precision and recall.

    F_beta = (1 + beta^2) * hP * hR / (beta^2 * hP + hR)

    Args:
        hP (float): The hierarchical precision.
        hR (float): The hierarchical recall.
        beta (float, optional): Trade-off between precision and recall;
            beta > 1 weights recall more heavily, beta < 1 weights precision
            more heavily. Default is 1.0 (F1, the harmonic mean).

    Returns:
        float: The hierarchical F-measure, or 0.0 when the denominator is zero
        (e.g. ``hP == hR == 0``, or ``beta == 0`` with ``hR == 0``).
    """
    beta_sq = beta ** 2
    denominator = beta_sq * hP + hR
    # Guard the actual denominator: testing hP + hR alone misses the
    # beta == 0, hR == 0 case, which would raise ZeroDivisionError.
    if denominator == 0:
        return 0.0
    return (beta_sq + 1) * hP * hR / denominator
# Example usage:
# reference_codes = ["1111", "1112", "1113", "1114"]
# predicted_codes = ["1111", "1113", "1120", "1211"]
# # hierarchy maps each code to {ancestor: weight}; the weights below are illustrative.
# hierarchy_dict = {'1111': {'111': 0.75, '11': 0.5, '1': 0.25}, '1112': {'111': 0.75, '11': 0.5, '1': 0.25}, ...}
# hP, hR = calculate_hierarchical_precision_recall(reference_codes, predicted_codes, hierarchy_dict)
# print(hP, hR, hierarchical_f_measure(hP, hR))