Crystalcareai committed
Commit f84e893 · verified · 1 Parent(s): 875b2bf

Update modeling_gemmoe.py

Files changed (1): modeling_gemmoe.py +6 -19
modeling_gemmoe.py CHANGED
@@ -169,8 +169,8 @@ class GemmoeRMSNorm(nn.Module):
 
     def forward(self, x):
         output = self._norm(x.float()).type_as(x)
-        return output * (self.weight.to(x.device) + 1) # Move self.weight to the same device as x
-
+        return output * (self.weight + 1)
+
 ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
 
 class GemmoeRotaryEmbedding(nn.Module):
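
The first hunk drops the per-call device copy in GemmoeRMSNorm.forward: an nn.Parameter moves with its module under model.to(device), so weight and x already share a device in ordinary use, and the extra .to(x.device) forced a check (and potentially a copy) on every forward. A minimal sketch of the module as it reads after this commit; the _norm body is not shown in the diff, so the Gemma-style RMS normalization below is an assumption:

import torch
from torch import nn

class GemmoeRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Registered parameter: follows the module across .to()/.cuda() calls,
        # which is why no per-forward device copy is needed.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Assumed Gemma-style RMS normalization; this method is not part of the diff.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * (self.weight + 1)
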
@@ -271,10 +271,10 @@ class GemmoeAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
         self.rotary_emb = GemmoeRotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
-        )
+            self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
 
     def forward(
         self,
@@ -312,11 +312,6 @@ class GemmoeAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        # Move query_states, key_states, and value_states to the same device as hidden_states
-        query_states = query_states.to(hidden_states.device)
-        key_states = key_states.to(hidden_states.device)
-        value_states = value_states.to(hidden_states.device)
-
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
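
The block deleted in this hunk was a no-op whenever the model was placed correctly: nn.Linear produces its output on the device of its input, so the q/k/v projections cannot land anywhere other than where hidden_states lives. A quick illustration (sizes are made up):

import torch
from torch import nn

bsz, q_len, hidden_size, num_heads, head_dim = 2, 8, 64, 4, 16
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)

hidden_states = torch.randn(bsz, q_len, hidden_size)
query_states = q_proj(hidden_states)

# Already true, which is why the explicit .to(hidden_states.device) was removed:
assert query_states.device == hidden_states.device

# The reshape that follows in the unchanged code:
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print(query_states.shape)  # torch.Size([2, 4, 8, 16])
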
@@ -1205,14 +1200,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        attention_mask = attention_mask.to(device) if attention_mask is not None else None
-        position_ids = position_ids.to(device) if position_ids is not None else None
-        past_key_values = [t.to(device) for t in past_key_values] if past_key_values is not None else None
-        inputs_embeds = inputs_embeds.to(device) if inputs_embeds is not None else None
-        labels = labels.to(device) if labels is not None else None
-        cache_position = cache_position.to(device) if cache_position is not None else None
-
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
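
The last hunk removes the blanket input moves at the top of GemmoeForCausalLM.forward. These were redundant on a single device; the past_key_values line also assumed a flat list of tensors, which matches neither the legacy tuple-of-tuples cache nor the Cache objects in recent transformers; and forcing everything onto input_ids.device can fight the hooks that accelerate installs under a sharded device_map. A sketch of the loading path these deletions rely on; the checkpoint name is illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "Crystalcareai/GemMoE-Beta-1"  # hypothetical repo id, for illustration only

# device_map="auto" lets accelerate shard the model and move tensors between
# devices via forward hooks, so forward() itself should not re-home its inputs.
model = AutoModelForCausalLM.from_pretrained(
    repo,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0], skip_special_tokens=True))
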