Crystalcareai committed: Update modeling_gemmoe.py

modeling_gemmoe.py CHANGED (+114 -31)
@@ -194,42 +194,54 @@ class GemmoeRMSNorm(nn.Module):
 
 ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
 
+class GemmoeRMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        return output * (self.weight + 1)
+
+ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
+
 class GemmoeRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-        self.max_seq_len_cached = None
-
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
-
-
-
-
-
-
-
-
-
-
-
+        freq_exponents = (2.0 / self.dim) * (
+            torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
+        )
+        timescale = self.base ** freq_exponents
+        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.int64).float()
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
+        emb = torch.cat((radians_new, radians_new), dim=-1)
+        cos = emb.cos().to(device=device, non_blocking=True)
+        sin = emb.sin().to(device=device, non_blocking=True)
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        if seq_len is None:
+            seq_len = x.size(2)
+        if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
         return (
-            self.cos_cached[:seq_len]
-            self.sin_cached[:seq_len]
+            self.cos_cached[:seq_len],
+            self.sin_cached[:seq_len],
         )
+
 
 class GemmoeLinearScalingRotaryEmbedding(GemmoeRotaryEmbedding):
     """GemmoeRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
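The hunk above does two things: it re-declares GemmoeRMSNorm with a zero-initialized weight applied as (weight + 1), and it rebuilds the rotary cos/sin caches eagerly from half-dimension frequencies (it also fixes the missing commas in the forward return tuple). Below is a minimal standalone sketch, with plain torch and toy sizes rather than the classes themselves, to sanity-check that arithmetic; nothing in it is part of the commit.

# Standalone sketch (not part of the commit): re-derives the math of the new GemmoeRMSNorm
# and GemmoeRotaryEmbedding cache so the hunk above can be checked without importing
# modeling_gemmoe.py. All sizes are toy values.
import torch

dim, eps = 8, 1e-6
x = torch.randn(2, 4, dim)

# GemmoeRMSNorm scales by (weight + 1) with weight initialized to zeros, so right after
# init it behaves as a plain RMS normalization.
weight = torch.zeros(dim)
normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
assert torch.allclose(normed * (weight + 1), normed)

# Rotary cache as built by the new _set_cos_sin_cache: dim // 2 timescales, angles
# duplicated so that cos_cached / sin_cached end up with shape (seq_len, dim).
base, max_pos = 10000, 16
freq_exponents = (2.0 / dim) * torch.arange(dim // 2, dtype=torch.int64).float()
timescale = base ** freq_exponents
positions = torch.arange(max_pos, dtype=torch.int64).float()
radians = (positions[..., None] / timescale[None, None, :]).squeeze(0)
emb = torch.cat((radians, radians), dim=-1)
cos_cached, sin_cached = emb.cos(), emb.sin()
assert cos_cached.shape == sin_cached.shape == (max_pos, dim)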
@@ -948,17 +960,78 @@ GEMMOE_ATTENTION_CLASSES = {
     "sdpa": GemmoeSdpaAttention,
 }
 
+class GemmoeBlockSparseTop2MLP(nn.Module):
+    def __init__(self, config: GemmoeConfig):
+        super().__init__()
+        self.ffn_dim = config.intermediate_size
+        self.hidden_dim = config.hidden_size
+
+        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+        self.act_fn = approx_gelu
+
+    def forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+class GemmoeSparseMoeBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = 2
+
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+
+        y = torch.empty_like(hidden_states)
+
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            expert_output = expert(hidden_states[flat_topk_idx == i])
+            y[flat_topk_idx == i] = expert_output.to(y.dtype)  # Cast expert_output to the same dtype as y
+
+        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
 
 class GemmoeDecoderLayer(nn.Module):
     def __init__(self, config: GemmoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-
         self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
-
-
-
+        if config.n_routed_experts is not None and \
+           layer_idx >= config.first_k_dense_replace and \
+           layer_idx % config.moe_layer_freq == 0:
+            self.block_sparse_moe = GemmoeSparseMoeBlock(config)
+        else:
+            self.mlp = GemmoeMLP(config)
+
         self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
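GemmoeSparseMoeBlock above routes each token to its top-2 experts, renormalizes the two gate weights, duplicates the token stream with repeat_interleave, and recombines the expert outputs as a weighted sum. Here is a toy sketch of just that routing arithmetic; identity functions stand in for GemmoeBlockSparseTop2MLP, and nothing below comes from the commit itself.

# Toy sketch of the top-2 routing used by GemmoeSparseMoeBlock. Shapes and variable
# names mirror the hunk above; the data is random and the "experts" are identities so
# the recombination is easy to verify.
import torch
import torch.nn.functional as F

num_experts, top_k, hidden_dim = 4, 2, 3
tokens = torch.randn(5, hidden_dim)            # flattened (batch * seq_len, hidden_dim)
router_logits = torch.randn(5, num_experts)    # what self.gate(hidden_states) would return

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)  # renormalize over the 2 chosen experts

# Every token is duplicated top_k times, sent to its chosen experts, then recombined.
expanded = tokens.repeat_interleave(top_k, dim=0)
y = torch.empty_like(expanded)
flat_topk_idx = topk_idx.view(-1)
for i in range(num_experts):
    y[flat_topk_idx == i] = expanded[flat_topk_idx == i]   # identity "expert"
combined = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)

# With identity experts and gate weights renormalized to sum to 1, routing is a no-op.
assert combined.shape == tokens.shape
assert torch.allclose(combined, tokens, atol=1e-6)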
@@ -969,6 +1042,7 @@ class GemmoeDecoderLayer(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -985,13 +1059,15 @@ class GemmoeDecoderLayer(nn.Module):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+                and should not be returned during inference.
         """
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
         residual = hidden_states
-
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
@@ -1009,7 +1085,12 @@ class GemmoeDecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-
+
+        if hasattr(self, 'block_sparse_moe'):
+            hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+        else:
+            hidden_states = self.mlp(hidden_states)
+
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
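Which branch the hasattr(self, 'block_sparse_moe') dispatch above takes is fixed at construction time by the __init__ condition added earlier in this diff. A small sketch with hypothetical config values (the four numbers are made up, not taken from the Gemmoe config) shows which layer indices end up on the MoE path:

# Sketch of the MoE placement rule from the new GemmoeDecoderLayer.__init__.
# All four values below are hypothetical stand-ins for the config fields.
n_routed_experts = 8        # stands in for config.n_routed_experts
first_k_dense_replace = 1   # stands in for config.first_k_dense_replace
moe_layer_freq = 2          # stands in for config.moe_layer_freq
num_hidden_layers = 8       # stands in for config.num_hidden_layers

for layer_idx in range(num_hidden_layers):
    use_moe = (
        n_routed_experts is not None
        and layer_idx >= first_k_dense_replace
        and layer_idx % moe_layer_freq == 0
    )
    print(layer_idx, "block_sparse_moe" if use_moe else "mlp")
# -> layers 2, 4 and 6 get the sparse MoE block; 0, 1, 3, 5 and 7 keep the dense GemmoeMLP
#    (layer 0 is excluded by first_k_dense_replace even though 0 % moe_layer_freq == 0).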
@@ -1019,10 +1100,12 @@ class GemmoeDecoderLayer(nn.Module):
 
         if use_cache:
             outputs += (present_key_value,)
+
+        if output_router_logits and hasattr(self, 'block_sparse_moe'):
+            outputs += (router_logits,)
 
         return outputs
 
-
 GEMMOE_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads