BorisAlbar
commited on
Commit
•
684f79f
1
Parent(s):
e56551d
Upload configuration_flash_t5.py with huggingface_hub
Browse files- configuration_flash_t5.py +84 -0
configuration_flash_t5.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from collections import OrderedDict
|
3 |
+
from typing import Mapping
|
4 |
+
import logging
|
5 |
+
|
6 |
+
from transformers import T5Config
|
7 |
+
|
8 |
+
AUTO_MAP = {
|
9 |
+
"AutoModel": "modeling_flash_t5.FlashT5ForConditionalGeneration",
|
10 |
+
"AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
|
11 |
+
"AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification",
|
12 |
+
"AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
|
13 |
+
"AutoModelForSequenceClassification": "custom_heads_flash_t5.FlashT5ForSequenceClassification",
|
14 |
+
}
|
15 |
+
|
16 |
+
class FlashT5Config(T5Config):
|
17 |
+
|
18 |
+
model_type = "flash_t5"
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
decoder_start_token_id=0,
|
23 |
+
pad_token_id=-100,
|
24 |
+
use_glu_mlp=False,
|
25 |
+
position_encoding_type="t5",
|
26 |
+
use_randomized_position_encoding=False,
|
27 |
+
label_smoothing=0.0,
|
28 |
+
z_loss=None,
|
29 |
+
attention_type="ref",
|
30 |
+
max_sequence_length=1024,
|
31 |
+
attention_dropout_rate=0.0,
|
32 |
+
alibi_mode="symetric",
|
33 |
+
use_triton_layernorm=False,
|
34 |
+
use_triton_crossentropy=False,
|
35 |
+
use_triton_gated_mlp=False,
|
36 |
+
use_gelu_act=True,
|
37 |
+
use_full_bias_size=False,
|
38 |
+
rotary_emb_fraction=1.0,
|
39 |
+
rotary_base=10000,
|
40 |
+
rotary_interleaved=False,
|
41 |
+
rotary_scale_base=None,
|
42 |
+
fire_mlp_width=32,
|
43 |
+
use_masking=False,
|
44 |
+
attention_scale=None,
|
45 |
+
**kwargs,
|
46 |
+
):
|
47 |
+
super().__init__(**kwargs)
|
48 |
+
|
49 |
+
self.decoder_start_token_id = decoder_start_token_id
|
50 |
+
self.pad_token_id = pad_token_id
|
51 |
+
self.use_glu_mlp = use_glu_mlp
|
52 |
+
self.position_encoding_type = position_encoding_type
|
53 |
+
self.use_randomized_position_encoding = use_randomized_position_encoding
|
54 |
+
self.label_smoothing = label_smoothing
|
55 |
+
self.z_loss = z_loss
|
56 |
+
self.attention_type = attention_type
|
57 |
+
self.max_sequence_length = max_sequence_length
|
58 |
+
self.alibi_mode = alibi_mode
|
59 |
+
self.attention_dropout_rate = attention_dropout_rate
|
60 |
+
self.use_triton_layernorm = use_triton_layernorm
|
61 |
+
self.use_triton_crossentropy = use_triton_crossentropy
|
62 |
+
self.use_triton_gated_mlp = use_triton_gated_mlp
|
63 |
+
self.use_gelu_act = use_gelu_act
|
64 |
+
self.use_full_bias_size = use_full_bias_size
|
65 |
+
self.rotary_base = rotary_base
|
66 |
+
self.rotary_interleaved = rotary_interleaved
|
67 |
+
self.rotary_scale_base = rotary_scale_base
|
68 |
+
self.rotary_emb_fraction = rotary_emb_fraction
|
69 |
+
self.fire_mlp_width = fire_mlp_width
|
70 |
+
self.use_masking = use_masking
|
71 |
+
self.attention_scale = attention_scale
|
72 |
+
|
73 |
+
self.auto_map = AUTO_MAP
|
74 |
+
|
75 |
+
def str_to_class(classname):
|
76 |
+
return getattr(sys.modules[__name__], classname)
|
77 |
+
|
78 |
+
# Register model in Auto API
|
79 |
+
try:
|
80 |
+
FlashT5Config.register_for_auto_class()
|
81 |
+
for key, value in AUTO_MAP.items():
|
82 |
+
str_to_class(value.split(".")[-1]).register_for_auto_class(key)
|
83 |
+
except:
|
84 |
+
logging.warn("AutoRegister isn't available.")
|