from transformers import PreTrainedModel | |
from .configuration_cetacean_classifier import TemplateClassifierConfig | |
from .model import TemplateClassifier | |
class TemplateClassifierModelForImageClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates inference to ``TemplateClassifier``.

    The wrapped classifier is stored on ``self.model``. It is built from the
    plain-dict form of the config and has ``.eval()`` called on it at
    construction time.
    """

    # Lets PreTrainedModel machinery know which config class pairs with this model.
    config_class = TemplateClassifierConfig

    def __init__(self, config):
        """Create the wrapper and its inner classifier.

        Args:
            config: The model configuration (expected to be a
                ``TemplateClassifierConfig``). It is handed to the
                ``PreTrainedModel`` base initializer as-is. Its ``to_dict()``
                form is passed to ``TemplateClassifier`` via the ``config``
                keyword.
        """
        super().__init__(config)
        # TemplateClassifier takes a plain dict rather than the config object.
        classifier_settings = config.to_dict()
        self.model = TemplateClassifier(config=classifier_settings)
        # Put the inner module into evaluation mode immediately after creation.
        self.model.eval()

    def forward(self, model_input):
        """Run the wrapped classifier on ``model_input``.

        Args:
            model_input: Passed unchanged to ``self.model``. The expected
                type/shape is defined by ``TemplateClassifier`` (not visible
                here — confirm against that class).

        Returns:
            dict: A single-entry mapping ``{"predictions": ...}`` holding
            whatever ``self.model`` returned.
        """
        return {"predictions": self.model(model_input)}