Spaces:
Build error
Build error
from abc import ABC | |
from typing import * | |
import torch | |
from allennlp.common import Registrable | |
from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, Vocabulary | |
from allennlp.training.metrics import CategoricalAccuracy | |
class SpanTyping(Registrable, torch.nn.Module, ABC):
    """
    Models the probability p(child_label | child_span, parent_span, parent_label).

    This is the abstract base: it owns the accuracy metric and the label
    ontology mask, while subclasses are expected to implement ``forward``.
    """

    def __init__(
            self,
            n_label: int,
            label_to_ignore: Optional[List[int]] = None,
    ):
        """
        :param n_label: Number of labels; the ontology mask is a square
            boolean matrix of shape [n_label, n_label].
        :param label_to_ignore: Label indexes in this list will be ignored.
            Usually this should include NULL, PADDING and UNKNOWN.
        """
        super().__init__()
        self.label_to_ignore = label_to_ignore or list()
        self.acc_metric = CategoricalAccuracy()
        # onto[parent, child] is True when `child` is permitted under `parent`.
        # Until an ontology file is loaded, every combination is permitted.
        self.onto = torch.ones([n_label, n_label], dtype=torch.bool)
        self.register_buffer('ontology', self.onto)

    def load_ontology(self, path: str, vocab: Vocabulary):
        """
        Restrict the ontology mask according to a tab-separated file.

        Each line is read as ``parent<TAB>child_1<TAB>child_2...``. Labels are
        mapped to indexes through the ``span_label`` namespace of ``vocab``.
        A line whose parent maps to the OOV index is skipped; children that
        map to the OOV index are dropped. For every remaining parent its row
        is first cleared and then only the listed children are enabled, so a
        later line for the same parent replaces an earlier one.

        :param path: Path of the ontology file.
        :param vocab: Vocabulary used to convert label strings into indexes.
        """
        unk_id = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
        # The context manager guarantees the file handle is closed, which the
        # bare ``open(path).readlines()`` form did not.
        with open(path) as ontology_file:
            for line in ontology_file:
                labels = line.replace('\n', '').split('\t')
                entities = [vocab.get_token_index(ent, 'span_label') for ent in labels]
                parent, children = entities[0], entities[1:]
                if parent == unk_id:
                    continue
                self.onto[parent, :] = False
                children = [child for child in children if child != unk_id]
                # Nothing to enable when no known child is left for this parent.
                if children:
                    self.onto[parent, children] = True
        # Re-register so that the 'ontology' buffer refers to the updated mask.
        self.register_buffer('ontology', self.onto)

    def forward(
            self,
            span_vec: torch.Tensor,
            parent_at_span: torch.Tensor,
            span_labels: Optional[torch.Tensor],
            prediction_only: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Inputs: All features for typing a child span.
        Output: The loss of typing and predictions.

        This base implementation only raises ``NotImplementedError``.

        :param span_vec: Shape [batch, span, token_dim]
        :param parent_at_span: Shape [batch, span]
        :param span_labels: Shape [batch, span]
        :param prediction_only: If True, no loss returned & metric will not be updated
        :return:
            loss: Loss for label prediction. (Absent when prediction_only is True.)
            prediction: Predicted labels.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_metric(self, reset: bool) -> Dict[str, float]:
        """
        Report the typing accuracy.

        :param reset: Forwarded to the underlying accuracy metric's ``get_metric``.
        :return: A dict with the single key ``typing_acc``, holding the metric
            value multiplied by 100.
        """
        return {
            "typing_acc": self.acc_metric.get_metric(reset) * 100
        }