File size: 553 Bytes
fe3b346 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from transformers import PreTrainedModel
from .configuration_cetacean_classifier import TemplateClassifierConfig
from .model import TemplateClassifier
class TemplateClassifierModelForImageClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a ``TemplateClassifier``.

    The wrapped classifier is constructed from the dictionary form of the
    supplied configuration and has ``eval()`` called on it at construction
    time.
    """

    # Configuration class paired with this model.
    config_class = TemplateClassifierConfig

    def __init__(self, config):
        """Create the wrapper and the underlying classifier.

        Args:
            config: Configuration object. It is forwarded to the parent
                constructor, and its ``to_dict()`` result is passed to
                ``TemplateClassifier`` as the ``config`` keyword argument.
        """
        super().__init__(config)
        classifier = TemplateClassifier(config=config.to_dict())
        classifier.eval()
        self.model = classifier

    def forward(self, model_input):
        """Run the wrapped classifier on ``model_input``.

        Args:
            model_input: Value passed unchanged to the wrapped classifier.

        Returns:
            A dict with the single key ``"predictions"`` holding whatever
            the wrapped classifier returned.
        """
        return {"predictions": self.model(model_input)}
|