import abc
import logging
from typing import Dict, Optional

import torch
from datasets import Dataset
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import (
    TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from pytorch_ie.annotations import Span
from pytorch_ie.documents import TextBasedDocument
from torch import FloatTensor, Tensor
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


class EmbeddingModel(abc.ABC):
    @abc.abstractmethod
    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Span, FloatTensor]:
        """Embed text annotations from a document.

        Args:
            document: The document to embed.
            span_layer_name: The name of the annotation layer in the document that
                contains the text span annotations to embed.

        Returns:
            A dictionary mapping text annotations to their embeddings.
        """


class HuggingfaceEmbeddingModel(EmbeddingModel):
    def __init__(
        self,
        model_name_or_path: str,
        revision: Optional[str] = None,
        device: str = "cpu",
        max_length: int = 512,
        batch_size: int = 16,
    ):
        self.load(model_name_or_path, revision, device)
        self.max_length = max_length
        self.batch_size = batch_size

    def load(
        self, model_name_or_path: str, revision: Optional[str] = None, device: str = "cpu"
    ) -> None:
        self._model = AutoModel.from_pretrained(model_name_or_path, revision=revision).to(device)
        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, revision=revision)

    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Span, FloatTensor]:
        # work on a copy so that the original document is not modified
        document = document.copy()
        # tokenize_document does not yet consider predictions, so we move them into the
        # main layer manually (clear() returns the removed annotations)
        document[span_layer_name].extend(document[span_layer_name].predictions.clear())

        added_annotations = []
        tokenizer_kwargs = {
            "max_length": self.max_length,
            "stride": self.max_length // 8,
            "truncation": True,
            "padding": True,
            "return_overflowing_tokens": True,
        }
        # tokenize once to get the tokenized documents with mapped annotations
        tokenized_documents = tokenize_document(
            document,
            tokenizer=self._tokenizer,
            result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
            partition_layer="labeled_partitions",
            added_annotations=added_annotations,
            strict_span_conversion=False,
            **tokenizer_kwargs,
        )

        # tokenize again, this time to get tensors in the correct format for the model
        dataset = Dataset.from_dict({"text": [document.text]})

        def tokenize_function(examples):
            return self._tokenizer(examples["text"], **tokenizer_kwargs)
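        # Overflow illustration (numbers are a sketch, not computed here): with
        # max_length=512 and stride=64, a text of, say, 1000 tokens is encoded as
        # overlapping windows that share 64 tokens at each boundary, so a single
        # input row can expand into multiple tokenized rows.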
        # Tokenize the texts. Note that we remove the text column directly in the map
        # call; otherwise the map would fail because we may produce multiple new rows
        # (tokenization results) for each input row (text).
        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
        # remove the overflow_to_sample_mapping column
        tokenized_dataset = tokenized_dataset.remove_columns(["overflow_to_sample_mapping"])
        tokenized_dataset.set_format(type="torch")

        dataloader = DataLoader(tokenized_dataset, batch_size=self.batch_size)

        embeddings = {}
        example_idx = 0
        for batch in dataloader:
            batch_at_device = {
                k: v.to(self._model.device) if isinstance(v, Tensor) else v
                for k, v in batch.items()
            }
            with torch.no_grad():
                model_output = self._model(**batch_at_device)

            for last_hidden_state in model_output.last_hidden_state:
                text2tok_ann = added_annotations[example_idx][span_layer_name]
                tok2text_ann = {v: k for k, v in text2tok_ann.items()}
                for tok_ann in tokenized_documents[example_idx].labeled_spans:
                    # skip "empty" annotations
                    if tok_ann.start == tok_ann.end:
                        continue
                    # max-pool over the token embeddings to get a single embedding
                    # for the annotation text
                    embedding = (
                        last_hidden_state[tok_ann.start : tok_ann.end].max(dim=0)[0].detach().cpu()
                    )
                    text_ann = tok2text_ann[tok_ann]
                    if text_ann in embeddings:
                        logger.warning(
                            f"Overwriting embedding for annotation '{text_ann}' "
                            f"(are you using striding?)"
                        )
                    embeddings[text_ann] = embedding
                example_idx += 1

        return embeddings
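

# Minimal usage sketch, not part of the module's API: the model checkpoint, the
# example document, and its annotations below are assumptions for illustration only.
if __name__ == "__main__":
    from pytorch_ie.annotations import LabeledSpan
    from pytorch_ie.documents import (
        TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    )

    # build a tiny document with one labeled span and a single partition covering
    # the whole text (tokenize_document partitions by the "labeled_partitions" layer)
    doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        text="Berlin is the capital of Germany."
    )
    doc.labeled_spans.append(LabeledSpan(start=0, end=6, label="LOCATION"))
    doc.labeled_partitions.append(LabeledSpan(start=0, end=len(doc.text), label="text"))

    # "bert-base-uncased" is just an assumed example checkpoint
    embedding_model = HuggingfaceEmbeddingModel("bert-base-uncased", batch_size=8)
    span2embedding = embedding_model(doc, span_layer_name="labeled_spans")
    for span, emb in span2embedding.items():
        print(doc.text[span.start : span.end], tuple(emb.shape))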