import abc
import logging
from typing import Dict, Optional

import torch
from datasets import Dataset
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pytorch_ie.annotations import Span
from pytorch_ie.documents import TextBasedDocument
from torch import FloatTensor, Tensor
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


class EmbeddingModel(abc.ABC):
    @abc.abstractmethod
    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Span, FloatTensor]:
        """Embed text annotations from a document.

        Args:
            document: The document to embed.
            span_layer_name: The name of the annotation layer in the document that contains the
                text span annotations to embed.

        Returns:
            A dictionary mapping text annotations to their embeddings.
        """
        pass


class HuggingfaceEmbeddingModel(EmbeddingModel):
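    """Compute span embeddings with a Huggingface transformer model.

    The document text is tokenized into (potentially overlapping) windows of at most max_length
    tokens, each window is passed through the model, and every span annotation is embedded by
    max-pooling the last hidden states of its tokens.
    """
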
    def __init__(
        self,
        model_name_or_path: str,
        revision: Optional[str] = None,
        device: str = "cpu",
        max_length: int = 512,
        batch_size: int = 16,
    ):
        self.load(model_name_or_path, revision, device)
        self.max_length = max_length
        self.batch_size = batch_size

    def load(
        self, model_name_or_path: str, revision: Optional[str] = None, device: str = "cpu"
    ) -> None:
        self._model = AutoModel.from_pretrained(model_name_or_path, revision=revision).to(device)
        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, revision=revision)

    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Span, FloatTensor]:
        # work on a copy so that the original document is not modified
        document = document.copy()
        # tokenize_document does not yet consider predictions, so we need to add them manually
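        # (clear() on the predictions sub-layer returns the removed annotations, so this moves
        # all predicted spans into the main layer before tokenization)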
        document[span_layer_name].extend(document[span_layer_name].predictions.clear())
        added_annotations = []
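        # tokenize into overlapping windows so that long documents fit into the model:
        # return_overflowing_tokens=True creates one row per window and the stride makes
        # consecutive windows overlap by max_length // 8 tokens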
        tokenizer_kwargs = {
            "max_length": self.max_length,
            "stride": self.max_length // 8,
            "truncation": True,
            "padding": True,
            "return_overflowing_tokens": True,
        }
        # tokenize once to get the tokenized documents with mapped annotations
        tokenized_documents = tokenize_document(
            document,
            tokenizer=self._tokenizer,
            result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
            partition_layer="labeled_partitions",
            added_annotations=added_annotations,
            strict_span_conversion=False,
            **tokenizer_kwargs,
        )
        # just tokenize again to get tensors in the correct format for the model
        dataset = Dataset.from_dict({"text": [document.text]})

        def tokenize_function(examples):
            return self._tokenizer(examples["text"], **tokenizer_kwargs)

        # Tokenize the texts. Note that we remove the text column directly in the map call,
        # otherwise the map would fail because we may produce multiple new rows
        # (tokenization results) for each input row (text).
        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
        # remove the overflow_to_sample_mapping column, which is not a model input
        tokenized_dataset = tokenized_dataset.remove_columns(["overflow_to_sample_mapping"])
        tokenized_dataset.set_format(type="torch")
        dataloader = DataLoader(tokenized_dataset, batch_size=self.batch_size)
        embeddings = {}
        example_idx = 0
        for batch in dataloader:
            batch_at_device = {
                k: v.to(self._model.device) if isinstance(v, Tensor) else v
                for k, v in batch.items()
            }
            with torch.no_grad():
                model_output = self._model(**batch_at_device)
            for last_hidden_state in model_output.last_hidden_state:
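                # added_annotations holds, per tokenized document (window), a mapping from layer
                # name to a dict of original (text) annotations -> token-based annotations;
                # invert it to map token spans back to the original annotations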
                text2tok_ann = added_annotations[example_idx][span_layer_name]
                tok2text_ann = {v: k for k, v in text2tok_ann.items()}
                for tok_ann in tokenized_documents[example_idx].labeled_spans:
                    # skip "empty" annotations
                    if tok_ann.start == tok_ann.end:
                        continue
                    # use the max pooling strategy to get a single embedding for the annotation text
                    embedding = (
                        last_hidden_state[tok_ann.start : tok_ann.end].max(dim=0)[0].detach().cpu()
                    )
                    text_ann = tok2text_ann[tok_ann]
                    if text_ann in embeddings:
                        logger.warning(
                            f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
                        )
                    embeddings[text_ann] = embedding
                example_idx += 1
        return embeddings
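

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): build a tiny document with one
    # partition and two labeled spans, then embed the spans. The document type and the
    # "bert-base-uncased" checkpoint are illustrative assumptions; any text-based document that
    # provides the requested span layer and a "labeled_partitions" layer should work.
    from pytorch_ie.annotations import LabeledSpan
    from pytorch_ie.documents import (
        TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    )

    doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        text="Berlin is the capital of Germany."
    )
    # one partition covering the whole text, plus two entity spans to embed
    doc.labeled_partitions.append(LabeledSpan(start=0, end=len(doc.text), label="paragraph"))
    doc.labeled_spans.append(LabeledSpan(start=0, end=6, label="LOC"))
    doc.labeled_spans.append(LabeledSpan(start=25, end=32, label="LOC"))

    embedding_model = HuggingfaceEmbeddingModel("bert-base-uncased", device="cpu")
    span_embeddings = embedding_model(doc, span_layer_name="labeled_spans")
    for span, emb in span_embeddings.items():
        print(doc.text[span.start : span.end], tuple(emb.shape))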