liang.zhao committed on
Commit b274ce6
1 Parent(s): 96b2e3e

update model and config

config.json CHANGED
@@ -33,7 +33,7 @@
33
  "rms_norm_eps": 1e-06,
34
  "tie_word_embeddings": false,
35
  "torch_dtype": "bfloat16",
36
- "transformers_version": "4.34.0",
37
  "use_cache": true,
38
  "vocab_size": 65519
39
  }
 
33
  "rms_norm_eps": 1e-06,
34
  "tie_word_embeddings": false,
35
  "torch_dtype": "bfloat16",
36
+ "transformers_version": "4.33.1",
37
  "use_cache": true,
38
  "vocab_size": 65519
39
  }
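The only change to config.json is the recorded `transformers_version`, which moves from 4.34.0 to 4.33.1; the model weights and the rest of the config are untouched. As a quick sanity check, the repo's custom config and model classes load through the Auto* classes with `trust_remote_code=True`. A minimal sketch, where `path/to/skywork-repo` is a placeholder for the actual Hub repo id or a local clone:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# "path/to/skywork-repo" is a placeholder for the real repo id or local path.
config = AutoConfig.from_pretrained("path/to/skywork-repo", trust_remote_code=True)
print(config.transformers_version)  # the version recorded when the config was saved, e.g. "4.33.1"

model = AutoModelForCausalLM.from_pretrained(
    "path/to/skywork-repo",
    trust_remote_code=True,  # needed because configuration_skywork.py / modeling_skywork.py live in the repo
    torch_dtype="auto",      # config.json declares bfloat16
)
```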
configuration_skywork.py CHANGED
@@ -1,13 +1,14 @@
1
  # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
  # This code is built upon Huggingface's transformers repository.
3
 
 
4
  from transformers.configuration_utils import PretrainedConfig
5
  from transformers.utils import logging
6
 
7
 
8
  logger = logging.get_logger(__name__)
9
 
10
- Skywork_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
11
 
12
 
13
  class SkyworkConfig(PretrainedConfig):
@@ -28,15 +29,13 @@ class SkyworkConfig(PretrainedConfig):
28
  initializer_range=0.02,
29
  rms_norm_eps=1e-6,
30
  use_cache=True,
31
- pad_token_id=0,
32
  bos_token_id=1,
33
  eos_token_id=2,
34
  pretraining_tp=1,
35
  tie_word_embeddings=False,
36
- rope_scaling=None,
37
  rope_theta=10000.0,
38
- attention_bias=False,
39
- use_flash_attention=False,
40
  **kwargs,
41
  ):
42
  self.vocab_size = vocab_size
@@ -56,16 +55,9 @@ class SkyworkConfig(PretrainedConfig):
56
  self.rms_norm_eps = rms_norm_eps
57
  self.pretraining_tp = pretraining_tp
58
  self.use_cache = use_cache
59
- self.rope_scaling = rope_scaling
60
  self.rope_theta = rope_theta
61
- self.attention_bias = attention_bias
62
- self.use_flash_attention = use_flash_attention
63
- if self.use_flash_attention:
64
- try:
65
- from flash_attn.flash_attn_interface import flash_attn_varlen_func
66
- from einops import rearrange
67
- except:
68
- raise ValueError("`use_flash_attention` requires Flash Attention 2+ and einops.\nTry `pip install einops` and installing Flash Attention from from https://github.com/Dao-AILab/flash-attention")
69
 
70
  super().__init__(
71
  pad_token_id=pad_token_id,
@@ -74,3 +66,24 @@ class SkyworkConfig(PretrainedConfig):
74
  tie_word_embeddings=tie_word_embeddings,
75
  **kwargs,
76
  )
 
1
  # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
  # This code is built upon Huggingface's transformers repository.
3
 
4
+
5
  from transformers.configuration_utils import PretrainedConfig
6
  from transformers.utils import logging
7
 
8
 
9
  logger = logging.get_logger(__name__)
10
 
11
+ LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
12
 
13
 
14
  class SkyworkConfig(PretrainedConfig):
 
29
  initializer_range=0.02,
30
  rms_norm_eps=1e-6,
31
  use_cache=True,
32
+ pad_token_id=None,
33
  bos_token_id=1,
34
  eos_token_id=2,
35
  pretraining_tp=1,
36
  tie_word_embeddings=False,
 
37
  rope_theta=10000.0,
38
+ rope_scaling=None,
 
39
  **kwargs,
40
  ):
41
  self.vocab_size = vocab_size
 
55
  self.rms_norm_eps = rms_norm_eps
56
  self.pretraining_tp = pretraining_tp
57
  self.use_cache = use_cache
 
58
  self.rope_theta = rope_theta
59
+ self.rope_scaling = rope_scaling
60
+ self._rope_scaling_validation()
 
61
 
62
  super().__init__(
63
  pad_token_id=pad_token_id,
 
66
  tie_word_embeddings=tie_word_embeddings,
67
  **kwargs,
68
  )
69
+
70
+ def _rope_scaling_validation(self):
71
+ """
72
+ Validate the `rope_scaling` configuration.
73
+ """
74
+ if self.rope_scaling is None:
75
+ return
76
+
77
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
78
+ raise ValueError(
79
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
80
+ f"got {self.rope_scaling}"
81
+ )
82
+ rope_scaling_type = self.rope_scaling.get("type", None)
83
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
84
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "ntk"]:
85
+ raise ValueError(
86
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'ntk'], got {rope_scaling_type}"
87
+ )
88
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
89
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
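Net effect in configuration_skywork.py: the flash-attention related options (`use_flash_attention`, `attention_bias`) are dropped, `pad_token_id` now defaults to `None`, and `rope_scaling` gains an explicit `_rope_scaling_validation()` step that also accepts an `ntk` type. A minimal sketch of what the validator accepts and rejects, assuming the module is importable from the working directory and the remaining constructor arguments keep their defaults:

```python
from configuration_skywork import SkyworkConfig

# Accepted: a dict with exactly the two fields `type` and `factor`,
# where type is one of "linear", "dynamic", "ntk" and factor is a float > 1.
cfg = SkyworkConfig(rope_scaling={"type": "ntk", "factor": 8.0})

# Rejected: a non-float factor (here an int), a factor <= 1.0, an unknown type,
# or missing/extra fields all raise ValueError from _rope_scaling_validation().
try:
    SkyworkConfig(rope_scaling={"type": "ntk", "factor": 1})
except ValueError as err:
    print(err)
```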
generation_config.json CHANGED
@@ -6,5 +6,5 @@
6
  "pad_token_id": 0,
7
  "temperature": 0.6,
8
  "top_p": 0.9,
9
- "transformers_version": "4.34.0"
10
  }
 
6
  "pad_token_id": 0,
7
  "temperature": 0.6,
8
  "top_p": 0.9,
9
+ "transformers_version": "4.33.1"
10
  }
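generation_config.json likewise only updates the recorded `transformers_version`; the sampling defaults shown in context (temperature 0.6, top_p 0.9, pad_token_id 0) are unchanged. A usage sketch with those values, again using a placeholder repo path; note that temperature and top_p only take effect when sampling is enabled:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "path/to/skywork-repo"  # placeholder for the real repo id or local path
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True, torch_dtype="auto")

inputs = tokenizer("The capital of France is", return_tensors="pt")
output = model.generate(
    **inputs,
    do_sample=True,    # temperature/top_p are no-ops under greedy decoding
    temperature=0.6,   # matches generation_config.json
    top_p=0.9,         # matches generation_config.json
    pad_token_id=0,    # matches generation_config.json
    max_new_tokens=32,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```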
modeling_skywork.py CHANGED
@@ -1,5 +1,6 @@
1
  # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
  # This code is built upon Huggingface's transformers repository.
 
3
  import math
4
  from typing import List, Optional, Tuple, Union
5
 
@@ -12,39 +13,15 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
12
  from transformers.activations import ACT2FN
13
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
14
  from transformers.modeling_utils import PreTrainedModel
15
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
16
- from transformers.utils import (
17
- add_start_docstrings,
18
- add_start_docstrings_to_model_forward,
19
- is_flash_attn_available,
20
- logging,
21
- replace_return_docstrings,
22
- )
23
  from .configuration_skywork import SkyworkConfig
24
 
25
 
26
- if is_flash_attn_available():
27
- from flash_attn import flash_attn_func, flash_attn_varlen_func
28
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
29
-
30
-
31
  logger = logging.get_logger(__name__)
32
 
33
  _CONFIG_FOR_DOC = "SkyworkConfig"
34
 
35
 
36
- def _get_unpad_data(padding_mask):
37
- seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
38
- indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
39
- max_seqlen_in_batch = seqlens_in_batch.max().item()
40
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
41
- return (
42
- indices,
43
- cu_seqlens,
44
- max_seqlen_in_batch,
45
- )
46
-
47
-
48
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
49
  def _make_causal_mask(
50
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -95,10 +72,7 @@ class SkyworkRMSNorm(nn.Module):
95
  return self.weight * hidden_states.to(input_dtype)
96
 
97
 
98
- ALL_LAYERNORM_LAYERS.append(SkyworkRMSNorm)
99
-
100
-
101
- class SkyworkRotaryEmbedding(nn.Module):
102
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
103
  super().__init__()
104
 
@@ -120,8 +94,8 @@ class SkyworkRotaryEmbedding(nn.Module):
120
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
121
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
122
  emb = torch.cat((freqs, freqs), dim=-1)
123
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
124
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
125
 
126
  def forward(self, x, seq_len=None):
127
  # x: [bs, num_attention_heads, seq_len, head_size]
@@ -129,8 +103,8 @@ class SkyworkRotaryEmbedding(nn.Module):
129
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
130
 
131
  return (
132
- self.cos_cached[:seq_len].to(dtype=x.dtype),
133
- self.sin_cached[:seq_len].to(dtype=x.dtype),
134
  )
135
 
136
 
@@ -149,8 +123,8 @@ class SkyworkLinearScalingRotaryEmbedding(SkyworkRotaryEmbedding):
149
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
150
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
151
  emb = torch.cat((freqs, freqs), dim=-1)
152
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
153
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
154
 
155
 
156
  class SkyworkDynamicNTKScalingRotaryEmbedding(SkyworkRotaryEmbedding):
@@ -175,9 +149,42 @@ class SkyworkDynamicNTKScalingRotaryEmbedding(SkyworkRotaryEmbedding):
175
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
176
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
177
  emb = torch.cat((freqs, freqs), dim=-1)
178
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
179
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
180
 
181
 
182
  def rotate_half(x):
183
  """Rotates half the hidden dims of the input."""
@@ -186,10 +193,12 @@ def rotate_half(x):
186
  return torch.cat((-x2, x1), dim=-1)
187
 
188
 
189
- # Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
190
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
191
- cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
192
- sin = sin[position_ids].unsqueeze(1)
 
 
 
193
  q_embed = (q * cos) + (rotate_half(q) * sin)
194
  k_embed = (k * cos) + (rotate_half(k) * sin)
195
  return q_embed, k_embed
@@ -260,10 +269,10 @@ class SkyworkAttention(nn.Module):
260
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
261
  f" and `num_heads`: {self.num_heads})."
262
  )
263
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
264
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
265
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
266
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
267
  self._init_rope()
268
 
269
  def _init_rope(self):
@@ -290,9 +299,18 @@ class SkyworkAttention(nn.Module):
290
  scaling_factor=scaling_factor,
291
  base=self.rope_theta,
292
  )
 
293
  else:
294
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
295
-
 
 
296
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
297
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
298
 
@@ -304,7 +322,6 @@ class SkyworkAttention(nn.Module):
304
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
305
  output_attentions: bool = False,
306
  use_cache: bool = False,
307
- padding_mask: Optional[torch.LongTensor] = None,
308
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
309
  bsz, q_len, _ = hidden_states.size()
310
 
@@ -347,6 +364,7 @@ class SkyworkAttention(nn.Module):
347
 
348
  past_key_value = (key_states, value_states) if use_cache else None
349
 
 
350
  key_states = repeat_kv(key_states, self.num_key_value_groups)
351
  value_states = repeat_kv(value_states, self.num_key_value_groups)
352
 
@@ -376,7 +394,6 @@ class SkyworkAttention(nn.Module):
376
  )
377
 
378
  attn_output = attn_output.transpose(1, 2).contiguous()
379
-
380
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
381
 
382
  if self.config.pretraining_tp > 1:
@@ -392,193 +409,11 @@ class SkyworkAttention(nn.Module):
392
  return attn_output, attn_weights, past_key_value
393
 
394
 
395
- class SkyworkFlashAttention2(SkyworkAttention):
396
- """
397
- Skywork flash attention module. This module inherits from `SkyworkAttention` as the weights of the module stays
398
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
399
- flash attention and deal with padding tokens in case the input contains any of them.
400
- """
401
-
402
- def forward(
403
- self,
404
- hidden_states: torch.Tensor,
405
- attention_mask: Optional[torch.Tensor] = None,
406
- position_ids: Optional[torch.LongTensor] = None,
407
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
408
- output_attentions: bool = False,
409
- use_cache: bool = False,
410
- padding_mask: Optional[torch.LongTensor] = None,
411
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
412
- # SkyworkFlashAttention2 attention does not support output_attentions
413
- output_attentions = False
414
-
415
- bsz, q_len, _ = hidden_states.size()
416
-
417
- query_states = self.q_proj(hidden_states)
418
- key_states = self.k_proj(hidden_states)
419
- value_states = self.v_proj(hidden_states)
420
-
421
- # Flash attention requires the input to have the shape
422
- # batch_size x seq_length x head_dime x hidden_dim
423
- # therefore we just need to keep the original shape
424
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
425
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
426
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
427
-
428
- kv_seq_len = key_states.shape[-2]
429
- if past_key_value is not None:
430
- kv_seq_len += past_key_value[0].shape[-2]
431
-
432
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
433
-
434
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
435
-
436
- if past_key_value is not None:
437
- # reuse k, v, self_attention
438
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
439
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
440
-
441
- past_key_value = (key_states, value_states) if use_cache else None
442
-
443
- query_states = query_states.transpose(1, 2)
444
- key_states = key_states.transpose(1, 2)
445
- value_states = value_states.transpose(1, 2)
446
-
447
- # TODO: skywork does not have dropout in the config??
448
- # It is recommended to use dropout with FA according to the docs
449
- # when training.
450
- dropout_rate = 0.0 # if not self.training else self.attn_dropout
451
-
452
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
453
- # therefore the input hidden states gets silently casted in float32. Hence, we need
454
- # cast them back in float16 just to be sure everything works as expected.
455
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
456
- # in fp32. (SkyworkRMSNorm handles it correctly)
457
- input_dtype = query_states.dtype
458
- if input_dtype == torch.float32:
459
- logger.warning_once(
460
- "The input hidden states seems to be silently casted in float32, this might be related to"
461
- " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
462
- " float16."
463
- )
464
-
465
- query_states = query_states.to(torch.float16)
466
- key_states = key_states.to(torch.float16)
467
- value_states = value_states.to(torch.float16)
468
-
469
- attn_output = self._flash_attention_forward(
470
- query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
471
- )
472
-
473
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
474
- attn_output = self.o_proj(attn_output)
475
-
476
- if not output_attentions:
477
- attn_weights = None
478
-
479
- return attn_output, attn_weights, past_key_value
480
-
481
- def _flash_attention_forward(
482
- self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
483
- ):
484
- """
485
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
486
- first unpad the input, then computes the attention scores and pad the final attention scores.
487
-
488
- Args:
489
- query_states (`torch.Tensor`):
490
- Input query states to be passed to Flash Attention API
491
- key_states (`torch.Tensor`):
492
- Input key states to be passed to Flash Attention API
493
- value_states (`torch.Tensor`):
494
- Input value states to be passed to Flash Attention API
495
- padding_mask (`torch.Tensor`):
496
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
497
- position of padding tokens and 1 for the position of non-padding tokens.
498
- dropout (`int`, *optional*):
499
- Attention dropout
500
- softmax_scale (`float`, *optional*):
501
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
502
- """
503
- # Contains at least one padding token in the sequence
504
- if padding_mask is not None:
505
- batch_size = query_states.shape[0]
506
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
507
- query_states, key_states, value_states, padding_mask, query_length
508
- )
509
-
510
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
511
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
512
-
513
- attn_output_unpad = flash_attn_varlen_func(
514
- query_states,
515
- key_states,
516
- value_states,
517
- cu_seqlens_q=cu_seqlens_q,
518
- cu_seqlens_k=cu_seqlens_k,
519
- max_seqlen_q=max_seqlen_in_batch_q,
520
- max_seqlen_k=max_seqlen_in_batch_k,
521
- dropout_p=dropout,
522
- softmax_scale=softmax_scale,
523
- causal=True,
524
- )
525
-
526
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
527
- else:
528
- attn_output = flash_attn_func(
529
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
530
- )
531
-
532
- return attn_output
533
-
534
- def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
535
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
536
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
537
-
538
- key_layer = index_first_axis(
539
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
540
- )
541
- value_layer = index_first_axis(
542
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
543
- )
544
- if query_length == kv_seq_len:
545
- query_layer = index_first_axis(
546
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
547
- )
548
- cu_seqlens_q = cu_seqlens_k
549
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
550
- indices_q = indices_k
551
- elif query_length == 1:
552
- max_seqlen_in_batch_q = 1
553
- cu_seqlens_q = torch.arange(
554
- batch_size + 1, dtype=torch.int32, device=query_layer.device
555
- ) # There is a memcpy here, that is very bad.
556
- indices_q = cu_seqlens_q[:-1]
557
- query_layer = query_layer.squeeze(1)
558
- else:
559
- # The -q_len: slice assumes left padding.
560
- padding_mask = padding_mask[:, -query_length:]
561
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
562
-
563
- return (
564
- query_layer,
565
- key_layer,
566
- value_layer,
567
- indices_q,
568
- (cu_seqlens_q, cu_seqlens_k),
569
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
570
- )
571
-
572
-
573
  class SkyworkDecoderLayer(nn.Module):
574
  def __init__(self, config: SkyworkConfig):
575
  super().__init__()
576
  self.hidden_size = config.hidden_size
577
- self.self_attn = (
578
- SkyworkAttention(config=config)
579
- if not getattr(config, "_flash_attn_2_enabled", False)
580
- else SkyworkFlashAttention2(config=config)
581
- )
582
  self.mlp = SkyworkMLP(config)
583
  self.input_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
584
  self.post_attention_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -591,7 +426,6 @@ class SkyworkDecoderLayer(nn.Module):
591
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
592
  output_attentions: Optional[bool] = False,
593
  use_cache: Optional[bool] = False,
594
- padding_mask: Optional[torch.LongTensor] = None,
595
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
596
  """
597
  Args:
@@ -619,7 +453,6 @@ class SkyworkDecoderLayer(nn.Module):
619
  past_key_value=past_key_value,
620
  output_attentions=output_attentions,
621
  use_cache=use_cache,
622
- padding_mask=padding_mask,
623
  )
624
  hidden_states = residual + hidden_states
625
 
@@ -645,7 +478,6 @@ class SkyworkPreTrainedModel(PreTrainedModel):
645
  supports_gradient_checkpointing = True
646
  _no_split_modules = ["SkyworkDecoderLayer"]
647
  _skip_keys_device_placement = "past_key_values"
648
- _supports_flash_attn_2 = True
649
 
650
  def _init_weights(self, module):
651
  std = self.config.initializer_range
@@ -735,13 +567,13 @@ class SkyworkModel(SkyworkPreTrainedModel):
735
 
736
  # retrieve input_ids and inputs_embeds
737
  if input_ids is not None and inputs_embeds is not None:
738
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
739
  elif input_ids is not None:
740
  batch_size, seq_length = input_ids.shape
741
  elif inputs_embeds is not None:
742
  batch_size, seq_length, _ = inputs_embeds.shape
743
  else:
744
- raise ValueError("You have to specify either input_ids or inputs_embeds")
745
 
746
  seq_length_with_past = seq_length
747
  past_key_values_length = 0
@@ -755,7 +587,9 @@ class SkyworkModel(SkyworkPreTrainedModel):
755
  position_ids = torch.arange(
756
  past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
757
  )
758
- position_ids = position_ids.unsqueeze(0)
 
 
759
 
760
  if inputs_embeds is None:
761
  inputs_embeds = self.embed_tokens(input_ids)
@@ -764,13 +598,6 @@ class SkyworkModel(SkyworkPreTrainedModel):
764
  attention_mask = torch.ones(
765
  (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
766
  )
767
- padding_mask = None
768
- else:
769
- if 0 in attention_mask:
770
- padding_mask = attention_mask
771
- else:
772
- padding_mask = None
773
-
774
  attention_mask = self._prepare_decoder_attention_mask(
775
  attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
776
  )
@@ -800,12 +627,15 @@ class SkyworkModel(SkyworkPreTrainedModel):
800
  def create_custom_forward(module):
801
  def custom_forward(*inputs):
802
  # None for past_key_value
803
- return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
804
 
805
  return custom_forward
806
 
807
  layer_outputs = torch.utils.checkpoint.checkpoint(
808
- create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
 
 
 
809
  )
810
  else:
811
  layer_outputs = decoder_layer(
@@ -815,7 +645,6 @@ class SkyworkModel(SkyworkPreTrainedModel):
815
  past_key_value=past_key_value,
816
  output_attentions=output_attentions,
817
  use_cache=use_cache,
818
- padding_mask=padding_mask,
819
  )
820
 
821
  hidden_states = layer_outputs[0]
@@ -873,7 +702,6 @@ class SkyworkForCausalLM(SkyworkPreTrainedModel):
873
  def get_decoder(self):
874
  return self.model
875
 
876
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
877
  def forward(
878
  self,
879
  input_ids: torch.LongTensor = None,
@@ -887,31 +715,6 @@ class SkyworkForCausalLM(SkyworkPreTrainedModel):
887
  output_hidden_states: Optional[bool] = None,
888
  return_dict: Optional[bool] = None,
889
  ) -> Union[Tuple, CausalLMOutputWithPast]:
890
- r"""
891
- Args:
892
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
893
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
894
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
895
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
896
-
897
- Returns:
898
-
899
- Example:
900
-
901
- ```python
902
- >>> from transformers import AutoTokenizer, SkyworkForCausalLM
903
-
904
- >>> model = SkyworkForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
905
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
906
-
907
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
908
- >>> inputs = tokenizer(prompt, return_tensors="pt")
909
-
910
- >>> # Generate
911
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
912
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
913
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
914
- ```"""
915
 
916
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
917
  output_hidden_states = (
@@ -1005,6 +808,7 @@ class SkyworkForCausalLM(SkyworkPreTrainedModel):
1005
  )
1006
  return reordered_past
1007
 
 
1008
  class SkyworkForSequenceClassification(SkyworkPreTrainedModel):
1009
  def __init__(self, config):
1010
  super().__init__(config)
@@ -1034,12 +838,8 @@ class SkyworkForSequenceClassification(SkyworkPreTrainedModel):
1034
  output_hidden_states: Optional[bool] = None,
1035
  return_dict: Optional[bool] = None,
1036
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1037
- r"""
1038
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1039
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1040
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1041
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1042
- """
1043
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1044
 
1045
  transformer_outputs = self.model(
@@ -1108,4 +908,4 @@ class SkyworkForSequenceClassification(SkyworkPreTrainedModel):
1108
  past_key_values=transformer_outputs.past_key_values,
1109
  hidden_states=transformer_outputs.hidden_states,
1110
  attentions=transformer_outputs.attentions,
1111
- )
 
1
  # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
  # This code is built upon Huggingface's transformers repository.
3
+
4
  import math
5
  from typing import List, Optional, Tuple, Union
6
 
 
13
  from transformers.activations import ACT2FN
14
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
15
  from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import logging
 
17
  from .configuration_skywork import SkyworkConfig
18
 
19
 
20
  logger = logging.get_logger(__name__)
21
 
22
  _CONFIG_FOR_DOC = "SkyworkConfig"
23
 
24
 
 
25
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
26
  def _make_causal_mask(
27
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 
72
  return self.weight * hidden_states.to(input_dtype)
73
 
74
 
75
+ class SkyworkRotaryEmbedding(torch.nn.Module):
 
 
 
76
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
77
  super().__init__()
78
 
 
94
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
95
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
96
  emb = torch.cat((freqs, freqs), dim=-1)
97
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
98
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
99
 
100
  def forward(self, x, seq_len=None):
101
  # x: [bs, num_attention_heads, seq_len, head_size]
 
103
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
104
 
105
  return (
106
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
107
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
108
  )
109
 
110
 
 
123
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
124
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
125
  emb = torch.cat((freqs, freqs), dim=-1)
126
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
127
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
128
 
129
 
130
  class SkyworkDynamicNTKScalingRotaryEmbedding(SkyworkRotaryEmbedding):
 
149
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
150
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
151
  emb = torch.cat((freqs, freqs), dim=-1)
152
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
153
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
154
+
155
+
156
+
157
+ class SkyworkNTKScalingRotaryEmbedding(torch.nn.Module):
158
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=100, device=None):
159
+ super().__init__()
160
+
161
+ self.dim = dim
162
+ self.max_position_embeddings = max_position_embeddings
163
+ self.base = base * scaling_factor
164
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
165
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
166
+
167
+ # Build here to make `torch.jit.trace` work.
168
+ self._set_cos_sin_cache(
169
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
170
+ )
171
+
172
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
173
+ self.max_seq_len_cached = seq_len
174
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
175
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
176
+ emb = torch.cat((freqs, freqs), dim=-1)
177
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
178
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
179
+
180
+ def forward(self, x, seq_len=None):
181
+ if seq_len > self.max_seq_len_cached:
182
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
183
 
184
+ return (
185
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
186
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
187
+ )
188
 
189
  def rotate_half(x):
190
  """Rotates half the hidden dims of the input."""
 
193
  return torch.cat((-x2, x1), dim=-1)
194
 
195
 
 
196
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
197
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
198
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
199
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
200
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
201
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
202
  q_embed = (q * cos) + (rotate_half(q) * sin)
203
  k_embed = (k * cos) + (rotate_half(k) * sin)
204
  return q_embed, k_embed
 
269
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
270
  f" and `num_heads`: {self.num_heads})."
271
  )
272
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
273
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
274
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
275
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
276
  self._init_rope()
277
 
278
  def _init_rope(self):
 
299
  scaling_factor=scaling_factor,
300
  base=self.rope_theta,
301
  )
302
+ elif scaling_type == "ntk":
303
+ self.rotary_emb = SkyworkNTKScalingRotaryEmbedding(
304
+ self.head_dim,
305
+ max_position_embeddings=self.max_position_embeddings,
306
+ scaling_factor=scaling_factor,
307
+ base=self.rope_theta,
308
+ )
309
  else:
310
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
311
+ print('-'*80)
312
+ print(f"USING CUSTOM MODELING, scaling_type is {scaling_type}, scaling_factor is {scaling_factor}")
313
+
314
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
315
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
316
 
 
322
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
323
  output_attentions: bool = False,
324
  use_cache: bool = False,
 
325
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
326
  bsz, q_len, _ = hidden_states.size()
327
 
 
364
 
365
  past_key_value = (key_states, value_states) if use_cache else None
366
 
367
+ # repeat k/v heads if n_kv_heads < n_heads
368
  key_states = repeat_kv(key_states, self.num_key_value_groups)
369
  value_states = repeat_kv(value_states, self.num_key_value_groups)
370
 
 
394
  )
395
 
396
  attn_output = attn_output.transpose(1, 2).contiguous()
 
397
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
398
 
399
  if self.config.pretraining_tp > 1:
 
409
  return attn_output, attn_weights, past_key_value
410
 
411
 
412
  class SkyworkDecoderLayer(nn.Module):
413
  def __init__(self, config: SkyworkConfig):
414
  super().__init__()
415
  self.hidden_size = config.hidden_size
416
+ self.self_attn = SkyworkAttention(config=config)
 
417
  self.mlp = SkyworkMLP(config)
418
  self.input_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
419
  self.post_attention_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
426
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
427
  output_attentions: Optional[bool] = False,
428
  use_cache: Optional[bool] = False,
 
429
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
430
  """
431
  Args:
 
453
  past_key_value=past_key_value,
454
  output_attentions=output_attentions,
455
  use_cache=use_cache,
 
456
  )
457
  hidden_states = residual + hidden_states
458
 
 
478
  supports_gradient_checkpointing = True
479
  _no_split_modules = ["SkyworkDecoderLayer"]
480
  _skip_keys_device_placement = "past_key_values"
 
481
 
482
  def _init_weights(self, module):
483
  std = self.config.initializer_range
 
567
 
568
  # retrieve input_ids and inputs_embeds
569
  if input_ids is not None and inputs_embeds is not None:
570
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
571
  elif input_ids is not None:
572
  batch_size, seq_length = input_ids.shape
573
  elif inputs_embeds is not None:
574
  batch_size, seq_length, _ = inputs_embeds.shape
575
  else:
576
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
577
 
578
  seq_length_with_past = seq_length
579
  past_key_values_length = 0
 
587
  position_ids = torch.arange(
588
  past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
589
  )
590
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
591
+ else:
592
+ position_ids = position_ids.view(-1, seq_length).long()
593
 
594
  if inputs_embeds is None:
595
  inputs_embeds = self.embed_tokens(input_ids)
 
598
  attention_mask = torch.ones(
599
  (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
600
  )
 
601
  attention_mask = self._prepare_decoder_attention_mask(
602
  attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
603
  )
 
627
  def create_custom_forward(module):
628
  def custom_forward(*inputs):
629
  # None for past_key_value
630
+ return module(*inputs, past_key_value, output_attentions)
631
 
632
  return custom_forward
633
 
634
  layer_outputs = torch.utils.checkpoint.checkpoint(
635
+ create_custom_forward(decoder_layer),
636
+ hidden_states,
637
+ attention_mask,
638
+ position_ids,
639
  )
640
  else:
641
  layer_outputs = decoder_layer(
 
645
  past_key_value=past_key_value,
646
  output_attentions=output_attentions,
647
  use_cache=use_cache,
 
648
  )
649
 
650
  hidden_states = layer_outputs[0]
 
702
  def get_decoder(self):
703
  return self.model
704
 
 
705
  def forward(
706
  self,
707
  input_ids: torch.LongTensor = None,
 
715
  output_hidden_states: Optional[bool] = None,
716
  return_dict: Optional[bool] = None,
717
  ) -> Union[Tuple, CausalLMOutputWithPast]:
 
718
 
719
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
720
  output_hidden_states = (
 
808
  )
809
  return reordered_past
810
 
811
+
812
  class SkyworkForSequenceClassification(SkyworkPreTrainedModel):
813
  def __init__(self, config):
814
  super().__init__(config)
 
838
  output_hidden_states: Optional[bool] = None,
839
  return_dict: Optional[bool] = None,
840
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
841
+
842
+
 
843
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
844
 
845
  transformer_outputs = self.model(
 
908
  past_key_values=transformer_outputs.past_key_values,
909
  hidden_states=transformer_outputs.hidden_states,
910
  attentions=transformer_outputs.attentions,
911
+ )
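The modeling change removes the FlashAttention2 code path (and its `padding_mask` plumbing) and adds `SkyworkNTKScalingRotaryEmbedding`, selected when `rope_scaling["type"] == "ntk"`. Unlike linear scaling, which divides the position indices by the factor, NTK scaling multiplies the rotary base by the factor before computing the inverse frequencies, so the highest frequency is preserved while the lowest frequencies stretch to longer wavelengths. A standalone sketch of just that base adjustment (`rope_inv_freq` is a helper introduced here for illustration):

```python
import torch

def rope_inv_freq(dim: int, base: float = 10000.0, ntk_factor: float = 1.0) -> torch.Tensor:
    # Mirrors `self.base = base * scaling_factor` followed by the usual inv_freq formula.
    scaled_base = base * ntk_factor
    return 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))

plain = rope_inv_freq(dim=128)
ntk = rope_inv_freq(dim=128, ntk_factor=100.0)  # 100 is the class's default scaling_factor
print(plain[0].item(), ntk[0].item())    # highest frequency is unchanged (both 1.0)
print(plain[-1].item(), ntk[-1].item())  # lowest frequency shrinks, i.e. its wavelength grows
```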
tokenization_skywork.py CHANGED
@@ -1,22 +1,5 @@
1
- # coding=utf-8
2
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
 
21
  """Tokenization classes for Skywork."""
22
  import os
 
1
+ # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
+ # This code is built upon Huggingface's transformers repository.
 
3
 
4
  """Tokenization classes for Skywork."""
5
  import os