imvladikon commited on
Commit
f1983d5
·
1 Parent(s): 9ad9c83

Create configuration_enc_t5.py

Browse files
Files changed (1) hide show
  1. configuration_enc_t5.py +120 -0
configuration_enc_t5.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from transformers import T5Config, T5TokenizerFast
6
+
7
+
8
+ class EncT5Config(T5Config):
9
+ model_type = "enc-t5"
10
+
11
+ def __init__(self, **kwargs: Any) -> None:
12
+ super().__init__(**kwargs)
13
+
14
+
15
+ class EncT5Tokenizer(T5TokenizerFast):
16
+
17
+ def __init__(
18
+ self,
19
+ vocab_file,
20
+ bos_token="<s>",
21
+ eos_token="</s>",
22
+ unk_token="<unk>",
23
+ pad_token="<pad>",
24
+ extra_ids=100,
25
+ additional_special_tokens=None,
26
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
27
+ **kwargs,
28
+ ) -> None:
29
+ sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
30
+
31
+ super().__init__(
32
+ vocab_file=vocab_file,
33
+ bos_token=bos_token,
34
+ eos_token=eos_token,
35
+ unk_token=unk_token,
36
+ pad_token=pad_token,
37
+ extra_ids=extra_ids,
38
+ additional_special_tokens=additional_special_tokens,
39
+ sp_model_kwargs=sp_model_kwargs,
40
+ **kwargs,
41
+ )
42
+
43
+ def get_special_tokens_mask(
44
+ self,
45
+ token_ids_0: List[int],
46
+ token_ids_1: Optional[List[int]] = None,
47
+ already_has_special_tokens: bool = False,
48
+ ) -> List[int]:
49
+ """
50
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
51
+ special tokens using the tokenizer `prepare_for_model` method.
52
+ Args:
53
+ token_ids_0 (`List[int]`):
54
+ List of IDs.
55
+ token_ids_1 (`List[int]`, *optional*):
56
+ Optional second list of IDs for sequence pairs.
57
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
58
+ Whether or not the token list is already formatted with special tokens for the model.
59
+ Returns:
60
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
61
+ """
62
+ if already_has_special_tokens:
63
+ return super().get_special_tokens_mask(
64
+ token_ids_0=token_ids_0,
65
+ token_ids_1=token_ids_1,
66
+ already_has_special_tokens=True,
67
+ )
68
+
69
+ # normal case: some special tokens
70
+ if token_ids_1 is None:
71
+ return [1] + ([0] * len(token_ids_0)) + [1]
72
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
73
+
74
+ def create_token_type_ids_from_sequences(
75
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
76
+ ) -> List[int]:
77
+ """
78
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
79
+ use of token type ids, therefore a list of zeros is returned.
80
+ Args:
81
+ token_ids_0 (`List[int]`):
82
+ List of IDs.
83
+ token_ids_1 (`List[int]`, *optional*):
84
+ Optional second list of IDs for sequence pairs.
85
+ Returns:
86
+ `List[int]`: List of zeros.
87
+ """
88
+ bos = [self.bos_token_id]
89
+ eos = [self.eos_token_id]
90
+
91
+ if token_ids_1 is None:
92
+ return len(bos + token_ids_0 + eos) * [0]
93
+ return len(bos + token_ids_0 + eos + token_ids_1 + eos) * [0]
94
+
95
+ def build_inputs_with_special_tokens(
96
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
97
+ ) -> List[int]:
98
+ """
99
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
100
+ adding special tokens. A sequence has the following format:
101
+ - single sequence: `<s> X </s>`
102
+ - pair of sequences: `<s> A </s> B </s>`
103
+ Args:
104
+ token_ids_0 (`List[int]`):
105
+ List of IDs to which the special tokens will be added.
106
+ token_ids_1 (`List[int]`, *optional*):
107
+ Optional second list of IDs for sequence pairs.
108
+ Returns:
109
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
110
+ """
111
+ if token_ids_1 is None:
112
+ return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
113
+ else:
114
+ return (
115
+ [self.bos_token_id]
116
+ + token_ids_0
117
+ + [self.eos_token_id]
118
+ + token_ids_1
119
+ + [self.eos_token_id]
120
+ )