# new-nlp-hw3-llama3/configure_for_hf.py
from collections import namedtuple

import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)

from .llama import CustomAttentionLLaMa


class MyLLaMaConfig(PretrainedConfig):
    model_type = "LLaMa"

    def __init__(
        self,
        embed_dim: int = 1536,
        n_layers: int = 24,
        n_heads: int = 24,
        n_chckpnt_segments: int = 24,
        **kwargs,
    ):
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_chckpnt_segments = n_chckpnt_segments
        super().__init__(**kwargs)
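
# Example construction (a sketch; the values shown are just the defaults
# made explicit):
#     config = MyLLaMaConfig(embed_dim=1536, n_layers=24, n_heads=24,
#                            n_chckpnt_segments=24)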


class MyLLaMa(PreTrainedModel):
    config_class = MyLLaMaConfig

    def __init__(self, config: MyLLaMaConfig):
        super().__init__(config)
        # Thin Hugging Face wrapper around the custom implementation;
        # n_chckpnt_segments controls how the layer stack is split for
        # activation checkpointing (per the parameter's name).
        self.model = CustomAttentionLLaMa(
            config.embed_dim,
            config.n_layers,
            config.n_heads,
            dropout=0,
            n_chckpnt_segments=config.n_chckpnt_segments,
        )

    def load_state_dict(self, state_dict, **kwargs):
        # The custom RMSNorm modules call their scale parameter `gamma`,
        # whereas exported checkpoints use the conventional `weight` name;
        # remap the keys so such checkpoints load cleanly.
        for key in list(state_dict.keys()):
            for name in ("rmsnorm1", "rmsnorm2", "rmsnorm"):
                if f"{name}.weight" in key:
                    new_key = key.replace(f"{name}.weight", f"{name}.gamma")
                    state_dict[new_key] = state_dict.pop(key)
                    break
        # Propagate the missing/unexpected-keys result rather than dropping it.
        return super().load_state_dict(state_dict, **kwargs)
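
    # Illustration of the remapping above, with a hypothetical checkpoint key:
    #     "layers.0.rmsnorm1.weight" -> "layers.0.rmsnorm1.gamma"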

    def forward(self, tensor, labels=None):
        seq_len = tensor.shape[1]
        device = self.model.embed.weight.device
        # Additive causal mask: 0 on and below the diagonal, -inf above it,
        # so each position attends only to itself and earlier positions.
        att_mask = torch.where(
            torch.tril(torch.ones((seq_len, seq_len))) == 1, 0.0, -torch.inf
        ).to(device)
        # Boolean padding mask: True for real tokens, False for padding.
        pad_mask = (tensor != self.model.tokenizer.pad_token_id).to(device)
        # Bundle results in a lightweight output object exposing `.logits`
        # and `.loss`, mirroring the Hugging Face output interface.
        Output = namedtuple("Output", ["logits", "loss"])
        logits = self.model(tensor, att_mask, pad_mask)["logits"]
        loss = None
        if labels is not None:
            # cross_entropy consumes the logits as returned by the inner model
            # (which implies they are laid out as (batch, vocab, seq_len));
            # the transposed copy below then matches the usual Hugging Face
            # (batch, seq_len, vocab) layout.
            loss = nn.functional.cross_entropy(logits, labels)
        return Output(logits=logits.transpose(1, 2), loss=loss)


# Make the custom config and model discoverable through the Auto* factories.
AutoConfig.register("LLaMa", MyLLaMaConfig)
AutoModel.register(MyLLaMaConfig, MyLLaMa)
AutoModelForCausalLM.register(MyLLaMaConfig, MyLLaMa)
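
# Example usage (a minimal sketch; it assumes this file lives in a package
# next to llama.py so the relative import resolves, and that the inner model
# ships its own tokenizer, as forward() requires):
#
#     config = MyLLaMaConfig()
#     model = AutoModelForCausalLM.from_config(config)
#     input_ids = torch.randint(0, 100, (1, 16))
#     out = model(input_ids)
#     out.logits.shape  # (1, 16, vocab_size), per the transpose in forward()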