# Custom model following https://huggingface.co/docs/transformers/custom_models
from transformers import PreTrainedModel, GPTNeoXForCausalLM, AutoModelForCausalLM
from torch.nn.functional import log_softmax

# The tutorial also defines a custom config class; here we simply reuse the
# config class that GPTNeoXForCausalLM already uses.
# The usual pattern is to subclass PreTrainedModel directly; we take a shortcut
# and subclass the existing GPTNeoX causal-LM model instead.


class CustomModel(GPTNeoXForCausalLM):
    """GPTNeoX causal LM whose ``logits`` output holds log-probabilities.

    Identical to ``GPTNeoXForCausalLM`` except that ``forward`` applies
    ``log_softmax`` over the last dimension of the logits.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, *args, **kwargs):
        """Run the parent forward pass and replace ``logits`` by log-probabilities.

        Supports both output styles: the parent is always asked for a
        ``ModelOutput`` (so ``logits`` can be located by name), and the result
        is converted back to a plain tuple if the caller requested
        ``return_dict=False`` (or the config disables return dicts).

        NOTE(review): assumes ``return_dict`` is passed by keyword, if at all.
        """
        # See https://huggingface.co/docs/transformers/main_classes/output
        return_dict = kwargs.pop("return_dict", None)
        if return_dict is None:
            return_dict = self.config.use_return_dict
        out = super().forward(*args, return_dict=True, **kwargs)
        # Any loss has already been computed by the parent from the raw logits;
        # only the returned logits are transformed here.
        out.logits = log_softmax(out.logits, dim=-1)
        return out if return_dict else out.to_tuple()

    @classmethod
    def copy_from_neox(cls, *args, **kwargs):
        """Load a pretrained GPTNeoX model and copy its weights into a new ``cls``.

        All arguments are forwarded unchanged to
        ``GPTNeoXForCausalLM.from_pretrained``. The new model is moved to the
        source model's dtype and device, and inherits its train/eval mode and
        generation config.

        NOTE(review): ``m0.device`` is a single device; models loaded with a
        multi-device ``device_map`` are not handled specially here.
        """
        m0 = GPTNeoXForCausalLM.from_pretrained(*args, **kwargs)
        m1 = cls(m0.config).to(dtype=m0.dtype, device=m0.device)
        m1.load_state_dict(m0.state_dict())
        # A freshly constructed nn.Module starts in training mode, whereas
        # from_pretrained returns the model in eval mode; mirror the source.
        m1.train(m0.training)
        # from_pretrained may have loaded a generation_config.json; keep it
        # instead of the default one derived from the model config.
        m1.generation_config = m0.generation_config
        return m1


# Lets AutoModelForCausalLM resolve this class (e.g. after save/push with
# trust_remote_code), per the custom_models guide.
CustomModel.register_for_auto_class('AutoModelForCausalLM')