# https://huggingface.co/docs/transformers/custom_models

from transformers import GPTNeoXForCausalLM, AutoModelForCausalLM
from torch.nn.functional import log_softmax

# The docs example also defines a custom config class, but we just reuse the
# GPTNeoXConfig that GPTNeoXForCausalLM already ships with.
# The norm is to subclass PreTrainedModel directly; we take a shortcut and
# inherit from GPTNeoXForCausalLM so the architecture comes for free.
class CustomModel(GPTNeoXForCausalLM):
    def forward(self, *args, **kwargs):
        # See https://huggingface.co/docs/transformers/main_classes/output
        out = super().forward(*args, **kwargs)
        # Replace the raw logits with log-probabilities over the vocabulary.
        out.logits = log_softmax(out.logits, dim=-1)
        return out

    @classmethod
    def copy_from_neox(cls, *args, **kwargs):
        # Load a stock GPT-NeoX checkpoint, then transplant its weights into an
        # instance of this subclass; the architectures match parameter-for-parameter.
        m0 = GPTNeoXForCausalLM.from_pretrained(*args, **kwargs)
        m1 = cls(m0.config)
        m1.load_state_dict(m0.state_dict())
        return m1

# Record the class in the config's auto_map so AutoModelForCausalLM can rebuild
# it from a saved checkpoint (loaded with trust_remote_code=True).
CustomModel.register_for_auto_class('AutoModelForCausalLM')
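
# A minimal usage sketch, assuming a small GPT-NeoX checkpoint such as
# EleutherAI/pythia-70m; the model id and save path are illustrative choices,
# not requirements of the class above.
if __name__ == "__main__":
    import torch

    model = CustomModel.copy_from_neox("EleutherAI/pythia-70m")
    model.eval()

    # Sanity check: the patched forward returns log-probabilities, so each
    # position's exp(logits) should sum to roughly 1.
    ids = torch.tensor([[0, 1, 2]])
    with torch.no_grad():
        row_sums = model(input_ids=ids).logits.exp().sum(dim=-1)
    print(row_sums)  # expect values close to 1.0

    # Thanks to register_for_auto_class, saving records the custom class so the
    # Auto API can rebuild it. Note: Transformers can only copy the defining
    # source file when this class lives in an importable module, not __main__.
    model.save_pretrained("./custom-pythia-70m")
    reloaded = AutoModelForCausalLM.from_pretrained(
        "./custom-pythia-70m", trust_remote_code=True
    )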