from typing import *

import torch
from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from allennlp.modules.span_extractors import SpanExtractor
from allennlp.training.metrics import FBetaMeasure

from ..smooth_crf import SmoothCRF
from .span_finder import SpanFinder
from ...utils import num2mask, mask2idx, BIO


class BIOSpanFinder(SpanFinder):
    """
    Train BIO representations for span finding.
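
    Given a parent span (and optionally its label), every token is scored with a
    B/I/O tag conditioned on that parent, and child spans are decoded from the tag
    sequence with a CRF.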
""" | |
    def __init__(
            self,
            bio_encoder: Seq2SeqEncoder,
            label_emb: torch.nn.Embedding,
            no_label: bool = True,
    ):
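        """
        :param bio_encoder: Seq2Seq encoder run over the concatenated token, parent-span,
            and label features; its output feeds the 3-way B/I/O classifier.
        :param label_emb: embedding table for span labels, used to condition the encoder
            on the parent span's label.
        :param no_label: if True, label features are zeroed out, i.e. the finder ignores
            parent labels.
        """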
        super().__init__(no_label)
        self.bio_encoder = bio_encoder
        self.label_emb = label_emb
        # 3-way token classifier over the BIO tag set, decoded with a (smoothable) CRF.
        self.classifier = torch.nn.Linear(bio_encoder.get_output_dim(), 3)
        self.crf = SmoothCRF(3)
        self.fb_measure = FBetaMeasure(1., 'micro', [BIO.index('B'), BIO.index('I')])

    def forward(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_vec: torch.Tensor,
            span_mask: Optional[torch.Tensor] = None,  # Do not need to provide
            span_labels: Optional[torch.Tensor] = None,  # Do not need to provide
            parent_indices: Optional[torch.Tensor] = None,  # Do not need to provide
            parent_mask: Optional[torch.Tensor] = None,
            bio_seqs: Optional[torch.Tensor] = None,
            prediction: bool = False,
            **extra
    ) -> Dict[str, torch.Tensor]:
""" | |
See doc of SpanFinder. | |
Possible extra variables: | |
smoothing_factor | |
:return: | |
- loss | |
- prediction | |
""" | |
        ret = dict()
        is_soft = span_labels.dtype != torch.int64
        distinct_parent_indices, num_parents = mask2idx(parent_mask)
        n_batch, n_parent = distinct_parent_indices.shape
        n_token = token_vec.shape[1]
        # Shape [batch, parent, token_dim]
        parent_span_features = span_vec.gather(
            1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, span_vec.shape[2])
        )
        label_features = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
        if self._no_label:
            label_features = label_features.zero_()
        # Shape [batch, parent, label_dim]
        parent_label_features = label_features.gather(
            1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, label_features.shape[2])
        )
        # Shape [batch, parent, token, token_dim*2 + label_dim]
        encoder_inputs = torch.cat([
            parent_span_features.unsqueeze(2).expand(-1, -1, n_token, -1),
            token_vec.unsqueeze(1).expand(-1, n_parent, -1, -1),
            parent_label_features.unsqueeze(2).expand(-1, -1, n_token, -1),
        ], dim=3)
        encoder_inputs = encoder_inputs.reshape(n_batch * n_parent, n_token, -1)
        # Shape [batch, parent]. Accounts for batches that contain fewer parents.
        seq_mask = num2mask(num_parents)
        # Shape [batch, parent, token]. Also accounts for batches that contain fewer tokens.
        token_mask = seq_mask.unsqueeze(2).expand(-1, -1, n_token) & token_mask.unsqueeze(1).expand(-1, n_parent, -1)
        class_in = self.bio_encoder(encoder_inputs, token_mask.flatten(0, 1))
        class_out = self.classifier(class_in).reshape(n_batch, n_parent, n_token, 3)

        if not prediction:
            # For training.
            # We use `seq_mask` here because sequences of length 0 are not acceptable.
            ret['loss'] = -self.crf(class_out[seq_mask], bio_seqs[seq_mask], token_mask[seq_mask])
            self.fb_measure(class_out[seq_mask], bio_seqs[seq_mask].max(2).indices, token_mask[seq_mask])
        else:
            # For prediction.
            features_for_decode = class_out.clone().detach()
            decoded = self.crf.viterbi_tags(features_for_decode.flatten(0, 1), token_mask.flatten(0, 1))
            # Pad each Viterbi path with O tags up to the full token length.
            pred_tag = torch.tensor(
                [path + [BIO.index('O')] * (n_token - len(path)) for path, _ in decoded]
            )
            pred_tag = pred_tag.reshape(n_batch, n_parent, n_token)
            ret['prediction'] = pred_tag

        return ret

    @staticmethod
    def bio2boundary(seqs) -> Tuple[torch.Tensor, torch.Tensor]:
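        """
        Convert batched BIO tag sequences into inclusive span boundaries.

        Returns a zero-padded tensor of [start, end] pairs of shape [..., max_span, 2]
        and a tensor with the number of spans found in each sequence.

        Illustrative example (assuming the BIO constant in utils is ordered as
        ['B', 'I', 'O']): the tag sequence B I O B yields the boundaries
        [[0, 1], [3, 3]] and a span count of 2.
        """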
        def recursive_construct_spans(seqs_):
            """
            Helper function for bio2boundary.
            Recursively converts sequences of integers into boundary indices.
            Returns boundary indices and the corresponding lengths.
            """
            if isinstance(seqs_, torch.Tensor):
                if seqs_.device.type == 'cuda':
                    seqs_ = seqs_.to(device='cpu')
                seqs_ = seqs_.tolist()
            if isinstance(seqs_[0], int):
                seqs_ = [BIO[i] for i in seqs_]
                span_boundary_list = bio_tags_to_spans(seqs_)
                return torch.tensor([item[1] for item in span_boundary_list]), len(span_boundary_list)
            span_boundary = list()
            lens_ = list()
            for seq in seqs_:
                one_bou, one_len = recursive_construct_spans(seq)
                span_boundary.append(one_bou)
                lens_.append(one_len)
            if isinstance(lens_[0], int):
                lens_ = torch.tensor(lens_)
            else:
                lens_ = torch.stack(lens_)
            return span_boundary, lens_

        boundary_list, lens = recursive_construct_spans(seqs)
        max_span = int(lens.max())
        boundary = torch.zeros((*lens.shape, max_span, 2), dtype=torch.long)

        def recursive_copy(list_var, tensor_var):
            if len(list_var) == 0:
                return
            if isinstance(list_var, torch.Tensor):
                tensor_var[:len(list_var)] = list_var
                return
            assert len(list_var) == len(tensor_var)
            for list_var_, tensor_var_ in zip(list_var, tensor_var):
                recursive_copy(list_var_, tensor_var_)

        recursive_copy(boundary_list, boundary)
        return boundary, lens

    def inference_forward_handler(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_extractor: SpanExtractor,
            **auxiliaries,
    ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]:
        """
        Refer to the doc of SpanFinder for the definition of this function.
        """
        def handler(
                span_boundary: torch.Tensor,
                span_labels: torch.Tensor,
                parent_mask: torch.Tensor,
                parent_indices: torch.Tensor,
                cursor: torch.Tensor,
        ):
            """
            Refer to the doc of SpanFinder for the definition of this function.
            """
            max_decoding_span = span_boundary.shape[1]
            # Shape [batch, span, token_dim]
            span_vec = span_extractor(token_vec, span_boundary)
            # Shape [batch, parent]
            parent_indices_at_span, _ = mask2idx(parent_mask)
            pred_bio = self(
                token_vec, token_mask, span_vec, None, span_labels, None, parent_mask, prediction=True
            )['prediction']
            # Shape [batch, parent, span, 2]; Shape [batch, parent]
            pred_boundary, pred_num = self.bio2boundary(pred_bio)
            if pred_boundary.device != span_boundary.device:
                pred_boundary = pred_boundary.to(device=span_boundary.device)
                pred_num = pred_num.to(device=span_boundary.device)
            # Shape [batch, parent, span]
            pred_mask = num2mask(pred_num)
            # Loop over parents.
            for pred_boundary_parent, pred_mask_parent, parent_indices_parent \
                    in zip(pred_boundary.unbind(1), pred_mask.unbind(1), parent_indices_at_span.unbind(1)):
                # Loop over the spans predicted for this parent.
                for pred_boundary_step, step_mask in zip(pred_boundary_parent.unbind(1), pred_mask_parent.unbind(1)):
                    # Do not write past the end of the decoding buffer.
                    step_mask &= cursor < max_decoding_span
                    parent_indices[step_mask] = parent_indices[step_mask].scatter(
                        1,
                        cursor[step_mask].unsqueeze(1),
                        parent_indices_parent[step_mask].unsqueeze(1)
                    )
                    span_boundary[step_mask] = span_boundary[step_mask].scatter(
                        1,
                        cursor[step_mask].reshape(-1, 1, 1).expand(-1, -1, 2),
                        pred_boundary_step[step_mask].unsqueeze(1)
                    )
                    cursor[step_mask] += 1

        return handler
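
    # Sketch of how the returned handler is meant to be used by the decoding model
    # (the call site shown here is illustrative, not taken from the original code):
    #   handler = finder.inference_forward_handler(token_vec, token_mask, span_extractor)
    #   handler(span_boundary, span_labels, parent_mask, parent_indices, cursor)
    # Each call predicts the children of the currently masked parent spans and appends
    # them to the shared decoding buffers in place.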

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        score = self.fb_measure.get_metric(reset)
        if reset:
            return {
                'finder_p': score['precision'] * 100,
                'finder_r': score['recall'] * 100,
                'finder_f': score['fscore'] * 100,
            }
        else:
            return {'finder_f': score['fscore'] * 100}