"""A ``PreTrainedModel`` subclass wrapping a 1-in/1-out float16 linear layer."""

import torch
from transformers import PreTrainedModel

from .configuration_evil import EvilConfig


class TinyEvilModel(torch.nn.Module):
    """Module holding one ``torch.nn.Linear(1, 1)`` layer created in float16."""

    def __init__(self) -> None:
        super().__init__()
        # The layer's parameters are created directly in float16.
        self.foo = torch.nn.Linear(1, 1, dtype=torch.float16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x``.

        Args:
            x: Input tensor whose last dimension is 1 (the layer's
                ``in_features``).
                NOTE(review): presumably expected to be float16 to match the
                layer's parameters -- confirm against callers.

        Returns:
            The output of ``self.foo`` applied to ``x``.
        """
        return self.foo(x)


class EvilModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`TinyEvilModel`."""

    config_class = EvilConfig

    def __init__(self, config: EvilConfig) -> None:
        """Build the model from ``config`` and attach a ``TinyEvilModel``.

        Args:
            config: Configuration object forwarded to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        self.model = TinyEvilModel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped ``TinyEvilModel``.

        Args:
            x: Input tensor passed unchanged to the wrapped module.

        Returns:
            The wrapped module's output for ``x``.
        """
        # Call the module itself instead of ``.forward`` so that any hooks
        # registered on the submodule run; calling ``.forward`` directly
        # silently skips them.
        return self.model(x)