singletongue committed
Commit 2601e53 • Parent(s): 380f2db

Update tokenizer

Files changed:
- README.md +22 -1
- entity_vocab.json +2 -2
- tokenization_luke_bert_japanese.py +412 -44
README.md CHANGED

@@ -13,4 +13,25 @@ tags:
 - Pre-trained on Japanese Wikipedia data as of July 1, 2023
 - Added support for the `[UNK]` (unknown) entity
 
-For details, please see the [blog post](https://tech.uzabase.com/entry/2023/09/07/172958).
+For details, please see the [blog post](https://tech.uzabase.com/entry/2023/09/07/172958).
+
+## Usage
+
+```python
+from transformers import AutoTokenizer, AutoModel
+
+# trust_remote_code=True is required because this model uses its own tokenizer code
+tokenizer = AutoTokenizer.from_pretrained("uzabase/luke-japanese-wordpiece-base", trust_remote_code=True)
+
+model = AutoModel.from_pretrained("uzabase/luke-japanese-wordpiece-base")
+```
+
+## Updates
+
+- **2023/11/28:** Made the following updates.
+  - Fixed an issue where the tokenizer could not be loaded with transformers v4.34.0 or later.
+  - Changed the tokenizer output to include `position_ids`.
+    - Previously, the `position_ids` [automatically assigned](https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/luke/modeling_luke.py#L424) by the LUKE model were used, but these follow the RoBERTa convention and were not correct for this BERT-based model. The tokenizer now explicitly includes `position_ids` in its output so that the correct BERT-style values are fed to the model.
+  - Removed the `"None:"` prefix that was attached to each token in the tokenizer's `entity_vocab` (except for special tokens such as `"[PAD]"`).
+    - For example, the token that was `"None:聖徳太子"` is now `"聖徳太子"`.
+- **2023/09/07:** Released the model.
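The `position_ids` change described in the update notes above can be sanity-checked from user code. A minimal sketch, not part of this commit (the example sentence and the expected output are assumptions based on the diffs below):

```python
from transformers import AutoTokenizer

# Load the tokenizer shipped with this commit (custom code requires trust_remote_code=True).
tokenizer = AutoTokenizer.from_pretrained(
    "uzabase/luke-japanese-wordpiece-base", trust_remote_code=True
)

# The encoding is now expected to carry sequential BERT-style position_ids
# alongside input_ids, i.e. [0, 1, 2, ...] up to the sequence length.
encoding = tokenizer("聖徳太子は飛鳥時代の皇族です。")
print(encoding["position_ids"])
```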
entity_vocab.json CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:44b62a4236024bcfbc396e434fb137edecbb106e7f6bc36bc2465016d99d84dd
+size 20763373
tokenization_luke_bert_japanese.py CHANGED

@@ -18,7 +18,7 @@ import collections
 import copy
 import json
 import os
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 from transformers.models.bert_japanese.tokenization_bert_japanese import (
     BasicTokenizer,
@@ -31,7 +31,9 @@ from transformers.models.bert_japanese.tokenization_bert_japanese import (
     load_vocab,
 )
 from transformers.models.luke import LukeTokenizer
-from transformers.tokenization_utils_base import
+from transformers.tokenization_utils_base import (
+    AddedToken, BatchEncoding, EncodedInput, PaddingStrategy, TensorType, TruncationStrategy
+)
 from transformers.utils import logging
 
 
@@ -53,7 +55,7 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["input_ids", "attention_mask"]
+    model_input_names = ["input_ids", "attention_mask", "position_ids"]
 
     def __init__(
         self,
@@ -85,35 +87,6 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
         jumanpp_kwargs=None,
         **kwargs,
     ):
-        # We call the grandparent's init, not the parent's.
-        super(LukeTokenizer, self).__init__(
-            spm_file=spm_file,
-            unk_token=unk_token,
-            sep_token=sep_token,
-            pad_token=pad_token,
-            cls_token=cls_token,
-            mask_token=mask_token,
-            do_lower_case=do_lower_case,
-            do_word_tokenize=do_word_tokenize,
-            do_subword_tokenize=do_subword_tokenize,
-            word_tokenizer_type=word_tokenizer_type,
-            subword_tokenizer_type=subword_tokenizer_type,
-            never_split=never_split,
-            mecab_kwargs=mecab_kwargs,
-            sudachi_kwargs=sudachi_kwargs,
-            jumanpp_kwargs=jumanpp_kwargs,
-            task=task,
-            max_entity_length=32,
-            max_mention_length=30,
-            entity_token_1="<ent>",
-            entity_token_2="<ent2>",
-            entity_unk_token=entity_unk_token,
-            entity_pad_token=entity_pad_token,
-            entity_mask_token=entity_mask_token,
-            entity_mask2_token=entity_mask2_token,
-            **kwargs,
-        )
-
         if subword_tokenizer_type == "sentencepiece":
             if not os.path.isfile(spm_file):
                 raise ValueError(
@@ -161,11 +134,11 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
         self.subword_tokenizer_type = subword_tokenizer_type
         if do_subword_tokenize:
             if subword_tokenizer_type == "wordpiece":
-                self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=
+                self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
             elif subword_tokenizer_type == "character":
-                self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=
+                self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=str(unk_token))
             elif subword_tokenizer_type == "sentencepiece":
-                self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=
+                self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=str(unk_token))
             else:
                 raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
 
@@ -212,6 +185,35 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
 
         self.max_mention_length = max_mention_length
 
+        # We call the grandparent's init, not the parent's.
+        super(LukeTokenizer, self).__init__(
+            spm_file=spm_file,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            pad_token=pad_token,
+            cls_token=cls_token,
+            mask_token=mask_token,
+            do_lower_case=do_lower_case,
+            do_word_tokenize=do_word_tokenize,
+            do_subword_tokenize=do_subword_tokenize,
+            word_tokenizer_type=word_tokenizer_type,
+            subword_tokenizer_type=subword_tokenizer_type,
+            never_split=never_split,
+            mecab_kwargs=mecab_kwargs,
+            sudachi_kwargs=sudachi_kwargs,
+            jumanpp_kwargs=jumanpp_kwargs,
+            task=task,
+            max_entity_length=32,
+            max_mention_length=30,
+            entity_token_1="<ent>",
+            entity_token_2="<ent2>",
+            entity_unk_token=entity_unk_token,
+            entity_pad_token=entity_pad_token,
+            entity_mask_token=entity_mask_token,
+            entity_mask2_token=entity_mask2_token,
+            **kwargs,
+        )
+
     @property
     # Copied from BertJapaneseTokenizer
     def do_lower_case(self):
@@ -298,16 +300,13 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
         """
         Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
         adding special tokens. A BERT sequence has the following format:
-
         - single sequence: `[CLS] X [SEP]`
         - pair of sequences: `[CLS] A [SEP] B [SEP]`
-
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs to which the special tokens will be added.
             token_ids_1 (`List[int]`, *optional*):
                 Optional second list of IDs for sequence pairs.
-
         Returns:
             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
         """
@@ -324,7 +323,6 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
         """
         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer `prepare_for_model` method.
-
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs.
@@ -332,7 +330,6 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
                 Optional second list of IDs for sequence pairs.
             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not the token list is already formatted with special tokens for the model.
-
         Returns:
             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
         """
@@ -353,20 +350,16 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
         """
         Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
         pair mask has the following format:
-
         ```
         0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
         | first sequence | second sequence |
         ```
-
         If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
-
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs.
             token_ids_1 (`List[int]`, *optional*):
                 Optional second list of IDs for sequence pairs.
-
         Returns:
             `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
         """
@@ -376,9 +369,384 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
             return len(cls + token_ids_0 + sep) * [0]
         return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
 
+    # Copied and modified from LukeTokenizer, removing the `add_prefix_space` process
     def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
         return (text, kwargs)
 
+    # Copied and modified from LukeTokenizer, adding `position_ids` to the output
+    def prepare_for_model(
+        self,
+        ids: List[int],
+        pair_ids: Optional[List[int]] = None,
+        entity_ids: Optional[List[int]] = None,
+        pair_entity_ids: Optional[List[int]] = None,
+        entity_token_spans: Optional[List[Tuple[int, int]]] = None,
+        pair_entity_token_spans: Optional[List[Tuple[int, int]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length: Optional[int] = None,
+        max_entity_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        prepend_batch_axis: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Prepares a sequence of input id, entity id and entity span, or a pair of sequences of inputs ids, entity ids,
+        entity spans so that it can be used by the model. It adds special tokens, truncates sequences if overflowing
+        while taking into account the special tokens and manages a moving window (with user defined stride) for
+        overflowing tokens. Please Note, for *pair_ids* different than `None` and *truncation_strategy = longest_first*
+        or `True`, it is not possible to return overflowing tokens. Such a combination of arguments will raise an
+        error.
+
+        Args:
+            ids (`List[int]`):
+                Tokenized input ids of the first sequence.
+            pair_ids (`List[int]`, *optional*):
+                Tokenized input ids of the second sequence.
+            entity_ids (`List[int]`, *optional*):
+                Entity ids of the first sequence.
+            pair_entity_ids (`List[int]`, *optional*):
+                Entity ids of the second sequence.
+            entity_token_spans (`List[Tuple[int, int]]`, *optional*):
+                Entity spans of the first sequence.
+            pair_entity_token_spans (`List[Tuple[int, int]]`, *optional*):
+                Entity spans of the second sequence.
+            max_entity_length (`int`, *optional*):
+                The maximum length of the entity sequence.
+        """
+
+        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        # Compute lengths
+        pair = bool(pair_ids is not None)
+        len_ids = len(ids)
+        len_pair_ids = len(pair_ids) if pair else 0
+
+        if return_token_type_ids and not add_special_tokens:
+            raise ValueError(
+                "Asking to return token_type_ids while setting add_special_tokens to False "
+                "results in an undefined behavior. Please set add_special_tokens to True or "
+                "set return_token_type_ids to None."
+            )
+        if (
+            return_overflowing_tokens
+            and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+            and pair_ids is not None
+        ):
+            raise ValueError(
+                "Not possible to return overflowing tokens for pair of sequences with the "
+                "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                "for instance `only_second` or `only_first`."
+            )
+
+        # Load from model defaults
+        if return_token_type_ids is None:
+            return_token_type_ids = "token_type_ids" in self.model_input_names
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        encoded_inputs = {}
+
+        # Compute the total size of the returned word encodings
+        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+        # Truncation: Handle max sequence length and max_entity_length
+        overflowing_tokens = []
+        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+            # truncate words up to max_length
+            ids, pair_ids, overflowing_tokens = self.truncate_sequences(
+                ids,
+                pair_ids=pair_ids,
+                num_tokens_to_remove=total_len - max_length,
+                truncation_strategy=truncation_strategy,
+                stride=stride,
+            )
+
+        if return_overflowing_tokens:
+            encoded_inputs["overflowing_tokens"] = overflowing_tokens
+            encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+        # Add special tokens
+        if add_special_tokens:
+            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+            entity_token_offset = 1  # 1 * <s> token
+            pair_entity_token_offset = len(ids) + 3  # 1 * <s> token & 2 * <sep> tokens
+        else:
+            sequence = ids + pair_ids if pair else ids
+            token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+            entity_token_offset = 0
+            pair_entity_token_offset = len(ids)
+
+        # Build output dictionary
+        encoded_inputs["input_ids"] = sequence
+        encoded_inputs["position_ids"] = list(range(len(sequence)))
+        if return_token_type_ids:
+            encoded_inputs["token_type_ids"] = token_type_ids
+        if return_special_tokens_mask:
+            if add_special_tokens:
+                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+            else:
+                encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+        # Set max entity length
+        if not max_entity_length:
+            max_entity_length = self.max_entity_length
+
+        if entity_ids is not None:
+            total_entity_len = 0
+            num_invalid_entities = 0
+            valid_entity_ids = [ent_id for ent_id, span in zip(entity_ids, entity_token_spans) if span[1] <= len(ids)]
+            valid_entity_token_spans = [span for span in entity_token_spans if span[1] <= len(ids)]
+
+            total_entity_len += len(valid_entity_ids)
+            num_invalid_entities += len(entity_ids) - len(valid_entity_ids)
+
+            valid_pair_entity_ids, valid_pair_entity_token_spans = None, None
+            if pair_entity_ids is not None:
+                valid_pair_entity_ids = [
+                    ent_id
+                    for ent_id, span in zip(pair_entity_ids, pair_entity_token_spans)
+                    if span[1] <= len(pair_ids)
+                ]
+                valid_pair_entity_token_spans = [span for span in pair_entity_token_spans if span[1] <= len(pair_ids)]
+                total_entity_len += len(valid_pair_entity_ids)
+                num_invalid_entities += len(pair_entity_ids) - len(valid_pair_entity_ids)
+
+            if num_invalid_entities != 0:
+                logger.warning(
+                    f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the"
+                    " truncation of input tokens"
+                )
+
+            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length:
+                # truncate entities up to max_entity_length
+                valid_entity_ids, valid_pair_entity_ids, overflowing_entities = self.truncate_sequences(
+                    valid_entity_ids,
+                    pair_ids=valid_pair_entity_ids,
+                    num_tokens_to_remove=total_entity_len - max_entity_length,
+                    truncation_strategy=truncation_strategy,
+                    stride=stride,
+                )
+                valid_entity_token_spans = valid_entity_token_spans[: len(valid_entity_ids)]
+                if valid_pair_entity_token_spans is not None:
+                    valid_pair_entity_token_spans = valid_pair_entity_token_spans[: len(valid_pair_entity_ids)]
+
+            if return_overflowing_tokens:
+                encoded_inputs["overflowing_entities"] = overflowing_entities
+                encoded_inputs["num_truncated_entities"] = total_entity_len - max_entity_length
+
+            final_entity_ids = valid_entity_ids + valid_pair_entity_ids if valid_pair_entity_ids else valid_entity_ids
+            encoded_inputs["entity_ids"] = list(final_entity_ids)
+            entity_position_ids = []
+            entity_start_positions = []
+            entity_end_positions = []
+            for token_spans, offset in (
+                (valid_entity_token_spans, entity_token_offset),
+                (valid_pair_entity_token_spans, pair_entity_token_offset),
+            ):
+                if token_spans is not None:
+                    for start, end in token_spans:
+                        start += offset
+                        end += offset
+                        position_ids = list(range(start, end))[: self.max_mention_length]
+                        position_ids += [-1] * (self.max_mention_length - end + start)
+                        entity_position_ids.append(position_ids)
+                        entity_start_positions.append(start)
+                        entity_end_positions.append(end - 1)
+
+            encoded_inputs["entity_position_ids"] = entity_position_ids
+            if self.task == "entity_span_classification":
+                encoded_inputs["entity_start_positions"] = entity_start_positions
+                encoded_inputs["entity_end_positions"] = entity_end_positions
+
+            if return_token_type_ids:
+                encoded_inputs["entity_token_type_ids"] = [0] * len(encoded_inputs["entity_ids"])
+
+        # Check lengths
+        self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+
+        # Padding
+        if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+            encoded_inputs = self.pad(
+                encoded_inputs,
+                max_length=max_length,
+                max_entity_length=max_entity_length,
+                padding=padding_strategy.value,
+                pad_to_multiple_of=pad_to_multiple_of,
+                return_attention_mask=return_attention_mask,
+            )
+
+        if return_length:
+            encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+        batch_outputs = BatchEncoding(
+            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+        )
+
+        return batch_outputs
+
+    # Copied and modified from LukeTokenizer, adding the padding of `position_ids`
+    def _pad(
+        self,
+        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+        max_length: Optional[int] = None,
+        max_entity_length: Optional[int] = None,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        pad_to_multiple_of: Optional[int] = None,
+        return_attention_mask: Optional[bool] = None,
+    ) -> dict:
+        """
+        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+
+        Args:
+            encoded_inputs:
+                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+            max_length: maximum length of the returned list and optionally padding length (see below).
+                Will truncate by taking into account the special tokens.
+            max_entity_length: The maximum length of the entity sequence.
+            padding_strategy: PaddingStrategy to use for padding.
+
+
+                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+                - PaddingStrategy.DO_NOT_PAD: Do not pad
+                The tokenizer padding sides are defined in self.padding_side:
+
+
+                - 'left': pads on the left of the sequences
+                - 'right': pads on the right of the sequences
+            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            return_attention_mask:
+                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+        """
+        entities_provided = bool("entity_ids" in encoded_inputs)
+
+        # Load from model defaults
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = len(encoded_inputs["input_ids"])
+            if entities_provided:
+                max_entity_length = len(encoded_inputs["entity_ids"])
+
+        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+        if (
+            entities_provided
+            and max_entity_length is not None
+            and pad_to_multiple_of is not None
+            and (max_entity_length % pad_to_multiple_of != 0)
+        ):
+            max_entity_length = ((max_entity_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and (
+            len(encoded_inputs["input_ids"]) != max_length
+            or (entities_provided and len(encoded_inputs["entity_ids"]) != max_entity_length)
+        )
+
+        # Initialize attention mask if not present.
+        if return_attention_mask and "attention_mask" not in encoded_inputs:
+            encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
+        if entities_provided and return_attention_mask and "entity_attention_mask" not in encoded_inputs:
+            encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"])
+
+        if needs_to_be_padded:
+            difference = max_length - len(encoded_inputs["input_ids"])
+            if entities_provided:
+                entity_difference = max_entity_length - len(encoded_inputs["entity_ids"])
+            if self.padding_side == "right":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+                    if entities_provided:
+                        encoded_inputs["entity_attention_mask"] = (
+                            encoded_inputs["entity_attention_mask"] + [0] * entity_difference
+                        )
+                if "token_type_ids" in encoded_inputs:
+                    encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [0] * difference
+                    if entities_provided:
+                        encoded_inputs["entity_token_type_ids"] = (
+                            encoded_inputs["entity_token_type_ids"] + [0] * entity_difference
+                        )
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+                encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
+                encoded_inputs["position_ids"] = encoded_inputs["position_ids"] + [0] * difference
+                if entities_provided:
+                    encoded_inputs["entity_ids"] = (
+                        encoded_inputs["entity_ids"] + [self.entity_pad_token_id] * entity_difference
+                    )
+                    encoded_inputs["entity_position_ids"] = (
+                        encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference
+                    )
+                    if self.task == "entity_span_classification":
+                        encoded_inputs["entity_start_positions"] = (
+                            encoded_inputs["entity_start_positions"] + [0] * entity_difference
+                        )
+                        encoded_inputs["entity_end_positions"] = (
+                            encoded_inputs["entity_end_positions"] + [0] * entity_difference
+                        )
+
+            elif self.padding_side == "left":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+                    if entities_provided:
+                        encoded_inputs["entity_attention_mask"] = [0] * entity_difference + encoded_inputs[
+                            "entity_attention_mask"
+                        ]
+                if "token_type_ids" in encoded_inputs:
+                    encoded_inputs["token_type_ids"] = [0] * difference + encoded_inputs["token_type_ids"]
+                    if entities_provided:
+                        encoded_inputs["entity_token_type_ids"] = [0] * entity_difference + encoded_inputs[
+                            "entity_token_type_ids"
+                        ]
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+                encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
+                encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
+                if entities_provided:
+                    encoded_inputs["entity_ids"] = [self.entity_pad_token_id] * entity_difference + encoded_inputs[
+                        "entity_ids"
+                    ]
+                    encoded_inputs["entity_position_ids"] = [
+                        [-1] * self.max_mention_length
+                    ] * entity_difference + encoded_inputs["entity_position_ids"]
+                    if self.task == "entity_span_classification":
+                        encoded_inputs["entity_start_positions"] = [0] * entity_difference + encoded_inputs[
+                            "entity_start_positions"
+                        ]
+                        encoded_inputs["entity_end_positions"] = [0] * entity_difference + encoded_inputs[
+                            "entity_end_positions"
+                        ]
+            else:
+                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+
+        return encoded_inputs
+
+    # Copied and modified from BertJapaneseTokenizer and LukeTokenizer
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if os.path.isdir(save_directory):
             if self.subword_tokenizer_type == "sentencepiece":