from transformers import GPT2LMHeadModel


class CustomGPT2Model(GPT2LMHeadModel):
    """GPT-2 LM-head model with a hook point for custom forward logic.

    Construction is inherited unchanged from ``GPT2LMHeadModel``: the
    previous explicit ``__init__`` only delegated to
    ``super().__init__(config)``, so it added nothing and has been dropped.
    """

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the parent forward pass and return its outputs.

        Args:
            input_ids: Token ids, passed unchanged to the parent ``forward``.
            attention_mask: Attention mask, passed unchanged to the parent
                ``forward``.
            **kwargs: Any other parent ``forward`` arguments, passed through
                as keywords.

        Returns:
            Whatever ``GPT2LMHeadModel.forward`` returns. It is currently
            not modified.
        """
        # NOTE(review): in the transformers releases I know of, the parent's
        # forward takes past_key_values as its second positional parameter,
        # while this override takes attention_mask second. Callers should
        # pass everything after input_ids by keyword. TODO confirm against
        # the installed version.
        outputs = super().forward(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        # Hook: modify `outputs` here as needed.

        # Marker confirming this wrapper is actually in use. It prints on
        # every forward call, so presumably once per decoding step when
        # driven by generate().
        print('USING CUSTOM WRAPPER')
        return outputs