# Author: gossminn — commit 6680682 ("First version"), 840 bytes.
from abc import ABC
from typing import *
from allennlp.training.metrics import Metric
class BaseF(Metric, ABC):
    """Accumulate TP/FP/FN counts and report precision, recall and F1.

    This base class only stores the counts and converts them into scores; it
    never increments them itself. Subclasses are expected to update
    ``self.tp``, ``self.fp`` and ``self.fn`` (presumably from ``__call__``,
    which is not defined here — confirm in the subclasses).

    All scores are reported on a 0-100 (percentage) scale under the keys
    ``<prefix>_p``, ``<prefix>_r`` and ``<prefix>_f``.
    """

    def __init__(self, prefix: str):
        """Create the metric with all counts at zero.

        Args:
            prefix: string prepended (followed by ``_``) to every key of the
                dict returned by ``get_metric``.
        """
        # Cooperative initialisation along the Metric/ABC MRO.
        super().__init__()
        self.tp = self.fp = self.fn = 0
        self.prefix = prefix

    def reset(self) -> None:
        """Zero the accumulated true-positive, false-positive and false-negative counts."""
        self.tp = self.fp = self.fn = 0

    def get_metric(
        self, reset: bool
    ) -> Union[float, Tuple[float, ...], Dict[str, float], Dict[str, List[float]]]:
        """Compute precision, recall and F1 (percentages) from the current counts.

        Args:
            reset: if true, the counts are zeroed after the scores are computed.

        Returns:
            A dict mapping ``<prefix>_p``, ``<prefix>_r`` and ``<prefix>_f`` to
            precision, recall and F1. All three are ``0.`` when there are no
            true positives.
        """
        # Every score is guarded by tp > 0. Assuming the counts are
        # non-negative, tp > 0 makes both denominators positive and both
        # precision and recall non-zero, so no division by zero can occur.
        has_tp = self.tp > 0
        precision = self.tp * 100 / (self.tp + self.fp) if has_tp else 0.
        recall = self.tp * 100 / (self.tp + self.fn) if has_tp else 0.
        # Harmonic mean of precision and recall (same percentage scale).
        f_score = 2 / (1 / precision + 1 / recall) if has_tp else 0.
        rst = {
            f'{self.prefix}_p': precision,
            f'{self.prefix}_r': recall,
            f'{self.prefix}_f': f_score,
        }
        if reset:
            self.reset()
        return rst