|
|
|
|
|
from transformers import PreTrainedModel, GPTNeoXForCausalLM, AutoModelForCausalLM |
|
from torch.nn.functional import log_softmax |
|
|
|
|
|
|
|
class CustomModel(GPTNeoXForCausalLM):
    """GPT-NeoX causal LM whose forward pass yields log-probabilities.

    Behaves like ``GPTNeoXForCausalLM`` except that the logits in the forward
    output are replaced by ``log_softmax(logits, dim=-1)``.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, *args, **kwargs):
        """Run the parent forward pass and log-softmax the returned logits.

        All positional and keyword arguments are passed through unchanged to
        ``GPTNeoXForCausalLM.forward``. The normalization is applied only
        after the parent call has returned, so anything the parent computes
        internally is based on the raw logits.

        Returns:
            Whatever container the parent returned (output object or plain
            tuple), with the logits entry replaced by its log-softmax over
            the last dimension.
        """
        out = super().forward(*args, **kwargs)

        if isinstance(out, tuple):
            # Tuple outputs (e.g. ``return_dict=False``) have no ``.logits``
            # attribute. The expected layout is ``(loss?, logits, ...)``: a
            # leading 0-dim tensor is taken to be the scalar loss, in which
            # case the logits sit at index 1.
            # NOTE(review): relies on the parent's tuple ordering -- confirm
            # against the installed transformers version.
            logits_at = 1 if len(out) > 1 and out[0].dim() == 0 else 0
            normalized = log_softmax(out[logits_at], dim=-1)
            return out[:logits_at] + (normalized,) + out[logits_at + 1:]

        out.logits = log_softmax(out.logits, dim=-1)
        return out

    @classmethod
    def copy_from_neox(cls, *args, **kwargs):
        """Build an instance of this class from pretrained GPT-NeoX weights.

        All arguments are forwarded to ``GPTNeoXForCausalLM.from_pretrained``.
        A new model is constructed from the loaded model's config, moved to
        the same dtype and device, and then populated with the loaded model's
        ``state_dict``.

        Returns:
            A new ``cls`` instance carrying the pretrained parameters.
        """
        source = GPTNeoXForCausalLM.from_pretrained(*args, **kwargs)

        # Match dtype/device before copying so load_state_dict is a
        # like-for-like copy.
        # NOTE(review): the clone is freshly constructed, so its train/eval
        # mode is not copied from the loaded model -- confirm callers set it.
        clone = cls(source.config).to(dtype=source.dtype, device=source.device)
        clone.load_state_dict(source.state_dict())
        return clone
|
|
|
# Register CustomModel with the AutoModelForCausalLM auto class.
# NOTE(review): presumably so saved checkpoints can be reloaded through the
# auto-class API as custom code -- confirm against the intended workflow.
CustomModel.register_for_auto_class('AutoModelForCausalLM')
|
|