alKoGolik's picture
Upload 169 files
c87c295 verified
raw
history blame contribute delete
702 Bytes
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import LlamaModel, LlamaConfig
import torch.nn as nn
# Vocabulary size of the CodeLlama tokenizer; used below to resize both the
# input embedding matrix and the LM head, overriding config.vocab_size.
CODELLAMA_VOCAB_SIZE = 32016
class CodeLlamaForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM variant whose input embeddings and LM head are sized
    for the CodeLlama tokenizer (``CODELLAMA_VOCAB_SIZE`` tokens) regardless
    of the vocab size carried in ``config``.

    The external interface is identical to ``LlamaForCausalLM``; only the
    shapes of ``model.embed_tokens`` and ``lm_head`` differ.
    """

    # Weight tying: lm_head shares its weight with the (resized) embeddings.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """Build the model, then replace embeddings/head with CodeLlama-sized ones.

        Args:
            config: a ``LlamaConfig``; ``hidden_size`` and ``pad_token_id``
                are read here, while ``vocab_size`` is overridden.
        """
        super().__init__(config)
        self.model = LlamaModel(config)
        # FIX: was `config.vocab_size`, which can disagree with the actual
        # output dimension of lm_head below. Generation utilities consult
        # self.vocab_size, so it must match the resized head.
        self.vocab_size = CODELLAMA_VOCAB_SIZE
        self.lm_head = nn.Linear(config.hidden_size, CODELLAMA_VOCAB_SIZE, bias=False)
        # Third positional arg of nn.Embedding is padding_idx.
        self.model.embed_tokens = nn.Embedding(
            CODELLAMA_VOCAB_SIZE, config.hidden_size, config.pad_token_id
        )
        # Initialize weights and apply final processing
        self.post_init()