Upload MyLLaMa
configure_for_hf.py  +14 -0
CHANGED
@@ -41,6 +41,20 @@ class MyLLaMa(PreTrainedModel):
             n_chckpnt_segments=config.n_chckpnt_segments,
         )
 
+    def load_state_dict(self, state_dict, **kwargs):
+        for key in list(state_dict.keys()):
+            if "rmsnorm1.weight" in key:
+                new_key = key.replace("rmsnorm1.weight", "rmsnorm1.gamma")
+                state_dict[new_key] = state_dict.pop(key)
+            elif "rmsnorm2.weight" in key:
+                new_key = key.replace("rmsnorm2.weight", "rmsnorm2.gamma")
+                state_dict[new_key] = state_dict.pop(key)
+            elif "rmsnorm.weight" in key:
+                new_key = key.replace("rmsnorm.weight", "rmsnorm.gamma")
+                state_dict[new_key] = state_dict.pop(key)
+
+        super().load_state_dict(state_dict, **kwargs)
+
     def forward(self, tensor, labels=None):
         logits = self.model(tensor)["logits"]
         if labels is not None:
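For context: the model's custom RMSNorm modules appear to register their scale parameter as "gamma", while earlier checkpoints were saved with PyTorch's default ".weight" naming, so the override added in this commit remaps those keys before delegating to PreTrainedModel.load_state_dict. Below is a minimal usage sketch, not part of this commit; the import path, MyLLaMaConfig, the config values, and the checkpoint filename are assumptions for illustration.

import torch

# Assumed import path and config class; adjust to the actual repo layout.
from configure_for_hf import MyLLaMa, MyLLaMaConfig

config = MyLLaMaConfig(n_chckpnt_segments=4)
model = MyLLaMa(config)

# A checkpoint saved while the RMSNorm scale was still stored under ".weight".
# The overridden load_state_dict renames keys such as "...rmsnorm1.weight" to
# "...rmsnorm1.gamma" in place, then calls the parent implementation.
state_dict = torch.load("old_checkpoint.pth", map_location="cpu")
model.load_state_dict(state_dict)

Note that the override iterates over list(state_dict.keys()), i.e. a snapshot of the keys, so popping and re-inserting entries while looping is safe.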