liang.zhao committed on
Commit
5019740
1 Parent(s): 92427c8

update model and config

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +53 -0
  2. config.json +1 -0
  3. configuration_skywork_moe.py +106 -0
  4. generation_config.json +1 -0
  5. modeling_skywork_moe.py +1604 -0
  6. pytorch_model-00001-of-00053.bin +3 -0
  7. pytorch_model-00002-of-00053.bin +3 -0
  8. pytorch_model-00003-of-00053.bin +3 -0
  9. pytorch_model-00004-of-00053.bin +3 -0
  10. pytorch_model-00005-of-00053.bin +3 -0
  11. pytorch_model-00006-of-00053.bin +3 -0
  12. pytorch_model-00007-of-00053.bin +3 -0
  13. pytorch_model-00008-of-00053.bin +3 -0
  14. pytorch_model-00009-of-00053.bin +3 -0
  15. pytorch_model-00010-of-00053.bin +3 -0
  16. pytorch_model-00011-of-00053.bin +3 -0
  17. pytorch_model-00012-of-00053.bin +3 -0
  18. pytorch_model-00013-of-00053.bin +3 -0
  19. pytorch_model-00014-of-00053.bin +3 -0
  20. pytorch_model-00015-of-00053.bin +3 -0
  21. pytorch_model-00016-of-00053.bin +3 -0
  22. pytorch_model-00017-of-00053.bin +3 -0
  23. pytorch_model-00018-of-00053.bin +3 -0
  24. pytorch_model-00019-of-00053.bin +3 -0
  25. pytorch_model-00020-of-00053.bin +3 -0
  26. pytorch_model-00021-of-00053.bin +3 -0
  27. pytorch_model-00022-of-00053.bin +3 -0
  28. pytorch_model-00023-of-00053.bin +3 -0
  29. pytorch_model-00024-of-00053.bin +3 -0
  30. pytorch_model-00025-of-00053.bin +3 -0
  31. pytorch_model-00026-of-00053.bin +3 -0
  32. pytorch_model-00027-of-00053.bin +3 -0
  33. pytorch_model-00028-of-00053.bin +3 -0
  34. pytorch_model-00029-of-00053.bin +3 -0
  35. pytorch_model-00030-of-00053.bin +3 -0
  36. pytorch_model-00031-of-00053.bin +3 -0
  37. pytorch_model-00032-of-00053.bin +3 -0
  38. pytorch_model-00033-of-00053.bin +3 -0
  39. pytorch_model-00034-of-00053.bin +3 -0
  40. pytorch_model-00035-of-00053.bin +3 -0
  41. pytorch_model-00036-of-00053.bin +3 -0
  42. pytorch_model-00037-of-00053.bin +3 -0
  43. pytorch_model-00038-of-00053.bin +3 -0
  44. pytorch_model-00039-of-00053.bin +3 -0
  45. pytorch_model-00040-of-00053.bin +3 -0
  46. pytorch_model-00041-of-00053.bin +3 -0
  47. pytorch_model-00042-of-00053.bin +3 -0
  48. pytorch_model-00043-of-00053.bin +3 -0
  49. pytorch_model-00044-of-00053.bin +3 -0
  50. pytorch_model-00045-of-00053.bin +3 -0
.gitattributes CHANGED
@@ -33,3 +33,56 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ pytorch_model-00011-of-00053.bin filter=lfs diff=lfs merge=lfs -text
37
+ pytorch_model-00035-of-00053.bin filter=lfs diff=lfs merge=lfs -text
38
+ pytorch_model-00047-of-00053.bin filter=lfs diff=lfs merge=lfs -text
39
+ pytorch_model-00008-of-00053.bin filter=lfs diff=lfs merge=lfs -text
40
+ pytorch_model-00009-of-00053.bin filter=lfs diff=lfs merge=lfs -text
41
+ pytorch_model-00021-of-00053.bin filter=lfs diff=lfs merge=lfs -text
42
+ pytorch_model-00034-of-00053.bin filter=lfs diff=lfs merge=lfs -text
43
+ pytorch_model-00052-of-00053.bin filter=lfs diff=lfs merge=lfs -text
44
+ pytorch_model-00010-of-00053.bin filter=lfs diff=lfs merge=lfs -text
45
+ pytorch_model-00016-of-00053.bin filter=lfs diff=lfs merge=lfs -text
46
+ pytorch_model-00024-of-00053.bin filter=lfs diff=lfs merge=lfs -text
47
+ pytorch_model-00040-of-00053.bin filter=lfs diff=lfs merge=lfs -text
48
+ pytorch_model-00033-of-00053.bin filter=lfs diff=lfs merge=lfs -text
49
+ pytorch_model-00036-of-00053.bin filter=lfs diff=lfs merge=lfs -text
50
+ pytorch_model-00038-of-00053.bin filter=lfs diff=lfs merge=lfs -text
51
+ pytorch_model-00046-of-00053.bin filter=lfs diff=lfs merge=lfs -text
52
+ pytorch_model-00026-of-00053.bin filter=lfs diff=lfs merge=lfs -text
53
+ pytorch_model-00048-of-00053.bin filter=lfs diff=lfs merge=lfs -text
54
+ pytorch_model-00015-of-00053.bin filter=lfs diff=lfs merge=lfs -text
55
+ pytorch_model-00014-of-00053.bin filter=lfs diff=lfs merge=lfs -text
56
+ pytorch_model-00004-of-00053.bin filter=lfs diff=lfs merge=lfs -text
57
+ pytorch_model-00025-of-00053.bin filter=lfs diff=lfs merge=lfs -text
58
+ pytorch_model-00050-of-00053.bin filter=lfs diff=lfs merge=lfs -text
59
+ pytorch_model-00019-of-00053.bin filter=lfs diff=lfs merge=lfs -text
60
+ pytorch_model-00027-of-00053.bin filter=lfs diff=lfs merge=lfs -text
61
+ pytorch_model-00028-of-00053.bin filter=lfs diff=lfs merge=lfs -text
62
+ pytorch_model-00045-of-00053.bin filter=lfs diff=lfs merge=lfs -text
63
+ pytorch_model-00003-of-00053.bin filter=lfs diff=lfs merge=lfs -text
64
+ pytorch_model-00006-of-00053.bin filter=lfs diff=lfs merge=lfs -text
65
+ pytorch_model-00017-of-00053.bin filter=lfs diff=lfs merge=lfs -text
66
+ pytorch_model-00039-of-00053.bin filter=lfs diff=lfs merge=lfs -text
67
+ pytorch_model-00041-of-00053.bin filter=lfs diff=lfs merge=lfs -text
68
+ pytorch_model-00002-of-00053.bin filter=lfs diff=lfs merge=lfs -text
69
+ pytorch_model-00005-of-00053.bin filter=lfs diff=lfs merge=lfs -text
70
+ pytorch_model-00037-of-00053.bin filter=lfs diff=lfs merge=lfs -text
71
+ pytorch_model-00042-of-00053.bin filter=lfs diff=lfs merge=lfs -text
72
+ pytorch_model-00007-of-00053.bin filter=lfs diff=lfs merge=lfs -text
73
+ pytorch_model-00023-of-00053.bin filter=lfs diff=lfs merge=lfs -text
74
+ pytorch_model-00051-of-00053.bin filter=lfs diff=lfs merge=lfs -text
75
+ pytorch_model-00053-of-00053.bin filter=lfs diff=lfs merge=lfs -text
76
+ pytorch_model-00043-of-00053.bin filter=lfs diff=lfs merge=lfs -text
77
+ pytorch_model-00044-of-00053.bin filter=lfs diff=lfs merge=lfs -text
78
+ pytorch_model-00001-of-00053.bin filter=lfs diff=lfs merge=lfs -text
79
+ pytorch_model-00013-of-00053.bin filter=lfs diff=lfs merge=lfs -text
80
+ pytorch_model-00022-of-00053.bin filter=lfs diff=lfs merge=lfs -text
81
+ pytorch_model-00020-of-00053.bin filter=lfs diff=lfs merge=lfs -text
82
+ pytorch_model-00031-of-00053.bin filter=lfs diff=lfs merge=lfs -text
83
+ pytorch_model-00032-of-00053.bin filter=lfs diff=lfs merge=lfs -text
84
+ pytorch_model-00049-of-00053.bin filter=lfs diff=lfs merge=lfs -text
85
+ pytorch_model-00012-of-00053.bin filter=lfs diff=lfs merge=lfs -text
86
+ pytorch_model-00030-of-00053.bin filter=lfs diff=lfs merge=lfs -text
87
+ pytorch_model-00018-of-00053.bin filter=lfs diff=lfs merge=lfs -text
88
+ pytorch_model-00029-of-00053.bin filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1 @@
1
+ {"architectures": ["SkyworkForCausalLM"], "auto_map": {"AutoConfig": "configuration_skywork_moe.SkyworkMoeConfig", "AutoModelForCausalLM": "modeling_skywork_moe.SkyworkForCausalLM"}, "model_type": "skywork", "vocab_size": 65532, "bos_token_id": 1, "eos_token_id": 2, "pad_token_id": 0, "hidden_act": "silu", "hidden_size": 4608, "initializer_range": 0.01, "intermediate_size": 12288, "max_position_embeddings": 8192, "num_attention_heads": 36, "num_key_value_heads": 36, "num_hidden_layers": 52, "num_experts": [16], "moe_use_skywork_gating": false, "moe_2layer_gate": false, "moe_use_logits_norm": true, "moe_gate_norm_std": 1.0, "moe_feature_no_mul_topk": true, "sliding_window": null, "moe_expert_interval": 1, "rms_norm_eps": 1e-06, "rotary_percent": 1.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "use_cache": true, "transformers_version": "4.40.1", "rope_theta": 10000}
configuration_skywork_moe.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright (c) SkyworkAI and the HuggingFace Inc. team. All rights reserved.
2
+ # This code is built upon Huggingface's transformers repository.
3
+
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+
9
+ logger = logging.get_logger(__name__)
10
+
11
+ LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
12
+
13
+
14
+ class SkyworkMoeConfig(PretrainedConfig):
15
+
16
+ model_type = "skywork"
17
+ keys_to_ignore_at_inference = ["past_key_values"]
18
+
19
+ def __init__(
20
+ self,
21
+ vocab_size=32000,
22
+ hidden_size=4096,
23
+ intermediate_size=11008,
24
+ num_hidden_layers=32,
25
+ num_attention_heads=32,
26
+ num_key_value_heads=None,
27
+ hidden_act="silu",
28
+ max_position_embeddings=2048,
29
+ initializer_range=0.02,
30
+ rms_norm_eps=1e-6,
31
+ use_cache=True,
32
+ pad_token_id=None,
33
+ bos_token_id=1,
34
+ eos_token_id=2,
35
+ pretraining_tp=1,
36
+ tie_word_embeddings=False,
37
+ rope_theta=10000.0,
38
+ rope_scaling=None,
39
+ num_experts=[32],
40
+ moe_expert_interval=1,
41
+ moe_use_skywork_gating=False,
42
+ moe_2layer_gate=True,
43
+ moe_use_logits_norm=False,
44
+ moe_gate_norm_std=1.0,
45
+ moe_feature_no_mul_topk=False,
46
+ sliding_window=None,
47
+
48
+ **kwargs,
49
+ ):
50
+ self.vocab_size = vocab_size
51
+ self.max_position_embeddings = max_position_embeddings
52
+ self.hidden_size = hidden_size
53
+ self.intermediate_size = intermediate_size
54
+ self.num_hidden_layers = num_hidden_layers
55
+ self.num_attention_heads = num_attention_heads
56
+
57
+ # for backward compatibility
58
+ if num_key_value_heads is None:
59
+ num_key_value_heads = num_attention_heads
60
+
61
+ self.num_key_value_heads = num_key_value_heads
62
+ self.hidden_act = hidden_act
63
+ self.initializer_range = initializer_range
64
+ self.rms_norm_eps = rms_norm_eps
65
+ self.pretraining_tp = pretraining_tp
66
+ self.use_cache = use_cache
67
+ self.rope_theta = rope_theta
68
+ self.rope_scaling = rope_scaling
69
+ self._rope_scaling_validation()
70
+ self.num_experts = num_experts
71
+ self.moe_expert_interval = moe_expert_interval
72
+ self.moe_use_skywork_gating = moe_use_skywork_gating
73
+ self.moe_2layer_gate = moe_2layer_gate
74
+ self.moe_use_logits_norm = moe_use_logits_norm
75
+ self.moe_gate_norm_std = moe_gate_norm_std
76
+ self.moe_feature_no_mul_topk = moe_feature_no_mul_topk
77
+ self.sliding_window = sliding_window
78
+
79
+ super().__init__(
80
+ pad_token_id=pad_token_id,
81
+ bos_token_id=bos_token_id,
82
+ eos_token_id=eos_token_id,
83
+ tie_word_embeddings=tie_word_embeddings,
84
+ **kwargs,
85
+ )
86
+
87
+ def _rope_scaling_validation(self):
88
+ """
89
+ Validate the `rope_scaling` configuration.
90
+ """
91
+ if self.rope_scaling is None:
92
+ return
93
+
94
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
95
+ raise ValueError(
96
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
97
+ f"got {self.rope_scaling}"
98
+ )
99
+ rope_scaling_type = self.rope_scaling.get("type", None)
100
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
101
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "ntk"]:
102
+ raise ValueError(
103
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
104
+ )
105
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
106
+ raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
generation_config.json ADDED
@@ -0,0 +1 @@
1
+ {"_from_model_config": true, "bos_token_id": 1, "eos_token_id": 2, "pad_token_id": 0, "transformers_version": "4.40.1"}
modeling_skywork_moe.py ADDED
@@ -0,0 +1,1604 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Skywork model."""
21
+ import inspect
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import (
35
+ _prepare_4d_causal_attention_mask,
36
+ _prepare_4d_causal_attention_mask_for_sdpa,
37
+ )
38
+ from transformers.modeling_outputs import (
39
+ MoeCausalLMOutputWithPast,
40
+ MoeModelOutputWithPast,
41
+ SequenceClassifierOutputWithPast,
42
+ )
43
+ from transformers.modeling_utils import PreTrainedModel
44
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
45
+ from transformers.utils import (
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ is_flash_attn_2_available,
49
+ is_flash_attn_greater_or_equal_2_10,
50
+ logging,
51
+ replace_return_docstrings,
52
+ )
53
+ from transformers.utils.import_utils import is_torch_fx_available
54
+ from .configuration_skywork_moe import SkyworkMoeConfig
55
+
56
+
57
+ if is_flash_attn_2_available():
58
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
59
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
60
+
61
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
62
+
63
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
64
+ # It means that the function will not be traced through and simply appear as a node in the graph.
65
+ if is_torch_fx_available():
66
+ if not is_torch_greater_or_equal_than_1_13:
67
+ import torch.fx
68
+
69
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
70
+
71
+
72
+ logger = logging.get_logger(__name__)
73
+
74
+ _CONFIG_FOR_DOC = "SkyworkMoeConfig"
75
+
76
+
77
+ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:
78
+ r"""
79
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
80
+
81
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
82
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
83
+ experts is too unbalanced.
84
+
85
+ Args:
86
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
87
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
88
+ shape [batch_size X sequence_length, num_experts].
89
+ num_experts (`int`, *optional*):
90
+ Number of experts
91
+
92
+ Returns:
93
+ The auxiliary loss.
94
+ """
95
+ if gate_logits is None or not isinstance(gate_logits, tuple):
96
+ return 0
97
+
98
+ if isinstance(gate_logits, tuple):
99
+ compute_device = gate_logits[0].device
100
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
101
+
102
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
103
+
104
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
105
+
106
+ # treat `top_k` as tokens (shape is `top_k X [batch_size X sequence_length]`)
107
+ selected_experts = selected_experts.reshape(-1)
108
+
109
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
110
+ expert_mask = torch.max(expert_mask, dim=-2).values
111
+
112
+ # Compute the percentage of tokens routed to each experts
113
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
114
+
115
+ # Compute the average probability of routing to these experts
116
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
117
+
118
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(-1))
119
+ return overall_loss * num_experts
120
+
121
+
122
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
123
+ def _get_unpad_data(attention_mask):
124
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
125
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
126
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
127
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
128
+ return (
129
+ indices,
130
+ cu_seqlens,
131
+ max_seqlen_in_batch,
132
+ )
133
+
134
+
135
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Skywork
136
+ class SkyworkRMSNorm(nn.Module):
137
+ def __init__(self, hidden_size, eps=1e-6):
138
+ """
139
+ SkyworkRMSNorm is equivalent to T5LayerNorm
140
+ """
141
+ super().__init__()
142
+ self.weight = nn.Parameter(torch.ones(hidden_size))
143
+ self.variance_epsilon = eps
144
+
145
+ def forward(self, hidden_states):
146
+ input_dtype = hidden_states.dtype
147
+ hidden_states = hidden_states.to(torch.float32)
148
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
149
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
150
+ return self.weight * hidden_states.to(input_dtype)
151
+
152
+
153
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Skywork
154
+ class SkyworkRotaryEmbedding(nn.Module):
155
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
156
+ super().__init__()
157
+
158
+ self.dim = dim
159
+ self.max_position_embeddings = max_position_embeddings
160
+ self.base = base
161
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
162
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
163
+
164
+ # Build here to make `torch.jit.trace` work.
165
+ self._set_cos_sin_cache(
166
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
167
+ )
168
+
169
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
170
+ self.max_seq_len_cached = seq_len
171
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
172
+
173
+ freqs = torch.outer(t, self.inv_freq)
174
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
175
+ emb = torch.cat((freqs, freqs), dim=-1)
176
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
177
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
178
+
179
+ def forward(self, x, seq_len=None):
180
+ # x: [bs, num_attention_heads, seq_len, head_size]
181
+ if seq_len > self.max_seq_len_cached:
182
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
183
+
184
+ return (
185
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
186
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
187
+ )
188
+
189
+
190
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
191
+ def rotate_half(x):
192
+ """Rotates half the hidden dims of the input."""
193
+ x1 = x[..., : x.shape[-1] // 2]
194
+ x2 = x[..., x.shape[-1] // 2 :]
195
+ return torch.cat((-x2, x1), dim=-1)
196
+
197
+
198
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
199
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
200
+ """Applies Rotary Position Embedding to the query and key tensors.
201
+
202
+ Args:
203
+ q (`torch.Tensor`): The query tensor.
204
+ k (`torch.Tensor`): The key tensor.
205
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
206
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
207
+ position_ids (`torch.Tensor`):
208
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
209
+ used to pass offsetted position ids when working with a KV-cache.
210
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
211
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
212
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
213
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
214
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
215
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
216
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
217
+ Returns:
218
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
219
+ """
220
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
221
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
222
+ q_embed = (q * cos) + (rotate_half(q) * sin)
223
+ k_embed = (k * cos) + (rotate_half(k) * sin)
224
+ return q_embed, k_embed
225
+
226
+
227
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
228
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
229
+ """
230
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
231
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
232
+ """
233
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
234
+ if n_rep == 1:
235
+ return hidden_states
236
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
237
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
238
+
239
+
240
+ # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Skywork
241
+ class SkyworkAttention(nn.Module):
242
+ """
243
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
244
+ and "Generating Long Sequences with Sparse Transformers".
245
+ """
246
+
247
+ def __init__(self, config: SkyworkMoeConfig, layer_idx: Optional[int] = None):
248
+ super().__init__()
249
+ self.config = config
250
+ self.layer_idx = layer_idx
251
+ if layer_idx is None:
252
+ logger.warning_once(
253
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
254
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
255
+ "when creating this class."
256
+ )
257
+
258
+ self.hidden_size = config.hidden_size
259
+ self.num_heads = config.num_attention_heads
260
+ self.head_dim = self.hidden_size // self.num_heads
261
+ self.num_key_value_heads = config.num_key_value_heads
262
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
263
+ self.max_position_embeddings = config.max_position_embeddings
264
+ self.rope_theta = config.rope_theta
265
+ self.is_causal = True
266
+ self.attention_dropout = 0.0 # notice: support inference only.
267
+
268
+ if (self.head_dim * self.num_heads) != self.hidden_size:
269
+ raise ValueError(
270
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
271
+ f" and `num_heads`: {self.num_heads})."
272
+ )
273
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
274
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
275
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
276
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
277
+
278
+ self.rotary_emb = SkyworkRotaryEmbedding(
279
+ self.head_dim,
280
+ max_position_embeddings=self.max_position_embeddings,
281
+ base=self.rope_theta,
282
+ )
283
+
284
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
285
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
286
+
287
+ def forward(
288
+ self,
289
+ hidden_states: torch.Tensor,
290
+ attention_mask: Optional[torch.Tensor] = None,
291
+ position_ids: Optional[torch.LongTensor] = None,
292
+ past_key_value: Optional[Cache] = None,
293
+ output_attentions: bool = False,
294
+ use_cache: bool = False,
295
+ **kwargs,
296
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
297
+ if "padding_mask" in kwargs:
298
+ warnings.warn(
299
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
300
+ )
301
+ bsz, q_len, _ = hidden_states.size()
302
+
303
+ query_states = self.q_proj(hidden_states)
304
+ key_states = self.k_proj(hidden_states)
305
+ value_states = self.v_proj(hidden_states)
306
+
307
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
308
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
309
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
310
+
311
+ kv_seq_len = key_states.shape[-2]
312
+ if past_key_value is not None:
313
+ if self.layer_idx is None:
314
+ raise ValueError(
315
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
316
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
317
+ "with a layer index."
318
+ )
319
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
320
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
321
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
322
+
323
+ if past_key_value is not None:
324
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
325
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
326
+
327
+ # repeat k/v heads if n_kv_heads < n_heads
328
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
329
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
330
+
331
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
332
+
333
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
334
+ raise ValueError(
335
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
336
+ f" {attn_weights.size()}"
337
+ )
338
+
339
+ if attention_mask is not None:
340
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
341
+ raise ValueError(
342
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
343
+ )
344
+
345
+ attn_weights = attn_weights + attention_mask
346
+
347
+ # upcast attention to fp32
348
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
349
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
350
+ attn_output = torch.matmul(attn_weights, value_states)
351
+
352
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
353
+ raise ValueError(
354
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
355
+ f" {attn_output.size()}"
356
+ )
357
+
358
+ attn_output = attn_output.transpose(1, 2).contiguous()
359
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
360
+
361
+ attn_output = self.o_proj(attn_output)
362
+
363
+ if not output_attentions:
364
+ attn_weights = None
365
+
366
+ return attn_output, attn_weights, past_key_value
367
+
368
+
369
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Skywork
370
+ class SkyworkFlashAttention2(SkyworkAttention):
371
+ """
372
+ Skywork flash attention module. This module inherits from `SkyworkAttention` as the weights of the module stay
373
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
374
+ flash attention and deal with padding tokens in case the input contains any of them.
375
+ """
376
+
377
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
378
+ def __init__(self, *args, **kwargs):
379
+ super().__init__(*args, **kwargs)
380
+
381
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
382
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
383
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
384
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
385
+
386
+ def forward(
387
+ self,
388
+ hidden_states: torch.Tensor,
389
+ attention_mask: Optional[torch.Tensor] = None,
390
+ position_ids: Optional[torch.LongTensor] = None,
391
+ past_key_value: Optional[Cache] = None,
392
+ output_attentions: bool = False,
393
+ use_cache: bool = False,
394
+ **kwargs,
395
+ ):
396
+ if "padding_mask" in kwargs:
397
+ warnings.warn(
398
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
399
+ )
400
+
401
+ # overwrite attention_mask with padding_mask
402
+ attention_mask = kwargs.pop("padding_mask")
403
+ bsz, q_len, _ = hidden_states.size()
404
+
405
+ query_states = self.q_proj(hidden_states)
406
+ key_states = self.k_proj(hidden_states)
407
+ value_states = self.v_proj(hidden_states)
408
+
409
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
410
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
411
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
412
+
413
+ kv_seq_len = key_states.shape[-2]
414
+ if past_key_value is not None:
415
+ if self.layer_idx is None:
416
+ raise ValueError(
417
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
418
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
419
+ "with a layer index."
420
+ )
421
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
422
+
423
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
424
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
425
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
426
+
427
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
428
+
429
+ use_sliding_windows = (
430
+ _flash_supports_window_size
431
+ and getattr(self.config, "sliding_window", None) is not None
432
+ and kv_seq_len > self.config.sliding_window
433
+ )
434
+
435
+ if not _flash_supports_window_size:
436
+ logger.warning_once(
437
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
438
+ " make sure to upgrade flash-attn library."
439
+ )
440
+
441
+ if past_key_value is not None:
442
+ # Activate cache slicing only if the config has a `sliding_window` attribute
443
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
444
+ if (
445
+ getattr(self.config, "sliding_window", None) is not None
446
+ and kv_seq_len > self.config.sliding_window
447
+ and cache_has_contents
448
+ ):
449
+ slicing_tokens = 1 - self.config.sliding_window
450
+
451
+ past_key = past_key_value[self.layer_idx][0]
452
+ past_value = past_key_value[self.layer_idx][1]
453
+
454
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
455
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
456
+
457
+ if past_key.shape[-2] != self.config.sliding_window - 1:
458
+ raise ValueError(
459
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
460
+ f" {past_key.shape}"
461
+ )
462
+
463
+ if attention_mask is not None:
464
+ attention_mask = attention_mask[:, slicing_tokens:]
465
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
466
+
467
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
468
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
469
+
470
+ # repeat k/v heads if n_kv_heads < n_heads
471
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
472
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
473
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
474
+
475
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
476
+ # therefore the input hidden states get silently cast in float32. Hence, we need
477
+ # cast them back in float16 just to be sure everything works as expected.
478
+ input_dtype = query_states.dtype
479
+ if input_dtype == torch.float32:
480
+ if torch.is_autocast_enabled():
481
+ target_dtype = torch.get_autocast_gpu_dtype()
482
+ # Handle the case where the model is quantized
483
+ elif hasattr(self.config, "_pre_quantization_dtype"):
484
+ target_dtype = self.config._pre_quantization_dtype
485
+ else:
486
+ target_dtype = self.q_proj.weight.dtype
487
+
488
+ logger.warning_once(
489
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
490
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
491
+ f" {target_dtype}."
492
+ )
493
+
494
+ query_states = query_states.to(target_dtype)
495
+ key_states = key_states.to(target_dtype)
496
+ value_states = value_states.to(target_dtype)
497
+
498
+ # Reshape to the expected shape for Flash Attention
499
+ query_states = query_states.transpose(1, 2)
500
+ key_states = key_states.transpose(1, 2)
501
+ value_states = value_states.transpose(1, 2)
502
+
503
+ attn_output = self._flash_attention_forward(
504
+ query_states,
505
+ key_states,
506
+ value_states,
507
+ attention_mask,
508
+ q_len,
509
+ dropout=dropout_rate,
510
+ use_sliding_windows=use_sliding_windows,
511
+ )
512
+
513
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
514
+ attn_output = self.o_proj(attn_output)
515
+
516
+ if not output_attentions:
517
+ attn_weights = None
518
+
519
+ return attn_output, attn_weights, past_key_value
520
+
521
+ def _flash_attention_forward(
522
+ self,
523
+ query_states,
524
+ key_states,
525
+ value_states,
526
+ attention_mask,
527
+ query_length,
528
+ dropout=0.0,
529
+ softmax_scale=None,
530
+ use_sliding_windows=False,
531
+ ):
532
+ """
533
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
534
+ first unpad the input, then computes the attention scores and pad the final attention scores.
535
+
536
+ Args:
537
+ query_states (`torch.Tensor`):
538
+ Input query states to be passed to Flash Attention API
539
+ key_states (`torch.Tensor`):
540
+ Input key states to be passed to Flash Attention API
541
+ value_states (`torch.Tensor`):
542
+ Input value states to be passed to Flash Attention API
543
+ attention_mask (`torch.Tensor`):
544
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
545
+ position of padding tokens and 1 for the position of non-padding tokens.
546
+ dropout (`int`, *optional*):
547
+ Attention dropout
548
+ softmax_scale (`float`, *optional*):
549
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
550
+ use_sliding_windows (`bool`, *optional*):
551
+ Whether to activate sliding window attention.
552
+ """
553
+ if not self._flash_attn_uses_top_left_mask:
554
+ causal = self.is_causal
555
+ else:
556
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
557
+ causal = self.is_causal and query_length != 1
558
+
559
+ # Contains at least one padding token in the sequence
560
+ if attention_mask is not None:
561
+ batch_size = query_states.shape[0]
562
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
563
+ query_states, key_states, value_states, attention_mask, query_length
564
+ )
565
+
566
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
567
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
568
+
569
+ if not use_sliding_windows:
570
+ attn_output_unpad = flash_attn_varlen_func(
571
+ query_states,
572
+ key_states,
573
+ value_states,
574
+ cu_seqlens_q=cu_seqlens_q,
575
+ cu_seqlens_k=cu_seqlens_k,
576
+ max_seqlen_q=max_seqlen_in_batch_q,
577
+ max_seqlen_k=max_seqlen_in_batch_k,
578
+ dropout_p=dropout,
579
+ softmax_scale=softmax_scale,
580
+ causal=causal,
581
+ )
582
+ else:
583
+ attn_output_unpad = flash_attn_varlen_func(
584
+ query_states,
585
+ key_states,
586
+ value_states,
587
+ cu_seqlens_q=cu_seqlens_q,
588
+ cu_seqlens_k=cu_seqlens_k,
589
+ max_seqlen_q=max_seqlen_in_batch_q,
590
+ max_seqlen_k=max_seqlen_in_batch_k,
591
+ dropout_p=dropout,
592
+ softmax_scale=softmax_scale,
593
+ causal=causal,
594
+ window_size=(self.config.sliding_window, self.config.sliding_window),
595
+ )
596
+
597
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
598
+ else:
599
+ if not use_sliding_windows:
600
+ attn_output = flash_attn_func(
601
+ query_states,
602
+ key_states,
603
+ value_states,
604
+ dropout,
605
+ softmax_scale=softmax_scale,
606
+ causal=causal,
607
+ )
608
+ else:
609
+ attn_output = flash_attn_func(
610
+ query_states,
611
+ key_states,
612
+ value_states,
613
+ dropout,
614
+ softmax_scale=softmax_scale,
615
+ causal=causal,
616
+ window_size=(self.config.sliding_window, self.config.sliding_window),
617
+ )
618
+
619
+ return attn_output
620
+
621
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
622
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
623
+
624
+ # On the first iteration we need to properly re-create the padding mask
625
+ # by slicing it on the proper place
626
+ if kv_seq_len != attention_mask.shape[-1]:
627
+ attention_mask_num_tokens = attention_mask.shape[-1]
628
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
629
+
630
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
631
+
632
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
633
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
634
+
635
+ if query_length == kv_seq_len:
636
+ query_layer = index_first_axis(
637
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
638
+ )
639
+ cu_seqlens_q = cu_seqlens_k
640
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
641
+ indices_q = indices_k
642
+ elif query_length == 1:
643
+ max_seqlen_in_batch_q = 1
644
+ cu_seqlens_q = torch.arange(
645
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
646
+ ) # There is a memcpy here, that is very bad.
647
+ indices_q = cu_seqlens_q[:-1]
648
+ query_layer = query_layer.squeeze(1)
649
+ else:
650
+ # The -q_len: slice assumes left padding.
651
+ attention_mask = attention_mask[:, -query_length:]
652
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
653
+
654
+ return (
655
+ query_layer,
656
+ key_layer,
657
+ value_layer,
658
+ indices_q,
659
+ (cu_seqlens_q, cu_seqlens_k),
660
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
661
+ )
662
+
663
+
664
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Skywork
665
+ class SkyworkSdpaAttention(SkyworkAttention):
666
+ """
667
+ Skywork attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
668
+ `SkyworkAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
669
+ SDPA API.
670
+ """
671
+
672
+ # Adapted from SkyworkAttention.forward
673
+ def forward(
674
+ self,
675
+ hidden_states: torch.Tensor,
676
+ attention_mask: Optional[torch.Tensor] = None,
677
+ position_ids: Optional[torch.LongTensor] = None,
678
+ past_key_value: Optional[Cache] = None,
679
+ output_attentions: bool = False,
680
+ use_cache: bool = False,
681
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
682
+ if output_attentions:
683
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
684
+ logger.warning_once(
685
+ "SkyworkModel is using SkyworkSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
686
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
687
+ )
688
+ return super().forward(
689
+ hidden_states=hidden_states,
690
+ attention_mask=attention_mask,
691
+ position_ids=position_ids,
692
+ past_key_value=past_key_value,
693
+ output_attentions=output_attentions,
694
+ use_cache=use_cache,
695
+ )
696
+
697
+ bsz, q_len, _ = hidden_states.size()
698
+
699
+ query_states = self.q_proj(hidden_states)
700
+ key_states = self.k_proj(hidden_states)
701
+ value_states = self.v_proj(hidden_states)
702
+
703
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
704
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
705
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
706
+
707
+ kv_seq_len = key_states.shape[-2]
708
+ if past_key_value is not None:
709
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
710
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
711
+
712
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
713
+
714
+ if past_key_value is not None:
715
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
716
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
717
+
718
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
719
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
720
+
721
+ if attention_mask is not None:
722
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
723
+ raise ValueError(
724
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
725
+ )
726
+
727
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
728
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
729
+ if query_states.device.type == "cuda" and attention_mask is not None:
730
+ query_states = query_states.contiguous()
731
+ key_states = key_states.contiguous()
732
+ value_states = value_states.contiguous()
733
+
734
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
735
+ query_states,
736
+ key_states,
737
+ value_states,
738
+ attn_mask=attention_mask,
739
+ dropout_p=self.attention_dropout if self.training else 0.0,
740
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
741
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
742
+ )
743
+
744
+ attn_output = attn_output.transpose(1, 2).contiguous()
745
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
746
+
747
+ attn_output = self.o_proj(attn_output)
748
+
749
+ return attn_output, None, past_key_value
750
+
751
+
752
+ SKYWORK_ATTENTION_CLASSES = {
753
+ "eager": SkyworkAttention,
754
+ "flash_attention_2": SkyworkFlashAttention2,
755
+ "sdpa": SkyworkSdpaAttention,
756
+ }
757
+
758
+
759
+ class SkyworkBLockSparseTop2MLP(nn.Module):
760
+ def __init__(self, config: SkyworkMoeConfig):
761
+ super().__init__()
762
+ self.ffn_dim = config.intermediate_size
763
+ self.hidden_dim = config.hidden_size
764
+
765
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
766
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
767
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
768
+
769
+ self.act_fn = ACT2FN[config.hidden_act]
770
+
771
+ def forward(self, hidden_states):
772
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
773
+ current_hidden_states = self.w2(current_hidden_states)
774
+ return current_hidden_states
775
+
776
+ MOE_TOP_K = 2
777
+
778
+ class SkyworkSparseMoeBlock(nn.Module):
779
+ """
780
+ This implementation is
781
+ strictly equivalent to standard MoE with full capacity (no
782
+ dropped tokens). It's faster since it formulates MoE operations
783
+ in terms of block-sparse operations to accommodate imbalanced
784
+ assignments of tokens to experts, whereas standard MoE either
785
+ (1) drop tokens at the cost of reduced performance or (2) set
786
+ capacity factor to number of experts and thus waste computation
787
+ and memory on padding.
788
+ """
789
+
790
+ def __init__(self, config):
791
+ super().__init__()
792
+ self.hidden_dim = config.hidden_size
793
+ self.ffn_dim = config.intermediate_size
794
+ self.num_experts = config.num_experts[0]
795
+ self.top_k = MOE_TOP_K
796
+ self.moe_use_skywork_gating = config.moe_use_skywork_gating
797
+ self.moe_use_logits_norm = config.moe_use_logits_norm
798
+ self.moe_gate_norm_std = config.moe_gate_norm_std
799
+ self.moe_feature_no_mul_topk = config.moe_feature_no_mul_topk
800
+
801
+ # gating
802
+ if config.moe_2layer_gate:
803
+ self.gate = torch.nn.Sequential(
804
+ nn.Linear(self.hidden_dim, self.num_experts * 8, bias=False).float(),
805
+ torch.nn.Tanh(),
806
+ nn.Linear(self.num_experts * 8, self.num_experts, bias=False).float()).float()
807
+ else:
808
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
809
+
810
+ self.experts = nn.ModuleList([SkyworkBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
811
+
812
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
813
+ """ """
814
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
815
+ hidden_states = hidden_states.view(-1, hidden_dim)
816
+
817
+ if isinstance(self.gate, torch.nn.Linear):
818
+ if self.gate.weight.dtype != torch.float32:
819
+ self.gate = self.gate.float()
820
+ setattr(self.gate.weight, 'router', True)
821
+ else:
822
+ if self.gate[0].weight.dtype != torch.float32:
823
+ self.gate = self.gate.float()
824
+ setattr(self.gate[0].weight, "router", True)
825
+ setattr(self.gate[2].weight, "router", True)
826
+ hidden_states_fp32 = hidden_states.float()
827
+ # router_logits: (batch * sequence_length, n_experts)
828
+ router_logits = self.gate(hidden_states_fp32)
829
+ if not (self.moe_use_skywork_gating or self.moe_feature_no_mul_topk):
830
+ router_logits *= self.top_k
831
+
832
+ if self.moe_use_skywork_gating:
833
+ if self.moe_use_logits_norm:
834
+ target_std = self.moe_gate_norm_std
835
+ logits_std = router_logits.std(dim=1, keepdim=True)
836
+ router_logits = router_logits / (logits_std / target_std)
837
+ routing_weights, selected_experts = torch.topk(router_logits, k=self.top_k, dim=1)
838
+ routing_weights = F.softmax(routing_weights, dim=1)
839
+ else:
840
+ target_std = self.moe_gate_norm_std
841
+ if self.moe_use_logits_norm:
842
+ logits_std = router_logits.std(dim=1, keepdim=True)
843
+ routing_weights = F.softmax(router_logits / (logits_std / target_std), dim=1)
844
+ else:
845
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
846
+
847
+ routing_weights, selected_experts = torch.topk(routing_weights,
848
+ self.top_k,
849
+ dim=-1)
850
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
851
+
852
+ # we cast back to the input dtype
853
+ routing_weights = routing_weights.to(hidden_states.dtype)
854
+
855
+ final_hidden_states = torch.zeros(
856
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
857
+ )
858
+
859
+ # One hot encode the selected experts to create an expert mask
860
+ # this will be used to easily index which expert is going to be solicited
861
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
862
+
863
+ # Loop over all available experts in the model and perform the computation on each expert
864
+ for expert_idx in range(self.num_experts):
865
+ expert_layer = self.experts[expert_idx]
866
+ idx, top_x = torch.where(expert_mask[expert_idx])
867
+
868
+ if top_x.shape[0] == 0:
869
+ continue
870
+
871
+ # in torch it is faster to index using lists than torch tensors
872
+ top_x_list = top_x.tolist()
873
+ idx_list = idx.tolist()
874
+
875
+ # Index the correct hidden states and compute the expert hidden state for
876
+ # the current expert. We need to make sure to multiply the output hidden
877
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
878
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
879
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
880
+
881
+ # However `index_add_` only support torch tensors for indexing so we'll use
882
+ # the `top_x` tensor here.
883
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
884
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
885
+ return final_hidden_states, router_logits
886
+
887
+
888
+ class SkyworkDecoderLayer(nn.Module):
889
+ def __init__(self, config: SkyworkMoeConfig, layer_idx: int):
890
+ super().__init__()
891
+ self.hidden_size = config.hidden_size
892
+
893
+ self.self_attn = SKYWORK_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
894
+
895
+ self.block_sparse_moe = SkyworkSparseMoeBlock(config)
896
+ self.input_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
897
+ self.post_attention_layernorm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
898
+
899
+ def forward(
900
+ self,
901
+ hidden_states: torch.Tensor,
902
+ attention_mask: Optional[torch.Tensor] = None,
903
+ position_ids: Optional[torch.LongTensor] = None,
904
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
905
+ output_attentions: Optional[bool] = False,
906
+ output_router_logits: Optional[bool] = False,
907
+ use_cache: Optional[bool] = False,
908
+ **kwargs,
909
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
910
+ if "padding_mask" in kwargs:
911
+ warnings.warn(
912
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
913
+ )
914
+ """
915
+ Args:
916
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
917
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
918
+ `(batch, sequence_length)` where padding elements are indicated by 0.
919
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
920
+ output_attentions (`bool`, *optional*):
921
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
922
+ returned tensors for more detail.
923
+ output_router_logits (`bool`, *optional*):
924
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
925
+ should not be returned during inference.
926
+ use_cache (`bool`, *optional*):
927
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
928
+ (see `past_key_values`).
929
+ """
930
+
931
+ residual = hidden_states
932
+
933
+ hidden_states = self.input_layernorm(hidden_states)
934
+
935
+ # Self Attention
936
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
937
+ hidden_states=hidden_states,
938
+ attention_mask=attention_mask,
939
+ position_ids=position_ids,
940
+ past_key_value=past_key_value,
941
+ output_attentions=output_attentions,
942
+ use_cache=use_cache,
943
+ )
944
+ hidden_states = residual + hidden_states
945
+
946
+ # Fully Connected
947
+ residual = hidden_states
948
+ hidden_states = self.post_attention_layernorm(hidden_states)
949
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states)
950
+ hidden_states = residual + hidden_states
951
+
952
+ outputs = (hidden_states,)
953
+
954
+ if output_attentions:
955
+ outputs += (self_attn_weights,)
956
+
957
+ if use_cache:
958
+ outputs += (present_key_value,)
959
+
960
+ if output_router_logits:
961
+ outputs += (router_logits,)
962
+
963
+ return outputs
964
+
965
+
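Because the decoder layer returns a variable-length tuple, the position of each optional output depends on the flags that were passed. A small illustrative helper for unpacking it, assuming only the ordering shown in the assembly above (this helper is hypothetical, not part of the Skywork code):

```python
# Unpack a SkyworkDecoderLayer output tuple according to the ordering
# hidden_states -> attention weights -> cache -> router logits.
def unpack_layer_outputs(outputs, output_attentions=False, use_cache=False, output_router_logits=False):
    hidden_states = outputs[0]
    pos = 1
    self_attn_weights = present_key_value = router_logits = None
    if output_attentions:
        self_attn_weights = outputs[pos]
        pos += 1
    if use_cache:
        present_key_value = outputs[pos]
        pos += 1
    if output_router_logits:
        router_logits = outputs[pos]
    return hidden_states, self_attn_weights, present_key_value, router_logits
```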
966
+ SKYWORK_START_DOCSTRING = r"""
967
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
968
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
969
+ etc.)
970
+
971
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
972
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
973
+ and behavior.
974
+
975
+ Parameters:
976
+ config ([`SkyworkMoeConfig`]):
977
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
978
+ load the weights associated with the model, only the configuration. Check out the
979
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
980
+ """
981
+
982
+
983
+ @add_start_docstrings(
984
+ "The bare Skywork Model outputting raw hidden-states without any specific head on top.",
985
+ SKYWORK_START_DOCSTRING,
986
+ )
987
+ # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Skywork
988
+ class SkyworkPreTrainedModel(PreTrainedModel):
989
+ config_class = SkyworkMoeConfig
990
+ base_model_prefix = "model"
991
+ supports_gradient_checkpointing = True
992
+ _no_split_modules = ["SkyworkDecoderLayer"]
993
+ _skip_keys_device_placement = "past_key_values"
994
+ _supports_flash_attn_2 = True
995
+ _supports_sdpa = True
996
+ _supports_cache_class = True
997
+
998
+ # def _init_weights(self, module):
999
+ # std = self.config.initializer_range
1000
+ # if isinstance(module, nn.Linear):
1001
+ # module.weight.data.normal_(mean=0.0, std=std)
1002
+ # if module.bias is not None:
1003
+ # module.bias.data.zero_()
1004
+ # elif isinstance(module, nn.Embedding):
1005
+ # module.weight.data.normal_(mean=0.0, std=std)
1006
+ # if module.padding_idx is not None:
1007
+ # module.weight.data[module.padding_idx].zero_()
1008
+
1009
+
1010
+ SKYWORK_INPUTS_DOCSTRING = r"""
1011
+ Args:
1012
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1013
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1014
+ it.
1015
+
1016
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1017
+ [`PreTrainedTokenizer.__call__`] for details.
1018
+
1019
+ [What are input IDs?](../glossary#input-ids)
1020
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1021
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1022
+
1023
+ - 1 for tokens that are **not masked**,
1024
+ - 0 for tokens that are **masked**.
1025
+
1026
+ [What are attention masks?](../glossary#attention-mask)
1027
+
1028
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1029
+ [`PreTrainedTokenizer.__call__`] for details.
1030
+
1031
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1032
+ `past_key_values`).
1033
+
1034
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1035
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1036
+ information on the default strategy.
1037
+
1038
+ - 1 indicates the head is **not masked**,
1039
+ - 0 indicates the head is **masked**.
1040
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1041
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
1042
+ config.n_positions - 1]`.
1043
+
1044
+ [What are position IDs?](../glossary#position-ids)
1045
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1046
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1047
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1048
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1049
+
1050
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1051
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1052
+
1053
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1054
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1055
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1056
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1057
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1058
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1059
+ model's internal embedding lookup matrix.
1060
+ use_cache (`bool`, *optional*):
1061
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1062
+ `past_key_values`).
1063
+ output_attentions (`bool`, *optional*):
1064
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1065
+ tensors for more detail.
1066
+ output_hidden_states (`bool`, *optional*):
1067
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1068
+ more detail.
1069
+ output_router_logits (`bool`, *optional*):
1070
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1071
+ should not be returned during inference.
1072
+ return_dict (`bool`, *optional*):
1073
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1074
+ """
1075
+
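To make the argument descriptions above concrete, here is a hedged sketch of preparing `input_ids`, `attention_mask`, and `position_ids` for a padded batch; the checkpoint path is a placeholder and `trust_remote_code` plus the EOS-as-pad fallback are assumptions, not requirements stated in this file:

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("PATH_TO_CONVERTED_TOKENIZER", trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token   # assumption: reuse EOS when no pad token is defined
tokenizer.padding_side = "left"                 # the flash-attention path in the model code below expects left padding

batch = tokenizer(["Hello", "A longer prompt"], return_tensors="pt", padding=True)
input_ids = batch["input_ids"]                  # (batch_size, sequence_length)
attention_mask = batch["attention_mask"]        # 1 = real token, 0 = padding

# position_ids derived from the mask so padding does not advance positions.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
```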
1076
+
1077
+ @add_start_docstrings(
1078
+ "The bare Skywork Model outputting raw hidden-states without any specific head on top.",
1079
+ SKYWORK_START_DOCSTRING,
1080
+ )
1081
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->SKYWORK,Mistral->Skywork
1082
+ class SkyworkModel(SkyworkPreTrainedModel):
1083
+ """
1084
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SkyworkDecoderLayer`]
1085
+
1086
+ Args:
1087
+ config: SkyworkMoeConfig
1088
+ """
1089
+
1090
+ def __init__(self, config: SkyworkMoeConfig):
1091
+ super().__init__(config)
1092
+ self.padding_idx = config.pad_token_id
1093
+ self.vocab_size = config.vocab_size
1094
+
1095
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1096
+ self.layers = nn.ModuleList(
1097
+ [SkyworkDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1098
+ )
1099
+ self._attn_implementation = config._attn_implementation
1100
+ self.norm = SkyworkRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1101
+
1102
+ self.gradient_checkpointing = False
1103
+ # Initialize weights and apply final processing
1104
+ self.post_init()
1105
+
1106
+ def get_input_embeddings(self):
1107
+ return self.embed_tokens
1108
+
1109
+ def set_input_embeddings(self, value):
1110
+ self.embed_tokens = value
1111
+
1112
+ # Ignore copy
1113
+ @add_start_docstrings_to_model_forward(SKYWORK_INPUTS_DOCSTRING)
1114
+ def forward(
1115
+ self,
1116
+ input_ids: torch.LongTensor = None,
1117
+ attention_mask: Optional[torch.Tensor] = None,
1118
+ position_ids: Optional[torch.LongTensor] = None,
1119
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1120
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1121
+ use_cache: Optional[bool] = None,
1122
+ output_attentions: Optional[bool] = None,
1123
+ output_hidden_states: Optional[bool] = None,
1124
+ output_router_logits: Optional[bool] = None,
1125
+ return_dict: Optional[bool] = None,
1126
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
1127
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1128
+ output_router_logits = (
1129
+ output_router_logits if output_router_logits is not None else False
1130
+ )
1131
+ output_hidden_states = (
1132
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1133
+ )
1134
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1135
+
1136
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1137
+
1138
+ # retrieve input_ids and inputs_embeds
1139
+ if input_ids is not None and inputs_embeds is not None:
1140
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1141
+ elif input_ids is not None:
1142
+ batch_size, seq_length = input_ids.shape
1143
+ elif inputs_embeds is not None:
1144
+ batch_size, seq_length, _ = inputs_embeds.shape
1145
+ else:
1146
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1147
+
1148
+ past_key_values_length = 0
1149
+
1150
+ if self.gradient_checkpointing and self.training:
1151
+ if use_cache:
1152
+ logger.warning_once(
1153
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1154
+ )
1155
+ use_cache = False
1156
+
1157
+ if use_cache:
1158
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1159
+ if use_legacy_cache:
1160
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1161
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1162
+
1163
+ if position_ids is None:
1164
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1165
+ position_ids = torch.arange(
1166
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1167
+ )
1168
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1169
+ else:
1170
+ position_ids = position_ids.view(-1, seq_length).long()
1171
+
1172
+ if inputs_embeds is None:
1173
+ inputs_embeds = self.embed_tokens(input_ids)
1174
+
1175
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1176
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1177
+ if is_padding_right:
1178
+ raise ValueError(
1179
+ "You are attempting to perform batched generation with padding_side='right'"
1180
+ " this may lead to unexpected behaviour for Flash Attention version of Skywork. Make sure to "
1181
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1182
+ )
1183
+
1184
+ if self._attn_implementation == "flash_attention_2":
1185
+ # 2d mask is passed through the layers
1186
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1187
+ elif self._attn_implementation == "sdpa" and not output_attentions:
1188
+ # output_attentions=True cannot be supported when using SDPA, and we fall back on
1189
+ # the manual implementation that requires a 4D causal mask in all cases.
1190
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1191
+ attention_mask,
1192
+ (batch_size, seq_length),
1193
+ inputs_embeds,
1194
+ past_key_values_length,
1195
+ )
1196
+ else:
1197
+ # 4d mask is passed through the layers
1198
+ attention_mask = _prepare_4d_causal_attention_mask(
1199
+ attention_mask,
1200
+ (batch_size, seq_length),
1201
+ inputs_embeds,
1202
+ past_key_values_length,
1203
+ sliding_window=self.config.sliding_window,
1204
+ )
1205
+
1206
+ hidden_states = inputs_embeds
1207
+
1208
+ # decoder layers
1209
+ all_hidden_states = () if output_hidden_states else None
1210
+ all_self_attns = () if output_attentions else None
1211
+ all_router_logits = () if output_router_logits else None
1212
+ next_decoder_cache = None
1213
+
1214
+ for decoder_layer in self.layers:
1215
+ if output_hidden_states:
1216
+ all_hidden_states += (hidden_states,)
1217
+
1218
+ if self.gradient_checkpointing and self.training:
1219
+ layer_outputs = self._gradient_checkpointing_func(
1220
+ decoder_layer.__call__,
1221
+ hidden_states,
1222
+ attention_mask,
1223
+ position_ids,
1224
+ past_key_values,
1225
+ output_attentions,
1226
+ output_router_logits,
1227
+ use_cache,
1228
+ )
1229
+ else:
1230
+ layer_outputs = decoder_layer(
1231
+ hidden_states,
1232
+ attention_mask=attention_mask,
1233
+ position_ids=position_ids,
1234
+ past_key_value=past_key_values,
1235
+ output_attentions=output_attentions,
1236
+ output_router_logits=output_router_logits,
1237
+ use_cache=use_cache,
1238
+ )
1239
+
1240
+ hidden_states = layer_outputs[0]
1241
+
1242
+ if use_cache:
1243
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1244
+
1245
+ if output_attentions:
1246
+ all_self_attns += (layer_outputs[1],)
1247
+
1248
+ if output_router_logits:
1249
+ all_router_logits += (layer_outputs[-1],)
1250
+
1251
+ hidden_states = self.norm(hidden_states)
1252
+
1253
+ # add hidden states from the last decoder layer
1254
+ if output_hidden_states:
1255
+ all_hidden_states += (hidden_states,)
1256
+
1257
+ next_cache = None
1258
+ if use_cache:
1259
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1260
+
1261
+ if not return_dict:
1262
+ return tuple(
1263
+ v
1264
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1265
+ if v is not None
1266
+ )
1267
+ return MoeModelOutputWithPast(
1268
+ last_hidden_state=hidden_states,
1269
+ past_key_values=next_cache,
1270
+ hidden_states=all_hidden_states,
1271
+ attentions=all_self_attns,
1272
+ router_logits=all_router_logits,
1273
+ )
1274
+
1275
+
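The mask-preparation branches above turn a 2D padding mask into the 4D additive causal mask expected by the eager/SDPA attention paths. A simplified pure-PyTorch sketch of that mask shape (illustrative only; it ignores the sliding-window and past-key-value offsets handled by the real helpers):

```python
import torch

def toy_4d_causal_mask(attention_mask_2d, dtype=torch.float32):
    """attention_mask_2d: (batch, seq_len) with 1 for real tokens and 0 for padding."""
    bsz, seq_len = attention_mask_2d.shape
    min_value = torch.finfo(dtype).min

    # Causal part: query position i may only attend to key positions <= i.
    causal = torch.zeros(seq_len, seq_len, dtype=dtype)
    causal = causal.masked_fill(torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1), min_value)

    # Padding part: masked key positions receive the large negative value as well.
    mask = causal[None, None, :, :].expand(bsz, 1, seq_len, seq_len).clone()
    mask = mask.masked_fill(attention_mask_2d[:, None, None, :] == 0, min_value)
    return mask  # (batch, 1, seq_len, seq_len), added to attention scores before softmax

print(toy_4d_causal_mask(torch.tensor([[1, 1, 1], [0, 1, 1]])).shape)  # torch.Size([2, 1, 3, 3])
```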
1276
+ class SkyworkForCausalLM(SkyworkPreTrainedModel):
1277
+ _tied_weights_keys = ["lm_head.weight"]
1278
+
1279
+ def __init__(self, config):
1280
+ super().__init__(config)
1281
+ self.model = SkyworkModel(config)
1282
+ self.vocab_size = config.vocab_size
1283
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1284
+ self.router_aux_loss_coef = 0.001
1285
+ self.num_experts = config.num_experts[0]
1286
+ self.num_experts_per_tok = MOE_TOP_K
1287
+ # Initialize weights and apply final processing
1288
+ self.post_init()
1289
+
1290
+ def get_input_embeddings(self):
1291
+ return self.model.embed_tokens
1292
+
1293
+ def set_input_embeddings(self, value):
1294
+ self.model.embed_tokens = value
1295
+
1296
+ def get_output_embeddings(self):
1297
+ return self.lm_head
1298
+
1299
+ def set_output_embeddings(self, new_embeddings):
1300
+ self.lm_head = new_embeddings
1301
+
1302
+ def set_decoder(self, decoder):
1303
+ self.model = decoder
1304
+
1305
+ def get_decoder(self):
1306
+ return self.model
1307
+
1308
+ @add_start_docstrings_to_model_forward(SKYWORK_INPUTS_DOCSTRING)
1309
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1310
+ # Ignore copy
1311
+ def forward(
1312
+ self,
1313
+ input_ids: torch.LongTensor = None,
1314
+ attention_mask: Optional[torch.Tensor] = None,
1315
+ position_ids: Optional[torch.LongTensor] = None,
1316
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1317
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1318
+ labels: Optional[torch.LongTensor] = None,
1319
+ use_cache: Optional[bool] = None,
1320
+ output_attentions: Optional[bool] = None,
1321
+ output_hidden_states: Optional[bool] = None,
1322
+ output_router_logits: Optional[bool] = None,
1323
+ return_dict: Optional[bool] = None,
1324
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1325
+ r"""
1326
+ Args:
1327
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1328
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1329
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1330
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1331
+
1332
+ Returns:
1333
+
1334
+ Example:
1335
+
1336
+ ```python
1337
+ >>> from transformers import AutoTokenizer, SkyworkForCausalLM
1338
+
1339
+ >>> model = SkyworkForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1340
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1341
+
1342
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1343
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1344
+
1345
+ >>> # Generate
1346
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1347
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1348
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1349
+ ```"""
1350
+
1351
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1352
+ output_router_logits = (
1353
+ output_router_logits if output_router_logits is not None else False
1354
+ )
1355
+
1356
+ output_hidden_states = (
1357
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1358
+ )
1359
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1360
+
1361
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1362
+ outputs = self.model(
1363
+ input_ids=input_ids,
1364
+ attention_mask=attention_mask,
1365
+ position_ids=position_ids,
1366
+ past_key_values=past_key_values,
1367
+ inputs_embeds=inputs_embeds,
1368
+ use_cache=use_cache,
1369
+ output_attentions=output_attentions,
1370
+ output_hidden_states=output_hidden_states,
1371
+ output_router_logits=output_router_logits,
1372
+ return_dict=return_dict,
1373
+ )
1374
+
1375
+ hidden_states = outputs[0]
1376
+ logits = self.lm_head(hidden_states)
1377
+ logits = logits.float()
1378
+
1379
+ loss = None
1380
+ if labels is not None:
1381
+ # Shift so that tokens < n predict n
1382
+ shift_logits = logits[..., :-1, :].contiguous()
1383
+ shift_labels = labels[..., 1:].contiguous()
1384
+ # Flatten the tokens
1385
+ loss_fct = CrossEntropyLoss()
1386
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1387
+ shift_labels = shift_labels.view(-1)
1388
+ # Enable model parallelism
1389
+ shift_labels = shift_labels.to(shift_logits.device)
1390
+ loss = loss_fct(shift_logits, shift_labels)
1391
+
1392
+ aux_loss = None
1393
+ if output_router_logits:
1394
+ aux_loss = load_balancing_loss_func(
1395
+ outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok
1396
+ )
1397
+ if labels is not None:
1398
+ loss += self.router_aux_loss_coef * aux_loss
1399
+
1400
+ if not return_dict:
1401
+ output = (logits,) + outputs[1:]
1402
+ if output_router_logits:
1403
+ output = (aux_loss,) + output
1404
+ return (loss,) + output if loss is not None else output
1405
+
1406
+ return MoeCausalLMOutputWithPast(
1407
+ loss=loss,
1408
+ aux_loss=aux_loss,
1409
+ logits=logits,
1410
+ past_key_values=outputs.past_key_values,
1411
+ hidden_states=outputs.hidden_states,
1412
+ attentions=outputs.attentions,
1413
+ router_logits=outputs.router_logits,
1414
+ )
1415
+
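`load_balancing_loss_func`, defined earlier in this file, implements a Switch-Transformer-style auxiliary loss that the forward pass above scales by `router_aux_loss_coef` and adds to the language-modeling loss. A hedged standalone sketch of that style of loss over concatenated router logits (names, shapes, and the exact reduction are illustrative, not the function above verbatim):

```python
import torch
import torch.nn.functional as F

def toy_load_balancing_loss(router_logits, num_experts, top_k):
    """router_logits: (num_layers * num_tokens, num_experts), concatenated across layers."""
    routing_weights = F.softmax(router_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_experts).float()    # (tokens, top_k, experts)

    # Fraction of routing assignments each expert receives vs. its mean router probability.
    tokens_per_expert = expert_mask.mean(dim=(0, 1))                  # (experts,)
    router_prob_per_expert = routing_weights.mean(dim=0)              # (experts,)
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

# e.g. router logits gathered from all layers for 32 tokens and 16 experts:
aux = toy_load_balancing_loss(torch.randn(32, 16), num_experts=16, top_k=2)
```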
1416
+ def prepare_inputs_for_generation(
1417
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1418
+ ):
1419
+ # Omit tokens covered by past_key_values
1420
+ if past_key_values is not None:
1421
+ if isinstance(past_key_values, Cache):
1422
+ cache_length = past_key_values.get_seq_length()
1423
+ past_length = past_key_values.seen_tokens
1424
+ max_cache_length = past_key_values.get_max_length()
1425
+ else:
1426
+ cache_length = past_length = past_key_values[0][0].shape[2]
1427
+ max_cache_length = None
1428
+
1429
+ # Keep only the unprocessed tokens:
1430
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1431
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
1432
+ # input)
1433
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1434
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1435
+ # 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens. We can discard
1436
+ # input_ids based on the past_length.
1437
+ elif past_length < input_ids.shape[1]:
1438
+ input_ids = input_ids[:, past_length:]
1439
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1440
+
1441
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1442
+ if (
1443
+ max_cache_length is not None
1444
+ and attention_mask is not None
1445
+ and cache_length + input_ids.shape[1] > max_cache_length
1446
+ ):
1447
+ attention_mask = attention_mask[:, -max_cache_length:]
1448
+
1449
+ position_ids = kwargs.get("position_ids", None)
1450
+ if attention_mask is not None and position_ids is None:
1451
+ # create position_ids on the fly for batch generation
1452
+ position_ids = attention_mask.long().cumsum(-1) - 1
1453
+ position_ids.masked_fill_(attention_mask == 0, 1)
1454
+ if past_key_values:
1455
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1456
+
1457
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1458
+ if inputs_embeds is not None and past_key_values is None:
1459
+ model_inputs = {"inputs_embeds": inputs_embeds}
1460
+ else:
1461
+ model_inputs = {"input_ids": input_ids}
1462
+
1463
+ model_inputs.update(
1464
+ {
1465
+ "position_ids": position_ids,
1466
+ "past_key_values": past_key_values,
1467
+ "use_cache": kwargs.get("use_cache"),
1468
+ "attention_mask": attention_mask,
1469
+ }
1470
+ )
1471
+ return model_inputs
1472
+
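The `position_ids` fallback above builds positions from the attention mask so that left padding does not shift the positions of real tokens. A tiny worked example of that cumsum trick on a hypothetical left-padded batch:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```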
1473
+ @staticmethod
1474
+ def _reorder_cache(past_key_values, beam_idx):
1475
+ reordered_past = ()
1476
+ for layer_past in past_key_values:
1477
+ reordered_past += (
1478
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1479
+ )
1480
+ return reordered_past
1481
+
1482
+
1483
+ @add_start_docstrings(
1484
+ """
1485
+ The Skywork Model transformer with a sequence classification head on top (linear layer).
1486
+
1487
+ [`SkyworkForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1488
+ (e.g. GPT-2) do.
1489
+
1490
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1491
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1492
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1493
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1494
+ each row of the batch).
1495
+ """,
1496
+ SKYWORK_START_DOCSTRING,
1497
+ )
1498
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Skywork, LLAMA->SKYWORK
1499
+ class SkyworkForSequenceClassification(SkyworkPreTrainedModel):
1500
+ def __init__(self, config):
1501
+ super().__init__(config)
1502
+ self.num_labels = config.num_labels
1503
+ self.model = SkyworkModel(config)
1504
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1505
+
1506
+ # Initialize weights and apply final processing
1507
+ self.post_init()
1508
+
1509
+ def get_input_embeddings(self):
1510
+ return self.model.embed_tokens
1511
+
1512
+ def set_input_embeddings(self, value):
1513
+ self.model.embed_tokens = value
1514
+
1515
+ @add_start_docstrings_to_model_forward(SKYWORK_INPUTS_DOCSTRING)
1516
+ def forward(
1517
+ self,
1518
+ input_ids: torch.LongTensor = None,
1519
+ attention_mask: Optional[torch.Tensor] = None,
1520
+ position_ids: Optional[torch.LongTensor] = None,
1521
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1522
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1523
+ labels: Optional[torch.LongTensor] = None,
1524
+ use_cache: Optional[bool] = None,
1525
+ output_attentions: Optional[bool] = None,
1526
+ output_hidden_states: Optional[bool] = None,
1527
+ return_dict: Optional[bool] = None,
1528
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1529
+ r"""
1530
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1531
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1532
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1533
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1534
+ """
1535
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1536
+
1537
+ transformer_outputs = self.model(
1538
+ input_ids,
1539
+ attention_mask=attention_mask,
1540
+ position_ids=position_ids,
1541
+ past_key_values=past_key_values,
1542
+ inputs_embeds=inputs_embeds,
1543
+ use_cache=use_cache,
1544
+ output_attentions=output_attentions,
1545
+ output_hidden_states=output_hidden_states,
1546
+ return_dict=return_dict,
1547
+ )
1548
+ hidden_states = transformer_outputs[0]
1549
+ logits = self.score(hidden_states)
1550
+
1551
+ if input_ids is not None:
1552
+ batch_size = input_ids.shape[0]
1553
+ else:
1554
+ batch_size = inputs_embeds.shape[0]
1555
+
1556
+ if self.config.pad_token_id is None and batch_size != 1:
1557
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1558
+ if self.config.pad_token_id is None:
1559
+ sequence_lengths = -1
1560
+ else:
1561
+ if input_ids is not None:
1562
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1563
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1564
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1565
+ sequence_lengths = sequence_lengths.to(logits.device)
1566
+ else:
1567
+ sequence_lengths = -1
1568
+
1569
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1570
+
1571
+ loss = None
1572
+ if labels is not None:
1573
+ labels = labels.to(logits.device)
1574
+ if self.config.problem_type is None:
1575
+ if self.num_labels == 1:
1576
+ self.config.problem_type = "regression"
1577
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1578
+ self.config.problem_type = "single_label_classification"
1579
+ else:
1580
+ self.config.problem_type = "multi_label_classification"
1581
+
1582
+ if self.config.problem_type == "regression":
1583
+ loss_fct = MSELoss()
1584
+ if self.num_labels == 1:
1585
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1586
+ else:
1587
+ loss = loss_fct(pooled_logits, labels)
1588
+ elif self.config.problem_type == "single_label_classification":
1589
+ loss_fct = CrossEntropyLoss()
1590
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1591
+ elif self.config.problem_type == "multi_label_classification":
1592
+ loss_fct = BCEWithLogitsLoss()
1593
+ loss = loss_fct(pooled_logits, labels)
1594
+ if not return_dict:
1595
+ output = (pooled_logits,) + transformer_outputs[1:]
1596
+ return ((loss,) + output) if loss is not None else output
1597
+
1598
+ return SequenceClassifierOutputWithPast(
1599
+ loss=loss,
1600
+ logits=pooled_logits,
1601
+ past_key_values=transformer_outputs.past_key_values,
1602
+ hidden_states=transformer_outputs.hidden_states,
1603
+ attentions=transformer_outputs.attentions,
1604
+ )
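The classification head above pools the logits at each sequence's last non-padding token, computing that index with an `argmax`-on-equality trick that stays ONNX-exportable. A small hedged example of the index computation with a hypothetical pad id and batch:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],    # padded row
                          [8, 9, 3, 4, 2]])   # full row, no padding

# First pad position minus one is the last real token; the modulo maps the
# no-padding case (argmax == 0, hence -1) back to the final position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])

logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)
pooled_logits = logits[torch.arange(2), sequence_lengths]  # (batch, num_labels)
```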
pytorch_model-00001-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a22af93ce5144d22d39fc08896e24066aa3aa4454f929810b962f5bfeeafe48f
3
+ size 5606012214
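The `pytorch_model-*.bin` entries in this commit are Git LFS pointer stubs rather than the weight shards themselves: each records the LFS spec version, the SHA-256 of the real file, and its size in bytes. A hedged sketch of parsing one such pointer (the path is a placeholder):

```python
# Minimal parser for a Git LFS pointer file; the key/value lines match the stubs in this diff.
def read_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# e.g. read_lfs_pointer("pytorch_model-00001-of-00053.bin") would yield something like
# {"version": "https://git-lfs.github.com/spec/v1", "oid": "sha256:...", "size": "5606012214"}
```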
pytorch_model-00002-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b80c780ce008698f4082cf51fecd1f1a41a33d52b782e21f7d8e011f26c3a3f2
3
+ size 5606012214
pytorch_model-00003-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4a7abe80a53eafa1a7e3231188e56941509f7b72a39c82f1ce24fbb1a15bf0b
3
+ size 5606012214
pytorch_model-00004-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e6db7100f0932219e3546d27fb5308e6023d9591d4fe9aed8ecb537ba828302
3
+ size 5606012214
pytorch_model-00005-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4380526ea334dbe7e58077abad9b70bd88532e51452b67bb375fa57c15a64ffe
3
+ size 5606012214
pytorch_model-00006-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2e4311e8a1659e3c6d09ca70052ac2e425242d7fa6a22cbefde6d08f385bb53
3
+ size 5606012214
pytorch_model-00007-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a29e4b9c28c9c2322a6cd5ea6f75cbac3f6b1d5608495f8e3cbdc399aec44dc6
3
+ size 5606012214
pytorch_model-00008-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0873082ca9658dac968cbc3ad419e3fc7ebbf9ec6e4eb25c96c66ad7c564046
3
+ size 5606012214
pytorch_model-00009-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e24d75f23c3f0a105d545d19bb6b9c07dd15a1e96702f4a61d9fd5bf54b6055c
3
+ size 5606012214
pytorch_model-00010-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60447e9efe1ef8317158f95be43ce20dd2030db0fb1133ac52109c4fedb2ae3a
3
+ size 5606012214
pytorch_model-00011-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e20320ec02bf3cec1a0905ed7324deb811fe379c5225eeb4c7423b62f85266be
3
+ size 5606012278
pytorch_model-00012-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:291dcb77667bf3ee689f2670762cf953736280ebf2f6e92f4c7d41594d5820c7
3
+ size 5606012278
pytorch_model-00013-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cec07f32ce4b620fdc0489b5bfd226a7bfbb216566c7b20585a0038536f638eb
3
+ size 5606012278
pytorch_model-00014-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e498e08211cc8a582c5cc4880fe843961f539c6c3bb9c7594cf2ed1e02725d4
3
+ size 5606012278
pytorch_model-00015-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3c7726513dfe8dc3268c24b8b79f37e9205cb09df3cac30842e29aeb71c3cd8
3
+ size 5606012278
pytorch_model-00016-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11fab816fe9a0ad46d165473b7ed053b18daa40d2ac24d469af03d7741e8a58d
3
+ size 5606012278
pytorch_model-00017-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8fc71f3a1f8f78a473caad029233cef47d8d8e9df9f3a5c164e0ad2795d085a
3
+ size 5606012278
pytorch_model-00018-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d44a2ccefc63d9a7c5e2edae37e1e3d75f465bd67379649a121416c3accf651
3
+ size 5606012278
pytorch_model-00019-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98ea769662b4b35d01528d2cfaa6d4094ae2ca2383027fea90098aa0e05570e8
3
+ size 5606012278
pytorch_model-00020-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbf514e0d46d6f12c97e332e9a51e3a9a6f283d4693a329643eeefc8e1236802
3
+ size 5606012278
pytorch_model-00021-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:286f6441449b7e5bc0d459029f24cbd22c0f4b9143ef56a49c0c99a271175070
3
+ size 5606012278
pytorch_model-00022-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d82db0d0271a0fcf77643bf025ffc0f471b5b6962a94065d744d0aef586a0b6a
3
+ size 5606012278
pytorch_model-00023-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6d31b5eec13e0cfc81690cf5a78151552a66bc1442051a48ead2d7cadd3b18b
3
+ size 5606012278
pytorch_model-00024-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eac3fb50eb1e1496415de23a198918b2a19e28bb6045125f6f48974376494dc4
3
+ size 5606012278
pytorch_model-00025-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a0c5a08f09c724cc6373d4cfa6b2e448040d3c913c149af43d8c7b8beb3730f
3
+ size 5606012278
pytorch_model-00026-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a37e8a69c9ded8a1c0abefff4d3a4d71dba14252f5def086976b0e4998e6dc2
3
+ size 5606012278
pytorch_model-00027-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5b26230f8aff06fcc78d2afc627dfff72bd70ea4861c51338c67ae1509141fb
3
+ size 5606012278
pytorch_model-00028-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb1c09a0c2be1d102e0c59c60067bdaae79861a6afba441cdd70a01b0964b84b
3
+ size 5606012278
pytorch_model-00029-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66b0e591ed6a337f7f332d4f8a4f0880a3020d0ed6db1db7f92b51a9fa7a1ca3
3
+ size 5606012278
pytorch_model-00030-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd82d4a5905bb08d96417d4ef91063738f5368ed4053186ab8f44c2baf316612
3
+ size 5606012278
pytorch_model-00031-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e16d1611822e21e228afa5ae410eb04456355b9b175333cc1903c16da60f7111
3
+ size 5606012278
pytorch_model-00032-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b92145427a08ca63ac63d3c7d61e6493f5a7b642ff940f2a4ea92e7a02f4a29
3
+ size 5606012278
pytorch_model-00033-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7d0ef57993eba2ed75d66d9bb49e43bb669eb7dd67e77eddc122f4832cd95e3
3
+ size 5606012278
pytorch_model-00034-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:272a3cfafea1857cb4d1b6e5fa306f1e3cdb97caf9b287adea8b622c4ac10d09
3
+ size 5606012278
pytorch_model-00035-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c7f13eae184e4a98193be2d5dd20bfdd636c12987869969829398a684ba2f26
3
+ size 5606012278
pytorch_model-00036-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33cde811a94ed1e15074439fe145206918d35ca5240ed2768aa79137b52f1ddb
3
+ size 5606012278
pytorch_model-00037-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:793c3ba7957c68e1763fe1b0aea616cc05868d555cabf7e4cda8165f5b93f719
3
+ size 5606012278
pytorch_model-00038-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d3d216014beb54c40db35235bd60285d611a58e6726bd2e9d14a5771310dc66
3
+ size 5606012278
pytorch_model-00039-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41e45051d4e883163bed6559aec96ddc95c446050f0b4262bdcb751ed3b3e92e
3
+ size 5606012278
pytorch_model-00040-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19bc57f59a16a571ea5583b90c16172709cd3f37e02ba1799fa4b28c44739f5c
3
+ size 5606012278
pytorch_model-00041-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f3d8770e8cb3723537e1c0826009cd6b71d58497974d4f8c8f876dfe1e26dee
3
+ size 5606012278
pytorch_model-00042-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c75613dc3754db1c596a6f6477497cbffb6962f22fcbf5ead458cc9f944f76d1
3
+ size 5606012278
pytorch_model-00043-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:138586728295924720ab003140e58b73db608fb11d31fe142fd244dae72c842a
3
+ size 5606012278
pytorch_model-00044-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dc46ef942b8ecb37ef8b1d81ea5ac12c3812bdc7d6f5e7d38a9d4701cb75ef2
3
+ size 5606012278
pytorch_model-00045-of-00053.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c68200d9abc50973203368892203dc59b2b19c95011a964e6ff2ece66f9e2c5e
3
+ size 5606012278