Mortie1 committed on
Commit
f11bd15
·
verified ·
1 Parent(s): 8ff2dc0

Upload MyLLaMa

Browse files
Files changed (1) hide show
  1. configure_for_hf.py +14 -0
configure_for_hf.py CHANGED
@@ -41,6 +41,20 @@ class MyLLaMa(PreTrainedModel):
41
  n_chckpnt_segments=config.n_chckpnt_segments,
42
  )
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def forward(self, tensor, labels=None):
45
  logits = self.model(tensor)["logits"]
46
  if labels is not None:
 
41
  n_chckpnt_segments=config.n_chckpnt_segments,
42
  )
43
 
44
+ def load_state_dict(self, state_dict, **kwargs):
45
+ for key in list(state_dict.keys()):
46
+ if "rmsnorm1.weight" in key:
47
+ new_key = key.replace("rmsnorm1.weight", "rmsnorm1.gamma")
48
+ state_dict[new_key] = state_dict.pop(key)
49
+ elif "rmsnorm2.weight" in key:
50
+ new_key = key.replace("rmsnorm2.weight", "rmsnorm2.gamma")
51
+ state_dict[new_key] = state_dict.pop(key)
52
+ elif "rmsnorm.weight" in key:
53
+ new_key = key.replace("rmsnorm.weight", "rmsnorm.gamma")
54
+ state_dict[new_key] = state_dict.pop(key)
55
+
56
+ super().load_state_dict(state_dict, **kwargs)
57
+
58
  def forward(self, tensor, labels=None):
59
  logits = self.model(tensor)["logits"]
60
  if labels is not None: