"""
Converts from linguistically motivated word-based tokenization to subword
tokenization used by pre-trained models.
"""

import numpy as np
import torch
import transformers


def retokenize(
    tokenizer,
    words,
    space_after,
    return_attention_mask=True,
    return_offsets_mapping=False,
    return_tensors=None,
    **kwargs
):
    """Re-tokenize into subwords.

    Args:
        tokenizer: An instance of transformers.PreTrainedTokenizerFast
        words: List of words
        space_after: A list of the same length as `words`, indicating whether
            whitespace follows each word.
        **kwargs: all remaining arguments are passed on to tokenizer.__call__

    Returns:
        The output of tokenizer.__call__, with one additional dictionary field:
        - **words_from_tokens** -- List of the same length as `words`, where
          each entry is the index of the *last* subword that overlaps the
          corresponding word.
    """
    s = "".join([w + (" " if sp else "") for w, sp in zip(words, space_after)])
    word_offset_starts = np.cumsum(
        [0] + [len(w) + (1 if sp else 0) for w, sp in zip(words, space_after)]
    )[:-1]
    word_offset_ends = word_offset_starts + np.asarray([len(w) for w in words])

    tokenized = tokenizer(
        s,
        return_attention_mask=return_attention_mask,
        return_offsets_mapping=True,
        return_tensors=return_tensors,
        **kwargs
    )
    if return_offsets_mapping:
        token_offset_mapping = tokenized["offset_mapping"]
    else:
        token_offset_mapping = tokenized.pop("offset_mapping")
    if return_tensors is not None:
        token_offset_mapping = np.asarray(token_offset_mapping)[0].tolist()

    offset_mapping_iter = iter(
        [
            (i, (start, end))
            for (i, (start, end)) in enumerate(token_offset_mapping)
            if start != end
        ]
    )
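    # The filter above drops special tokens (their offsets are (0, 0)), but
    # enumerate runs before the filter, so the recorded indices still refer to
    # positions in the tokenizer's full output, special tokens included.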
    token_idx, (token_start, token_end) = next(offset_mapping_iter)
    words_from_tokens = [-100] * len(words)
    for word_idx, (word_start, word_end) in enumerate(
        zip(word_offset_starts, word_offset_ends)
    ):
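        # Skip subwords that end before this word starts, then record every
        # subword that overlaps the word; the last assignment wins, so the
        # final value is the index of the last overlapping subword.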
        while token_end <= word_start:
            token_idx, (token_start, token_end) = next(offset_mapping_iter)
        if token_end > word_end:
            words_from_tokens[word_idx] = token_idx
        while token_end <= word_end:
            words_from_tokens[word_idx] = token_idx
            try:
                token_idx, (token_start, token_end) = next(offset_mapping_iter)
            except StopIteration:
                assert word_idx == len(words) - 1
                break
    if return_tensors == "np":
        words_from_tokens = np.asarray(words_from_tokens, dtype=int)
    elif return_tensors == "pt":
        words_from_tokens = torch.tensor(words_from_tokens, dtype=torch.long)
    elif return_tensors == "tf":
        raise NotImplementedError("Returning tf tensors is not implemented")
    tokenized["words_from_tokens"] = words_from_tokens
    return tokenized


class Retokenizer:
    def __init__(self, pretrained_model_name_or_path, retain_start_stop=False):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, use_fast=True
        )
        if not self.tokenizer.is_fast:
            raise NotImplementedError(
                "Converting from treebank tokenization to tokenization used by a "
                "pre-trained model requires a 'fast' tokenizer, which appears to not "
                "be available for this pre-trained model type."
            )
        self.retain_start_stop = retain_start_stop
        self.is_t5 = "T5Tokenizer" in str(type(self.tokenizer))
        self.is_gpt2 = "GPT2Tokenizer" in str(type(self.tokenizer))

        if self.is_gpt2:
            # The provided GPT-2 tokenizer does not specify a padding token by default
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if self.retain_start_stop:
            # When retain_start_stop is set, the next layer after the pre-trained model
            # expects start and stop token embeddings. For BERT these can naturally be
            # the feature vectors for CLS and SEP, but pre-trained models differ in the
            # special tokens that they use. This code attempts to find special token
            # positions for each pre-trained model.
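            # Rough worked example (assuming a BERT-style tokenizer):
            # build_inputs_with_special_tokens([-100]) returns something like
            # [cls_id, -100, sep_id], so num_prefix_tokens == 1 and
            # num_suffix_tokens == 1 below, giving start_token_idx = 0 and
            # stop_token_idx = -1.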
            dummy_ids = self.tokenizer.build_inputs_with_special_tokens([-100])
            if self.is_t5:
                # For T5 we use the output from the decoder, which accepts inputs that
                # are shifted relative to the encoder.
                dummy_ids = [self.tokenizer.pad_token_id] + dummy_ids
            if self.is_gpt2:
                # For GPT-2, we append an eos token if special tokens are needed
                dummy_ids = dummy_ids + [self.tokenizer.eos_token_id]
            try:
                input_idx = dummy_ids.index(-100)
            except ValueError:
                raise NotImplementedError(
                    "Could not automatically infer how to extract start/stop tokens "
                    "from this pre-trained model"
                )
            num_prefix_tokens = input_idx
            num_suffix_tokens = len(dummy_ids) - input_idx - 1
            self.start_token_idx = None
            self.stop_token_idx = None
            if num_prefix_tokens > 0:
                self.start_token_idx = num_prefix_tokens - 1
            if num_suffix_tokens > 0:
                self.stop_token_idx = -num_suffix_tokens
            if self.start_token_idx is None and num_suffix_tokens > 0:
                self.start_token_idx = -1
            if self.stop_token_idx is None and num_prefix_tokens > 0:
                self.stop_token_idx = 0
            if self.start_token_idx is None or self.stop_token_idx is None:
                assert num_prefix_tokens == 0 and num_suffix_tokens == 0
                raise NotImplementedError(
                    "Could not automatically infer how to extract start/stop tokens "
                    "from this pre-trained model because the associated tokenizer "
                    "appears not to add any special start/stop/cls/sep/etc. tokens "
                    "to the sequence."
                )

    def __call__(self, words, space_after, **kwargs):
        example = retokenize(self.tokenizer, words, space_after, **kwargs)
        if self.is_t5:
            # decoder_input_ids (which are shifted wrt input_ids) will be created after
            # padding, but we adjust words_from_tokens now, in anticipation.
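            # (Illustrative: decoder_input_ids will be input_ids with a pad
            # token prepended, so the subword at encoder position i sits at
            # decoder position i + 1, hence the +1 below.)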
            if isinstance(example["words_from_tokens"], list):
                example["words_from_tokens"] = [
                    x + 1 for x in example["words_from_tokens"]
                ]
            else:
                example["words_from_tokens"] += 1
        if self.retain_start_stop:
            num_tokens = len(example["input_ids"])
            if self.is_t5:
                num_tokens += 1
            if self.is_gpt2:
                num_tokens += 1
                if kwargs.get("return_tensors") == "pt":
                    example["input_ids"] = torch.cat(
                        example["input_ids"],
                        torch.tensor([self.tokenizer.eos_token_id]),
                    )
                    example["attention_mask"] = torch.cat(
                        example["attention_mask"], torch.tensor([1])
                    )
                else:
                    example["input_ids"].append(self.tokenizer.eos_token_id)
                    example["attention_mask"].append(1)
            if num_tokens > self.tokenizer.model_max_length:
                raise ValueError(
                    f"Sentence of length {num_tokens} (in sub-word tokens) exceeds the "
                    f"maximum supported length of {self.tokenizer.model_max_length}"
                )
            start_token_idx = (
                self.start_token_idx
                if self.start_token_idx >= 0
                else num_tokens + self.start_token_idx
            )
            stop_token_idx = (
                self.stop_token_idx
                if self.stop_token_idx >= 0
                else num_tokens + self.stop_token_idx
            )
            if kwargs.get("return_tensors") == "pt":
                example["words_from_tokens"] = torch.cat(
                    [
                        torch.tensor([start_token_idx]),
                        example["words_from_tokens"],
                        torch.tensor([stop_token_idx]),
                    ]
                )
            else:
                example["words_from_tokens"] = (
                    [start_token_idx] + example["words_from_tokens"] + [stop_token_idx]
                )
        return example

    def pad(self, encoded_inputs, return_tensors=None, **kwargs):
        if return_tensors != "pt":
            raise NotImplementedError("Only return_tensors='pt' is supported.")
        res = self.tokenizer.pad(
            [
                {k: v for k, v in example.items() if k != "words_from_tokens"}
                for example in encoded_inputs
            ],
            return_tensors=return_tensors,
            **kwargs
        )
        if self.tokenizer.padding_side == "right":
            res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence(
                [
                    torch.tensor(example["words_from_tokens"])
                    for example in encoded_inputs
                ],
                batch_first=True,
                padding_value=-100,
            )
        else:
            # XLNet adds padding tokens on the left of the sequence, so
            # words_from_tokens must be adjusted to skip the added padding tokens.
            assert self.tokenizer.padding_side == "left"
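            # Illustrative shift: if the padded batch is 12 subwords wide and
            # this example has 9, the tokenizer prepends 3 pad tokens, so every
            # index in words_from_tokens moves right by 12 - 9 = 3.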
            res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence(
                [
                    torch.tensor(example["words_from_tokens"])
                    + (res["input_ids"].shape[-1] - len(example["input_ids"]))
                    for example in encoded_inputs
                ],
                batch_first=True,
                padding_value=-100,
            )

        if self.is_t5:
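            # T5's decoder expects inputs shifted one position to the right of
            # the encoder inputs: prepend a pad token to input_ids and a
            # matching 1 to the attention mask.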
            res["decoder_input_ids"] = torch.cat(
                [
                    torch.full_like(
                        res["input_ids"][:, :1], self.tokenizer.pad_token_id
                    ),
                    res["input_ids"],
                ],
                1,
            )
            res["decoder_attention_mask"] = torch.cat(
                [
                    torch.ones_like(res["attention_mask"][:, :1]),
                    res["attention_mask"],
                ],
                1,
            )
        res["valid_token_mask"] = res["words_from_tokens"] != -100
        return res
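

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): exercise
    # Retokenizer on a tiny batch. "bert-base-uncased" is only an example
    # choice and is assumed to be downloadable or already cached; any model
    # with a fast tokenizer should behave similarly.
    retokenizer = Retokenizer("bert-base-uncased", retain_start_stop=True)
    examples = [
        retokenizer(["Mrs.", "Smith", "arrived", "."], [True, True, False, False]),
        retokenizer(["Hello", ",", "world", "!"], [False, True, False, False]),
    ]
    batch = retokenizer.pad(examples, return_tensors="pt")
    # input_ids is padded to a common length; words_from_tokens holds the index
    # of the last subword of each word (plus start/stop slots), padded with -100.
    print(batch["input_ids"].shape)
    print(batch["words_from_tokens"])
    print(batch["valid_token_mask"])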