import collections
from typing import Any, Dict, Iterator, List, Optional

import torch
from transformers import AutoModel
from transformers.activations import ClippedGELUActivation, GELUActivation
from transformers.modeling_utils import PoolerEndLogits

from relik.reader.data.relik_reader_sample import RelikReaderSample
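
# map activation names from the config to the corresponding torch modules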
activation2functions = {
    "relu": torch.nn.ReLU(),
    "gelu": GELUActivation(),
    "gelu_10": ClippedGELUActivation(-10, 10),
}


class RelikReaderCoreModel(torch.nn.Module):
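    """Core ReLiK reader: a transformer encoder with named entity detection
    (NED) start/end heads and an entity disambiguation (ED) head that scores
    spans against the candidate special symbols added to the vocabulary."""
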
|
def __init__( |
|
self, |
|
transformer_model: str, |
|
additional_special_symbols: int, |
|
num_layers: Optional[int] = None, |
|
activation: str = "gelu", |
|
linears_hidden_size: Optional[int] = 512, |
|
use_last_k_layers: int = 1, |
|
training: bool = False, |
|
) -> None: |
|
super().__init__() |
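
        # transformer backbone, optionally truncated to `num_layers` hidden layers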
        self.transformer_model_name = transformer_model
        self.transformer_model = (
            AutoModel.from_pretrained(transformer_model)
            if num_layers is None
            else AutoModel.from_pretrained(
                transformer_model, num_hidden_layers=num_layers
            )
        )
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size + additional_special_symbols
        )

        self.activation = activation
        self.linears_hidden_size = linears_hidden_size
        self.use_last_k_layers = use_last_k_layers
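
        # named entity detection (NED) heads: start classifier and end pooler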
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
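
        # entity disambiguation (ED) projectors for start/end token representations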
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        self.training = training

        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
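        # two-layer MLP head: Dropout -> Linear -> activation -> Dropout -> Linear,
        # optionally followed by LayerNorm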
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.transformer_model.config.hidden_size * self.use_last_k_layers
                if input_hidden is None
                else input_hidden,
                self.linears_hidden_size,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.linears_hidden_size,
                self.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    self.linears_hidden_size if last_hidden is None else last_hidden,
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        mask = mask.unsqueeze(-1)
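        # -1e30 overflows to -inf in fp16; use a large finite value instead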
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)
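
        # concatenate the last k hidden layers when requested; otherwise use the
        # last-layer output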
        if self.use_last_k_layers > 1:
            model_features = torch.cat(
                model_output[1][-self.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
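        # teacher forcing: use gold start positions during training, predicted
        # ones at inference time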
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)
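
        # replicate each sequence's features once per start position so that the
        # end classifier can score every start candidate independently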
        if len(start_positions_indices) > 0:
            expanded_features = torch.cat(
                [
                    model_features[i].unsqueeze(0).expand(x, -1, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(start_positions_indices.device)

            expanded_prediction_mask = torch.cat(
                [
                    prediction_mask[i].unsqueeze(0).expand(x, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(expanded_features.device)

            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )

            return end_logits

        return None

    def compute_classification_logits(
        self,
        model_features,
        special_symbols_mask,
        prediction_mask,
        batch_size,
        start_positions=None,
        end_positions=None,
    ) -> torch.Tensor:
        if start_positions is None or end_positions is None:
            start_positions = torch.zeros_like(prediction_mask)
            end_positions = torch.zeros_like(prediction_mask)

        model_start_features = self.ed_start_projector(model_features)
        model_end_features = self.ed_end_projector(model_features)
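        # overwrite each span start with its end-token representation, so the
        # start slot encodes the whole span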
        model_end_features[start_positions > 0] = model_end_features[end_positions > 0]

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )
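
        # number of candidate special symbols per sequence, assumed constant
        # across the batch (read from the first element)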
        classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
        special_symbols_representation = model_ed_features[special_symbols_mask].view(
            batch_size, classes_representations, -1
        )
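
        # score every token representation against every candidate representation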
        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

        logits = self._mask_logits(logits, prediction_mask)

        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        batch_size, seq_len = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )
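
        # named entity detection: either adopt the provided gold spans or
        # predict start/end boundaries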
        if use_predefined_spans:
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.clone(start_labels)
                if start_labels is not None
                else torch.zeros_like(input_ids),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.clone(end_labels)
                if end_labels is not None
                else torch.zeros_like(input_ids),
            )

            ned_start_predictions[ned_start_predictions > 0] = 1
            ned_end_predictions[ned_end_predictions > 0] = 1

        else:
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
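
            # binarize the gold start labels (keeping the -100 ignore index) for
            # the end classifier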
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
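
            # at inference, flatten the per-start end indices into a
            # (batch_size, seq_len) 0/1 matrix, discarding overlapping spans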
            if not self.training:
                flattened_end_predictions = torch.clone(ned_start_predictions)
                flattened_end_predictions[flattened_end_predictions > 0] = 0

                batch_start_predictions = list()
                for elem_idx in range(batch_size):
                    batch_start_predictions.append(
                        torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
                    )
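
                # sanity check: there must be one end prediction per predicted start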
                total_start_predictions = sum(map(len, batch_start_predictions))
                total_end_predictions = len(ned_end_predictions)
                assert (
                    total_start_predictions == 0
                    or total_start_predictions == total_end_predictions
                ), (
                    f"Total number of start predictions = {total_start_predictions}. "
                    f"Total number of end predictions = {total_end_predictions}"
                )

                curr_end_pred_num = 0
                for elem_idx, bsp in enumerate(batch_start_predictions):
                    for sp in bsp:
                        ep = ned_end_predictions[curr_end_pred_num].item()
                        if ep < sp:
                            ep = sp
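
                        # if this end token is already taken by another span,
                        # drop the current start (no overlapping spans)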
                        if flattened_end_predictions[elem_idx, ep] == 1:
                            ned_start_predictions[elem_idx, sp] = 0
                        else:
                            flattened_end_predictions[elem_idx, ep] = 1

                        curr_end_pred_num += 1

                ned_end_predictions = flattened_end_predictions

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )
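
        # entity disambiguation: score tokens against the candidate symbols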
        ed_logits = self.compute_classification_logits(
            model_features,
            special_symbols_mask,
            prediction_mask,
            batch_size,
            start_position,
            end_position,
        )
        ed_probabilities = torch.softmax(ed_logits, dim=-1)
        ed_predictions = torch.argmax(ed_probabilities, dim=-1)

        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )
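
        # compute the training losses when gold labels are available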
        if start_labels is not None and end_labels is not None and self.training:
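            # NED start loss over the binarized start labels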
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0
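
            # NED end loss: the target for each gold start is the index of its
            # gold end token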
            if ned_end_logits is not None:
                ned_end_labels = torch.zeros_like(end_labels)
                ned_end_labels[end_labels == -100] = -100
                ned_end_labels[end_labels > 0] = 1

                ned_end_loss = self.criterion(
                    ned_end_logits,
                    (
                        torch.arange(
                            ned_end_labels.size(1), device=ned_end_labels.device
                        )
                        .unsqueeze(0)
                        .expand(batch_size, -1)[ned_end_labels > 0]
                    ).to(ned_end_labels.device),
                )
            else:
                ned_end_loss = 0
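
            # ED loss: supervise gold span starts only, copying each span's
            # entity index from the gold end labels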
            start_labels[ned_start_labels != 1] = -100
            ed_labels = torch.clone(start_labels)
            ed_labels[end_labels > 0] = end_labels[end_labels > 0]
            ed_loss = self.criterion(
                ed_logits.view(-1, ed_logits.shape[-1]),
                ed_labels.view(-1),
            )

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss

            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        sample: Optional[List[RelikReaderSample]] = None,
        top_k: int = 5,
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]
        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
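            # skip the first token (e.g. [CLS] in BERT-style inputs) when
            # collecting boundary indices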
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
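                # predicted entity for this span; the -1 compensates for the
                # reserved index-0 ("no entity") class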
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )
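
                # collect the top-k candidate titles with their probabilities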
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                titles_2_probs = []
                top_k = (
                    min(
                        top_k,
                        len(classes_probabilities_best_indices),
                    )
                    if top_k != -1
                    else len(classes_probabilities_best_indices)
                )
                for i in range(top_k):
                    titles_2_probs.append(
                        (
                            pred_cands[classes_probabilities_best_indices[i] - 1],
                            classes_probabilities[
                                classes_probabilities_best_indices[i]
                            ].item(),
                        )
                    )
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs
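
            # store this window's predictions on the sample, keyed by patch offset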
            if "patches" not in ts._d:
                ts._d["patches"] = dict()

            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            sample_patch["predictable_candidates"] = pred_cands

            yield ts