import collections
from typing import Any, Dict, Iterator, List, Optional

import torch
from transformers import AutoModel
from transformers.activations import ClippedGELUActivation, GELUActivation
from transformers.modeling_utils import PoolerEndLogits

from relik.reader.data.relik_reader_sample import RelikReaderSample

activation2functions = {
    "relu": torch.nn.ReLU(),
    "gelu": GELUActivation(),
    "gelu_10": ClippedGELUActivation(-10, 10),
}


class RelikReaderCoreModel(torch.nn.Module):
    """Core reader model: named entity detection (span start/end prediction)
    followed by entity disambiguation against candidate special symbols."""

    def __init__(
        self,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
    ) -> None:
        super().__init__()

        # Transformer model declaration
        self.transformer_model_name = transformer_model
        self.transformer_model = (
            AutoModel.from_pretrained(transformer_model)
            if num_layers is None
            else AutoModel.from_pretrained(
                transformer_model, num_hidden_layers=num_layers
            )
        )
        # self.transformer_model.resize_token_embeddings(
        #     self.transformer_model.config.vocab_size + additional_special_symbols
        # )

        self.activation = activation
        self.linears_hidden_size = linears_hidden_size
        self.use_last_k_layers = use_last_k_layers

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)

        # entity disambiguation layers
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        # reuses torch.nn.Module's `training` flag: when True, gold start/end
        # labels (rather than predictions) feed the end and ED heads
        self.training = training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.transformer_model.config.hidden_size * self.use_last_k_layers
                if input_hidden is None
                else input_hidden,
                self.linears_hidden_size,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.linears_hidden_size,
                self.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    self.linears_hidden_size if last_hidden is None else last_hidden,
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.use_last_k_layers > 1:
            model_features = torch.cat(
                model_output[1][-self.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            expanded_features = torch.cat(
                [
                    model_features[i].unsqueeze(0).expand(x, -1, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(start_positions_indices.device)

            expanded_prediction_mask = torch.cat(
                [
                    prediction_mask[i].unsqueeze(0).expand(x, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(expanded_features.device)

            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )

            return end_logits

        return None

    def compute_classification_logits(
        self,
        model_features,
        special_symbols_mask,
        prediction_mask,
        batch_size,
        start_positions=None,
        end_positions=None,
    ) -> torch.Tensor:
        if start_positions is None or end_positions is None:
            start_positions = torch.zeros_like(prediction_mask)
            end_positions = torch.zeros_like(prediction_mask)

        model_start_features = self.ed_start_projector(model_features)
        model_end_features = self.ed_end_projector(model_features)
        model_end_features[start_positions > 0] = model_end_features[end_positions > 0]

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )

        # computing ed features
        classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
        special_symbols_representation = model_ed_features[special_symbols_mask].view(
            batch_size, classes_representations, -1
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

        logits = self._mask_logits(logits, prediction_mask)

        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        batch_size, seq_len = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # named entity detection if required
        if use_predefined_spans:  # no need to compute spans
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.clone(start_labels)
                if start_labels is not None
                else torch.zeros_like(input_ids),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.clone(end_labels)
                if end_labels is not None
                else torch.zeros_like(input_ids),
            )

            ned_start_predictions[ned_start_predictions > 0] = 1
            ned_end_predictions[ned_end_predictions > 0] = 1

        else:  # compute spans
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = ned_start_predictions.new_zeros(batch_size)

            # flattening end predictions
            # (flattening can happen only if the
            # end boundaries were not predicted using the gold labels)
            if not self.training:
                # zero tensor with the start-predictions shape; it will hold
                # one end position per predicted span
                flattened_end_predictions = torch.clone(ned_start_predictions)
                flattened_end_predictions[flattened_end_predictions > 0] = 0

                batch_start_predictions = list()
                for elem_idx in range(batch_size):
                    batch_start_predictions.append(
                        torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
                    )

                # check that the total number of start predictions
                # is equal to the number of end predictions
                total_start_predictions = sum(map(len, batch_start_predictions))
                total_end_predictions = len(ned_end_predictions)
                assert (
                    total_start_predictions == 0
                    or total_start_predictions == total_end_predictions
                ), (
                    f"Total number of start predictions = {total_start_predictions}. "
                    f"Total number of end predictions = {total_end_predictions}"
                )

                curr_end_pred_num = 0
                for elem_idx, bsp in enumerate(batch_start_predictions):
                    for sp in bsp:
                        ep = ned_end_predictions[curr_end_pred_num].item()
                        if ep < sp:
                            ep = sp

                        # if this end position is already taken, drop the span
                        # (no overlapping spans)
                        if flattened_end_predictions[elem_idx, ep] == 1:
                            ned_start_predictions[elem_idx, sp] = 0
                        else:
                            flattened_end_predictions[elem_idx, ep] = 1

                        curr_end_pred_num += 1

                ned_end_predictions = flattened_end_predictions

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )

        # Entity disambiguation
        ed_logits = self.compute_classification_logits(
            model_features,
            special_symbols_mask,
            prediction_mask,
            batch_size,
            start_position,
            end_position,
        )
        ed_probabilities = torch.softmax(ed_logits, dim=-1)
        ed_predictions = torch.argmax(ed_probabilities, dim=-1)

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )

        # compute loss if labels are provided
        if start_labels is not None and end_labels is not None and self.training:
            # named entity detection loss

            # start
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0

            # end
            if ned_end_logits is not None:
                ned_end_labels = torch.zeros_like(end_labels)
                ned_end_labels[end_labels == -100] = -100
                ned_end_labels[end_labels > 0] = 1

                ned_end_loss = self.criterion(
                    ned_end_logits,
                    (
                        torch.arange(
                            ned_end_labels.size(1), device=ned_end_labels.device
                        )
                        .unsqueeze(0)
                        .expand(batch_size, -1)[ned_end_labels > 0]
                    ).to(ned_end_labels.device),
                )
            else:
                ned_end_loss = 0

            # entity disambiguation loss
            start_labels[ned_start_labels != 1] = -100
            ed_labels = torch.clone(start_labels)
            ed_labels[end_labels > 0] = end_labels[end_labels > 0]
            ed_loss = self.criterion(
                ed_logits.view(-1, ed_logits.shape[-1]),
                ed_labels.view(-1),
            )

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss
            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        sample: Optional[List[RelikReaderSample]] = None,
        top_k: int = 5,  # the amount of top-k most probable entities to predict
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]

        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # predicted candidate
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # candidates probabilities
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                titles_2_probs = []
                top_k = (
                    min(
                        top_k,
                        len(classes_probabilities_best_indices),
                    )
                    if top_k != -1
                    else len(classes_probabilities_best_indices)
                )
                for i in range(top_k):
                    titles_2_probs.append(
                        (
                            pred_cands[classes_probabilities_best_indices[i] - 1],
                            classes_probabilities[
                                classes_probabilities_best_indices[i]
                            ].item(),
                        )
                    )
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()

            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            # additional info
            sample_patch["predictable_candidates"] = pred_cands

            yield ts
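

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library): it runs one
# inference-time forward pass with toy tensors. The checkpoint name
# "bert-base-cased", the tensor shapes, and the way the masks are built below
# are assumptions made for this example only; adapt them to your data pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = RelikReaderCoreModel(
        transformer_model="bert-base-cased",  # assumption: any HF encoder
        additional_special_symbols=0,  # unused here (embedding resize is commented out)
    )
    model.eval()  # use predicted (not gold) start positions for the end/ED heads

    batch_size, seq_len, num_candidates = 2, 16, 3
    input_ids = torch.randint(100, 1000, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    # special_symbols_mask marks the candidate-entity tokens; every sample in
    # the batch must contain the same number of them.
    special_symbols_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    special_symbols_mask[:, 1 : 1 + num_candidates] = True

    # prediction_mask is 1 on positions that must not be predicted as span
    # tokens (special symbols, first token, padding) and 0 elsewhere.
    prediction_mask = special_symbols_mask.long()
    prediction_mask[:, 0] = 1

    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            prediction_mask=prediction_mask,
            special_symbols_mask=special_symbols_mask,
        )

    print(output["ned_start_predictions"].shape)  # (batch_size, seq_len)
    print(output["ed_predictions"].shape)  # (batch_size, seq_len), candidate indices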