Crystalcareai committed
Update modeling_gemmoe.py

modeling_gemmoe.py CHANGED (+6 -19)
@@ -169,8 +169,8 @@ class GemmoeRMSNorm(nn.Module):
 
     def forward(self, x):
         output = self._norm(x.float()).type_as(x)
-        return output * (self.weight
-
+        return output * (self.weight + 1)
+
 ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
 
 class GemmoeRotaryEmbedding(nn.Module):
@@ -271,10 +271,10 @@ class GemmoeAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
         self.rotary_emb = GemmoeRotaryEmbedding(
-
-
-
-
+            self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
 
     def forward(
         self,
@@ -312,11 +312,6 @@ class GemmoeAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        # Move query_states, key_states, and value_states to the same device as hidden_states
-        query_states = query_states.to(hidden_states.device)
-        key_states = key_states.to(hidden_states.device)
-        value_states = value_states.to(hidden_states.device)
-
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -1205,14 +1200,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        attention_mask = attention_mask.to(device) if attention_mask is not None else None
-        position_ids = position_ids.to(device) if position_ids is not None else None
-        past_key_values = [t.to(device) for t in past_key_values] if past_key_values is not None else None
-        inputs_embeds = inputs_embeds.to(device) if inputs_embeds is not None else None
-        labels = labels.to(device) if labels is not None else None
-        cache_position = cache_position.to(device) if cache_position is not None else None
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
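Note on the GemmoeRMSNorm hunk: the new return line follows the Gemma-style parameterization in which the learnable weight is stored as an offset from 1 rather than as the scale itself. Below is a minimal, self-contained sketch of that pattern; the class name RMSNormSketch is illustrative, and the zero initialization of the weight is an assumption based on the usual Gemma convention, not something shown in this diff.

import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    # Minimal RMSNorm with the Gemma-style (weight + 1) scaling.
    # The weight is an offset from 1 (assumed zero-initialized here),
    # so a freshly constructed layer leaves the normalized values unscaled.
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Normalize by the root-mean-square over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * (self.weight + 1)

# Usage: with the default zero weight, the output is just the RMS-normalized input.
x = torch.randn(2, 4, 8)
print(RMSNormSketch(8)(x).shape)  # torch.Size([2, 4, 8])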
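Note on the GemmoeAttention hunk: the restored GemmoeRotaryEmbedding arguments (head_dim, max_position_embeddings, base=rope_theta) are the standard rotary-embedding hyperparameters. The sketch below shows what they parameterize under the usual RoPE formulation; RotaryEmbeddingSketch and its forward signature are illustrative and not the actual GemmoeRotaryEmbedding API.

import torch
import torch.nn as nn

class RotaryEmbeddingSketch(nn.Module):
    # Illustrative rotary embedding: `dim` is the per-head dimension,
    # `base` is the frequency base (rope_theta above), and
    # `max_position_embeddings` is the longest position it is meant to cover.
    def __init__(self, dim, max_position_embeddings=8192, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_position_embeddings = max_position_embeddings

    def forward(self, position_ids):
        # position_ids: [batch, seq_len] -> cos/sin tables of shape [batch, seq_len, dim]
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

# Usage: the attention layer builds one of these with head_dim and rope_theta,
# then rotates the query/key projections with the returned cos/sin tables.
cos, sin = RotaryEmbeddingSketch(dim=64, base=10000.0)(torch.arange(16)[None])
print(cos.shape)  # torch.Size([1, 16, 64])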