from transformers import GPT2LMHeadModel


class CustomGPT2Model(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Run the stock GPT-2 forward pass, then apply custom logic to the result.
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        # Modify the outputs as needed; the print is just a marker proving the wrapper ran.
        print('USING CUSTOM WRAPPER')
        return outputs
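Because CustomGPT2Model subclasses GPT2LMHeadModel, it inherits from_pretrained and can load standard GPT-2 weights directly. A minimal usage sketch follows; the "gpt2" checkpoint name and the prompt are illustrative assumptions, not part of the original snippet:

import torch
from transformers import GPT2Tokenizer

# Hypothetical usage: load the wrapper with pretrained GPT-2 weights.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = CustomGPT2Model.from_pretrained("gpt2")
model.eval()

inputs = tokenizer("Hello, world", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"])

# The forward call above prints 'USING CUSTOM WRAPPER' first, confirming the
# override is active; the logits keep the usual (batch, seq_len, vocab) shape.
print(outputs.logits.shape)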