shivanandmn committed
Commit d036c0e · verified · 1 Parent(s): a213916

Model save

Files changed (3)
  1. README.md +79 -0
  2. generation_config.json +7 -0
  3. modeling_rotating_head_gpt2.py +1131 -0
README.md ADDED
@@ -0,0 +1,79 @@
+ ---
+ library_name: transformers
+ tags:
+ - generated_from_trainer
+ metrics:
+ - accuracy
+ - bleu
+ model-index:
+ - name: rotating-head-gp-gpt2-medium-wikitext
+   results: []
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # rotating-head-gp-gpt2-medium-wikitext
+
+ This model is a `rotating-head-gpt2` variant of GPT-2 medium (GPT-2 with head-specific rotary position embeddings, defined in `modeling_rotating_head_gpt2.py` in this repository). The auto-generated card does not record a base checkpoint or dataset; the model name indicates WikiText.
+ It achieves the following results on the evaluation set:
+ - Loss: 3.2064
+ - Accuracy: 0.4186
+ - Perplexity: 24.6904
+ - Bleu: 0.1335
+
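The reported perplexity matches `exp(eval_loss)`, i.e. it appears to be derived from the natural-log cross-entropy loss. A quick check:

```python
import math

eval_loss = 3.2064
print(math.exp(eval_loss))  # ~24.690, matching the reported perplexity of 24.6904 up to rounding of the loss
```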
+ ## Model description
+
+ A GPT-2-medium-style causal language model in which every attention head applies its own rotary position embedding. `modeling_rotating_head_gpt2.py` defines two variants: `lr` (learnable per-head frequencies) and `gp` (fixed per-head frequencies on a geometric progression of head-specific scales); the `gp` variant is used here, as the `-gp-` in the model name indicates.
+
+ ## Intended uses & limitations
+
+ More information needed
+
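A minimal loading and generation sketch. The repository id and tokenizer below are assumptions rather than facts recorded in this commit, and loading through the `auto` classes assumes the repo's `config.json` carries the appropriate `auto_map` entries for the custom architecture:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "shivanandmn/rotating-head-gp-gpt2-medium-wikitext"  # assumed from the commit author and model name
tokenizer = AutoTokenizer.from_pretrained("gpt2")              # assumes the standard GPT-2 BPE tokenizer was used

model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

inputs = tokenizer("The history of the city", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```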
+ ## Training and evaluation data
+
+ Not recorded by the Trainer; the model name indicates the WikiText corpus was used for training and evaluation.
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training (a `TrainingArguments` sketch follows the list):
+ - learning_rate: 0.0001
+ - train_batch_size: 64
+ - eval_batch_size: 64
+ - seed: 42
+ - optimizer: AdamW (torch implementation) with betas=(0.9, 0.999), epsilon=1e-08, and no additional optimizer arguments
+ - lr_scheduler_type: linear
+ - lr_scheduler_warmup_ratio: 0.1
+ - num_epochs: 5
+
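A minimal sketch of how these hyperparameters map onto `transformers.TrainingArguments`. The output directory is a placeholder, and the card does not say whether the batch size of 64 is total or per device:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="rotating-head-gp-gpt2-medium-wikitext",  # placeholder
    learning_rate=1e-4,
    per_device_train_batch_size=64,   # card reports 64; may be a total batch size
    per_device_eval_batch_size=64,
    seed=42,
    optim="adamw_torch",              # AdamW with betas=(0.9, 0.999), eps=1e-8 (library defaults)
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    num_train_epochs=5,
)
```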
+ ### Training results
+
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy | Perplexity | Bleu |
+ |:-------------:|:------:|:----:|:---------------:|:--------:|:----------:|:------:|
+ | 5.9051 | 0.2806 | 500 | 5.7418 | 0.2238 | 311.6215 | 0.0481 |
+ | 4.8617 | 0.5612 | 1000 | 4.7405 | 0.2812 | 114.4922 | 0.0721 |
+ | 4.2992 | 0.8418 | 1500 | 4.2277 | 0.3179 | 68.5611 | 0.0832 |
+ | 3.9585 | 1.1223 | 2000 | 3.9252 | 0.3466 | 50.6640 | 0.0905 |
+ | 3.7838 | 1.4029 | 2500 | 3.7564 | 0.3626 | 42.7959 | 0.0986 |
+ | 3.6863 | 1.6835 | 3000 | 3.6388 | 0.3739 | 38.0459 | 0.1069 |
+ | 3.5869 | 1.9641 | 3500 | 3.5518 | 0.3826 | 34.8757 | 0.1100 |
+ | 3.4733 | 2.2447 | 4000 | 3.4846 | 0.3886 | 32.6092 | 0.1159 |
+ | 3.4122 | 2.5253 | 4500 | 3.4307 | 0.3941 | 30.8979 | 0.1212 |
+ | 3.3791 | 2.8058 | 5000 | 3.3804 | 0.3991 | 29.3811 | 0.1223 |
+ | 3.2616 | 3.0864 | 5500 | 3.3447 | 0.4026 | 28.3518 | 0.1222 |
+ | 3.2499 | 3.3670 | 6000 | 3.3096 | 0.4067 | 27.3740 | 0.1261 |
+ | 3.2277 | 3.6476 | 6500 | 3.2812 | 0.4100 | 26.6073 | 0.1299 |
+ | 3.1992 | 3.9282 | 7000 | 3.2523 | 0.4128 | 25.8505 | 0.1305 |
+ | 3.13 | 4.2088 | 7500 | 3.2332 | 0.4154 | 25.3608 | 0.1326 |
+ | 3.0915 | 4.4893 | 8000 | 3.2200 | 0.4168 | 25.0291 | 0.1317 |
+ | 3.1011 | 4.7699 | 8500 | 3.2064 | 0.4186 | 24.6904 | 0.1335 |
+
+
+ ### Framework versions
+
+ - Transformers 4.49.0
+ - Pytorch 2.6.0+cu124
+ - Datasets 3.3.2
+ - Tokenizers 0.21.0
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 50256,
+   "eos_token_id": 50256,
+   "transformers_version": "4.49.0",
+   "use_cache": false
+ }
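`use_cache` is saved as `false`, matching the training configuration. Note that the head-specific RoPE in `modeling_rotating_head_gpt2.py` derives rotation angles from the current query length, so cached single-token decoding may not see consistent positions; keeping the saved setting is the conservative choice. Continuing the README's loading sketch:

```python
# Explicitly disable the KV cache, mirroring generation_config.json
out = model.generate(**inputs, max_new_tokens=40, use_cache=False)
```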
modeling_rotating_head_gpt2.py ADDED
@@ -0,0 +1,1131 @@
1
+
2
+ """PyTorch OpenAI GPT-2 model with head-specific rotary position embeddings; adapted from the Hugging Face transformers GPT-2 implementation."""
3
+
4
+ import math
5
+ import os
6
+ import warnings
7
+ from dataclasses import dataclass
8
+ from typing import Callable, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.generation import GenerationMixin
17
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutputWithPastAndCrossAttentions,
20
+ CausalLMOutputWithCrossAttentions,
21
+ QuestionAnsweringModelOutput,
22
+ SequenceClassifierOutputWithPast,
23
+ TokenClassifierOutput,
24
+ )
25
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary
26
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
27
+ from transformers.utils import (
28
+ ModelOutput,
29
+ add_code_sample_docstrings,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
36
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
37
+ from src.models.modeling_gpt2 import GPT2PreTrainedModel, GPT2Block
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ class RotatingHeadGPT2Config(GPT2Config):
42
+ model_type = "rotating-head-gpt2"
43
+ architectures = ["RotatingHeadGPT2LMHeadModel"]
44
+
45
+ class RotatingHeadGPT2PretrainedModel(GPT2PreTrainedModel):
46
+ config_class = RotatingHeadGPT2Config
47
+
48
+
49
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
50
+ """Load tf checkpoints in a pytorch model"""
51
+ try:
52
+ import re
53
+
54
+ import tensorflow as tf
55
+ except ImportError:
56
+ logger.error(
57
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
58
+ "https://www.tensorflow.org/install/ for installation instructions."
59
+ )
60
+ raise
61
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
62
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
63
+ # Load weights from TF model
64
+ init_vars = tf.train.list_variables(tf_path)
65
+ names = []
66
+ arrays = []
67
+ for name, shape in init_vars:
68
+ logger.info(f"Loading TF weight {name} with shape {shape}")
69
+ array = tf.train.load_variable(tf_path, name)
70
+ names.append(name)
71
+ arrays.append(array.squeeze())
72
+
73
+ for name, array in zip(names, arrays):
74
+ name = name[6:] # skip "model/"
75
+ name = name.split("/")
76
+ pointer = model
77
+ for m_name in name:
78
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
79
+ scope_names = re.split(r"(\d+)", m_name)
80
+ else:
81
+ scope_names = [m_name]
82
+ if scope_names[0] == "w" or scope_names[0] == "g":
83
+ pointer = getattr(pointer, "weight")
84
+ elif scope_names[0] == "b":
85
+ pointer = getattr(pointer, "bias")
86
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
87
+ pointer = getattr(pointer, scope_names[0])
88
+ pointer = getattr(pointer, "weight")
89
+ else:
90
+ pointer = getattr(pointer, scope_names[0])
91
+ if len(scope_names) >= 2:
92
+ num = int(scope_names[1])
93
+ pointer = pointer[num]
94
+ try:
95
+ if pointer.shape != array.shape:
96
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
97
+ except ValueError as e:
98
+ e.args += (pointer.shape, array.shape)
99
+ raise
100
+ logger.info(f"Initialize PyTorch weight {name}")
101
+ pointer.data = torch.from_numpy(array)
102
+ return model
103
+
104
+
105
+ def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
106
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
107
+
108
+ if module.scale_attn_weights:
109
+ attn_weights = attn_weights / torch.full(
110
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
111
+ )
112
+
113
+ # Layer-wise attention scaling
114
+ if module.scale_attn_by_inverse_layer_idx:
115
+ attn_weights = attn_weights / float(module.layer_idx + 1)
116
+
117
+ if not module.is_cross_attention:
118
+ # if only "normal" attention layer implements causal mask
119
+ query_length, key_length = query.size(-2), key.size(-2)
120
+ causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
121
+ mask_value = torch.finfo(attn_weights.dtype).min
122
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
123
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
124
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
125
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
126
+
127
+ if attention_mask is not None:
128
+ # Apply the attention mask
129
+ attn_weights = attn_weights + attention_mask
130
+
131
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
132
+
133
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
134
+ attn_weights = attn_weights.type(value.dtype)
135
+ attn_weights = module.attn_dropout(attn_weights)
136
+
137
+ # Mask heads if we want to
138
+ if head_mask is not None:
139
+ attn_weights = attn_weights * head_mask
140
+
141
+ attn_output = torch.matmul(attn_weights, value)
142
+ attn_output = attn_output.transpose(1, 2)
143
+
144
+ return attn_output, attn_weights
145
+
146
+
147
+ class HeadSpecificLRRoPE(nn.Module):
148
+ def __init__(self, num_heads, head_dim):
149
+ super().__init__()
150
+ self.num_heads = num_heads
151
+ self.head_dim = head_dim
152
+
153
+ # Initialize head-specific frequencies (learnable)
154
+ self.frequencies = nn.Parameter(torch.randn(num_heads, head_dim // 2))
155
+
156
+ def forward(self, Q, K):
157
+ bs, heads, seq, embed = Q.size()
158
+ # Q = torch.einsum('bse,hed->bhse', X, W_Q) # [batch, heads, seq, embed]
159
+ # K = torch.einsum('bse,hed->bhse', X, W_K)
160
+
161
+ positions = torch.arange(seq, device=Q.device).unsqueeze(1) # [seq_length, 1]
162
+
163
+ cos_theta = torch.cos(positions * self.frequencies.unsqueeze(1))
164
+ sin_theta = torch.sin(positions * self.frequencies.unsqueeze(1))
165
+
166
+ Q_even, Q_odd = Q[..., ::2], Q[..., 1::2]
167
+ K_even, K_odd = K[..., ::2], K[..., 1::2]
168
+
169
+ Q_rotated = torch.stack([Q_even * cos_theta - Q_odd * sin_theta,
170
+ Q_even * sin_theta + Q_odd * cos_theta], dim=-1).reshape_as(Q)
171
+ K_rotated = torch.stack([K_even * cos_theta - K_odd * sin_theta,
172
+ K_even * sin_theta + K_odd * cos_theta], dim=-1).reshape_as(K)
173
+
174
+ return Q_rotated, K_rotated
175
+
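For orientation, a small self-contained check of the head-specific RoPE above (assuming `HeadSpecificLRRoPE` as defined in this file is in scope). Each even/odd pair of Q and K channels is rotated by a per-head, per-position angle, so vector norms are preserved while the Q·K dot products become position-dependent:

```python
import torch

torch.manual_seed(0)
num_heads, head_dim = 4, 8
rope = HeadSpecificLRRoPE(num_heads, head_dim)

Q = torch.randn(2, num_heads, 10, head_dim)  # [batch, heads, seq, head_dim]
K = torch.randn(2, num_heads, 10, head_dim)
Q_rot, K_rot = rope(Q, K)

print(Q_rot.shape)  # torch.Size([2, 4, 10, 8]) -- shapes are unchanged
print(torch.allclose(Q.norm(dim=-1), Q_rot.norm(dim=-1), atol=1e-5))  # True: rotations preserve norms
```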
176
+ class HeadSpecificGPRoPE(nn.Module):
177
+ def __init__(self, num_heads, head_dim, base_frequency=10000):
178
+ super().__init__()
179
+ self.num_heads = num_heads
180
+ self.head_dim = head_dim
181
+
182
+ # Geometric frequency progression (fixed)
183
+ frequency_base = base_frequency ** (-torch.arange(0, head_dim, 2).float() / head_dim)
184
+ scales = torch.logspace(0, -1, steps=num_heads, base=10.0).unsqueeze(1) # [num_heads, 1]
185
+ # Registered as a non-persistent buffer so it follows the module across devices
+ # (a plain tensor attribute would stay on CPU after `model.to(device)`) without entering the state dict.
+ self.register_buffer("frequencies", scales @ frequency_base.unsqueeze(0), persistent=False)  # [num_heads, dim//2]
186
+
187
+ def forward(self, Q, K):
188
+ bs, heads, seq, embed = Q.size()
189
+ # Q = torch.einsum('bse,hed->bhse', X, W_Q) # [batch, heads, seq, embed]
190
+ # K = torch.einsum('bse,hed->bhse', X, W_K)
191
+
192
+ positions = torch.arange(seq, device=Q.device).unsqueeze(1).unsqueeze(0) # [1, seq_length, 1]
193
+
194
+ cos_theta = torch.cos(positions * self.frequencies.unsqueeze(1))
195
+ sin_theta = torch.sin(positions * self.frequencies.unsqueeze(1))
196
+
197
+ Q_even, Q_odd = Q[..., ::2], Q[..., 1::2]
198
+ K_even, K_odd = K[..., ::2], K[..., 1::2]
199
+
200
+ Q_rotated = torch.stack([Q_even * cos_theta - Q_odd * sin_theta,
201
+ Q_even * sin_theta + Q_odd * cos_theta], dim=-1).reshape_as(Q)
202
+ K_rotated = torch.stack([K_even * cos_theta - K_odd * sin_theta,
203
+ K_even * sin_theta + K_odd * cos_theta], dim=-1).reshape_as(K)
204
+
205
+ return Q_rotated, K_rotated
206
+
207
+
208
+
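The `gp` variant fixes the frequencies: every head uses the standard RoPE frequency ladder, scaled by a head-specific factor drawn from a geometric progression between 1 and 0.1 (`torch.logspace(0, -1, steps=num_heads)`), so later heads rotate more slowly and effectively see a longer positional wavelength. A quick illustration of those per-head scales:

```python
import torch

num_heads, head_dim = 16, 64  # GPT-2 medium uses 16 heads of size 64
scales = torch.logspace(0, -1, steps=num_heads, base=10.0)
print(scales[:4])   # tensor([1.0000, 0.8577, 0.7356, 0.6310]) -- head 0 uses the unscaled RoPE base
print(scales[-1])   # tensor(0.1000) -- the last head rotates 10x more slowly
```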
209
+ class GPT2Attention(nn.Module):
210
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
211
+ super().__init__()
212
+ self.config = config
213
+ max_positions = config.max_position_embeddings
214
+ self.register_buffer(
215
+ "bias",
216
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
217
+ 1, 1, max_positions, max_positions
218
+ ),
219
+ persistent=False,
220
+ )
221
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
222
+
223
+ self.embed_dim = config.hidden_size
224
+ self.num_heads = config.num_attention_heads
225
+ self.head_dim = self.embed_dim // self.num_heads
226
+ self.split_size = self.embed_dim
227
+ if self.head_dim * self.num_heads != self.embed_dim:
228
+ raise ValueError(
229
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
230
+ f" {self.num_heads})."
231
+ )
232
+
233
+ self.scale_attn_weights = config.scale_attn_weights
234
+ self.is_cross_attention = is_cross_attention
235
+
236
+ # Layer-wise attention scaling, reordering, and upcasting
237
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
238
+ self.layer_idx = layer_idx
239
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
240
+
241
+ if self.is_cross_attention:
242
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
243
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
244
+ else:
245
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
246
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
247
+
248
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
249
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
250
+ self.is_causal = True
251
+
252
+ self.pruned_heads = set()
253
+
254
+ def prune_heads(self, heads):
255
+ if len(heads) == 0:
256
+ return
257
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
258
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
259
+
260
+ # Prune conv1d layers
261
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
262
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
263
+
264
+ # Update hyper params
265
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
266
+ self.num_heads = self.num_heads - len(heads)
267
+ self.pruned_heads = self.pruned_heads.union(heads)
268
+
269
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
270
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
271
+ bsz, num_heads, q_seq_len, dk = query.size()
272
+ _, _, k_seq_len, _ = key.size()
273
+
274
+ # Preallocate attn_weights for `baddbmm`
275
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
276
+
277
+ # Compute Scale Factor
278
+ scale_factor = 1.0
279
+ if self.scale_attn_weights:
280
+ scale_factor /= float(value.size(-1)) ** 0.5
281
+
282
+ if self.scale_attn_by_inverse_layer_idx:
283
+ scale_factor /= float(self.layer_idx + 1)
284
+
285
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
286
+ with torch.amp.autocast(query.device.type, enabled=False):
287
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
288
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
289
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
290
+
291
+ if not self.is_cross_attention:
292
+ # if only "normal" attention layer implements causal mask
293
+ query_length, key_length = query.size(-2), key.size(-2)
294
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
295
+ mask_value = torch.finfo(attn_weights.dtype).min
296
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
297
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
298
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
299
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
300
+
301
+ if attention_mask is not None:
302
+ # Apply the attention mask
303
+ attn_weights = attn_weights + attention_mask
304
+
305
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
306
+
307
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
308
+ if attn_weights.dtype != torch.float32:
309
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
310
+ attn_weights = attn_weights.type(value.dtype)
311
+ attn_weights = self.attn_dropout(attn_weights)
312
+
313
+ # Mask heads if we want to
314
+ if head_mask is not None:
315
+ attn_weights = attn_weights * head_mask
316
+
317
+ attn_output = torch.matmul(attn_weights, value)
318
+ attn_output = attn_output.transpose(1, 2)
319
+
320
+ return attn_output, attn_weights
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
325
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
326
+ attention_mask: Optional[torch.FloatTensor] = None,
327
+ head_mask: Optional[torch.FloatTensor] = None,
328
+ encoder_hidden_states: Optional[torch.Tensor] = None,
329
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
330
+ use_cache: Optional[bool] = False,
331
+ output_attentions: Optional[bool] = False,
332
+ **kwargs,
333
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
334
+ if encoder_hidden_states is not None:
335
+ if not hasattr(self, "q_attn"):
336
+ raise ValueError(
337
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
338
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
339
+ )
340
+
341
+ query_states = self.q_attn(hidden_states)
342
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
343
+ attention_mask = encoder_attention_mask
344
+ else:
345
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
346
+
347
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
348
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
349
+
350
+ query_states = query_states.view(shape_q).transpose(1, 2)
351
+ key_states = key_states.view(shape_kv).transpose(1, 2)
352
+ value_states = value_states.view(shape_kv).transpose(1, 2)
353
+
354
+ if layer_past is not None:
355
+ past_key, past_value = layer_past
356
+ key_states = torch.cat((past_key, key_states), dim=-2)
357
+ value_states = torch.cat((past_value, value_states), dim=-2)
358
+
359
+ if use_cache is True:
360
+ present = (key_states, value_states)
361
+ else:
362
+ present = None
363
+
364
+ is_cross_attention = encoder_hidden_states is not None
365
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
366
+
367
+ using_eager = self.config._attn_implementation == "eager"
368
+ attention_interface: Callable = eager_attention_forward
369
+ if self.config._attn_implementation != "eager":
370
+ if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
371
+ using_eager = True
372
+ logger.warning_once(
373
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
374
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
375
+ )
376
+ else:
377
+ # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
378
+ # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
379
+ # not necessarily to eager (if mentioned options are provided).
380
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
381
+
382
+ if using_eager and self.reorder_and_upcast_attn:
383
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
384
+ query_states, key_states, value_states, attention_mask, head_mask
385
+ )
386
+ else:
387
+ attn_output, attn_weights = attention_interface(
388
+ self,
389
+ query_states,
390
+ key_states,
391
+ value_states,
392
+ attention_mask,
393
+ head_mask=head_mask,
394
+ dropout=self.attn_dropout.p if self.training else 0.0,
395
+ is_causal=is_causal,
396
+ **kwargs,
397
+ )
398
+
399
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
400
+ attn_output = self.c_proj(attn_output)
401
+ attn_output = self.resid_dropout(attn_output)
402
+
403
+ outputs = (attn_output, present)
404
+ if output_attentions:
405
+ outputs += (attn_weights,)
406
+
407
+ return outputs # a, present, (attentions)
408
+
409
+ class RotatingheadGPT2Attention(GPT2Attention):
410
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
411
+ super().__init__(config, is_cross_attention, layer_idx)
412
+ if config.rotatinghead == 'lr':
413
+ self.rope = HeadSpecificLRRoPE(config.num_attention_heads, self.head_dim)
414
+ elif config.rotatinghead == 'gp':
415
+ self.rope = HeadSpecificGPRoPE(config.num_attention_heads, self.head_dim)
416
+
417
+ self.rotatinghead = config.rotatinghead
418
+
419
+ def forward(
420
+ self,
421
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
422
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
423
+ attention_mask: Optional[torch.FloatTensor] = None,
424
+ head_mask: Optional[torch.FloatTensor] = None,
425
+ encoder_hidden_states: Optional[torch.Tensor] = None,
426
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
427
+ use_cache: Optional[bool] = False,
428
+ output_attentions: Optional[bool] = False,
429
+ **kwargs,
430
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
431
+ if encoder_hidden_states is not None:
432
+ if not hasattr(self, "q_attn"):
433
+ raise ValueError(
434
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
435
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
436
+ )
437
+
438
+ query_states = self.q_attn(hidden_states)
439
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
440
+ attention_mask = encoder_attention_mask
441
+ else:
442
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
443
+
444
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
445
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
446
+
447
+ query_states = query_states.view(shape_q).transpose(1, 2)
448
+ key_states = key_states.view(shape_kv).transpose(1, 2)
449
+ value_states = value_states.view(shape_kv).transpose(1, 2)
450
+
451
+ if layer_past is not None:
452
+ past_key, past_value = layer_past
453
+ key_states = torch.cat((past_key, key_states), dim=-2)
454
+ value_states = torch.cat((past_value, value_states), dim=-2)
455
+
456
+ if use_cache is True:
457
+ present = (key_states, value_states)
458
+ else:
459
+ present = None
460
+
461
+ is_cross_attention = encoder_hidden_states is not None
462
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
463
+
464
+ using_eager = self.config._attn_implementation == "eager"
465
+ attention_interface: Callable = eager_attention_forward
466
+ if self.config._attn_implementation != "eager":
467
+ if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
468
+ using_eager = True
469
+ logger.warning_once(
470
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
471
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
472
+ )
473
+ else:
474
+ # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
475
+ # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
476
+ # not necessarily to eager (if mentioned options are provided).
477
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
478
+
479
+ query_states, key_states = self.rope(query_states, key_states)
480
+
481
+ if using_eager and self.reorder_and_upcast_attn:
482
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
483
+ query_states, key_states, value_states, attention_mask, head_mask
484
+ )
485
+ else:
486
+ attn_output, attn_weights = attention_interface(
487
+ self,
488
+ query_states,
489
+ key_states,
490
+ value_states,
491
+ attention_mask,
492
+ head_mask=head_mask,
493
+ dropout=self.attn_dropout.p if self.training else 0.0,
494
+ is_causal=is_causal,
495
+ **kwargs,
496
+ )
497
+
498
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
499
+ attn_output = self.c_proj(attn_output)
500
+ attn_output = self.resid_dropout(attn_output)
501
+
502
+ outputs = (attn_output, present)
503
+ if output_attentions:
504
+ outputs += (attn_weights,)
505
+
506
+ return outputs # a, present, (attentions)
507
+
508
+
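`RotatingheadGPT2Attention` matches the base `GPT2Attention.forward` except for the single `self.rope(query_states, key_states)` call, applied after the heads are split (and after any cached keys are concatenated). A minimal smoke test of the module in isolation, assuming the classes above are in scope; `rotatinghead` is the custom config attribute this file relies on:

```python
import torch
from transformers import GPT2Config

config = GPT2Config(n_embd=64, n_head=4, n_layer=2, n_positions=128, attn_implementation="eager")
config.rotatinghead = "gp"  # select the geometric-progression RoPE variant

attn = RotatingheadGPT2Attention(config, layer_idx=0)
hidden = torch.randn(2, 16, 64)  # [batch, seq, hidden]
attn_output, present = attn(hidden, use_cache=True)
print(attn_output.shape)  # torch.Size([2, 16, 64])
print(present[0].shape)   # cached keys: torch.Size([2, 4, 16, 16])
```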
509
+ class GPT2MLP(nn.Module):
510
+ def __init__(self, intermediate_size, config):
511
+ super().__init__()
512
+ embed_dim = config.hidden_size
513
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
514
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
515
+ self.act = ACT2FN[config.activation_function]
516
+ self.dropout = nn.Dropout(config.resid_pdrop)
517
+
518
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
519
+ hidden_states = self.c_fc(hidden_states)
520
+ hidden_states = self.act(hidden_states)
521
+ hidden_states = self.c_proj(hidden_states)
522
+ hidden_states = self.dropout(hidden_states)
523
+ return hidden_states
524
+
525
+
526
+ class GPT2Block(nn.Module):
527
+ def __init__(self, config, layer_idx=None):
528
+ super().__init__()
529
+ hidden_size = config.hidden_size
530
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
531
+
532
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
533
+ if config.rotatinghead is not None:
534
+ self.attn = RotatingheadGPT2Attention(config, layer_idx=layer_idx)
535
+ else:
536
+ self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
537
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
538
+
539
+ if config.add_cross_attention:
540
+ self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
541
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
542
+
543
+ self.mlp = GPT2MLP(inner_dim, config)
544
+
545
+ def forward(
546
+ self,
547
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
548
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
549
+ attention_mask: Optional[torch.FloatTensor] = None,
550
+ head_mask: Optional[torch.FloatTensor] = None,
551
+ encoder_hidden_states: Optional[torch.Tensor] = None,
552
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
553
+ use_cache: Optional[bool] = False,
554
+ output_attentions: Optional[bool] = False,
555
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
556
+ residual = hidden_states
557
+ hidden_states = self.ln_1(hidden_states)
558
+ attn_outputs = self.attn(
559
+ hidden_states,
560
+ layer_past=layer_past,
561
+ attention_mask=attention_mask,
562
+ head_mask=head_mask,
563
+ use_cache=use_cache,
564
+ output_attentions=output_attentions,
565
+ )
566
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
567
+ outputs = attn_outputs[1:]
568
+ # residual connection
569
+ hidden_states = attn_output + residual
570
+
571
+ if encoder_hidden_states is not None:
572
+ # add one self-attention block for cross-attention
573
+ if not hasattr(self, "crossattention"):
574
+ raise ValueError(
575
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
576
+ "cross-attention layers by setting `config.add_cross_attention=True`"
577
+ )
578
+ residual = hidden_states
579
+ hidden_states = self.ln_cross_attn(hidden_states)
580
+ cross_attn_outputs = self.crossattention(
581
+ hidden_states,
582
+ attention_mask=attention_mask,
583
+ head_mask=head_mask,
584
+ encoder_hidden_states=encoder_hidden_states,
585
+ encoder_attention_mask=encoder_attention_mask,
586
+ output_attentions=output_attentions,
587
+ )
588
+ attn_output = cross_attn_outputs[0]
589
+ # residual connection
590
+ hidden_states = residual + attn_output
591
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
592
+
593
+ residual = hidden_states
594
+ hidden_states = self.ln_2(hidden_states)
595
+ feed_forward_hidden_states = self.mlp(hidden_states)
596
+ # residual connection
597
+ hidden_states = residual + feed_forward_hidden_states
598
+
599
+ if use_cache:
600
+ outputs = (hidden_states,) + outputs
601
+ else:
602
+ outputs = (hidden_states,) + outputs[1:]
603
+
604
+ return outputs # hidden_states, present, (attentions, cross_attentions)
605
+
606
+
607
+ class RotatingHeadGPT2PretrainedModel(PreTrainedModel):
608
+ """
609
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
610
+ models.
611
+ """
612
+
613
+ config_class = RotatingHeadGPT2Config
614
+ load_tf_weights = load_tf_weights_in_gpt2
615
+ base_model_prefix = "transformer"
616
+ is_parallelizable = True
617
+ supports_gradient_checkpointing = True
618
+ _no_split_modules = ["GPT2Block"]
619
+ _skip_keys_device_placement = "past_key_values"
620
+ _supports_flash_attn_2 = True
621
+ _supports_sdpa = True
622
+
623
+ def __init__(self, *inputs, **kwargs):
624
+ super().__init__(*inputs, **kwargs)
625
+
626
+ def _init_weights(self, module):
627
+ """Initialize the weights."""
628
+ if isinstance(module, (nn.Linear, Conv1D)):
629
+ # Slightly different from the TF version which uses truncated_normal for initialization
630
+ # cf https://github.com/pytorch/pytorch/pull/5617
631
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
632
+ if module.bias is not None:
633
+ module.bias.data.zero_()
634
+ elif isinstance(module, nn.Embedding):
635
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
636
+ if module.padding_idx is not None:
637
+ module.weight.data[module.padding_idx].zero_()
638
+ elif isinstance(module, nn.LayerNorm):
639
+ module.bias.data.zero_()
640
+ module.weight.data.fill_(1.0)
641
+
642
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
643
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
644
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
645
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
646
+ #
647
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
648
+ for name, p in module.named_parameters():
649
+ if name == "c_proj.weight":
650
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
651
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
652
+
653
+
654
+ @dataclass
655
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
656
+ """
657
+ Base class for outputs of models predicting if two sentences are consecutive or not.
658
+
659
+ Args:
660
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
661
+ Language modeling loss.
662
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
663
+ Multiple choice classification loss.
664
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
665
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
666
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
667
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
668
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
669
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
670
+ sequence_length, embed_size_per_head)`).
671
+
672
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
673
+ `past_key_values` input) to speed up sequential decoding.
674
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
675
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
676
+ shape `(batch_size, sequence_length, hidden_size)`.
677
+
678
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
679
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
680
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
681
+ sequence_length)`.
682
+
683
+ GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
684
+ self-attention heads.
685
+ """
686
+
687
+ loss: Optional[torch.FloatTensor] = None
688
+ mc_loss: Optional[torch.FloatTensor] = None
689
+ logits: torch.FloatTensor = None
690
+ mc_logits: torch.FloatTensor = None
691
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
692
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
693
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
694
+
695
+
696
+
697
+ class RotatingHeadGPT2Model(RotatingHeadGPT2PretrainedModel):
698
+ _supports_param_buffer_assignment = False
699
+
700
+ def __init__(self, config):
701
+ super().__init__(config)
702
+
703
+ self.embed_dim = config.hidden_size
704
+
705
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
706
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
707
+
708
+ self.drop = nn.Dropout(config.embd_pdrop)
709
+ self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
710
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
711
+
712
+ # Model parallel
713
+ self.model_parallel = False
714
+ self.device_map = None
715
+ self.gradient_checkpointing = False
716
+ self._attn_implementation = config._attn_implementation
717
+
718
+ # Initialize weights and apply final processing
719
+ self.post_init()
720
+
721
+ def parallelize(self, device_map=None):
722
+ # Check validity of device_map
723
+ warnings.warn(
724
+ "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
725
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
726
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
727
+ " ...}",
728
+ FutureWarning,
729
+ )
730
+ self.device_map = (
731
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
732
+ )
733
+ assert_device_map(self.device_map, len(self.h))
734
+ self.model_parallel = True
735
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
736
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
737
+ self.wte = self.wte.to(self.first_device)
738
+ self.wpe = self.wpe.to(self.first_device)
739
+ # Load onto devices
740
+ for k, v in self.device_map.items():
741
+ for block in v:
742
+ cuda_device = "cuda:" + str(k)
743
+ self.h[block] = self.h[block].to(cuda_device)
744
+ # ln_f to last
745
+ self.ln_f = self.ln_f.to(self.last_device)
746
+
747
+ def deparallelize(self):
748
+ warnings.warn(
749
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
750
+ FutureWarning,
751
+ )
752
+ self.model_parallel = False
753
+ self.device_map = None
754
+ self.first_device = "cpu"
755
+ self.last_device = "cpu"
756
+ self.wte = self.wte.to("cpu")
757
+ self.wpe = self.wpe.to("cpu")
758
+ for index in range(len(self.h)):
759
+ self.h[index] = self.h[index].to("cpu")
760
+ self.ln_f = self.ln_f.to("cpu")
761
+ torch.cuda.empty_cache()
762
+
763
+ def get_input_embeddings(self):
764
+ return self.wte
765
+
766
+ def set_input_embeddings(self, new_embeddings):
767
+ self.wte = new_embeddings
768
+
769
+ def _prune_heads(self, heads_to_prune):
770
+ """
771
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
772
+ """
773
+ for layer, heads in heads_to_prune.items():
774
+ self.h[layer].attn.prune_heads(heads)
775
+
776
+ def forward(
777
+ self,
778
+ input_ids: Optional[torch.LongTensor] = None,
779
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
780
+ attention_mask: Optional[torch.FloatTensor] = None,
781
+ token_type_ids: Optional[torch.LongTensor] = None,
782
+ position_ids: Optional[torch.LongTensor] = None,
783
+ head_mask: Optional[torch.FloatTensor] = None,
784
+ inputs_embeds: Optional[torch.FloatTensor] = None,
785
+ encoder_hidden_states: Optional[torch.Tensor] = None,
786
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
787
+ use_cache: Optional[bool] = None,
788
+ output_attentions: Optional[bool] = None,
789
+ output_hidden_states: Optional[bool] = None,
790
+ return_dict: Optional[bool] = None,
791
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
792
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
793
+ output_hidden_states = (
794
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
795
+ )
796
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
797
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
798
+
799
+ if input_ids is not None and inputs_embeds is not None:
800
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
801
+ elif input_ids is not None:
802
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
803
+ input_shape = input_ids.size()
804
+ input_ids = input_ids.view(-1, input_shape[-1])
805
+ batch_size = input_ids.shape[0]
806
+ elif inputs_embeds is not None:
807
+ input_shape = inputs_embeds.size()[:-1]
808
+ batch_size = inputs_embeds.shape[0]
809
+ else:
810
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
811
+
812
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
813
+
814
+ if token_type_ids is not None:
815
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
816
+
817
+ if past_key_values is None:
818
+ past_length = 0
819
+ past_key_values = tuple([None] * len(self.h))
820
+ else:
821
+ past_length = past_key_values[0][0].size(-2)
822
+ if position_ids is None:
823
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
824
+ position_ids = position_ids.unsqueeze(0)
825
+
826
+ if inputs_embeds is None:
827
+ inputs_embeds = self.wte(input_ids)
828
+ position_embeds = self.wpe(position_ids)
829
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
830
+
831
+ # Attention mask.
832
+ _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
833
+ attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
834
+ if self._attn_implementation == "flash_attention_2":
835
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
836
+ elif _use_sdpa:
837
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
838
+ attention_mask=attention_mask,
839
+ input_shape=(batch_size, input_shape[-1]),
840
+ inputs_embeds=inputs_embeds,
841
+ past_key_values_length=past_length,
842
+ )
843
+ else:
844
+ if attention_mask is not None:
845
+ # We create a 3D attention mask from a 2D tensor mask.
846
+ # Sizes are [batch_size, 1, 1, to_seq_length]
847
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
848
+ # this attention mask is more simple than the triangular masking of causal attention
849
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
850
+ attention_mask = attention_mask[:, None, None, :]
851
+
852
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
853
+ # masked positions, this operation will create a tensor which is 0.0 for
854
+ # positions we want to attend and the dtype's smallest value for masked positions.
855
+ # Since we are adding it to the raw scores before the softmax, this is
856
+ # effectively the same as removing these entirely.
857
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
858
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
859
+
860
+ # If a 2D or 3D attention mask is provided for the cross-attention
861
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
862
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
863
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
864
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
865
+ if encoder_attention_mask is None:
866
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
867
+ if _use_sdpa:
868
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
869
+ mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
870
+ )
871
+ elif not self._attn_implementation == "flash_attention_2":
872
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
873
+ else:
874
+ encoder_attention_mask = None
875
+
876
+ # Prepare head mask if needed
877
+ # 1.0 in head_mask indicate we keep the head
878
+ # attention_probs has shape bsz x n_heads x N x N
879
+ # head_mask has shape n_layer x batch x n_heads x N x N
880
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
881
+
882
+ if token_type_ids is not None:
883
+ token_type_embeds = self.wte(token_type_ids)
884
+ hidden_states = hidden_states + token_type_embeds
885
+
886
+ hidden_states = self.drop(hidden_states)
887
+
888
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
889
+
890
+ if self.gradient_checkpointing and self.training:
891
+ if use_cache:
892
+ logger.warning_once(
893
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
894
+ )
895
+ use_cache = False
896
+
897
+ presents = () if use_cache else None
898
+ all_self_attentions = () if output_attentions else None
899
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
900
+ all_hidden_states = () if output_hidden_states else None
901
+ for i in range(len(self.h)):
902
+ block, layer_past = self.h[i], past_key_values[i]
903
+ # Model parallel
904
+ if self.model_parallel:
905
+ torch.cuda.set_device(hidden_states.device)
906
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
907
+ if layer_past is not None:
908
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
909
+ # Ensure that attention_mask is always on the same device as hidden_states
910
+ if attention_mask is not None:
911
+ attention_mask = attention_mask.to(hidden_states.device)
912
+ if isinstance(head_mask, torch.Tensor):
913
+ head_mask = head_mask.to(hidden_states.device)
914
+ if output_hidden_states:
915
+ all_hidden_states = all_hidden_states + (hidden_states,)
916
+
917
+ if self.gradient_checkpointing and self.training:
918
+ outputs = self._gradient_checkpointing_func(
919
+ block.__call__,
920
+ hidden_states,
921
+ None,
922
+ attention_mask,
923
+ head_mask[i],
924
+ encoder_hidden_states,
925
+ encoder_attention_mask,
926
+ use_cache,
927
+ output_attentions,
928
+ )
929
+ else:
930
+ outputs = block(
931
+ hidden_states,
932
+ layer_past=layer_past,
933
+ attention_mask=attention_mask,
934
+ head_mask=head_mask[i],
935
+ encoder_hidden_states=encoder_hidden_states,
936
+ encoder_attention_mask=encoder_attention_mask,
937
+ use_cache=use_cache,
938
+ output_attentions=output_attentions,
939
+ )
940
+
941
+ hidden_states = outputs[0]
942
+ if use_cache is True:
943
+ presents = presents + (outputs[1],)
944
+
945
+ if output_attentions:
946
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
947
+ if self.config.add_cross_attention:
948
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
949
+
950
+ # Model Parallel: If it's the last layer for that device, put things on the next device
951
+ if self.model_parallel:
952
+ for k, v in self.device_map.items():
953
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
954
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
955
+
956
+ hidden_states = self.ln_f(hidden_states)
957
+
958
+ hidden_states = hidden_states.view(output_shape)
959
+ # Add last hidden state
960
+ if output_hidden_states:
961
+ all_hidden_states = all_hidden_states + (hidden_states,)
962
+
963
+ if not return_dict:
964
+ return tuple(
965
+ v
966
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
967
+ if v is not None
968
+ )
969
+
970
+ return BaseModelOutputWithPastAndCrossAttentions(
971
+ last_hidden_state=hidden_states,
972
+ past_key_values=presents,
973
+ hidden_states=all_hidden_states,
974
+ attentions=all_self_attentions,
975
+ cross_attentions=all_cross_attentions,
976
+ )
977
+
978
+
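A small end-to-end check of the backbone with a reduced configuration (the sizes below are chosen only to keep the example light; the released checkpoint uses the GPT-2-medium configuration):

```python
import torch

config = RotatingHeadGPT2Config(n_embd=64, n_head=4, n_layer=2, n_positions=128, vocab_size=1000)
config.rotatinghead = "gp"

model = RotatingHeadGPT2Model(config)
input_ids = torch.randint(0, config.vocab_size, (2, 20))
out = model(input_ids)
print(out.last_hidden_state.shape)  # torch.Size([2, 20, 64])
```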
979
+ class RotatingHeadGPT2LMHeadModel(RotatingHeadGPT2PretrainedModel, GenerationMixin):
980
+ _tied_weights_keys = ["lm_head.weight"]
981
+
982
+ def __init__(self, config):
983
+ super().__init__(config)
984
+ self.transformer = RotatingHeadGPT2Model(config)
985
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
986
+
987
+ # Model parallel
988
+ self.model_parallel = False
989
+ self.device_map = None
990
+
991
+ # Initialize weights and apply final processing
992
+ self.post_init()
993
+
994
+ def parallelize(self, device_map=None):
995
+ warnings.warn(
996
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
997
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
998
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
999
+ " 0, 'transformer.h.1': 1, ...}",
1000
+ FutureWarning,
1001
+ )
1002
+ self.device_map = (
1003
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1004
+ if device_map is None
1005
+ else device_map
1006
+ )
1007
+ assert_device_map(self.device_map, len(self.transformer.h))
1008
+ self.transformer.parallelize(self.device_map)
1009
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1010
+ self.model_parallel = True
1011
+
1012
+ def deparallelize(self):
1013
+ warnings.warn(
1014
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1015
+ FutureWarning,
1016
+ )
1017
+ self.transformer.deparallelize()
1018
+ self.transformer = self.transformer.to("cpu")
1019
+ self.lm_head = self.lm_head.to("cpu")
1020
+ self.model_parallel = False
1021
+ torch.cuda.empty_cache()
1022
+
1023
+ def get_output_embeddings(self):
1024
+ return self.lm_head
1025
+
1026
+ def set_output_embeddings(self, new_embeddings):
1027
+ self.lm_head = new_embeddings
1028
+
1029
+ def forward(
1030
+ self,
1031
+ input_ids: Optional[torch.LongTensor] = None,
1032
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1033
+ attention_mask: Optional[torch.FloatTensor] = None,
1034
+ token_type_ids: Optional[torch.LongTensor] = None,
1035
+ position_ids: Optional[torch.LongTensor] = None,
1036
+ head_mask: Optional[torch.FloatTensor] = None,
1037
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1038
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1039
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1040
+ labels: Optional[torch.LongTensor] = None,
1041
+ use_cache: Optional[bool] = None,
1042
+ output_attentions: Optional[bool] = None,
1043
+ output_hidden_states: Optional[bool] = None,
1044
+ return_dict: Optional[bool] = None,
1045
+ **kwargs,
1046
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1047
+ r"""
1048
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1049
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1050
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1051
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1052
+ """
1053
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1054
+
1055
+ transformer_outputs = self.transformer(
1056
+ input_ids,
1057
+ past_key_values=past_key_values,
1058
+ attention_mask=attention_mask,
1059
+ token_type_ids=token_type_ids,
1060
+ position_ids=position_ids,
1061
+ head_mask=head_mask,
1062
+ inputs_embeds=inputs_embeds,
1063
+ encoder_hidden_states=encoder_hidden_states,
1064
+ encoder_attention_mask=encoder_attention_mask,
1065
+ use_cache=use_cache,
1066
+ output_attentions=output_attentions,
1067
+ output_hidden_states=output_hidden_states,
1068
+ return_dict=return_dict,
1069
+ )
1070
+ hidden_states = transformer_outputs[0]
1071
+
1072
+ # Set device for model parallelism
1073
+ if self.model_parallel:
1074
+ torch.cuda.set_device(self.transformer.first_device)
1075
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1076
+
1077
+ lm_logits = self.lm_head(hidden_states)
1078
+
1079
+ loss = None
1080
+ if labels is not None:
1081
+ # Flatten the tokens
1082
+ loss = self.loss_function(
1083
+ lm_logits,
1084
+ labels,
1085
+ vocab_size=self.config.vocab_size,
1086
+ **kwargs,
1087
+ )
1088
+
1089
+ if not return_dict:
1090
+ output = (lm_logits,) + transformer_outputs[1:]
1091
+ return ((loss,) + output) if loss is not None else output
1092
+
1093
+ return CausalLMOutputWithCrossAttentions(
1094
+ loss=loss,
1095
+ logits=lm_logits,
1096
+ past_key_values=transformer_outputs.past_key_values,
1097
+ hidden_states=transformer_outputs.hidden_states,
1098
+ attentions=transformer_outputs.attentions,
1099
+ cross_attentions=transformer_outputs.cross_attentions,
1100
+ )
1101
+
1102
+ @staticmethod
1103
+ def _reorder_cache(
1104
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1105
+ ) -> Tuple[Tuple[torch.Tensor]]:
1106
+ """
1107
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1108
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1109
+ beam_idx at every generation step.
1110
+ """
1111
+ return tuple(
1112
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1113
+ for layer_past in past_key_values
1114
+ )
1115
+
1116
+ __all__ = [
1117
+ "RotatingHeadGPT2LMHeadModel",
1118
+ "RotatingHeadGPT2Model",
1119
+ "RotatingHeadGPT2PretrainedModel",
1120
+ "load_tf_weights_in_gpt2",
1121
+ ]
1122
+
1123
+
1124
+ if __name__ == "__main__":
1125
+ cg = GPT2Config.from_pretrained("gpt2-medium")
1126
+ cg.rotatinghead = 'gp'
1127
+ model = RotatingHeadGPT2LMHeadModel(cg)
1128
+ from src.utils.model_utlis import print_trainable_parameters
1129
+ print_trainable_parameters(model)
1130
+ model(torch.randint(0, 10000, (1, 100)))
1131
+ print()
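To load this architecture through the `auto` classes locally (outside the Hub's `trust_remote_code` route), the custom config and model can be registered. This is a sketch of a common pattern, not part of the commit:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Local registration; the Hub route instead relies on `auto_map` entries in config.json
# together with `trust_remote_code=True`.
AutoConfig.register("rotating-head-gpt2", RotatingHeadGPT2Config)
AutoModelForCausalLM.register(RotatingHeadGPT2Config, RotatingHeadGPT2LMHeadModel)
```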