stefan-it committed
Commit 65b7504 · verified · 1 Parent(s): 5ece352

model: add initial version of NeoBERTForTokenClassification

Files changed (1):
  1. model.py +495 -0
model.py ADDED
@@ -0,0 +1,495 @@
+ # From https://github.com/facebookresearch/llama/blob/main/llama/model.py
+
+ import torch
+ from torch import nn
+
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from torch.nn.functional import scaled_dot_product_attention
+
+ from typing import Optional
+ import numpy as np
+
+ from xformers.ops import SwiGLU
+
+ try:
+     from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+     FLASH_ATTN_AVAILABLE = True
+ except ImportError:
+     FLASH_ATTN_AVAILABLE = False
+
+ from transformers import (
+     PreTrainedModel,
+     PretrainedConfig,
+     DataCollatorForLanguageModeling,
+ )
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     MaskedLMOutput,
+     SequenceClassifierOutput,
+     TokenClassifierOutput,
+ )
+
+ from .rotary import precompute_freqs_cis, apply_rotary_emb
+
+
+ class DataCollatorWithPacking(DataCollatorForLanguageModeling):
+     def __init__(self, pack_sequences=False, **kwargs):
+         super().__init__(**kwargs)
+         self.pack_sequences = pack_sequences
+
+     def __call__(self, batch):
+         if self.pack_sequences:
+             # Add position_ids if not present
+             if "position_ids" not in batch[0]:
+                 for item in batch:
+                     item["position_ids"] = list(range(len(item["input_ids"])))
+
+             # Pack the sequences into a single list
+             input_ids_list = [item["input_ids"] for item in batch]
+             position_ids_list = [item["position_ids"] for item in batch]
+             seqlens = np.array([0] + [len(ids) for ids in input_ids_list])
+
+             packed_batch = {
+                 "position_ids": np.concatenate(position_ids_list, axis=0),
+                 "input_ids": np.concatenate(input_ids_list, axis=0),
+                 "cu_seqlens": np.cumsum(seqlens),
+                 "max_seqlen": max(seqlens),
+             }
+
+             batch = super().__call__([packed_batch])
+             batch["cu_seqlens"] = batch["cu_seqlens"].to(torch.int32).squeeze()
+         else:
+             batch = super().__call__(batch)
+             batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
+
+         return batch
+
+
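For context, a minimal sketch of how this collator might be wired up; the tokenizer checkpoint and MLM settings below are illustrative assumptions, not part of this commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder tokenizer choice

# Packed mode: the batch is concatenated into one long sequence and cu_seqlens /
# max_seqlen are added so the model can dispatch to flash_attn_varlen_func.
collator = DataCollatorWithPacking(pack_sequences=True, tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

features = [tokenizer("a first sentence"), tokenizer("a second, slightly longer sentence")]
batch = collator(features)  # contains input_ids, labels, cu_seqlens (int32) and max_seqlen

With pack_sequences=False the collator behaves like the stock DataCollatorForLanguageModeling, except that attention_mask is cast to bool as the model expects.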
+ class NeoBERTConfig(PretrainedConfig):
+     model_type = "neobert"
+
+     # All config parameters must have a default value.
+     def __init__(
+         self,
+         hidden_size: int = 768,
+         num_hidden_layers: int = 28,
+         num_attention_heads: int = 12,
+         intermediate_size: int = 3072,
+         embedding_init_range: float = 0.02,
+         decoder_init_range: float = 0.02,
+         norm_eps: float = 1e-06,
+         vocab_size: int = 30522,
+         pad_token_id: int = 0,
+         max_length: int = 1024,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         if hidden_size % num_attention_heads != 0:
+             raise ValueError("Hidden size must be divisible by the number of heads.")
+         self.dim_head = hidden_size // num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.embedding_init_range = embedding_init_range
+         self.decoder_init_range = decoder_init_range
+         self.norm_eps = norm_eps
+         self.vocab_size = vocab_size
+         self.pad_token_id = pad_token_id
+         self.max_length = max_length
+         self.kwargs = kwargs
+
+
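As a small illustration (a sketch, assuming this module is importable): the head dimension is derived from the defaults above, and extra keyword arguments such as num_labels are standard PretrainedConfig fields that the classification heads below read back with getattr:

config = NeoBERTConfig(num_labels=9)  # num_labels is illustrative, e.g. a BIO tag set for NER

assert config.hidden_size == 768 and config.num_attention_heads == 12
assert config.dim_head == 64          # 768 // 12
assert config.num_labels == 9         # picked up via getattr(config, "num_labels", 2) in the heads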
+ class EncoderBlock(nn.Module):
+     """Transformer encoder block."""
+
+     def __init__(self, config: NeoBERTConfig):
+         super().__init__()
+
+         self.config = config
+
+         # Attention
+         self.qkv = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size * 3, bias=False)
+         self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=False)
+
+         # Feedforward network
+         multiple_of = 8
+         intermediate_size = int(2 * config.intermediate_size / 3)
+         intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
+         self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)
+
+         # Layer norms
+         self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
+         self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         output_attentions: bool,
+         max_seqlen: int = None,
+         cu_seqlens: torch.Tensor = None,
+     ):
+         # Attention
+         attn_output, attn_weights = self._att_block(
+             self.attention_norm(x), attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens
+         )
+
+         # Residual
+         x = x + attn_output
+
+         # Feed-forward
+         x = x + self.ffn(self.ffn_norm(x))
+
+         return x, attn_weights
+
+     def _att_block(
+         self,
+         x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         output_attentions: bool,
+         max_seqlen: int = None,
+         cu_seqlens: torch.Tensor = None,
+     ):
+         batch_size, seq_len, _ = x.shape
+
+         xq, xk, xv = self.qkv(x).view(batch_size, seq_len, self.config.num_attention_heads, self.config.dim_head * 3).chunk(3, axis=-1)
+
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
+
+         # Attn block
+         attn_weights = None
+
+         # Flash attention if the tensors are packed
+         if cu_seqlens is not None:
+             attn = flash_attn_varlen_func(
+                 q=xq.squeeze(0),
+                 k=xk.squeeze(0),
+                 v=xv.squeeze(0),
+                 cu_seqlens_q=cu_seqlens,
+                 cu_seqlens_k=cu_seqlens,
+                 max_seqlen_q=max_seqlen,
+                 max_seqlen_k=max_seqlen,
+                 dropout_p=0.0,
+                 causal=False,
+             )
+         # Eager attention if attention weights are needed in the output
+         elif output_attentions:
+             attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
+             if attention_mask is not None:
+                 attn_weights = attn_weights * attention_mask
+             attn_weights = attn_weights.softmax(-1)
+             attn = attn_weights @ xv.permute(0, 2, 1, 3)
+             attn = attn.transpose(1, 2)
+         # Fall back to SDPA otherwise
+         else:
+             attn = scaled_dot_product_attention(
+                 query=xq.transpose(1, 2),
+                 key=xk.transpose(1, 2),
+                 value=xv.transpose(1, 2),
+                 attn_mask=attention_mask.bool(),
+                 dropout_p=0,
+             ).transpose(1, 2)
+
+         return self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.config.dim_head)), attn_weights
+
+
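With the default intermediate_size of 3072, the SwiGLU sizing above works out as follows (plain arithmetic, shown only to make the rounding step explicit):

multiple_of = 8
intermediate_size = int(2 * 3072 / 3)                                        # 2048
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
assert intermediate_size == 2048  # already a multiple of 8, so the round-up is a no-op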
+ class NeoBERTPreTrainedModel(PreTrainedModel):
+     config_class = NeoBERTConfig
+     base_model_prefix = "model"
+     _supports_cache_class = True
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)
+
+
+ class NeoBERT(NeoBERTPreTrainedModel):
+     config_class = NeoBERTConfig
+
+     def __init__(self, config: NeoBERTConfig):
+         super().__init__(config)
+
+         self.config = config
+
+         self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+
+         # Ensures freqs_cis is moved to the same devices as the model. Non-persistent buffers are not saved in the state_dict.
+         freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
+         self.register_buffer("freqs_cis", freqs_cis, persistent=False)
+
+         self.transformer_encoder = nn.ModuleList()
+         for _ in range(config.num_hidden_layers):
+             self.transformer_encoder.append(EncoderBlock(config))
+
+         self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor = None,
+         max_seqlen: int = None,
+         cu_seqlens: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         output_hidden_states: bool = False,
+         output_attentions: bool = False,
+         **kwargs,
+     ):
+         # Initialize
+         hidden_states, attentions = [], []
+
+         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
+         if attention_mask is not None:
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
+
+         # Checks to be done if inputs are packed sequences
+         if cu_seqlens is not None:
+             assert (
+                 FLASH_ATTN_AVAILABLE
+             ), "Flash-attention is not available. Please ''pip install flash_attn'', or provide un-packed sequences."
+             assert not output_attentions, "Output attentions is not supported when sequences are packed."
+             assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
+             assert input_ids.shape[0] == 1, "Cumulative sequence lengths are provided but input_ids are not packed."
+             assert input_ids.is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."
+
+         # RoPE
+         freqs_cis = self.freqs_cis[position_ids] if position_ids is not None else self.freqs_cis[: input_ids.shape[1]].unsqueeze(0)
+
+         # Embedding
+         x = self.encoder(input_ids)
+
+         # Transformer encoder
+         for layer in self.transformer_encoder:
+             x, attn = layer(x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
+             if output_hidden_states:
+                 hidden_states.append(x)
+             if output_attentions:
+                 attentions.append(attn)
+
+         # Final normalization layer
+         x = self.layer_norm(x)
+
+         # Return the output of the last hidden layer
+         return BaseModelOutput(
+             last_hidden_state=x,
+             hidden_states=hidden_states if output_hidden_states else None,
+             attentions=attentions if output_attentions else None,
+         )
+
+
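A minimal sketch of running the bare encoder on dummy inputs; it assumes xformers and a PyTorch recent enough to provide nn.RMSNorm, and takes the SDPA (non-packed) path:

import torch

config = NeoBERTConfig()                 # defaults: 28 layers, hidden size 768
model = NeoBERT(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))   # dummy batch: 2 sequences of 16 tokens
attention_mask = torch.ones(2, 16)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)

print(out.last_hidden_state.shape)       # torch.Size([2, 16, 768])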
+ class NeoBERTLMHead(NeoBERTPreTrainedModel):
+     config_class = NeoBERTConfig
+
+     def __init__(self, config: NeoBERTConfig):
+         super().__init__(config)
+
+         self.config = config
+
+         self.model = NeoBERT(config)
+         self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor = None,
+         max_seqlen: int = None,
+         cu_seqlens: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         output_hidden_states: bool = False,
+         output_attentions: bool = False,
+         **kwargs,
+     ):
+
+         output = self.model.forward(
+             input_ids,
+             position_ids,
+             max_seqlen,
+             cu_seqlens,
+             attention_mask,
+             output_hidden_states,
+             output_attentions,
+         )
+         logits = self.decoder(output.last_hidden_state)
+
+         return MaskedLMOutput(
+             hidden_states=output.hidden_states if output_hidden_states else None,
+             attentions=output.attentions if output_attentions else None,
+             logits=logits,
+         )
+
+
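Note that this head returns logits only; the masked-LM loss is left to the training loop. A minimal sketch under the usual transformers convention of -100 labels for unmasked positions (dummy data, illustrative only):

import torch
from torch.nn import CrossEntropyLoss

config = NeoBERTConfig()
lm_head = NeoBERTLMHead(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones(2, 16)
labels = input_ids.clone()
labels[:, :8] = -100                     # pretend only the last 8 positions were masked

out = lm_head(input_ids=input_ids, attention_mask=attention_mask)
loss = CrossEntropyLoss()(out.logits.view(-1, config.vocab_size), labels.view(-1))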
+ class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
+     config_class = NeoBERTConfig
+
+     def __init__(self, config: NeoBERTConfig):
+         super().__init__(config)
+
+         self.config = config
+
+         self.num_labels = getattr(config, "num_labels", 2)
+         self.classifier_dropout = getattr(config, "classifier_dropout", 0.1)
+         self.classifier_init_range = getattr(config, "classifier_init_range", 0.02)
+
+         self.model = NeoBERT(config)
+
+         self.dense = nn.Linear(self.config.hidden_size, self.config.hidden_size)
+         self.dropout = nn.Dropout(self.classifier_dropout)
+         self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)
+
+         self.post_init()
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=self.classifier_init_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor = None,
+         max_seqlen: int = None,
+         cu_seqlens: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         output_hidden_states: bool = False,
+         output_attentions: bool = False,
+         labels: Optional[torch.Tensor] = None,
+         return_dict: Optional[bool] = None,
+     ):
+
+         output = self.model.forward(
+             input_ids,
+             position_ids,
+             max_seqlen,
+             cu_seqlens,
+             attention_mask,
+             output_hidden_states,
+             output_attentions,
+         )
+         hidden_states = output.last_hidden_state
+
+         x = hidden_states[:, 0, :]
+         x = self.dropout(x)
+         x = self.dense(x)
+         x = torch.tanh(x)
+         x = self.dropout(x)
+
+         logits = self.classifier(x)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             result = (logits,)
+             return ((loss,) + result) if loss is not None else result
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=output.hidden_states if output_hidden_states else None,
+             attentions=output.attentions if output_attentions else None,
+         )
+
+
+ class NeoBERTForTokenClassification(NeoBERTPreTrainedModel):
+     config_class = NeoBERTConfig
+
+     def __init__(self, config: NeoBERTConfig):
+         super().__init__(config)
+
+         self.config = config
+
+         self.num_labels = getattr(config, "num_labels", 2)
+         self.classifier_dropout = getattr(config, "classifier_dropout", 0.1)
+         self.classifier_init_range = getattr(config, "classifier_init_range", 0.02)
+
+         self.model = NeoBERT(config)
+
+         self.dense = nn.Linear(self.config.hidden_size, self.config.hidden_size)
+         self.dropout = nn.Dropout(self.classifier_dropout)
+         self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)
+
+         self.post_init()
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=self.classifier_init_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor = None,
+         max_seqlen: int = None,
+         cu_seqlens: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         output_hidden_states: bool = False,
+         output_attentions: bool = False,
+         labels: Optional[torch.Tensor] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         output = self.model.forward(
+             input_ids,
+             position_ids,
+             max_seqlen,
+             cu_seqlens,
+             attention_mask,
+             output_hidden_states,
+             output_attentions,
+         )
+         x = output.last_hidden_state
+
+         x = self.dropout(x)
+         x = self.dense(x)
+         x = torch.tanh(x)
+         x = self.dropout(x)
+
+         logits = self.classifier(x)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             result = (logits,) + output[1:]
+             return ((loss,) + result) if loss is not None else result
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=output.hidden_states if output_hidden_states else None,
+             attentions=output.attentions if output_attentions else None,
+         )
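Finally, a minimal end-to-end sketch of the new token-classification head added in this commit; the label count and dummy inputs are illustrative, and a published checkpoint would normally be loaded with from_pretrained (plus trust_remote_code in the Auto* path) rather than freshly initialized:

import torch

config = NeoBERTConfig(num_labels=9)                  # e.g. BIO tags for a CoNLL-style NER task
model = NeoBERTForTokenClassification(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 12))
attention_mask = torch.ones(1, 12)
labels = torch.randint(0, config.num_labels, (1, 12))  # pad/special positions can be set to -100

out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
print(out.logits.shape)                               # torch.Size([1, 12, 9]): one score per token and label
print(out.loss)                                       # token-level cross-entropy (ignores -100 labels)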