from transformers import GPT2LMHeadModel


class CustomGPT2Model(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Custom forward logic
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        # Modify the outputs as needed
        print('USING CUSTOM WRAPPER')
        return outputs
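

# A minimal usage sketch (an assumption, not part of the original): since the
# wrapper subclasses GPT2LMHeadModel, it can be loaded through the standard
# from_pretrained path. The "gpt2" checkpoint name and prompt are illustrative.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = CustomGPT2Model.from_pretrained("gpt2")

inputs = tokenizer("Hello, world", return_tensors="pt")
outputs = model(**inputs)    # runs the custom forward, prints 'USING CUSTOM WRAPPER'
print(outputs.logits.shape)  # torch.Size([1, seq_len, vocab_size])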