File size: 553 Bytes
fe3b346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import PreTrainedModel

from .configuration_cetacean_classifier import TemplateClassifierConfig
from .model import TemplateClassifier


class TemplateClassifierModelForImageClassification(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``TemplateClassifier``.

    The wrapped classifier is built from the plain-dict form of the
    ``TemplateClassifierConfig`` and stored as ``self.model``. ``forward``
    delegates to it and packs the result into a dict under ``"predictions"``.
    """

    config_class = TemplateClassifierConfig

    def __init__(self, config):
        """Build the wrapped classifier from *config* and switch it to eval mode.

        Args:
            config: A ``TemplateClassifierConfig``; ``config.to_dict()`` is
                handed to ``TemplateClassifier`` as its ``config`` argument.
        """
        super().__init__(config)

        # The inner classifier takes its settings as a plain dict rather than
        # as the config object itself.
        settings = config.to_dict()
        self.model = TemplateClassifier(config=settings)
        # Put the wrapped classifier into inference mode as soon as it exists.
        # NOTE(review): self.post_init() is not called here; confirm that this
        # is intentional for this wrapper.
        self.model.eval()

    def forward(self, model_input):
        """Run the wrapped classifier on *model_input*.

        Args:
            model_input: Passed unchanged to ``self.model``. It is presumably
                a batch of images; TODO confirm the expected type and shape.

        Returns:
            A dict with a single key, ``"predictions"``, holding whatever
            ``self.model(model_input)`` returned.
        """
        return {"predictions": self.model(model_input)}