das
#9
by
jupyterjazz
- opened
- README.md +0 -33
- config.json +0 -40
- configuration_bert.py +6 -6
- convert_v2_weights.py +0 -151
- mha.py +0 -4
- mlp.py +0 -47
- modeling_bert.py +27 -145
- modeling_lora.py +15 -75
- tokenizer.py +88 -0
README.md
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
# BERT with Flash-Attention
|
2 |
-
### Installing dependencies
|
3 |
-
To run the model on GPU, you need to install Flash Attention.
|
4 |
-
You may either install from pypi (which may not work with fused-dense), or from source.
|
5 |
-
To install from source, clone the GitHub repository:
|
6 |
-
```console
|
7 |
-
git clone git@github.com:Dao-AILab/flash-attention.git
|
8 |
-
```
|
9 |
-
The code provided here should work with commit `43950dd`.
|
10 |
-
Change to the cloned repo and install:
|
11 |
-
```console
|
12 |
-
cd flash-attention && python setup.py install
|
13 |
-
```
|
14 |
-
This will compile the flash-attention kernel, which will take some time.
|
15 |
-
|
16 |
-
If you would like to use fused MLPs (e.g. to use activation checkpointing),
|
17 |
-
you may install fused-dense also from source:
|
18 |
-
```console
|
19 |
-
cd csrc/fused_dense_lib && python setup.py install
|
20 |
-
```
|
21 |
-
|
22 |
-
|
23 |
-
### Configuration
|
24 |
-
The config adds some new parameters:
|
25 |
-
- `use_flash_attn`: If `True`, always use flash attention. If `None`, use flash attention when GPU is available. If `False`, never use flash attention (works on CPU).
|
26 |
-
- `window_size`: Size (left and right) of the local attention window. If `(-1, -1)`, use global attention
|
27 |
-
- `dense_seq_output`: If true, we only need to pass the hidden states for the masked out token (around 15%) to the classifier heads. I set this to true for pretraining.
|
28 |
-
- `fused_mlp`: Whether to use fused-dense. Useful to reduce VRAM in combination with activation checkpointing
|
29 |
-
- `mlp_checkpoint_lvl`: One of `{0, 1, 2}`. Increasing this increases the amount of activation checkpointing within the MLP. Keep this at 0 for pretraining and use gradient accumulation instead. For embedding training, increase this as much as needed.
|
30 |
-
- `last_layer_subset`: If true, we only need the compute the last layer for a subset of tokens. I left this to false.
|
31 |
-
- `use_qk_norm`: Whether or not to use QK-normalization
|
32 |
-
- `num_loras`: Number of LoRAs to use when initializing a `BertLoRA` model. Has no effect on other models.
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.json
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"_name_or_path": "jinaai/jina-bert-flash-implementation",
|
3 |
-
"auto_map": {
|
4 |
-
"AutoConfig": "jinaai/jina-bert-flash-implementation--configuration_bert.JinaBertConfig",
|
5 |
-
"AutoModel": "jinaai/jina-bert-flash-implementation--modeling_bert.BertModel",
|
6 |
-
"AutoModelForPreTraining": "jinaai/jina-bert-flash-implementation--modeling_bert.BertForPreTraining",
|
7 |
-
"AutoModelForMaskedLM": "jinaai/jina-bert-flash-implementation--modeling_bert.BertForPreTraining"
|
8 |
-
},
|
9 |
-
"attention_probs_dropout_prob": 0.1,
|
10 |
-
"classifier_dropout": null,
|
11 |
-
"dense_seq_output": false,
|
12 |
-
"emb_pooler": null,
|
13 |
-
"fused_bias_fc": false,
|
14 |
-
"fused_dropout_add_ln": false,
|
15 |
-
"hidden_act": "gelu",
|
16 |
-
"hidden_dropout_prob": 0.1,
|
17 |
-
"hidden_size": 768,
|
18 |
-
"initializer_range": 0.02,
|
19 |
-
"intermediate_size": 3072,
|
20 |
-
"last_layer_subset": false,
|
21 |
-
"layer_norm_eps": 1e-12,
|
22 |
-
"mlp_checkpoint_lvl": 0,
|
23 |
-
"mlp_type": "glu",
|
24 |
-
"model_type": "bert",
|
25 |
-
"num_attention_heads": 12,
|
26 |
-
"num_hidden_layers": 12,
|
27 |
-
"num_loras": 5,
|
28 |
-
"pad_token_id": 0,
|
29 |
-
"pad_vocab_size_multiple": 1,
|
30 |
-
"torch_dtype": "float16",
|
31 |
-
"transformers_version": "4.39.3",
|
32 |
-
"type_vocab_size": 2,
|
33 |
-
"use_flash_attn": true,
|
34 |
-
"use_qk_norm": false,
|
35 |
-
"vocab_size": 30528,
|
36 |
-
"window_size": [
|
37 |
-
-1,
|
38 |
-
-1
|
39 |
-
]
|
40 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configuration_bert.py
CHANGED
@@ -75,24 +75,24 @@ class JinaBertConfig(PretrainedConfig):
|
|
75 |
pad_token_id=0,
|
76 |
window_size=(-1, -1),
|
77 |
dense_seq_output=False,
|
78 |
-
|
79 |
mlp_checkpoint_lvl=0,
|
80 |
last_layer_subset=False,
|
81 |
fused_dropout_add_ln=False,
|
82 |
fused_bias_fc=False,
|
83 |
pad_vocab_size_multiple=1,
|
|
|
84 |
use_flash_attn=True,
|
85 |
use_qk_norm=True,
|
86 |
emb_pooler=None,
|
87 |
classifier_dropout=None,
|
88 |
-
num_loras=5,
|
89 |
**kwargs,
|
90 |
):
|
91 |
assert 'position_embedding_type' not in kwargs
|
92 |
assert 'max_position_embeddings' not in kwargs
|
93 |
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
94 |
|
95 |
-
if
|
96 |
raise ValueError('Fused MLP only supports approximate gelu')
|
97 |
|
98 |
self.vocab_size = vocab_size
|
@@ -108,14 +108,14 @@ class JinaBertConfig(PretrainedConfig):
|
|
108 |
self.layer_norm_eps = layer_norm_eps
|
109 |
self.window_size = window_size
|
110 |
self.dense_seq_output = dense_seq_output
|
111 |
-
self.
|
112 |
self.mlp_checkpoint_lvl = mlp_checkpoint_lvl
|
113 |
self.last_layer_subset = last_layer_subset
|
114 |
self.fused_dropout_add_ln = fused_dropout_add_ln
|
115 |
self.fused_bias_fc = fused_bias_fc
|
116 |
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
|
|
117 |
self.use_flash_attn = use_flash_attn
|
118 |
self.use_qk_norm = use_qk_norm
|
119 |
self.emb_pooler = emb_pooler
|
120 |
-
self.classifier_dropout = classifier_dropout
|
121 |
-
self.num_loras = num_loras
|
|
|
75 |
pad_token_id=0,
|
76 |
window_size=(-1, -1),
|
77 |
dense_seq_output=False,
|
78 |
+
fused_mlp=False,
|
79 |
mlp_checkpoint_lvl=0,
|
80 |
last_layer_subset=False,
|
81 |
fused_dropout_add_ln=False,
|
82 |
fused_bias_fc=False,
|
83 |
pad_vocab_size_multiple=1,
|
84 |
+
num_tasks=0,
|
85 |
use_flash_attn=True,
|
86 |
use_qk_norm=True,
|
87 |
emb_pooler=None,
|
88 |
classifier_dropout=None,
|
|
|
89 |
**kwargs,
|
90 |
):
|
91 |
assert 'position_embedding_type' not in kwargs
|
92 |
assert 'max_position_embeddings' not in kwargs
|
93 |
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
94 |
|
95 |
+
if fused_mlp and hidden_act not in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]:
|
96 |
raise ValueError('Fused MLP only supports approximate gelu')
|
97 |
|
98 |
self.vocab_size = vocab_size
|
|
|
108 |
self.layer_norm_eps = layer_norm_eps
|
109 |
self.window_size = window_size
|
110 |
self.dense_seq_output = dense_seq_output
|
111 |
+
self.fused_mlp = fused_mlp
|
112 |
self.mlp_checkpoint_lvl = mlp_checkpoint_lvl
|
113 |
self.last_layer_subset = last_layer_subset
|
114 |
self.fused_dropout_add_ln = fused_dropout_add_ln
|
115 |
self.fused_bias_fc = fused_bias_fc
|
116 |
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
117 |
+
self.num_tasks = num_tasks
|
118 |
self.use_flash_attn = use_flash_attn
|
119 |
self.use_qk_norm = use_qk_norm
|
120 |
self.emb_pooler = emb_pooler
|
121 |
+
self.classifier_dropout = classifier_dropout
|
|
convert_v2_weights.py
DELETED
@@ -1,151 +0,0 @@
|
|
1 |
-
import re
|
2 |
-
from collections import OrderedDict
|
3 |
-
from transformers import AutoModel, AutoTokenizer
|
4 |
-
from .configuration_bert import JinaBertConfig
|
5 |
-
import torch
|
6 |
-
from .modeling_bert import BertModel
|
7 |
-
|
8 |
-
def remap_state_dict(state_dict, config: JinaBertConfig):
|
9 |
-
"""
|
10 |
-
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
|
11 |
-
"""
|
12 |
-
|
13 |
-
# LayerNorm
|
14 |
-
def key_mapping_ln_gamma_beta(key):
|
15 |
-
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
|
16 |
-
key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
|
17 |
-
return key
|
18 |
-
|
19 |
-
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
|
20 |
-
|
21 |
-
# Layers
|
22 |
-
def key_mapping_layers(key):
|
23 |
-
return re.sub(r"^encoder.layer.", "encoder.layers.", key)
|
24 |
-
|
25 |
-
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
26 |
-
|
27 |
-
# LayerNorm
|
28 |
-
def key_mapping_ln(key):
|
29 |
-
key = re.sub(r"^embeddings.LayerNorm.", "emb_ln.", key)
|
30 |
-
key = re.sub(
|
31 |
-
r"^encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
|
32 |
-
r"encoder.layers.\1.norm1.\2",
|
33 |
-
key,
|
34 |
-
)
|
35 |
-
key = re.sub(
|
36 |
-
r"^encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
|
37 |
-
r"encoder.layers.\1.norm2.\2",
|
38 |
-
key,
|
39 |
-
)
|
40 |
-
key = re.sub(
|
41 |
-
r"^cls.predictions.transform.LayerNorm.(weight|bias)",
|
42 |
-
r"cls.predictions.transform.layer_norm.\1",
|
43 |
-
key,
|
44 |
-
)
|
45 |
-
return key
|
46 |
-
|
47 |
-
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
48 |
-
|
49 |
-
# MLP
|
50 |
-
def key_mapping_mlp(key):
|
51 |
-
key = re.sub(
|
52 |
-
r"^encoder.layers.(\d+).intermediate.dense.(weight|bias)",
|
53 |
-
r"encoder.layers.\1.mlp.fc1.\2",
|
54 |
-
key,
|
55 |
-
)
|
56 |
-
key = re.sub(
|
57 |
-
r"^encoder.layers.(\d+).output.dense.(weight|bias)",
|
58 |
-
r"encoder.layers.\1.mlp.fc2.\2",
|
59 |
-
key,
|
60 |
-
)
|
61 |
-
return key
|
62 |
-
|
63 |
-
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
64 |
-
|
65 |
-
# Attention
|
66 |
-
last_layer_subset = getattr(config, "last_layer_subset", False)
|
67 |
-
for d in range(config.num_hidden_layers):
|
68 |
-
Wq = state_dict.pop(f"encoder.layers.{d}.attention.self.query.weight")
|
69 |
-
Wk = state_dict.pop(f"encoder.layers.{d}.attention.self.key.weight")
|
70 |
-
Wv = state_dict.pop(f"encoder.layers.{d}.attention.self.value.weight")
|
71 |
-
bq = state_dict.pop(f"encoder.layers.{d}.attention.self.query.bias")
|
72 |
-
bk = state_dict.pop(f"encoder.layers.{d}.attention.self.key.bias")
|
73 |
-
bv = state_dict.pop(f"encoder.layers.{d}.attention.self.value.bias")
|
74 |
-
if not (last_layer_subset and d == config.num_hidden_layers - 1):
|
75 |
-
state_dict[f"encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
|
76 |
-
[Wq, Wk, Wv], dim=0
|
77 |
-
)
|
78 |
-
state_dict[f"encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
|
79 |
-
else:
|
80 |
-
state_dict[f"encoder.layers.{d}.mixer.Wq.weight"] = Wq
|
81 |
-
state_dict[f"encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
|
82 |
-
state_dict[f"encoder.layers.{d}.mixer.Wq.bias"] = bq
|
83 |
-
state_dict[f"encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
|
84 |
-
|
85 |
-
def key_mapping_attn(key):
|
86 |
-
return re.sub(
|
87 |
-
r"^encoder.layers.(\d+).attention.output.dense.(weight|bias)",
|
88 |
-
r"encoder.layers.\1.mixer.out_proj.\2",
|
89 |
-
key,
|
90 |
-
)
|
91 |
-
|
92 |
-
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
93 |
-
|
94 |
-
def key_mapping_decoder_bias(key):
|
95 |
-
return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
|
96 |
-
|
97 |
-
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
|
98 |
-
|
99 |
-
# Word embedding
|
100 |
-
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
101 |
-
if pad_vocab_size_multiple > 1:
|
102 |
-
word_embeddings = state_dict["embeddings.word_embeddings.weight"]
|
103 |
-
state_dict["embeddings.word_embeddings.weight"] = F.pad(
|
104 |
-
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
|
105 |
-
)
|
106 |
-
decoder_weight = state_dict["cls.predictions.decoder.weight"]
|
107 |
-
state_dict["cls.predictions.decoder.weight"] = F.pad(
|
108 |
-
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
|
109 |
-
)
|
110 |
-
# If the vocab was padded, we want to set the decoder bias for those padded indices to be
|
111 |
-
# strongly negative (i.e. the decoder shouldn't predict those indices).
|
112 |
-
# TD [2022-05-09]: I don't think it affects the MLPerf training.
|
113 |
-
decoder_bias = state_dict["cls.predictions.decoder.bias"]
|
114 |
-
state_dict["cls.predictions.decoder.bias"] = F.pad(
|
115 |
-
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
|
116 |
-
)
|
117 |
-
|
118 |
-
# LayerNorm
|
119 |
-
def key_mapping_layernorm(key):
|
120 |
-
return re.sub(r'^encoder.layers.(\d+).mlp.layernorm.(weight|bias)', r"encoder.layers.\1.norm2.\2", key)
|
121 |
-
|
122 |
-
state_dict = OrderedDict((key_mapping_layernorm(k), v) for k, v in state_dict.items())
|
123 |
-
|
124 |
-
return state_dict
|
125 |
-
|
126 |
-
|
127 |
-
v2_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
|
128 |
-
config = JinaBertConfig(vocab_size=30528, use_qk_norm=False, mlp_type='glu', hidden_act='gelu')
|
129 |
-
state_dict = v2_model.state_dict()
|
130 |
-
new_state_dict = remap_state_dict(state_dict, config)
|
131 |
-
flash_model = BertModel(config)
|
132 |
-
flash_model.load_state_dict(new_state_dict)
|
133 |
-
|
134 |
-
|
135 |
-
torch.save(new_state_dict, 'converted_weights.bin')
|
136 |
-
print(config.to_json_string())
|
137 |
-
|
138 |
-
|
139 |
-
"""
|
140 |
-
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
|
141 |
-
inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
|
142 |
-
v2_model.eval()
|
143 |
-
flash_model.eval()
|
144 |
-
v2_model = v2_model.to('cuda', torch.float16)
|
145 |
-
flash_model = flash_model.to('cuda', torch.float16)
|
146 |
-
output_v2 = v2_model(**inp)
|
147 |
-
output_flash = flash_model(**inp)
|
148 |
-
x = output_v2.last_hidden_state
|
149 |
-
y = output_flash.last_hidden_state
|
150 |
-
print(torch.abs(x - y))
|
151 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mha.py
CHANGED
@@ -514,10 +514,6 @@ class MHA(nn.Module):
|
|
514 |
alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
|
515 |
else:
|
516 |
alibi_slopes = None
|
517 |
-
|
518 |
-
if isinstance(window_size, list):
|
519 |
-
window_size = tuple(window_size)
|
520 |
-
|
521 |
if window_size != (-1, -1):
|
522 |
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
|
523 |
|
|
|
514 |
alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
|
515 |
else:
|
516 |
alibi_slopes = None
|
|
|
|
|
|
|
|
|
517 |
if window_size != (-1, -1):
|
518 |
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
|
519 |
|
mlp.py
CHANGED
@@ -27,53 +27,6 @@ except ImportError:
|
|
27 |
FusedMLP, ParallelFusedMLP = None, None
|
28 |
|
29 |
|
30 |
-
class GLUMLP(nn.Module):
|
31 |
-
def __init__(
|
32 |
-
self,
|
33 |
-
in_features,
|
34 |
-
hidden_features,
|
35 |
-
activation,
|
36 |
-
use_flash_attn,
|
37 |
-
return_residual=False,
|
38 |
-
hidden_dropout_prob=0.1
|
39 |
-
):
|
40 |
-
super().__init__()
|
41 |
-
self.hidden_features = hidden_features
|
42 |
-
self.gated_layers = nn.Linear(
|
43 |
-
in_features, hidden_features * 2, bias=False
|
44 |
-
)
|
45 |
-
if activation == 'relu':
|
46 |
-
self.act = nn.ReLU()
|
47 |
-
elif activation == 'gelu':
|
48 |
-
self.act = nn.GELU()
|
49 |
-
else:
|
50 |
-
raise ValueError(
|
51 |
-
f"activation {activation} not supported"
|
52 |
-
)
|
53 |
-
self.wo = nn.Linear(hidden_features, in_features)
|
54 |
-
self.dropout = nn.Dropout(hidden_dropout_prob)
|
55 |
-
self.return_residual = return_residual
|
56 |
-
self.use_flash_attn = use_flash_attn
|
57 |
-
#self.layernorm = nn.LayerNorm(in_features, eps=layer_norm_eps)
|
58 |
-
|
59 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
60 |
-
residual_connection = hidden_states
|
61 |
-
# compute the activation
|
62 |
-
hidden_states = self.gated_layers(hidden_states)
|
63 |
-
if self.use_flash_attn:
|
64 |
-
gated = hidden_states[:, : self.hidden_features]
|
65 |
-
non_gated = hidden_states[:, self.hidden_features :]
|
66 |
-
else:
|
67 |
-
gated = hidden_states[:, :, : self.hidden_features]
|
68 |
-
non_gated = hidden_states[:, :, self.hidden_features :]
|
69 |
-
hidden_states = self.act(gated) * non_gated
|
70 |
-
hidden_states = self.dropout(hidden_states)
|
71 |
-
# multiply by the second matrix
|
72 |
-
hidden_states = self.wo(hidden_states)
|
73 |
-
# add the residual connection and post-LN
|
74 |
-
# hidden_states = self.layernorm(hidden_states + residual_connection)
|
75 |
-
return hidden_states if not self.return_residual else (hidden_states, residual_connection)
|
76 |
-
|
77 |
class Mlp(nn.Module):
|
78 |
def __init__(
|
79 |
self,
|
|
|
27 |
FusedMLP, ParallelFusedMLP = None, None
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
class Mlp(nn.Module):
|
31 |
def __init__(
|
32 |
self,
|
modeling_bert.py
CHANGED
@@ -39,7 +39,7 @@ from .bert_padding import (
|
|
39 |
from .block import Block
|
40 |
from .embedding import BertEmbeddings
|
41 |
from .mha import MHA
|
42 |
-
from .mlp import FusedMLP, Mlp
|
43 |
|
44 |
try:
|
45 |
from flash_attn.ops.fused_dense import FusedDense
|
@@ -81,23 +81,19 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
81 |
return_residual=return_residual,
|
82 |
use_alibi=True,
|
83 |
window_size=window_size,
|
84 |
-
qk_norm=use_qk_norm
|
85 |
-
checkpointing=False,
|
86 |
)
|
87 |
return mixer_cls
|
88 |
|
89 |
|
90 |
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
91 |
inner_dim = config.intermediate_size
|
92 |
-
|
93 |
-
|
94 |
-
if mlp_type == 'fused_mlp':
|
95 |
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
|
96 |
"fused_mlp only " "supports approximate gelu"
|
97 |
)
|
98 |
-
if
|
99 |
-
assert config.hidden_act in ('relu', 'gelu')
|
100 |
-
if mlp_type == 'mlp':
|
101 |
approximate = (
|
102 |
"tanh"
|
103 |
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
@@ -109,16 +105,7 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
|
109 |
activation=partial(F.gelu, approximate=approximate),
|
110 |
return_residual=return_residual,
|
111 |
)
|
112 |
-
|
113 |
-
mlp_cls = partial(
|
114 |
-
GLUMLP,
|
115 |
-
hidden_features=inner_dim,
|
116 |
-
activation=config.hidden_act,
|
117 |
-
use_flash_attn=config.use_flash_attn,
|
118 |
-
hidden_dropout_prob=config.hidden_dropout_prob,
|
119 |
-
return_residual=return_residual,
|
120 |
-
)
|
121 |
-
elif mlp_type == 'fused_mlp':
|
122 |
if FusedMLP is None:
|
123 |
raise ImportError("fused_dense is not installed")
|
124 |
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
@@ -132,8 +119,6 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
|
132 |
checkpoint_lvl=mlp_checkpoint_lvl,
|
133 |
return_residual=return_residual,
|
134 |
)
|
135 |
-
else:
|
136 |
-
raise NotImplementedError
|
137 |
return mlp_cls
|
138 |
|
139 |
|
@@ -167,7 +152,7 @@ def _init_weights(module, initializer_range=0.02):
|
|
167 |
nn.init.normal_(module.weight, std=initializer_range)
|
168 |
if module.bias is not None:
|
169 |
nn.init.zeros_(module.bias)
|
170 |
-
elif isinstance(module, nn.Embedding):
|
171 |
nn.init.normal_(module.weight, std=initializer_range)
|
172 |
if module.padding_idx is not None:
|
173 |
nn.init.zeros_(module.weight[module.padding_idx])
|
@@ -189,6 +174,8 @@ class BertEncoder(nn.Module):
|
|
189 |
@gradient_checkpointing.setter
|
190 |
def gradient_checkpointing(self, value):
|
191 |
self._grad_checkpointing = value
|
|
|
|
|
192 |
|
193 |
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
194 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
@@ -200,15 +187,7 @@ class BertEncoder(nn.Module):
|
|
200 |
{"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
|
201 |
)
|
202 |
for layer in self.layers:
|
203 |
-
|
204 |
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
205 |
-
layer,
|
206 |
-
hidden_states,
|
207 |
-
use_reentrant=False,
|
208 |
-
mixer_kwargs=mixer_kwargs
|
209 |
-
)
|
210 |
-
else:
|
211 |
-
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
212 |
if subset_mask is not None:
|
213 |
hidden_states = hidden_states[subset_mask]
|
214 |
else:
|
@@ -219,27 +198,11 @@ class BertEncoder(nn.Module):
|
|
219 |
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
220 |
if subset_mask is None:
|
221 |
for layer in self.layers:
|
222 |
-
|
223 |
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
224 |
-
layer,
|
225 |
-
hidden_states,
|
226 |
-
use_reentrant=False,
|
227 |
-
mixer_kwargs=mixer_kwargs
|
228 |
-
)
|
229 |
-
else:
|
230 |
-
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
231 |
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
232 |
else:
|
233 |
for layer in self.layers[:-1]:
|
234 |
-
|
235 |
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
236 |
-
layer,
|
237 |
-
hidden_states,
|
238 |
-
use_reentrant=False,
|
239 |
-
mixer_kwargs=mixer_kwargs
|
240 |
-
)
|
241 |
-
else:
|
242 |
-
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
243 |
if key_padding_mask is not None:
|
244 |
subset_idx = torch.nonzero(
|
245 |
subset_mask[key_padding_mask], as_tuple=False
|
@@ -265,15 +228,7 @@ class BertEncoder(nn.Module):
|
|
265 |
"cu_seqlens_k": cu_seqlens,
|
266 |
"max_seqlen_k": max_seqlen_in_batch,
|
267 |
}
|
268 |
-
|
269 |
-
torch.utils.checkpoint.checkpoint(
|
270 |
-
self.layers[-1],
|
271 |
-
hidden_states_subset,
|
272 |
-
use_reentrant=False,
|
273 |
-
mixer_kwargs=mixer_kwargs
|
274 |
-
)
|
275 |
-
else:
|
276 |
-
hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
|
277 |
return hidden_states
|
278 |
|
279 |
|
@@ -396,16 +351,24 @@ class BertModel(BertPreTrainedModel):
|
|
396 |
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
397 |
self.encoder = BertEncoder(config)
|
398 |
self.pooler = BertPooler(config) if add_pooling_layer else None
|
|
|
399 |
|
400 |
self.emb_pooler = config.emb_pooler
|
401 |
self._name_or_path = config._name_or_path
|
402 |
if self.emb_pooler is not None:
|
403 |
from transformers import AutoTokenizer
|
404 |
|
405 |
-
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path
|
406 |
else:
|
407 |
self.tokenizer = None
|
408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
410 |
|
411 |
def forward(
|
@@ -413,9 +376,9 @@ class BertModel(BertPreTrainedModel):
|
|
413 |
input_ids,
|
414 |
position_ids=None,
|
415 |
token_type_ids=None,
|
|
|
416 |
attention_mask=None,
|
417 |
masked_tokens_mask=None,
|
418 |
-
return_dict=True,
|
419 |
):
|
420 |
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
|
421 |
we only want the output for the masked tokens. This means that we only compute the last
|
@@ -425,6 +388,8 @@ class BertModel(BertPreTrainedModel):
|
|
425 |
hidden_states = self.embeddings(
|
426 |
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
427 |
)
|
|
|
|
|
428 |
|
429 |
# TD [2022-12:18]: Don't need to force residual in fp32
|
430 |
# BERT puts embedding LayerNorm before embedding dropout.
|
@@ -464,9 +429,6 @@ class BertModel(BertPreTrainedModel):
|
|
464 |
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
465 |
pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
|
466 |
|
467 |
-
if not return_dict:
|
468 |
-
return (sequence_output, pooled_output)
|
469 |
-
|
470 |
return BaseModelOutputWithPoolingAndCrossAttentions(
|
471 |
last_hidden_state=sequence_output,
|
472 |
pooler_output=pooled_output,
|
@@ -522,7 +484,7 @@ class BertModel(BertPreTrainedModel):
|
|
522 |
self.emb_pooler = 'mean'
|
523 |
from transformers import AutoTokenizer
|
524 |
|
525 |
-
self.tokenizer = AutoTokenizer.from_pretrained(self._name_or_path
|
526 |
if self.emb_pooler != 'mean':
|
527 |
raise NotImplementedError
|
528 |
|
@@ -723,84 +685,4 @@ class BertForPreTraining(BertPreTrainedModel):
|
|
723 |
loss=total_loss,
|
724 |
prediction_logits=prediction_scores,
|
725 |
seq_relationship_logits=seq_relationship_score,
|
726 |
-
)
|
727 |
-
|
728 |
-
|
729 |
-
class BertForMaskedLM(BertPreTrainedModel):
|
730 |
-
def __init__(self, config: JinaBertConfig):
|
731 |
-
super().__init__(config)
|
732 |
-
# If dense_seq_output, we only need to pass the hidden states for the masked out tokens
|
733 |
-
# (around 15%) to the classifier heads.
|
734 |
-
self.dense_seq_output = getattr(config, "dense_seq_output", False)
|
735 |
-
# If last_layer_subset, we only need the compute the last layer for a subset of tokens
|
736 |
-
# (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
|
737 |
-
self.last_layer_subset = getattr(config, "last_layer_subset", False)
|
738 |
-
if self.last_layer_subset:
|
739 |
-
assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
|
740 |
-
use_xentropy = getattr(config, "use_xentropy", False)
|
741 |
-
if use_xentropy and CrossEntropyLoss is None:
|
742 |
-
raise ImportError("xentropy_cuda is not installed")
|
743 |
-
loss_cls = (
|
744 |
-
nn.CrossEntropyLoss
|
745 |
-
if not use_xentropy
|
746 |
-
else partial(CrossEntropyLoss, inplace_backward=True)
|
747 |
-
)
|
748 |
-
|
749 |
-
self.bert = BertModel(config)
|
750 |
-
self.cls = BertPreTrainingHeads(config)
|
751 |
-
self.mlm_loss = loss_cls(ignore_index=0)
|
752 |
-
|
753 |
-
# Initialize weights and apply final processing
|
754 |
-
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
755 |
-
self.tie_weights()
|
756 |
-
|
757 |
-
def tie_weights(self):
|
758 |
-
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
|
759 |
-
|
760 |
-
def get_input_embeddings(self):
|
761 |
-
return self.bert.embeddings.word_embeddings
|
762 |
-
|
763 |
-
def forward(
|
764 |
-
self,
|
765 |
-
input_ids,
|
766 |
-
position_ids=None,
|
767 |
-
token_type_ids=None,
|
768 |
-
attention_mask=None,
|
769 |
-
labels=None
|
770 |
-
):
|
771 |
-
masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
|
772 |
-
outputs = self.bert(
|
773 |
-
input_ids,
|
774 |
-
position_ids=position_ids,
|
775 |
-
token_type_ids=token_type_ids,
|
776 |
-
attention_mask=attention_mask.bool() if attention_mask is not None else None,
|
777 |
-
masked_tokens_mask=masked_tokens_mask,
|
778 |
-
)
|
779 |
-
sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
|
780 |
-
if self.dense_seq_output and labels is not None:
|
781 |
-
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
|
782 |
-
if not self.last_layer_subset:
|
783 |
-
sequence_output = index_first_axis(
|
784 |
-
rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
|
785 |
-
)
|
786 |
-
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
787 |
-
|
788 |
-
if (
|
789 |
-
self.dense_seq_output and labels is not None
|
790 |
-
): # prediction_scores are already flattened
|
791 |
-
masked_lm_loss = self.mlm_loss(
|
792 |
-
prediction_scores, labels.flatten()[masked_token_idx]
|
793 |
-
).float()
|
794 |
-
elif labels is not None:
|
795 |
-
masked_lm_loss = self.mlm_loss(
|
796 |
-
rearrange(prediction_scores, "... v -> (...) v"),
|
797 |
-
rearrange(labels, "... -> (...)"),
|
798 |
-
).float()
|
799 |
-
else:
|
800 |
-
raise ValueError('MLM labels must not be None')
|
801 |
-
|
802 |
-
return BertForPreTrainingOutput(
|
803 |
-
loss=masked_lm_loss,
|
804 |
-
prediction_logits=prediction_scores,
|
805 |
-
seq_relationship_logits=seq_relationship_score,
|
806 |
-
)
|
|
|
39 |
from .block import Block
|
40 |
from .embedding import BertEmbeddings
|
41 |
from .mha import MHA
|
42 |
+
from .mlp import FusedMLP, Mlp
|
43 |
|
44 |
try:
|
45 |
from flash_attn.ops.fused_dense import FusedDense
|
|
|
81 |
return_residual=return_residual,
|
82 |
use_alibi=True,
|
83 |
window_size=window_size,
|
84 |
+
qk_norm=use_qk_norm
|
|
|
85 |
)
|
86 |
return mixer_cls
|
87 |
|
88 |
|
89 |
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
90 |
inner_dim = config.intermediate_size
|
91 |
+
fused_mlp = getattr(config, "fused_mlp", False)
|
92 |
+
if fused_mlp:
|
|
|
93 |
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
|
94 |
"fused_mlp only " "supports approximate gelu"
|
95 |
)
|
96 |
+
if not fused_mlp:
|
|
|
|
|
97 |
approximate = (
|
98 |
"tanh"
|
99 |
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
|
|
105 |
activation=partial(F.gelu, approximate=approximate),
|
106 |
return_residual=return_residual,
|
107 |
)
|
108 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
if FusedMLP is None:
|
110 |
raise ImportError("fused_dense is not installed")
|
111 |
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
|
|
119 |
checkpoint_lvl=mlp_checkpoint_lvl,
|
120 |
return_residual=return_residual,
|
121 |
)
|
|
|
|
|
122 |
return mlp_cls
|
123 |
|
124 |
|
|
|
152 |
nn.init.normal_(module.weight, std=initializer_range)
|
153 |
if module.bias is not None:
|
154 |
nn.init.zeros_(module.bias)
|
155 |
+
elif isinstance(module, nn.Embedding) and not getattr(module, "skip_init", False):
|
156 |
nn.init.normal_(module.weight, std=initializer_range)
|
157 |
if module.padding_idx is not None:
|
158 |
nn.init.zeros_(module.weight[module.padding_idx])
|
|
|
174 |
@gradient_checkpointing.setter
|
175 |
def gradient_checkpointing(self, value):
|
176 |
self._grad_checkpointing = value
|
177 |
+
for block in self.layers:
|
178 |
+
block.mixer.checkpointing = value
|
179 |
|
180 |
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
181 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
|
|
187 |
{"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
|
188 |
)
|
189 |
for layer in self.layers:
|
190 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
if subset_mask is not None:
|
192 |
hidden_states = hidden_states[subset_mask]
|
193 |
else:
|
|
|
198 |
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
199 |
if subset_mask is None:
|
200 |
for layer in self.layers:
|
201 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
203 |
else:
|
204 |
for layer in self.layers[:-1]:
|
205 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
if key_padding_mask is not None:
|
207 |
subset_idx = torch.nonzero(
|
208 |
subset_mask[key_padding_mask], as_tuple=False
|
|
|
228 |
"cu_seqlens_k": cu_seqlens,
|
229 |
"max_seqlen_k": max_seqlen_in_batch,
|
230 |
}
|
231 |
+
hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
return hidden_states
|
233 |
|
234 |
|
|
|
351 |
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
352 |
self.encoder = BertEncoder(config)
|
353 |
self.pooler = BertPooler(config) if add_pooling_layer else None
|
354 |
+
self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
|
355 |
|
356 |
self.emb_pooler = config.emb_pooler
|
357 |
self._name_or_path = config._name_or_path
|
358 |
if self.emb_pooler is not None:
|
359 |
from transformers import AutoTokenizer
|
360 |
|
361 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
|
362 |
else:
|
363 |
self.tokenizer = None
|
364 |
|
365 |
+
# We now initialize the task embeddings to 0; We do not use task types during
|
366 |
+
# pretraining. When we start using task types during embedding training,
|
367 |
+
# we want the model to behave exactly as in pretraining (i.e. task types
|
368 |
+
# have no effect).
|
369 |
+
nn.init.zeros_(self.task_type_embeddings.weight)
|
370 |
+
self.task_type_embeddings.skip_init = True
|
371 |
+
# The following code should skip the embeddings layer
|
372 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
373 |
|
374 |
def forward(
|
|
|
376 |
input_ids,
|
377 |
position_ids=None,
|
378 |
token_type_ids=None,
|
379 |
+
task_type_ids=None,
|
380 |
attention_mask=None,
|
381 |
masked_tokens_mask=None,
|
|
|
382 |
):
|
383 |
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
|
384 |
we only want the output for the masked tokens. This means that we only compute the last
|
|
|
388 |
hidden_states = self.embeddings(
|
389 |
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
390 |
)
|
391 |
+
if task_type_ids is not None:
|
392 |
+
hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
|
393 |
|
394 |
# TD [2022-12:18]: Don't need to force residual in fp32
|
395 |
# BERT puts embedding LayerNorm before embedding dropout.
|
|
|
429 |
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
430 |
pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
|
431 |
|
|
|
|
|
|
|
432 |
return BaseModelOutputWithPoolingAndCrossAttentions(
|
433 |
last_hidden_state=sequence_output,
|
434 |
pooler_output=pooled_output,
|
|
|
484 |
self.emb_pooler = 'mean'
|
485 |
from transformers import AutoTokenizer
|
486 |
|
487 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self._name_or_path)
|
488 |
if self.emb_pooler != 'mean':
|
489 |
raise NotImplementedError
|
490 |
|
|
|
685 |
loss=total_loss,
|
686 |
prediction_logits=prediction_scores,
|
687 |
seq_relationship_logits=seq_relationship_score,
|
688 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modeling_lora.py
CHANGED
@@ -65,8 +65,6 @@ class LoRAParametrization(nn.Module):
|
|
65 |
fan_in_fan_out = layer_type == "embedding"
|
66 |
self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
|
67 |
|
68 |
-
# For the officially "correct" LoRA initialization, check here: https://github.com/microsoft/LoRA
|
69 |
-
# TODO: Ensure that the initialization here is correct
|
70 |
if layer_type == "linear":
|
71 |
self.lora_A = nn.Parameter(
|
72 |
initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
|
@@ -196,64 +194,30 @@ class LoRAParametrization(nn.Module):
|
|
196 |
),
|
197 |
)
|
198 |
|
199 |
-
@
|
200 |
-
def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
|
201 |
if isinstance(layer, LoRAParametrization):
|
202 |
layer.current_task = task_idx
|
203 |
|
204 |
-
@staticmethod
|
205 |
-
def merge_lora_into_layer(layer: nn.Module):
|
206 |
-
if hasattr(layer, "parametrizations"):
|
207 |
-
for attr_name in layer.parametrizations.keys():
|
208 |
-
parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
|
209 |
-
|
210 |
|
211 |
class BertLoRA(BertPreTrainedModel):
|
212 |
-
def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True):
|
213 |
super().__init__(config)
|
214 |
if bert is None:
|
215 |
self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
|
216 |
else:
|
217 |
self.bert = bert
|
218 |
-
self.
|
219 |
-
self._num_adaptions = config.num_loras
|
220 |
-
self._register_lora(self._num_adaptions)
|
221 |
-
self.main_params_trainable = False
|
222 |
-
self._task_idx = None
|
223 |
-
# By default, we select the first LoRA
|
224 |
-
self.current_task = 0
|
225 |
-
|
226 |
-
@property
|
227 |
-
def main_params_trainable(self):
|
228 |
-
return self._main_params_trainable
|
229 |
-
|
230 |
-
@main_params_trainable.setter
|
231 |
-
def main_params_trainable(self, val: bool):
|
232 |
-
"""Whether the main parameters (i.e. those that are not LoRA) should be trainable.
|
233 |
-
|
234 |
-
This method sets the `requires_grad_` attribute of the main weights
|
235 |
-
and controls which parameters are returned in `self.parameters()`.
|
236 |
-
|
237 |
-
:param val: Whether or not to make the parameters trainable.
|
238 |
-
:return: None
|
239 |
-
"""
|
240 |
-
self._main_params_trainable = val
|
241 |
for name, param in super().named_parameters():
|
242 |
if "lora" not in name:
|
243 |
-
param.requires_grad_(
|
|
|
244 |
|
245 |
@classmethod
|
246 |
-
def from_bert(cls, *args, **kwargs):
|
247 |
bert = BertModel.from_pretrained(*args, **kwargs)
|
248 |
config = JinaBertConfig.from_pretrained(*args, **kwargs)
|
249 |
-
return cls(config, bert=bert)
|
250 |
-
|
251 |
-
def merge_lora(self):
|
252 |
-
"""Merges currently selected LoRA into main weights."""
|
253 |
-
if self._is_merged:
|
254 |
-
raise Exception('LoRA has already been merged, cannot merge again')
|
255 |
-
self._is_merged = True
|
256 |
-
self.apply(LoRAParametrization.merge_lora_into_layer)
|
257 |
|
258 |
@classmethod
|
259 |
def from_pretrained(
|
@@ -270,13 +234,7 @@ class BertLoRA(BertPreTrainedModel):
|
|
270 |
use_safetensors: bool = None,
|
271 |
**kwargs,
|
272 |
):
|
273 |
-
|
274 |
-
TODO: choose between from_bert and super().from_pretrained
|
275 |
-
|
276 |
-
We want to be able to load both a pretrained BertModel, and a trained
|
277 |
-
BertLoRA via this method. To this end, we need to check which of these
|
278 |
-
models we are expected to load.
|
279 |
-
"""
|
280 |
return cls.from_bert(pretrained_model_name_or_path)
|
281 |
|
282 |
def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
@@ -292,34 +250,16 @@ class BertLoRA(BertPreTrainedModel):
|
|
292 |
|
293 |
@property
|
294 |
def current_task(self):
|
295 |
-
""" Which LoRA is currently selected
|
296 |
-
:return: Integer or None (when LoRA is disabled)
|
297 |
-
"""
|
298 |
return self._task_idx
|
299 |
|
300 |
@current_task.setter
|
301 |
def current_task(self, task_idx: Union[None, int]):
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
:param task_idx: Which LoRA to use
|
308 |
-
:return:
|
309 |
-
"""
|
310 |
-
if self._is_merged:
|
311 |
-
raise Exception('LoRA has been merged, cannot select new task')
|
312 |
-
assert task_idx is None or 0 <= task_idx < self._num_adaptions
|
313 |
-
if self._task_idx != task_idx:
|
314 |
-
# In this case, we need to update the LoRAs everywhere
|
315 |
-
self._task_idx = task_idx
|
316 |
-
self.apply(
|
317 |
-
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
318 |
-
)
|
319 |
|
320 |
-
def forward(self, *args,
|
321 |
-
if current_task is None or current_task >= 0:
|
322 |
-
self.current_task = current_task
|
323 |
return self.bert(*args, **kwargs)
|
324 |
|
325 |
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
@@ -332,5 +272,5 @@ class BertLoRA(BertPreTrainedModel):
|
|
332 |
for name, param in super().named_parameters(
|
333 |
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
|
334 |
):
|
335 |
-
if "lora" in name
|
336 |
yield name, param
|
|
|
65 |
fan_in_fan_out = layer_type == "embedding"
|
66 |
self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
|
67 |
|
|
|
|
|
68 |
if layer_type == "linear":
|
69 |
self.lora_A = nn.Parameter(
|
70 |
initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
|
|
|
194 |
),
|
195 |
)
|
196 |
|
197 |
+
@classmethod
|
198 |
+
def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
|
199 |
if isinstance(layer, LoRAParametrization):
|
200 |
layer.current_task = task_idx
|
201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
class BertLoRA(BertPreTrainedModel):
|
204 |
+
def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True, num_adaptions=1):
|
205 |
super().__init__(config)
|
206 |
if bert is None:
|
207 |
self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
|
208 |
else:
|
209 |
self.bert = bert
|
210 |
+
self._register_lora(num_adaptions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
for name, param in super().named_parameters():
|
212 |
if "lora" not in name:
|
213 |
+
param.requires_grad_(False)
|
214 |
+
self.current_task = 0
|
215 |
|
216 |
@classmethod
|
217 |
+
def from_bert(cls, *args, num_adaptions=1, **kwargs):
|
218 |
bert = BertModel.from_pretrained(*args, **kwargs)
|
219 |
config = JinaBertConfig.from_pretrained(*args, **kwargs)
|
220 |
+
return cls(config, bert=bert, num_adaptions=num_adaptions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
@classmethod
|
223 |
def from_pretrained(
|
|
|
234 |
use_safetensors: bool = None,
|
235 |
**kwargs,
|
236 |
):
|
237 |
+
# TODO: choose between from_bert and super().from_pretrained
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
return cls.from_bert(pretrained_model_name_or_path)
|
239 |
|
240 |
def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
|
|
250 |
|
251 |
@property
|
252 |
def current_task(self):
|
|
|
|
|
|
|
253 |
return self._task_idx
|
254 |
|
255 |
@current_task.setter
|
256 |
def current_task(self, task_idx: Union[None, int]):
|
257 |
+
self._task_idx = task_idx
|
258 |
+
self.apply(
|
259 |
+
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
260 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
|
262 |
+
def forward(self, *args, **kwargs):
|
|
|
|
|
263 |
return self.bert(*args, **kwargs)
|
264 |
|
265 |
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
|
|
272 |
for name, param in super().named_parameters(
|
273 |
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
|
274 |
):
|
275 |
+
if "lora" in name:
|
276 |
yield name, param
|
tokenizer.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
|
7 |
+
def get_tokenizer(parent_class):
|
8 |
+
class TokenizerClass(parent_class):
|
9 |
+
def __init__(self, *args, **kwargs):
|
10 |
+
"""
|
11 |
+
This class dynamically extends a given tokenizer class from the HF
|
12 |
+
Transformers library (RobertaTokenizer or RobertaTokenizerFast).
|
13 |
+
The task_type_ids are used to pass instruction information to the model.
|
14 |
+
A task_type should either be an integer or a sequence of integers with the same
|
15 |
+
length as the batch size.
|
16 |
+
"""
|
17 |
+
super().__init__(*args, **kwargs)
|
18 |
+
|
19 |
+
def __call__(self, *args, task_type=None, **kwargs):
|
20 |
+
batch_encoding = super().__call__(*args, **kwargs)
|
21 |
+
if task_type is not None:
|
22 |
+
batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
|
23 |
+
return batch_encoding
|
24 |
+
|
25 |
+
def _batch_encode_plus(self, *args, task_type=None, **kwargs):
|
26 |
+
batch_encoding = super()._batch_encode_plus(*args, **kwargs)
|
27 |
+
if task_type is not None:
|
28 |
+
batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
|
29 |
+
return batch_encoding
|
30 |
+
|
31 |
+
def _encode_plus(self, *args, task_type=None, **kwargs):
|
32 |
+
batch_encoding = super()._encode_plus(*args, **kwargs)
|
33 |
+
if task_type is not None:
|
34 |
+
batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
|
35 |
+
return batch_encoding
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
def _add_task_type_ids(cls, batch_encoding, task_type, tensor_type):
|
39 |
+
return BatchEncoding(
|
40 |
+
{
|
41 |
+
'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
|
42 |
+
**batch_encoding,
|
43 |
+
},
|
44 |
+
tensor_type=tensor_type,
|
45 |
+
)
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
|
49 |
+
|
50 |
+
def apply_task_type(m, x):
|
51 |
+
x = torch.tensor(x)
|
52 |
+
assert (
|
53 |
+
len(x.shape) == 0 or x.shape[0] == m.shape[0]
|
54 |
+
), 'The shape of task_type does not match the size of the batch.'
|
55 |
+
return m * x if len(x.shape) == 0 else m * x[:, None]
|
56 |
+
|
57 |
+
if isinstance(batch_encoding['input_ids'], torch.Tensor):
|
58 |
+
shape = batch_encoding['input_ids'].shape
|
59 |
+
return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
|
60 |
+
else:
|
61 |
+
try:
|
62 |
+
shape = torch.tensor(batch_encoding['input_ids']).shape
|
63 |
+
except:
|
64 |
+
raise ValueError(
|
65 |
+
"Unable to create tensor, you should probably "
|
66 |
+
"activate truncation and/or padding with "
|
67 |
+
"'padding=True' 'truncation=True' to have batched "
|
68 |
+
"tensors with the same length."
|
69 |
+
)
|
70 |
+
if isinstance(batch_encoding['input_ids'], list):
|
71 |
+
return (
|
72 |
+
apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
|
73 |
+
).tolist()
|
74 |
+
elif isinstance(batch_encoding['input_ids'], np.array):
|
75 |
+
return (
|
76 |
+
apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
|
77 |
+
).numpy()
|
78 |
+
else:
|
79 |
+
warnings.warn(
|
80 |
+
'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
|
81 |
+
)
|
82 |
+
return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
|
83 |
+
|
84 |
+
return TokenizerClass
|
85 |
+
|
86 |
+
|
87 |
+
JinaTokenizer = get_tokenizer(RobertaTokenizer)
|
88 |
+
JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)
|