from collections import namedtuple

import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)

from .llama import CustomAttentionLLaMa


class MyLLaMaConfig(PretrainedConfig):
    """Hugging Face config holding the CustomAttentionLLaMa hyperparameters."""

    model_type = "LLaMa"

    def __init__(
        self,
        embed_dim: int = 1536,
        n_layers: int = 24,
        n_heads: int = 24,
        n_chckpnt_segments: int = 24,
        **kwargs,
    ):
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_chckpnt_segments = n_chckpnt_segments
        super().__init__(**kwargs)


class MyLLaMa(PreTrainedModel):
    """PreTrainedModel wrapper around CustomAttentionLLaMa for the Auto* factories."""

    config_class = MyLLaMaConfig

    def __init__(self, config: MyLLaMaConfig):
        super().__init__(config)
        self.model = CustomAttentionLLaMa(
            config.embed_dim,
            config.n_layers,
            config.n_heads,
            dropout=0,
            n_chckpnt_segments=config.n_chckpnt_segments,
        )

    def load_state_dict(self, state_dict, **kwargs):
        # Checkpoints may store the RMSNorm scale parameters under ".weight";
        # the custom modules name them ".gamma", so rename matching keys first.
        for key in list(state_dict.keys()):
            if "rmsnorm1.weight" in key:
                new_key = key.replace("rmsnorm1.weight", "rmsnorm1.gamma")
                state_dict[new_key] = state_dict.pop(key)
            elif "rmsnorm2.weight" in key:
                new_key = key.replace("rmsnorm2.weight", "rmsnorm2.gamma")
                state_dict[new_key] = state_dict.pop(key)
            elif "rmsnorm.weight" in key:
                new_key = key.replace("rmsnorm.weight", "rmsnorm.gamma")
                state_dict[new_key] = state_dict.pop(key)

        return super().load_state_dict(state_dict, **kwargs)

    def forward(self, tensor, labels=None):
        device = self.model.embed.weight.device
        seq_len = tensor.shape[1]

        # Additive causal mask: 0 on and below the diagonal, -inf above it.
        att_mask = (
            torch.where(
                torch.triu(torch.ones((seq_len, seq_len))) == 1,
                0,
                -torch.inf,
            )
            .transpose(0, 1)
            .to(device)
        )

        # Boolean padding mask: False at pad positions, True elsewhere.
        pad_mask = torch.where(
            tensor == self.model.tokenizer.pad_token_id, False, True
        ).to(device)

        logits = self.model(tensor, att_mask, pad_mask)["logits"]

        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)

        # Return an instance (not the namedtuple class itself) so that
        # `outs.logits` and `outs.loss` behave as regular attributes.
        Output = namedtuple("Output", ["logits", "loss"])
        return Output(logits=logits.transpose(1, 2), loss=loss)


AutoConfig.register("LLaMa", MyLLaMaConfig)
AutoModel.register(MyLLaMaConfig, MyLLaMa)
AutoModelForCausalLM.register(MyLLaMaConfig, MyLLaMa)
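
# A minimal usage sketch (not part of the original module): because the config and
# model classes are registered with the Auto* factories above, they can be built
# through the standard `from_config` path. It assumes only that CustomAttentionLLaMa
# constructs from the default hyperparameters; run as a module
# (python -m <package>.<this_file>) so the relative import above resolves.
if __name__ == "__main__":
    cfg = MyLLaMaConfig()
    model = AutoModelForCausalLM.from_config(cfg)
    print(type(model).__name__, "parameter count:", sum(p.numel() for p in model.parameters()))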