import contextlib
import logging
from typing import Any, Dict, Iterator, List

import numpy as np
import torch
import transformers as tr
from lightning_fabric.utilities import move_data_to_device
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from relik.common.log import get_logger
from relik.common.torch_utils import get_autocast_context
from relik.common.utils import get_callable_from_string
from relik.inference.data.objects import AnnotationType
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.base import RelikReaderBase
from relik.retriever.pytorch_modules import PRECISION_MAP

logger = get_logger(__name__, level=logging.INFO)

class RelikReaderForTripletExtraction(RelikReaderBase):
    """
    A class for the RelikReader model for triplet extraction.

    Args:
        transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
            The transformer model to use. If `None`, the default model is used.
        additional_special_symbols (:obj:`int`, `optional`, defaults to 0):
            The number of additional special symbols to add to the tokenizer.
        additional_special_symbols_types (:obj:`int`, `optional`, defaults to 0):
            The number of additional special symbols for entity types to add to the tokenizer.
        entity_type_loss (:obj:`bool`, `optional`):
            Whether to also classify entity types; when enabled, decoded entities include a
            predicted type.
        add_entity_embedding (:obj:`bool`, `optional`):
            Whether to add an entity embedding (passed through to the underlying reader model).
        num_layers (:obj:`int`, `optional`):
            The number of layers to use. If `None`, all layers are used.
        activation (:obj:`str`, `optional`, defaults to "gelu"):
            The activation function to use.
        linears_hidden_size (:obj:`int`, `optional`, defaults to 512):
            The hidden size of the linear layers.
        use_last_k_layers (:obj:`int`, `optional`, defaults to 1):
            The number of last layers to use.
        training (:obj:`bool`, `optional`, defaults to False):
            Whether the model is in training mode.
        device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`):
            The device to use. If `None`, the default device is used.
        tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`):
            The tokenizer to use. If `None`, the default tokenizer is used.
        dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`):
            The dataset to use. If `None`, the default dataset is used.
        dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`):
            The keyword arguments to pass to the dataset class.
        default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
            The default reader class to use. If `None`, the default reader class is used.
        **kwargs:
            Keyword arguments.
    """

    default_reader_class: str = (
        "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderREModel"
    )
    default_data_class: str = "relik.reader.data.relik_reader_re_data.RelikREDataset"

    def __init__(
        self,
        transformer_model: str | tr.PreTrainedModel | None = None,
        additional_special_symbols: int = 0,
        additional_special_symbols_types: int = 0,
        entity_type_loss: bool | None = None,
        add_entity_embedding: bool | None = None,
        num_layers: int | None = None,
        activation: str = "gelu",
        linears_hidden_size: int | None = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: str | torch.device | None = None,
        tokenizer: str | tr.PreTrainedTokenizer | None = None,
        dataset: IterableDataset | str | None = None,
        dataset_kwargs: Dict[str, Any] | None = None,
        default_reader_class: tr.PreTrainedModel | str | None = None,
        **kwargs,
    ):
        super().__init__(
            transformer_model=transformer_model,
            additional_special_symbols=additional_special_symbols,
            additional_special_symbols_types=additional_special_symbols_types,
            entity_type_loss=entity_type_loss,
            add_entity_embedding=add_entity_embedding,
            num_layers=num_layers,
            activation=activation,
            linears_hidden_size=linears_hidden_size,
            use_last_k_layers=use_last_k_layers,
            training=training,
            device=device,
            tokenizer=tokenizer,
            dataset=dataset,
            default_reader_class=default_reader_class,
            **kwargs,
        )
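
        # When no dataset is given and the model is not in training mode, fall
        # back to the default RE dataset so that `read`/`_read` can tokenize
        # raw samples at inference time.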
        self.dataset = dataset
        if self.dataset is None and training is False:
            self.default_data_class = get_callable_from_string(self.default_data_class)
            default_data_kwargs = dict(
                dataset_path=None,
                materialize_samples=False,
                transformer_model=self.tokenizer,
                special_symbols=self.default_data_class.get_special_symbols_re(
                    self.relik_reader_model.config.additional_special_symbols,
                    use_nme=kwargs.get("use_nme_re", False),
                ),
                special_symbols_types=(
                    self.default_data_class.get_special_symbols(
                        self.relik_reader_model.config.additional_special_symbols_types - 1
                    )
                    if self.relik_reader_model.config.additional_special_symbols_types > 0
                    else []
                ),
                for_inference=True,
                use_nme=kwargs.get("use_nme", False),
            )
            default_data_kwargs.update(dataset_kwargs or {})
            self.dataset = self.default_data_class(**default_data_kwargs)
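        # Hypothetical override example: any of the defaults above can be
        # replaced through `dataset_kwargs`, e.g.
        #
        #     RelikReaderForTripletExtraction(
        #         ..., dataset_kwargs={"for_inference": True, "use_nme": True}
        #     )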

    @torch.inference_mode()
    def _read(
        self,
        samples: List[RelikReaderSample] | None = None,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        max_length: int = 2048,
        max_batch_size: int = 128,
        token_batch_size: int = 2048,
        precision: str | int = 32,
        annotation_type: AnnotationType = AnnotationType.CHAR,
        progress_bar: bool = False,
        *args: object,
        **kwargs: object,
    ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
        """
        A wrapper around the forward method that returns the predicted labels for each sample.

        Args:
            samples (:obj:`List[RelikReaderSample]`, `optional`):
                The samples to read. If provided, the tensor inputs below are ignored.
            input_ids (:obj:`torch.Tensor`, `optional`):
                The input ids of the text. Ignored if `samples` is provided.
            attention_mask (:obj:`torch.Tensor`, `optional`):
                The attention mask of the text. Ignored if `samples` is provided.
            token_type_ids (:obj:`torch.Tensor`, `optional`):
                The token type ids of the text. Ignored if `samples` is provided.
            prediction_mask (:obj:`torch.Tensor`, `optional`):
                The prediction mask of the text. Ignored if `samples` is provided.
            special_symbols_mask (:obj:`torch.Tensor`, `optional`):
                The special symbols mask of the text. Ignored if `samples` is provided.
            max_length (:obj:`int`, `optional`, defaults to 2048):
                The maximum length of the text.
            max_batch_size (:obj:`int`, `optional`, defaults to 128):
                The maximum batch size.
            token_batch_size (:obj:`int`, `optional`, defaults to 2048):
                The maximum number of tokens per batch.
            precision (:obj:`str` or :obj:`int`, `optional`, defaults to 32):
                The precision to use for the model.
            annotation_type (`AnnotationType`, `optional`, defaults to `AnnotationType.CHAR`):
                The type of annotation to return. If `char`, the spans will be in terms of
                character offsets. If `word`, the spans will be in terms of word offsets.
            progress_bar (:obj:`bool`, `optional`, defaults to False):
                Whether to show a progress bar.
            *args:
                Positional arguments.
            **kwargs:
                Keyword arguments.

        Returns:
            :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`:
                The predicted labels for each sample.
        """
        precision = precision or self.precision
        if samples is not None:

            def _read_iterator():
                def samples_it():
                    for i, sample in enumerate(samples):
                        assert sample._mixin_prediction_position is None
                        sample._mixin_prediction_position = i
                        if sample.spans is not None and len(sample.spans) > 0:
                            entities = []
                            offset_span = sample.char2token_start[str(sample.offset)]
                            last_char_end = int(list(sample.char2token_end.keys())[-1])
                            for span_start, span_end in sample.spans:
                                if str(span_start) not in sample.char2token_start:
                                    # The start offset does not fall on a token
                                    # boundary: snap it left to the closest one.
                                    while str(span_start) not in sample.char2token_start:
                                        span_start -= 1
                                        if span_start < 0:
                                            break
                                if str(span_end) not in sample.char2token_end:
                                    # Likewise, snap the end offset right.
                                    while str(span_end) not in sample.char2token_end:
                                        span_end += 1
                                        if span_end >= last_char_end:
                                            break
                                # Skip spans that could not be aligned to tokens.
                                if span_start < 0 or span_end > last_char_end:
                                    continue
                                entities.append(
                                    [
                                        sample.char2token_start[str(span_start)] - offset_span,
                                        sample.char2token_end[str(span_end)] + 1 - offset_span,
                                        "",
                                    ]
                                )
                            sample.entities = entities
                        yield sample
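
                # Batches may complete out of input order; buffer predicted
                # samples by position and re-emit them in the original order.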
                next_prediction_position = 0
                position2predicted_sample = {}

                if self.dataset is None:
                    raise ValueError(
                        "You need to pass a dataset to the model in order to predict"
                    )
                self.dataset.samples = samples_it()
                self.dataset.model_max_length = max_length
                self.dataset.tokens_per_batch = token_batch_size
                self.dataset.max_batch_size = max_batch_size

                iterator = DataLoader(
                    self.dataset, batch_size=None, num_workers=0, shuffle=False
                )
                if progress_bar:
                    iterator = tqdm(iterator, desc="Predicting with RelikReader")

                with get_autocast_context(self.device, precision):
                    for batch in iterator:
                        batch = move_data_to_device(batch, self.device)
                        batch.update(kwargs)
                        batch_out = self._batch_predict(**batch)

                        for sample in batch_out:
                            if (
                                sample._mixin_prediction_position
                                >= next_prediction_position
                            ):
                                position2predicted_sample[
                                    sample._mixin_prediction_position
                                ] = sample

                        # Emit the contiguous prefix of completed samples.
                        while next_prediction_position in position2predicted_sample:
                            yield position2predicted_sample[next_prediction_position]
                            del position2predicted_sample[next_prediction_position]
                            next_prediction_position += 1

            outputs = list(_read_iterator())
            for sample in outputs:
                self.dataset.merge_patches_predictions(sample)
                if annotation_type == AnnotationType.CHAR:
                    self.dataset.convert_to_char_annotations(sample)
                elif annotation_type == AnnotationType.WORD:
                    self.dataset.convert_to_word_annotations(sample)
                else:
                    raise ValueError(
                        f"Annotation type {annotation_type} not recognized. "
                        f"Please choose one of {list(AnnotationType)}."
                    )
        else:
            outputs = list(
                self._batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    *args,
                    **kwargs,
                )
            )
        return outputs

    def _batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        special_symbols_mask_entities: torch.Tensor | None = None,
        sample: List[RelikReaderSample] | None = None,
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """
        A wrapper around the forward method that yields the predicted labels for each sample.
        It also adds the predicted labels to the samples.

        Args:
            input_ids (:obj:`torch.Tensor`):
                The input ids of the text.
            attention_mask (:obj:`torch.Tensor`):
                The attention mask of the text.
            token_type_ids (:obj:`torch.Tensor`, `optional`):
                The token type ids of the text.
            prediction_mask (:obj:`torch.Tensor`, `optional`):
                The prediction mask of the text.
            special_symbols_mask (:obj:`torch.Tensor`, `optional`):
                The special symbols mask of the text.
            special_symbols_mask_entities (:obj:`torch.Tensor`, `optional`):
                The special symbols mask for entity types.
            sample (:obj:`List[RelikReaderSample]`, `optional`):
                The samples corresponding to the batch; predictions are attached to each one.
            *args:
                Positional arguments.
            **kwargs:
                Keyword arguments.

        Returns:
            The predicted labels for each sample.
        """
        forward_output = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            prediction_mask=prediction_mask,
            special_symbols_mask=special_symbols_mask,
            special_symbols_mask_entities=special_symbols_mask_entities,
            is_prediction=True,
            *args,
            **kwargs,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        # Kept on-device here; moved to CPU one sample at a time in the loop below.
        ned_end_predictions = forward_output["ned_end_predictions"]
        ed_predictions = forward_output["re_entities_predictions"].cpu().numpy()
        ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy()
        re_predictions = forward_output["re_predictions"].cpu().numpy()
        re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy()
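
        # Span decoding: `np.argwhere` over the start predictions gives the
        # candidate start positions; each start has a corresponding row of end
        # predictions that is decoded the same way below.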
        for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            re_predictions,
            re_probabilities,
            ed_predictions,
            ned_type_predictions,
        ):
            ne_end = ne_end.cpu().numpy()
            entities = []
            if self.relik_reader_model.config.entity_type_loss:
                starts = np.argwhere(ne_st)
                i = 0
                for start, end in zip(starts, ne_end):
                    ends = np.argwhere(end)
                    for e in ends:
                        entities.append([start[0], e[0], ne_et[i]])
                        i += 1
            else:
                starts = np.argwhere(ne_st)
                for start, end in zip(starts, ne_end):
                    ends = np.argwhere(end)
                    for e in ends:
                        entities.append([start[0], e[0]])

            # Restrict the relation tensors to the entities actually found.
            edp = edp[: len(entities)]
            re_pred = re_pred[: len(entities), : len(entities)]
            re_prob = re_prob[: len(entities), : len(entities)]
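
            # Any nonzero (head, tail, relation) index in `re_pred` is a
            # candidate triplet; the disambiguation filter below may drop it.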
            possible_re = np.argwhere(re_pred)
            predicted_triplets = []
            predicted_triplets_prob = []
            for i, j, r in possible_re:
                if self.relik_reader_model.relation_disambiguation_loss:
                    # Keep (i, j, r) only for distinct arguments whose entity
                    # disambiguation agrees with relation r (column 0 appears
                    # to be the null symbol).
                    if not (
                        i != j
                        and edp[i, r] == 1
                        and edp[j, r] == 1
                        and edp[i, 0] == 0
                        and edp[j, 0] == 0
                    ):
                        continue
                predicted_triplets.append([i, j, r])
                predicted_triplets_prob.append(re_prob[i, j, r])

            ts._d["predicted_relations"] = predicted_triplets
            ts._d["predicted_entities"] = entities
            ts._d["predicted_relations_probabilities"] = predicted_triplets_prob
            # `predicted_triples` mirrors `predicted_relations` so both keys
            # are available to downstream consumers.
            ts._d["predicted_triples"] = predicted_triplets

            yield ts