File size: 5,109 Bytes
1681237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8df5fb
 
 
 
1681237
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import abc
import logging
from typing import Dict, Optional

import torch
from datasets import Dataset
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pytorch_ie.annotations import Span
from pytorch_ie.documents import TextBasedDocument
from torch import FloatTensor, Tensor
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(__name__)


class EmbeddingModel(abc.ABC):
    """Interface for models that compute embeddings for span annotations of a document.

    Concrete subclasses must implement ``__call__``; the base class itself cannot be
    instantiated.
    """

    @abc.abstractmethod
    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Span, FloatTensor]:
        """Embed text annotations from a document.

        Args:
            document: The document to embed.
            span_layer_name: The name of the annotation layer in the document that contains the
                text span annotations to embed.

        Returns:
            A dictionary mapping text annotations to their embeddings.

        Raises:
            NotImplementedError: Always, if reached via ``super()``; subclasses must override.
        """
        raise NotImplementedError


class HuggingfaceEmbeddingModel(EmbeddingModel):
    """Embed span annotations with a Hugging Face transformer model.

    Every non-empty span is represented by the element-wise maximum over the last hidden
    states of the tokens it covers (max pooling). Texts longer than ``max_length`` tokens are
    split into overlapping windows (the tokenizer ``stride`` is ``max_length // 8``).
    """

    def __init__(
        self,
        model_name_or_path: str,
        revision: Optional[str] = None,
        device: str = "cpu",
        max_length: int = 512,
        batch_size: int = 16,
    ):
        """Load model and tokenizer and store the windowing / batching settings.

        Args:
            model_name_or_path: Identifier or local path passed to ``from_pretrained``.
            revision: Optional model revision passed to ``from_pretrained``.
            device: Device the model is moved to.
            max_length: Maximum number of tokens per window fed to the model.
            batch_size: Number of windows per forward pass.
        """
        self.load(model_name_or_path, revision, device)
        self.max_length = max_length
        self.batch_size = batch_size

    def load(
        self, model_name_or_path: str, revision: Optional[str] = None, device: str = "cpu"
    ) -> None:
        """Load the model (moved to ``device``) and its tokenizer via the ``Auto*`` classes."""
        self._model = AutoModel.from_pretrained(model_name_or_path, revision=revision).to(device)
        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, revision=revision)

    def _tokenize_text(self, text: str, tokenizer_kwargs: Dict) -> Dataset:
        """Tokenize ``text`` into a torch-formatted dataset with one row per window."""
        dataset = Dataset.from_dict({"text": [text]})

        def tokenize_function(examples):
            return self._tokenizer(examples["text"], **tokenizer_kwargs)

        # Remove the text column directly in the map call. With overflowing tokens, a single
        # input row (text) may produce multiple new rows (windows), and the map would fail
        # if it had to keep the original column alongside them.
        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
        # Not a model input: every remaining column is passed to the model as a keyword argument.
        tokenized_dataset = tokenized_dataset.remove_columns(["overflow_to_sample_mapping"])
        tokenized_dataset.set_format(type="torch")
        return tokenized_dataset

    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Span, FloatTensor]:
        """Embed all span annotations (including predictions) of one document layer.

        Args:
            document: The document to embed. It is not modified; a copy is processed.
            span_layer_name: Name of the span annotation layer to embed.

        Returns:
            A dictionary mapping span annotations (taken from the processed copy) to their
            max-pooled embeddings, moved to CPU. Spans that are empty after tokenization
            are skipped.
        """
        # work on a copy to not modify the original document
        document = document.copy()
        # tokenize_document does not yet consider predictions, so move them into the main
        # layer (this relies on `clear()` returning the removed annotations)
        span_layer = document[span_layer_name]
        span_layer.extend(span_layer.predictions.clear())

        added_annotations = []
        tokenizer_kwargs = {
            "max_length": self.max_length,
            "stride": self.max_length // 8,
            "truncation": True,
            "padding": True,
            "return_overflowing_tokens": True,
        }
        # tokenize once to get the tokenized documents with mapped annotations
        tokenized_documents = tokenize_document(
            document,
            tokenizer=self._tokenizer,
            result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
            partition_layer="labeled_partitions",
            added_annotations=added_annotations,
            strict_span_conversion=False,
            **tokenizer_kwargs,
        )

        # tokenize again to get tensors in the correct format for the model
        tokenized_dataset = self._tokenize_text(document.text, tokenizer_kwargs)

        # The two tokenizations are matched purely by position (window i <-> tokenized
        # document i). tokenize_document works per partition while the model inputs come
        # from the full text, so a differing window count means the alignment is broken.
        if len(tokenized_dataset) != len(tokenized_documents):
            logger.warning(
                "number of model input windows (%d) differs from the number of tokenized "
                "documents (%d); span embeddings may be misaligned or missing",
                len(tokenized_dataset),
                len(tokenized_documents),
            )

        dataloader = DataLoader(tokenized_dataset, batch_size=self.batch_size)

        embeddings = {}
        # index of the current window across all batches
        example_idx = 0
        for batch in dataloader:
            batch_at_device = {
                k: v.to(self._model.device) if isinstance(v, Tensor) else v
                for k, v in batch.items()
            }
            with torch.no_grad():
                model_output = self._model(**batch_at_device)

            for last_hidden_state in model_output.last_hidden_state:
                text2tok_ann = added_annotations[example_idx][span_layer_name]
                tok2text_ann = {v: k for k, v in text2tok_ann.items()}
                # NOTE(review): the result document type stores spans in `labeled_spans`;
                # this presumably assumes span_layer_name == "labeled_spans" — confirm
                for tok_ann in tokenized_documents[example_idx].labeled_spans:
                    # skip spans that cover no tokens
                    if tok_ann.start == tok_ann.end:
                        continue
                    # max pooling over the span's tokens gives a single embedding
                    embedding = (
                        last_hidden_state[tok_ann.start : tok_ann.end].max(dim=0)[0].detach().cpu()
                    )
                    text_ann = tok2text_ann[tok_ann]
                    # With striding, a span covered by several overlapping windows is embedded
                    # once per window; the embedding from the last such window is kept.
                    embeddings[text_ann] = embedding
                example_idx += 1

        return embeddings