foo3 / modeling_custom.py
denizyuret-shallowai's picture
Upload model
c4c6d7d
raw
history blame
960 Bytes
# https://huggingface.co/docs/transformers/custom_models
from transformers import PreTrainedModel, GPTNeoXForCausalLM, AutoModelForCausalLM
from torch.nn.functional import log_softmax
# The custom-models example also defines its own config class, but we simply reuse GPTNeoX's config.
# The usual approach is to inherit from PreTrainedModel, but we take a shortcut and subclass GPTNeoXForCausalLM directly.
class CustomModel(GPTNeoXForCausalLM):
    """GPT-NeoX causal LM whose ``logits`` output holds log-probabilities.

    The architecture, weights and config are exactly those of
    ``GPTNeoXForCausalLM``. The only difference is that ``forward`` applies
    ``log_softmax`` over the last (vocabulary) dimension of the logits before
    returning them.
    """

    def forward(self, *args, **kwargs):
        """Run GPT-NeoX and replace the returned logits with log-probabilities.

        All arguments are forwarded to ``GPTNeoXForCausalLM.forward``.
        ``return_dict`` is expected as a keyword argument, if given.

        Returns:
            The parent's output object (or, when ``return_dict`` resolves to
            False, the equivalent tuple) with ``logits`` replaced by
            ``log_softmax(logits, dim=-1)``. Any ``loss`` is computed by the
            parent on the raw logits, before the replacement.
        """
        # See https://huggingface.co/docs/transformers/main_classes/output
        # Resolve return_dict the same way the parent forward does.
        return_dict = kwargs.pop("return_dict", None)
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # Always ask the parent for a ModelOutput so ``logits`` can be
        # addressed by name. With return_dict=False the parent returns a
        # plain tuple, and ``out.logits`` would raise AttributeError.
        out = super().forward(*args, return_dict=True, **kwargs)
        out.logits = log_softmax(out.logits, dim=-1)
        # Setting the attribute also updates the underlying mapping, so
        # to_tuple() carries the log-probabilities.
        return out if return_dict else out.to_tuple()

    @classmethod
    def copy_from_neox(cls, *args, **kwargs):
        """Create a ``CustomModel`` with the weights of a GPT-NeoX checkpoint.

        All arguments are forwarded unchanged to
        ``GPTNeoXForCausalLM.from_pretrained``.

        Returns:
            A new ``cls`` instance with the same config, the same parameter
            dtype, and weights copied strictly from the loaded model. It is
            put in eval mode, matching what ``from_pretrained`` returns.
        """
        source = GPTNeoXForCausalLM.from_pretrained(*args, **kwargs)
        model = cls(source.config)
        # A freshly constructed module uses the default dtype. Match the
        # source (e.g. when torch_dtype=... was requested) before copying.
        model.to(dtype=source.dtype)
        # Strict load: the subclass adds no parameters, so keys must match.
        model.load_state_dict(source.state_dict())
        # nn.Module instances start in training mode (dropout active).
        # from_pretrained returns eval mode, so do the same here.
        model.eval()
        return model


# Register the class so AutoModelForCausalLM can resolve it (custom-code path).
CustomModel.register_for_auto_class('AutoModelForCausalLM')