File size: 447 Bytes
e8fea61
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import logging

from transformers import GPT2LMHeadModel

class CustomGPT2Model(GPT2LMHeadModel):
    """GPT-2 language-model head with a hook point around ``forward``.

    Construction is inherited unchanged from ``GPT2LMHeadModel``. The
    previous ``__init__`` only called ``super().__init__(config)``, so it was
    dropped. ``forward`` currently delegates to the parent and returns its
    outputs untouched. It is the place to add pre/post-processing of the
    model outputs.
    """

    # Looked up once at class-definition time rather than on every forward call.
    _wrapper_logger = logging.getLogger(__name__)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the GPT-2 forward pass and return the parent's outputs unchanged.

        Args:
            input_ids: Token ids, forwarded as-is to ``GPT2LMHeadModel.forward``.
            attention_mask: Attention mask, forwarded as-is.
            **kwargs: Any other keyword accepted by ``GPT2LMHeadModel.forward``
                (e.g. ``past_key_values``, ``labels``, ``return_dict``),
                forwarded as-is.

        Returns:
            Whatever ``GPT2LMHeadModel.forward`` returns, unmodified.

        NOTE(review): only ``input_ids`` and ``attention_mask`` may be passed
        positionally, in that order. In the parent's signature the second
        positional parameter is ``past_key_values`` (per the transformers
        docs), so pass every argument after ``input_ids`` by keyword.
        """
        outputs = super().forward(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        # forward() runs once per model step (once per generated token under
        # generate()). The "wrapper is active" marker therefore goes to the
        # logging system at DEBUG level instead of being printed to stdout on
        # every call. Enable DEBUG for this module's logger to see it.
        self._wrapper_logger.debug("USING CUSTOM WRAPPER")
        return outputs