File size: 492 Bytes
b29fdb3
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from transformers import PreTrainedModel
from uniformer_finetune_config import UniformerXXSFinetuneConfig
from uniformer_xs import UniformerXXSFinetune

class UniformerXXSFinetuneModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``UniformerXXSFinetune``.

    Subclassing ``PreTrainedModel`` with ``config_class`` set to
    ``UniformerXXSFinetuneConfig`` lets the network be used through the
    ``transformers`` model API (e.g. ``save_pretrained``/``from_pretrained``).
    The actual network lives in ``self.model``.
    """

    config_class = UniformerXXSFinetuneConfig

    def __init__(self, config):
        """Build the wrapped network from ``config``.

        Args:
            config: A ``UniformerXXSFinetuneConfig``. Only its ``out_class``
                attribute is read here (presumably the number of output
                classes -- confirm against ``UniformerXXSFinetune``).
        """
        super().__init__(config)
        self.model = UniformerXXSFinetune(
            out_class=config.out_class,
        )

    def forward(self, tensor):
        """Run the wrapped network on ``tensor`` and return its output as-is.

        Args:
            tensor: Input passed straight through to ``UniformerXXSFinetune``;
                its expected shape is defined by that module (not visible
                here -- TODO confirm).

        Returns:
            Whatever ``UniformerXXSFinetune`` returns.
        """
        # Invoke the submodule via ``__call__`` rather than ``.forward`` so
        # that hooks registered on ``self.model`` are run; calling
        # ``.forward`` directly silently skips them.
        return self.model(tensor)