import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)
from collections import namedtuple

from .llama import CustomAttentionLLaMa


class MyLLaMaConfig(PretrainedConfig):
    model_type = "LLaMa"

    def __init__(
        self,
        embed_dim: int = 1536,
        n_layers: int = 24,
        n_heads: int = 24,
        n_chckpnt_segments: int = 24,
        **kwargs,
    ):
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_chckpnt_segments = n_chckpnt_segments
        super().__init__(**kwargs)


class MyLLaMa(PreTrainedModel):
    config_class = MyLLaMaConfig

    def __init__(self, config: MyLLaMaConfig):
        super().__init__(config)
        self.model = CustomAttentionLLaMa(
            config.embed_dim,
            config.n_layers,
            config.n_heads,
            dropout=0,
            n_chckpnt_segments=config.n_chckpnt_segments,
        )

    def load_state_dict(self, state_dict, **kwargs):
        # Rename RMSNorm parameters saved under ".weight" to the ".gamma" names
        # used by CustomAttentionLLaMa, then delegate to the default loader.
        for key in list(state_dict.keys()):
            if "rmsnorm1.weight" in key:
                new_key = key.replace("rmsnorm1.weight", "rmsnorm1.gamma")
                state_dict[new_key] = state_dict.pop(key)
            elif "rmsnorm2.weight" in key:
                new_key = key.replace("rmsnorm2.weight", "rmsnorm2.gamma")
                state_dict[new_key] = state_dict.pop(key)
            elif "rmsnorm.weight" in key:
                new_key = key.replace("rmsnorm.weight", "rmsnorm.gamma")
                state_dict[new_key] = state_dict.pop(key)
        return super().load_state_dict(state_dict, **kwargs)

    def forward(self, tensor, labels=None):
        # Causal attention mask: 0 on and below the diagonal, -inf above it.
        att_mask = (
            torch.where(
                torch.triu(torch.ones((tensor.shape[1], tensor.shape[1]))) == 1,
                0,
                -torch.inf,
            )
            .transpose(0, 1)
            .to(self.model.embed.weight.device)
        )
        # Padding mask: False at pad positions, True everywhere else.
        pad_mask = torch.where(
            tensor == self.model.tokenizer.pad_token_id, False, True
        ).to(self.model.embed.weight.device)

        logits = self.model(tensor, att_mask, pad_mask)["logits"]

        loss = None
        if labels is not None:
            # cross_entropy reads dim 1 of the raw backbone logits as the class dimension.
            loss = nn.functional.cross_entropy(logits, labels)

        # Return a proper namedtuple instance instead of mutating the class itself.
        Output = namedtuple("Output", ["logits", "loss"])
        return Output(logits=logits.transpose(1, 2), loss=loss)


AutoConfig.register("LLaMa", MyLLaMaConfig)
AutoModel.register(MyLLaMaConfig, MyLLaMa)
AutoModelForCausalLM.register(MyLLaMaConfig, MyLLaMa)
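
# Minimal usage sketch (illustrative only, not part of the original module): once the
# registrations above have run, the custom config/model pair can be built through the
# Auto classes like any built-in architecture. The hyperparameter values shown are
# placeholders taken from the config defaults.
#
#     config = MyLLaMaConfig(embed_dim=1536, n_layers=24, n_heads=24)
#     model = AutoModelForCausalLM.from_config(config)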