custom-gpt2-model / custom_model.py
akshat57's picture
Upload custom_model.py
e8fea61 verified
raw
history blame contribute delete
447 Bytes
from transformers import GPT2LMHeadModel
class CustomGPT2Model(GPT2LMHeadModel):
    """GPT-2 LM-head model with a hook point for custom forward logic.

    Behaves like ``GPT2LMHeadModel``: the constructor is inherited unchanged,
    and ``forward`` delegates to the parent implementation and returns its
    outputs as-is. The first time an instance runs ``forward`` it prints a
    marker line. This makes it easy to confirm the custom class (not the
    stock one) was loaded.
    """

    # Class-level default. It is set to True on the instance after the first
    # forward call, so the marker prints once per instance instead of on
    # every call. Generation loops typically call forward once per token.
    _custom_wrapper_announced = False

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the parent forward pass and return its outputs.

        Args:
            input_ids: Token ids, forwarded unchanged to the parent forward.
            attention_mask: Attention mask, forwarded unchanged.
            **kwargs: Any other parent ``forward`` arguments, forwarded
                unchanged by keyword.

        Returns:
            Whatever ``GPT2LMHeadModel.forward`` returns for these arguments.
        """
        # NOTE(review): this signature's positional order differs from the
        # parent's. In GPT2LMHeadModel.forward the second positional
        # parameter is past_key_values, so callers written against the parent
        # should pass everything after input_ids by keyword.
        outputs = super().forward(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        # Custom post-processing of `outputs` goes here.
        if not self._custom_wrapper_announced:
            print('USING CUSTOM WRAPPER')
            self._custom_wrapper_announced = True
        return outputs