|
from abc import ABC |
|
from typing import * |
|
|
|
from allennlp.training.metrics import Metric |
|
|
|
|
|
class BaseF(Metric, ABC):
    """Base metric reporting precision, recall and F-score as percentages.

    The class keeps three running counters -- ``tp``, ``fp`` and ``fn`` --
    and turns them into scores in :meth:`get_metric`.  Nothing in this class
    increments the counters; presumably concrete subclasses do so when they
    are called with predictions (verify against the subclasses).
    """

    def __init__(self, prefix: str):
        """Start with zeroed counters and remember the key prefix.

        Args:
            prefix: String prepended to the ``_p`` / ``_r`` / ``_f`` keys of
                the dictionary returned by :meth:`get_metric`.
        """
        self.tp = 0
        self.fp = 0
        self.fn = 0
        self.prefix = prefix

    def reset(self) -> None:
        """Set the ``tp``, ``fp`` and ``fn`` counters back to zero."""
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def get_metric(
        self, reset: bool
    ) -> Union[float, Tuple[float, ...], Dict[str, float], Dict[str, List[float]]]:
        """Compute the current scores from the accumulated counters.

        Args:
            reset: When true, the counters are zeroed after the scores have
                been computed.

        Returns:
            A dictionary with the keys ``<prefix>_p`` (precision),
            ``<prefix>_r`` (recall) and ``<prefix>_f`` (harmonic mean of the
            two).  Precision and recall are scaled by 100; all three values
            are ``0.`` when ``tp`` is not positive.
        """
        if self.tp > 0:
            precision = self.tp * 100 / (self.tp + self.fp)
            recall = self.tp * 100 / (self.tp + self.fn)
            # Harmonic mean; safe here because tp > 0 makes both terms non-zero.
            f_score = 2 / (1 / precision + 1 / recall)
        else:
            # No true positives: every score is defined as zero, which also
            # avoids dividing by zero above.
            precision = 0.
            recall = 0.
            f_score = 0.

        scores = {
            f'{self.prefix}_p': precision,
            f'{self.prefix}_r': recall,
            f'{self.prefix}_f': f_score,
        }

        if reset:
            self.reset()
        return scores
|
|