from typing import *
import torch
from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from allennlp.modules.span_extractors import SpanExtractor
from allennlp.training.metrics import FBetaMeasure
from ..smooth_crf import SmoothCRF
from .span_finder import SpanFinder
from ...utils import num2mask, mask2idx, BIO
@SpanFinder.register("bio")
class BIOSpanFinder(SpanFinder):
    """
    Train BIO representations for span finding.

    Each candidate parent span conditions a token-level tagger: the parent's
    span vector and (optionally) its label embedding are concatenated onto
    every token vector, encoded with ``bio_encoder``, classified into the 3
    BIO tags, and scored/decoded with a CRF. Decoded B/I segments become the
    predicted child spans.
    """

    def __init__(
            self,
            bio_encoder: Seq2SeqEncoder,
            label_emb: torch.nn.Embedding,
            no_label: bool = True,
    ):
        """
        :param bio_encoder: Seq2seq encoder applied to the per-parent token
            feature sequences.
        :param label_emb: Embedding table for parent span labels.
        :param no_label: If True, the parent label features are zeroed out in
            ``forward`` so the tagger is label-agnostic (also passed through
            to the ``SpanFinder`` base class).
        """
        super().__init__(no_label)
        self.bio_encoder = bio_encoder
        self.label_emb = label_emb
        # One logit per BIO tag (3 classes).
        self.classifier = torch.nn.Linear(bio_encoder.get_output_dim(), 3)
        self.crf = SmoothCRF(3)
        # Micro F1 restricted to the B and I classes; O is excluded.
        self.fb_measure = FBetaMeasure(1., 'micro', [BIO.index('B'), BIO.index('I')])

    def forward(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_vec: torch.Tensor,
            span_mask: Optional[torch.Tensor] = None,  # Do not need to provide
            span_labels: Optional[torch.Tensor] = None,  # Do not need to provide
            parent_indices: Optional[torch.Tensor] = None,  # Do not need to provide
            parent_mask: Optional[torch.Tensor] = None,
            bio_seqs: Optional[torch.Tensor] = None,
            prediction: bool = False,
            **extra
    ) -> Dict[str, torch.Tensor]:
        """
        See doc of SpanFinder.
        Possible extra variables:
            smoothing_factor
        :return:
            - loss (training mode, when ``prediction`` is False)
            - prediction (inference mode, when ``prediction`` is True)
        """
        ret = dict()
        # Soft (label-smoothed) labels are float distributions over classes;
        # hard labels are int64 class indices.
        # NOTE(review): span_labels is annotated Optional but its dtype is read
        # unconditionally here -- callers apparently must always provide it.
        is_soft = span_labels.dtype != torch.int64
        distinct_parent_indices, num_parents = mask2idx(parent_mask)
        n_batch, n_parent = distinct_parent_indices.shape
        n_token = token_vec.shape[1]
        # Shape [batch, parent, token_dim]
        parent_span_features = span_vec.gather(
            1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, span_vec.shape[2])
        )
        # Soft labels: expected embedding under the distribution; hard labels: lookup.
        label_features = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
        if self._no_label:
            # Suppress label information entirely (note: in-place zero).
            label_features = label_features.zero_()
        # Shape [batch, span, label_dim]
        parent_label_features = label_features.gather(
            1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, label_features.shape[2])
        )
        # Shape [batch, parent, token, token_dim*2]
        # NOTE(review): three feature groups (parent span, token, parent label)
        # are concatenated; the `token_dim*2` above may predate label features.
        encoder_inputs = torch.cat([
            parent_span_features.unsqueeze(2).expand(-1, -1, n_token, -1),
            token_vec.unsqueeze(1).expand(-1, n_parent, -1, -1),
            parent_label_features.unsqueeze(2).expand(-1, -1, n_token, -1),
        ], dim=3)
        # Fold the parent axis into the batch axis for the seq2seq encoder.
        encoder_inputs = encoder_inputs.reshape(n_batch * n_parent, n_token, -1)
        # Shape [batch, parent]. Considers batches may have fewer seqs.
        seq_mask = num2mask(num_parents)
        # Shape [batch, parent, token]. Also considers batches may have fewer tokens.
        token_mask = seq_mask.unsqueeze(2).expand(-1, -1, n_token) & token_mask.unsqueeze(1).expand(-1, n_parent, -1)
        class_in = self.bio_encoder(encoder_inputs, token_mask.flatten(0, 1))
        # Shape [batch, parent, token, 3]: per-token BIO logits.
        class_out = self.classifier(class_in).reshape(n_batch, n_parent, n_token, 3)
        if not prediction:
            # For training
            # We use `seq_mask` here because seq with length 0 is not acceptable.
            ret['loss'] = -self.crf(class_out[seq_mask], bio_seqs[seq_mask], token_mask[seq_mask])
            # bio_seqs may be one-hot/soft along the last axis; argmax recovers
            # hard tag indices for the F-beta metric.
            self.fb_measure(class_out[seq_mask], bio_seqs[seq_mask].max(2).indices, token_mask[seq_mask])
        else:
            # For prediction
            features_for_decode = class_out.clone().detach()
            decoded = self.crf.viterbi_tags(features_for_decode.flatten(0, 1), token_mask.flatten(0, 1))
            # Pad each Viterbi path with 'O' tags up to the max token length.
            pred_tag = torch.tensor(
                [path + [BIO.index('O')] * (n_token - len(path)) for path, _ in decoded]
            )
            pred_tag = pred_tag.reshape(n_batch, n_parent, n_token)
            ret['prediction'] = pred_tag
        return ret

    @staticmethod
    def bio2boundary(seqs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert (possibly nested) BIO tag-index sequences into span boundaries.

        :param seqs: Arbitrarily nested lists/tensors whose innermost dimension
            is a sequence of BIO tag indices.
        :return: Pair (boundary, lens). ``boundary`` has shape
            [*outer_dims, max_span, 2] with (start, end) token indices,
            zero-padded; ``lens`` gives the number of spans per sequence.
        """
        def recursive_construct_spans(seqs_):
            """
            Helper function for bio2boundary
            Recursively convert seqs of integers to boundary indices.
            Return boundary indices and corresponding lens
            """
            if isinstance(seqs_, torch.Tensor):
                # bio_tags_to_spans needs Python lists; pull off GPU first.
                if seqs_.device.type == 'cuda':
                    seqs_ = seqs_.to(device='cpu')
                seqs_ = seqs_.tolist()
            # Base case: a flat sequence of tag indices -> tag strings -> spans.
            if isinstance(seqs_[0], int):
                seqs_ = [BIO[i] for i in seqs_]
                span_boundary_list = bio_tags_to_spans(seqs_)
                return torch.tensor([item[1] for item in span_boundary_list]), len(span_boundary_list)
            span_boundary = list()
            lens_ = list()
            for seq in seqs_:
                one_bou, one_len = recursive_construct_spans(seq)
                span_boundary.append(one_bou)
                lens_.append(one_len)
            if isinstance(lens_[0], int):
                lens_ = torch.tensor(lens_)
            else:
                lens_ = torch.stack(lens_)
            return span_boundary, lens_

        boundary_list, lens = recursive_construct_spans(seqs)
        max_span = int(lens.max())
        # Zero-padded destination tensor for the ragged boundary lists.
        boundary = torch.zeros((*lens.shape, max_span, 2), dtype=torch.long)

        def recursive_copy(list_var, tensor_var):
            # Copy the ragged nested structure into the padded tensor in place.
            if len(list_var) == 0:
                return
            if isinstance(list_var, torch.Tensor):
                tensor_var[:len(list_var)] = list_var
                return
            assert len(list_var) == len(tensor_var)
            for list_var_, tensor_var_ in zip(list_var, tensor_var):
                recursive_copy(list_var_, tensor_var_)

        recursive_copy(boundary_list, boundary)
        return boundary, lens

    def inference_forward_handler(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_extractor: SpanExtractor,
            **auxiliaries,
    ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]:
        """
        Refer to the doc of the SpanFinder for definition of this function.
        Returns a closure that, given the current decoding state, predicts
        child spans for the masked parents and appends them to the decoding
        buffers in place.
        """
        def handler(
                span_boundary: torch.Tensor,
                span_labels: torch.Tensor,
                parent_mask: torch.Tensor,
                parent_indices: torch.Tensor,
                cursor: torch.tensor,
        ):
            """
            Refer to the doc of the SpanFinder for definition of this function.
            Mutates span_boundary, parent_indices, and cursor in place.
            """
            max_decoding_span = span_boundary.shape[1]
            # Shape [batch, span, token_dim]
            span_vec = span_extractor(token_vec, span_boundary)
            # Shape [batch, parent]
            parent_indices_at_span, _ = mask2idx(parent_mask)
            pred_bio = self(
                token_vec, token_mask, span_vec, None, span_labels, None, parent_mask, prediction=True
            )['prediction']
            # Shape [batch, parent, span, 2]; Shape [batch, parent]
            pred_boundary, pred_num = self.bio2boundary(pred_bio)
            # bio2boundary builds its tensors on CPU; move back if needed.
            if pred_boundary.device != span_boundary.device:
                pred_boundary = pred_boundary.to(device=span_boundary.device)
                pred_num = pred_num.to(device=span_boundary.device)
            # Shape [batch, parent, span]
            pred_mask = num2mask(pred_num)
            # Parent Loop
            for pred_boundary_parent, pred_mask_parent, parent_indices_parent \
                    in zip(pred_boundary.unbind(1), pred_mask.unbind(1), parent_indices_at_span.unbind(1)):
                # Span loop: write one predicted span per step at the cursor position.
                for pred_boundary_step, step_mask in zip(pred_boundary_parent.unbind(1), pred_mask_parent.unbind(1)):
                    # Drop batch elements whose decoding buffer is already full.
                    step_mask &= cursor < max_decoding_span
                    parent_indices[step_mask] = parent_indices[step_mask].scatter(
                        1,
                        cursor[step_mask].unsqueeze(1),
                        parent_indices_parent[step_mask].unsqueeze(1)
                    )
                    span_boundary[step_mask] = span_boundary[step_mask].scatter(
                        1,
                        cursor[step_mask].reshape(-1, 1, 1).expand(-1, -1, 2),
                        pred_boundary_step[step_mask].unsqueeze(1)
                    )
                    cursor[step_mask] += 1

        return handler

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return span-finding metrics scaled to percentages.
        Precision and recall are only included when ``reset`` is True
        (typically at epoch end); otherwise only F1 is reported.
        """
        score = self.fb_measure.get_metric(reset)
        if reset:
            return {
                'finder_p': score['precision'] * 100,
                'finder_r': score['recall'] * 100,
                'finder_f': score['fscore'] * 100,
            }
        else:
            return {'finder_f': score['fscore'] * 100}