import logging
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import numpy as np
import torch
import transformers as tr
from reader.data.relik_reader_data_utils import batchify, flatten
from reader.data.relik_reader_sample import RelikReaderSample
from reader.pytorch_modules.hf.modeling_relik import (
    RelikReaderConfig,
    RelikReaderREModel,
)
from tqdm import tqdm
from transformers import AutoConfig

from relik.common.log import get_console_logger, get_logger
from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols_re

console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)


class RelikReaderForTripletExtraction(torch.nn.Module):
    """Reader that predicts entity spans and (subject, relation, object) triplets from pre-tokenized text."""

    def __init__(
        self,
        transformer_model: Optional[Union[str, tr.PreTrainedModel]] = None,
        additional_special_symbols: Optional[int] = 0,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: Optional[Union[str, torch.device]] = None,
        tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        **kwargs,
    ) -> None:
        super().__init__()

        if isinstance(transformer_model, str):
            config = AutoConfig.from_pretrained(
                transformer_model, trust_remote_code=True
            )
            if "relik_reader" in config.model_type:
                # the string points to an already-trained RelikReader checkpoint
                transformer_model = RelikReaderREModel.from_pretrained(
                    transformer_model, **kwargs
                )
            else:
                # the string points to a plain transformer: build a fresh reader on top of it
                reader_config = RelikReaderConfig(
                    transformer_model=transformer_model,
                    additional_special_symbols=additional_special_symbols,
                    num_layers=num_layers,
                    activation=activation,
                    linears_hidden_size=linears_hidden_size,
                    use_last_k_layers=use_last_k_layers,
                    training=training,
                )
                transformer_model = RelikReaderREModel(reader_config)

        self.relik_reader_re_model = transformer_model
        self._tokenizer = tokenizer

        # move the model to the device
        self.to(device or torch.device("cpu"))
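
    # Construction sketch (hypothetical checkpoint name): a string whose
    # AutoConfig `model_type` contains "relik_reader" is loaded as a trained
    # reader; any other string is treated as a base transformer to wrap.
    #
    #   reader = RelikReaderForTripletExtraction(
    #       "relik-ie/relik-reader-re-large",  # hypothetical
    #       device="cuda" if torch.cuda.is_available() else "cpu",
    #   )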

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        return self.relik_reader_re_model(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            start_labels,
            end_labels,
            disambiguation_labels,
            relation_labels,
            is_validation,
            is_prediction,
            *args,
            **kwargs,
        )

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        sample: Optional[List[RelikReaderSample]] = None,
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """
        Predicts the labels for a batch of samples.

        Args:
            input_ids: The input ids of the batch.
            attention_mask: The attention mask of the batch.
            token_type_ids: The token type ids of the batch.
            prediction_mask: The prediction mask of the batch.
            special_symbols_mask: The special symbols mask of the batch.
            special_symbols_mask_entities: The special symbols mask entities of the batch.
            sample: The samples of the batch.

        Returns:
            The predicted labels for each sample.
        """
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            is_prediction=True,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"]  # .cpu().numpy()
        ed_predictions = forward_output["re_entities_predictions"].cpu().numpy()
        ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy()
        re_predictions = forward_output["re_predictions"].cpu().numpy()
        re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy()

        if sample is None:
            sample = [RelikReaderSample() for _ in range(len(input_ids))]

        for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            re_predictions,
            re_probabilities,
            ed_predictions,
            ned_type_predictions,
        ):
            ne_end = ne_end.cpu().numpy()
            entities = []
            if self.relik_reader_re_model.entity_type_loss:
                starts = np.argwhere(ne_st)
                i = 0
                for start, end in zip(starts, ne_end):
                    ends = np.argwhere(end)
                    for e in ends:
                        entities.append([start[0], e[0], ne_et[i]])
                        i += 1
            else:
                starts = np.argwhere(ne_st)
                for start, end in zip(starts, ne_end):
                    ends = np.argwhere(end)
                    for e in ends:
                        entities.append([start[0], e[0]])

            edp = edp[: len(entities)]
            re_pred = re_pred[: len(entities), : len(entities)]
            re_prob = re_prob[: len(entities), : len(entities)]

            possible_re = np.argwhere(re_pred)
            predicted_triplets = []
            predicted_triplets_prob = []
            for i, j, r in possible_re:
                if self.relik_reader_re_model.relation_disambiguation_loss:
                    # keep the triplet only if both arguments are disambiguated
                    # to relation r and neither is assigned the null candidate
                    if not (
                        i != j
                        and edp[i, r] == 1
                        and edp[j, r] == 1
                        and edp[i, 0] == 0
                        and edp[j, 0] == 0
                    ):
                        continue
                predicted_triplets.append([i, j, r])
                predicted_triplets_prob.append(re_prob[i, j, r])

            ts._d["predicted_relations"] = predicted_triplets
            ts._d["predicted_entities"] = entities
            ts._d["predicted_relations_probabilities"] = predicted_triplets_prob

            if ts.token2word:
                self._convert_tokens_to_word_annotations(ts)

            yield ts

    def _build_input(self, text: List[str], candidates: List[str]) -> List[str]:
        candidates_symbols = get_special_symbols_re(len(candidates))
        candidates = [
            [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
            for cs, ct in zip(candidates_symbols, candidates)
        ]
        return (
            [self.tokenizer.cls_token]
            + text
            + [self.tokenizer.sep_token]
            + flatten(candidates)
            + [self.tokenizer.sep_token]
        )

    @staticmethod
    def _compute_offsets(offsets_mapping):
        offsets_mapping = offsets_mapping.numpy()
        token2word = []
        word2token = {}
        count = 0
        for i, offset in enumerate(offsets_mapping):
            if offset[0] == 0:
                token2word.append(i - count)
                word2token[i - count] = [i]
            else:
                token2word.append(token2word[-1])
                word2token[token2word[-1]].append(i)
                count += 1
        return token2word, word2token
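
    # Worked example for `_compute_offsets` (toy offsets): with
    # offsets_mapping [(0, 3), (3, 5), (0, 2)], token 1 does not start at
    # character 0, so it continues word 0; token 2 starts a new word:
    #   token2word = [0, 0, 1]
    #   word2token = {0: [0, 1], 1: [2]}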

    @staticmethod
    def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
        triplets = []
        entities = []
        for entity in sample.predicted_entities:
            if sample.entity_candidates:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        sample.entity_candidates[entity[2]],
                    )
                )
            else:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        -1,
                    )
                )
        for predicted_triplet, predicted_triplet_probabilities in zip(
            sample.predicted_relations, sample.predicted_relations_probabilities
        ):
            subject, object_, relation = predicted_triplet
            subject = entities[subject]
            object_ = entities[object_]
            relation = sample.candidates[relation]
            triplets.append(
                {
                    "subject": {
                        "start": subject[0],
                        "end": subject[1],
                        "type": subject[2],
                        "name": " ".join(sample.tokens[subject[0] : subject[1]]),
                    },
                    "relation": {
                        "name": relation,
                        "probability": float(predicted_triplet_probabilities.round(2)),
                    },
                    "object": {
                        "start": object_[0],
                        "end": object_[1],
                        "type": object_[2],
                        "name": " ".join(sample.tokens[object_[0] : object_[1]]),
                    },
                }
            )
        sample.predicted_entities = entities
        sample.predicted_relations = triplets
        sample.predicted_relations_probabilities = None

    @torch.no_grad()
    @torch.inference_mode()
    def read(
        self,
        text: Optional[Union[List[str], List[List[str]]]] = None,
        samples: Optional[List[RelikReaderSample]] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        candidates: Optional[List[List[str]]] = None,
        max_length: Optional[int] = 1024,
        max_batch_size: Optional[int] = 64,
        token_batch_size: Optional[int] = None,
        progress_bar: bool = False,
        *args,
        **kwargs,
    ) -> List[RelikReaderSample]:
        """
        Reads the given text.

        Args:
            text: The text to read in tokens.
            samples: The pre-built samples to read.
            input_ids: The input ids of the text.
            attention_mask: The attention mask of the text.
            token_type_ids: The token type ids of the text.
            prediction_mask: The prediction mask of the text.
            special_symbols_mask: The special symbols mask of the text.
            special_symbols_mask_entities: The special symbols mask entities of the text.
            candidates: The candidates of the text.
            max_length: The maximum length of the text.
            max_batch_size: The maximum batch size.
            token_batch_size: The maximum number of tokens per batch.
            progress_bar: Whether to show a progress bar.

        Returns:
            The predicted labels for each sample.
        """
        if text is None and input_ids is None and samples is None:
            raise ValueError(
                "Either `text` or `input_ids` or `samples` must be provided."
            )
        if (input_ids is None and samples is None) and (
            text is None or candidates is None
        ):
            raise ValueError(
                "`text` and `candidates` must be provided to return the predictions "
                "when `input_ids` and `samples` are not provided."
            )
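        # Input-shape sketch (toy values): a single pre-tokenized sentence
        #   text = ["Rome", "is", "the", "capital", "of", "Italy"]
        #   candidates = ["capital of", "located in"]
        # is wrapped into a one-element batch below; for batched input, pass
        # one candidate list per sentence.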
        if text is not None and samples is None:
            if len(text) != len(candidates):
                raise ValueError("`text` and `candidates` must have the same length.")
            if isinstance(text[0], str):  # wrap a single sample into a batch of one
                text = [text]
                candidates = [candidates]

            samples = [
                RelikReaderSample(tokens=t, candidates=c)
                for t, c in zip(text, candidates)
            ]

        if samples is not None:
            # function that creates a batch from the 'current_batch' list
            def output_batch() -> Dict[str, Any]:
                assert (
                    len(
                        set(
                            [
                                len(elem["predictable_candidates"])
                                for elem in current_batch
                            ]
                        )
                    )
                    == 1
                ), " ".join(
                    map(
                        str,
                        [len(elem["predictable_candidates"]) for elem in current_batch],
                    )
                )

                batch_dict = dict()

                de_values_by_field = {
                    fn: [de[fn] for de in current_batch if fn in de]
                    for fn in self.fields_batcher
                }

                # in case you provide fields batchers but in the batch
                # there are no elements for that field
                de_values_by_field = {
                    fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
                }

                # every remaining field must cover the whole batch
                assert len(set([len(v) for v in de_values_by_field.values()])) == 1

                # todo: maybe we should warn the user about possible
                #  fields filtering due to "None" instances
                de_values_by_field = {
                    fn: fvs
                    for fn, fvs in de_values_by_field.items()
                    if all([fv is not None for fv in fvs])
                }

                for field_name, field_values in de_values_by_field.items():
                    field_batch = (
                        self.fields_batcher[field_name]([fv[0] for fv in field_values])
                        if self.fields_batcher[field_name] is not None
                        else field_values
                    )
                    batch_dict[field_name] = field_batch

                batch_dict = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch_dict.items()
                }
                return batch_dict

            current_batch = []
            predictions = []
            current_cand_len = -1

            for sample in tqdm(samples, disable=not progress_bar):
                sample.candidates = [NME_SYMBOL] + sample.candidates
                inputs_text = self._build_input(sample.tokens, sample.candidates)
                model_inputs = self.tokenizer(
                    inputs_text,
                    is_split_into_words=True,
                    add_special_tokens=False,
                    padding=False,
                    truncation=True,
                    max_length=max_length or self.tokenizer.model_max_length,
                    return_offsets_mapping=True,
                    return_tensors="pt",
                )
                model_inputs["special_symbols_mask"] = (
                    model_inputs["input_ids"] > self.tokenizer.vocab_size
                )
                # token_type_ids flip to 1 from the first special symbol onwards
                model_inputs["token_type_ids"] = (
                    torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
                ).long()
                # the prediction mask is the token_type_ids shifted one position
                # to the left, with the first and last positions forced to 1;
                # positions where it is 0 are the text tokens to predict on
                model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
                    shifts=-1, dims=1
                )
                model_inputs["prediction_mask"][:, -1] = 1
                model_inputs["prediction_mask"][:, 0] = 1

                assert (
                    len(model_inputs["special_symbols_mask"])
                    == len(model_inputs["prediction_mask"])
                    == len(model_inputs["input_ids"])
                )

                model_inputs["sample"] = sample

                # compute cand_len using special_symbols_mask
                model_inputs["predictable_candidates"] = sample.candidates[
                    : model_inputs["special_symbols_mask"].sum().item()
                ]
                # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]])
                offsets = model_inputs.pop("offset_mapping")
                offsets = offsets[model_inputs["prediction_mask"] == 0]
                sample.token2word, sample.word2token = self._compute_offsets(offsets)

                # longest sequence if this sample joined the current batch
                future_max_len = max(
                    model_inputs["input_ids"].size(1),
                    max(
                        [b["input_ids"].size(1) for b in current_batch],
                        default=0,
                    ),
                )
                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

                # flush when the candidate count changes, the token budget is
                # exceeded, or the batch is full
                if len(current_batch) > 0 and (
                    (
                        len(model_inputs["predictable_candidates"]) != current_cand_len
                        and current_cand_len != -1
                    )
                    or (
                        isinstance(token_batch_size, int)
                        and future_tokens_per_batch >= token_batch_size
                    )
                    or len(current_batch) == max_batch_size
                ):
                    batch_inputs = output_batch()
                    current_batch = []
                    predictions.extend(list(self.batch_predict(**batch_inputs)))
                current_cand_len = len(model_inputs["predictable_candidates"])
                current_batch.append(model_inputs)

            if current_batch:
                batch_inputs = output_batch()
                predictions.extend(list(self.batch_predict(**batch_inputs)))
        else:
            predictions = list(
                self.batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    special_symbols_mask_entities,
                    *args,
                    **kwargs,
                )
            )
        return predictions
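
    # Note on batching in `read`: the running batch is flushed whenever the
    # number of predictable candidates changes, since `output_batch` asserts
    # that every element in a batch shares the same candidate count
    # (presumably so the special candidate symbols align across the batch).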

    @property
    def device(self) -> torch.device:
        """
        The device of the model.
        """
        return next(self.parameters()).device

    @property
    def tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The tokenizer.
        """
        if self._tokenizer:
            return self._tokenizer

        self._tokenizer = tr.AutoTokenizer.from_pretrained(
            self.relik_reader_re_model.config.name_or_path
        )
        return self._tokenizer

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
        }
        if "roberta" in self.relik_reader_re_model.config.model_type:
            del fields_batchers["token_type_ids"]
        return fields_batchers

    def save_pretrained(
        self,
        output_dir: str,
        model_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """
        Saves the model to the given path.

        Args:
            output_dir: The path to save the model to.
            model_name: The name of the model.
            push_to_hub: Whether to push the model to the hub.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        model_name = model_name or "relik_reader_for_triplet_extraction"

        logger.info(f"Saving reader to {output_dir / model_name}")

        # save the model
        self.relik_reader_re_model.register_for_auto_class()
        self.relik_reader_re_model.save_pretrained(
            output_dir / model_name, push_to_hub=push_to_hub, **kwargs
        )
        logger.info("Saving reader to disk done.")

        if self.tokenizer:
            self.tokenizer.save_pretrained(
                output_dir / model_name, push_to_hub=push_to_hub, **kwargs
            )
            logger.info("Saving tokenizer to disk done.")
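

# Minimal end-to-end sketch. The checkpoint name is hypothetical; substitute
# any RelikReaderREModel checkpoint, and note that the candidate relation
# inventory is whatever the checkpoint was trained with.
if __name__ == "__main__":
    reader = RelikReaderForTripletExtraction(
        "relik-ie/relik-reader-re-large",  # hypothetical name
        device="cpu",
    )
    predictions = reader.read(
        text=["Rome", "is", "the", "capital", "of", "Italy"],
        candidates=["capital of", "located in"],
    )
    for prediction in predictions:
        print(prediction.predicted_relations)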