import abc
import logging
from typing import Dict, Optional

import torch
from datasets import Dataset
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pytorch_ie.annotations import Span
from pytorch_ie.documents import TextBasedDocument
from torch import FloatTensor, Tensor
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


class EmbeddingModel(abc.ABC):
    @abc.abstractmethod
    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Span, FloatTensor]:
        """Embed text annotations from a document.

        Args:
            document: The document to embed.
            span_layer_name: The name of the annotation layer in the document that contains the
                text span annotations to embed.

        Returns:
            A dictionary mapping text annotations to their embeddings.
        """
        pass


class HuggingfaceEmbeddingModel(EmbeddingModel):
    """Embeds span annotations with a Huggingface transformer model by max-pooling the token
    embeddings of each span over the model's last hidden state."""

    def __init__(
        self,
        model_name_or_path: str,
        revision: Optional[str] = None,
        device: str = "cpu",
        max_length: int = 512,
        batch_size: int = 16,
    ):
        self.load(model_name_or_path, revision, device)
        self.max_length = max_length
        self.batch_size = batch_size

    def load(
        self, model_name_or_path: str, revision: Optional[str] = None, device: str = "cpu"
    ) -> None:
        self._model = AutoModel.from_pretrained(model_name_or_path, revision=revision).to(device)
        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, revision=revision)

    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Span, FloatTensor]:
        # work on a copy so that the input document is not modified
        document = document.copy()
        # move the predicted spans into the base layer so that they get tokenized as well
        # (clear() returns the removed annotations)
        document[span_layer_name].extend(document[span_layer_name].predictions.clear())
        added_annotations = []
        tokenizer_kwargs = {
            "max_length": self.max_length,
            "stride": self.max_length // 8,
            "truncation": True,
            "padding": True,
            "return_overflowing_tokens": True,
        }
        # convert the text document into token documents; added_annotations collects the mapping
        # from the original text spans to their token-level counterparts
        tokenized_documents = tokenize_document(
            document,
            tokenizer=self._tokenizer,
            result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
            partition_layer="labeled_partitions",
            added_annotations=added_annotations,
            strict_span_conversion=False,
            **tokenizer_kwargs,
        )

        # tokenize the raw text with the same settings (max_length, stride, overflow handling)
        # to obtain model inputs that line up with the token documents produced above
        dataset = Dataset.from_dict({"text": [document.text]})

        def tokenize_function(examples):
            return self._tokenizer(examples["text"], **tokenizer_kwargs)

        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
        # the overflow mapping is not a model input, so it has to be removed before batching
        tokenized_dataset = tokenized_dataset.remove_columns(["overflow_to_sample_mapping"])
        tokenized_dataset.set_format(type="torch")

        dataloader = DataLoader(tokenized_dataset, batch_size=self.batch_size)

        embeddings = {}
        example_idx = 0
        for batch in dataloader:
            batch_at_device = {
                k: v.to(self._model.device) if isinstance(v, Tensor) else v
                for k, v in batch.items()
            }
            with torch.no_grad():
                model_output = self._model(**batch_at_device)

            # each entry of last_hidden_state corresponds to one (possibly overflowing) window,
            # i.e. to one tokenized document
            for last_hidden_state in model_output.last_hidden_state:
                text2tok_ann = added_annotations[example_idx][span_layer_name]
                tok2text_ann = {v: k for k, v in text2tok_ann.items()}
                for tok_ann in tokenized_documents[example_idx].labeled_spans:
                    # skip empty token spans, they cannot be embedded
                    if tok_ann.start == tok_ann.end:
                        continue
                    # max-pool the token embeddings over the span
                    embedding = (
                        last_hidden_state[tok_ann.start : tok_ann.end].max(dim=0)[0].detach().cpu()
                    )
                    text_ann = tok2text_ann[tok_ann]
                    if text_ann in embeddings:
                        logger.warning(
                            f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
                        )
                    embeddings[text_ann] = embedding
                example_idx += 1

        return embeddings
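

# A minimal usage sketch (added for illustration, not part of the original module). It assumes a
# small publicly available checkpoint ("prajjwal1/bert-tiny") and a hand-built example document;
# both are placeholders and can be swapped for any checkpoint and any document that provides
# "labeled_spans" and "labeled_partitions" layers.
if __name__ == "__main__":
    from pytorch_ie.annotations import LabeledSpan
    from pytorch_ie.documents import (
        TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    )

    doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        text="Berlin is the capital of Germany."
    )
    # one partition covering the whole text and a single span to embed (offsets are illustrative)
    doc.labeled_partitions.append(LabeledSpan(start=0, end=len(doc.text), label="paragraph"))
    doc.labeled_spans.append(LabeledSpan(start=0, end=6, label="LOCATION"))

    embedding_model = HuggingfaceEmbeddingModel("prajjwal1/bert-tiny", device="cpu")
    span2embedding = embedding_model(doc, span_layer_name="labeled_spans")
    for span, embedding in span2embedding.items():
        print(span, tuple(embedding.shape))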