from abc import ABC, abstractmethod from typing import * import torch from allennlp.common import Registrable from allennlp.modules.span_extractors import SpanExtractor class SpanFinder(Registrable, ABC, torch.nn.Module): """ Model the probability p(child_span | parent_span [, parent_label]) It's optional to model parent_label, since in some cases we may want the parameters to be shared across different tasks, where we may have similar span semantics but different label space. """ def __init__( self, no_label: bool = True, ): """ :param no_label: If True, will not use input labels as features and use all 0 vector instead. """ super().__init__() self._no_label = no_label @abstractmethod def forward( self, token_vec: torch.Tensor, token_mask: torch.Tensor, span_vec: torch.Tensor, span_mask: Optional[torch.Tensor] = None, # Do not need to provide span_labels: Optional[torch.Tensor] = None, # Do not need to provide parent_indices: Optional[torch.Tensor] = None, # Do not need to provide parent_mask: Optional[torch.Tensor] = None, bio_seqs: Optional[torch.Tensor] = None, prediction: bool = False, **extra ) -> Dict[str, torch.Tensor]: """ Return training loss and predictions. :param token_vec: Vector representation of tokens. Shape [batch, token ,token_dim] :param token_mask: True for non-padding tokens. :param span_vec: Vector representation of spans. Shape [batch, span, token_dim] :param span_mask: True for non-padding spans. Shape [batch, span] :param span_labels: The labels of spans. Shape [batch, span] :param parent_indices: Parent indices of spans. Shape [batch, span] :param parent_mask: True for parent spans. Shape [batch, span] :param prediction: If True, no loss will be return & no metrics will be updated. :param bio_seqs: BIO sequences. Shape [batch, parent, token, 3] :return: loss: Training loss prediction: Shape [batch, span]. True for positive predictions. """ raise NotImplementedError @abstractmethod def inference_forward_handler( self, token_vec: torch.Tensor, token_mask: torch.Tensor, span_extractor: SpanExtractor, **auxiliaries, ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]: """ Pre-process some information and return a callable module for p(child_span | parent_span [,parent_label]) :param token_vec: Vector representation of tokens. Shape [batch, token ,token_dim] :param token_mask: True for non-padding tokens. :param span_extractor: The same module in model. :param auxiliaries: Environment variables. You can pass extra environment variables since the extras will be ignored. :return: A callable function in a closure. The arguments for the callable object are: - span_boundary: Shape [batch, span, 2] - span_labels: Shape [batch, span] - parent_mask: Shape [batch, span] - parent_indices: Shape [batch, span] - cursor: Shape [batch] No return values. Everything should be done inplace. Note the span indexing space has different meaning from training process. We don't have gold span list, so span here refers to the predicted spans. """ raise NotImplementedError @abstractmethod def get_metrics(self, reset: bool = False) -> Dict[str, float]: raise NotImplementedError