Crystalcareai committed
Commit • 9a966d4
Parent(s): ad43155
Update modeling_gemmoe.py

modeling_gemmoe.py (+24 -43)
CHANGED
@@ -18,12 +18,13 @@
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
-
+
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_attn_mask_utils import (
@@ -305,7 +306,6 @@ class GemmoeAttention(nn.Module):
             - The attention weights (if `output_attentions=True`).
             - The past key-value cache (if `use_cache=True`).
         """
-
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -331,14 +331,12 @@
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-
-
-
-
-
-
-        causal_mask = attention_mask
-        attn_weights = attn_weights + causal_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            if cache_position is not None:
+                causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+            else:
+                causal_mask = attention_mask
+            attn_weights = attn_weights + causal_mask
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
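The added branch slices a pre-built 4-D additive mask down to the current key length, and to the active query rows when `cache_position` is given, before adding it to the raw attention scores. A minimal, self-contained sketch of that masking step, with illustrative shapes rather than the real Gemmoe tensors:

    import torch

    bsz, num_heads, kv_len, max_len = 1, 2, 5, 8
    q_len = 3
    attn_weights = torch.randn(bsz, num_heads, q_len, kv_len)            # raw scores

    # Additive causal mask over the full cache: 0 where attention is allowed,
    # a very large negative value where it is not.
    full_mask = torch.full((max_len, max_len), torch.finfo(torch.float32).min).triu(diagonal=1)
    full_mask = full_mask[None, None, :, :]                              # (1, 1, max_len, max_len)

    cache_position = torch.arange(2, 2 + q_len)                          # query rows inside the cache
    causal_mask = full_mask[:, :, cache_position, :kv_len]               # (1, 1, q_len, kv_len)

    attn_weights = attn_weights + causal_mask                            # broadcasts over heads
    probs = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)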
@@ -686,7 +684,6 @@ class GemmoeSparseMoeBlock(nn.Module):
 
         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
-    @torch.jit.script
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
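For orientation, `GemmoeBlockSparseTop2MLP` and the surrounding block follow the familiar sparse top-2 mixture-of-experts pattern: a router scores every token, the two best experts per token are run, and their outputs are mixed with renormalized routing weights. A generic Mixtral-style sketch of that routing, assumed structure only and not the actual GemmoeSparseMoeBlock code:

    import torch
    import torch.nn.functional as F
    from torch import nn

    hidden_dim, num_experts, top_k = 8, 4, 2
    gate = nn.Linear(hidden_dim, num_experts, bias=False)                # router
    experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])

    tokens = torch.randn(6, hidden_dim)                                  # flattened (batch*seq, hidden)
    routing_weights = F.softmax(gate(tokens), dim=-1)                    # (tokens, num_experts)
    topk_weights, topk_idx = torch.topk(routing_weights, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) # renormalize over the top-2

    out = torch.zeros_like(tokens)
    for k in range(top_k):                                               # accumulate each token's chosen experts
        for e in range(num_experts):
            mask = topk_idx[:, k] == e
            if mask.any():
                out[mask] += topk_weights[mask, k, None] * experts[e](tokens[mask])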
@@ -727,7 +724,6 @@ class GemmoeDecoderLayer(nn.Module):
         self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-    @torch.jit.script
     def forward(
         self,
         hidden_states: torch.Tensor,
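Removing the `@torch.jit.script` decorators is consistent with how TorchScript is normally used: scripting targets a whole `nn.Module` instance, or free functions with concrete type annotations, rather than the bound `forward` of an ordinary module, where eager compilation of untyped `self` tends to fail. A minimal sketch of the usual pattern with a toy module, not the Gemmoe classes:

    import torch
    from torch import nn

    class TinyBlock(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.proj(x))

    scripted = torch.jit.script(TinyBlock())        # script the module, not its forward method
    print(scripted(torch.randn(2, 8)).shape)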
@@ -977,7 +973,7 @@ class GemmoeModel(GemmoePreTrainedModel):
         hidden_states = inputs_embeds
 
         # Normalize
-        scale_factor = torch.tensor(
+        scale_factor = torch.tensor(math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype)
         hidden_states = hidden_states * scale_factor
         # Decoder layers
         all_hidden_states = () if output_hidden_states else None
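The new line scales the input embeddings by the square root of the hidden size, in line with Gemma-style embedding normalization, and materializes the factor as a tensor so it can be cast to the activation dtype. A minimal sketch of that step, written with the standard `math.sqrt` since it is not visible here whether `math_sqrt` is a local alias in the file:

    import math
    import torch

    hidden_size = 16
    hidden_states = torch.randn(2, 4, hidden_size, dtype=torch.float16)  # embeddings

    scale_factor = torch.tensor(math.sqrt(hidden_size), dtype=hidden_states.dtype)
    hidden_states = hidden_states * scale_factor                         # normalize before the decoder stack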
@@ -990,8 +986,8 @@ class GemmoeModel(GemmoePreTrainedModel):
                 all_hidden_states += (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs =
-                    decoder_layer,
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     causal_mask,
                     position_ids,
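`self._gradient_checkpointing_func` is the hook transformers sets up when gradient checkpointing is enabled, and passing `decoder_layer.__call__` keeps module hooks in play while the layer's activations are recomputed during backward instead of being stored. The underlying mechanism, sketched with plain `torch.utils.checkpoint` and a toy layer rather than the GemmoeModel code:

    import torch
    from torch.utils.checkpoint import checkpoint

    layer = torch.nn.Linear(16, 16)                                      # stand-in for a decoder layer
    hidden_states = torch.randn(2, 4, 16, requires_grad=True)

    # Activations inside `layer` are recomputed during backward instead of stored.
    layer_outputs = checkpoint(layer.__call__, hidden_states, use_reentrant=False)
    layer_outputs.sum().backward()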
@@ -1204,34 +1200,19 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-        else:
-            outputs = self.model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_router_logits=output_router_logits,
-                return_dict=return_dict,
-                cache_position=cache_position,
-            )
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
 
         hidden_states = outputs[0]
 