---
library_name: transformers
tags: []
---
## Usage
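### Register the model

Register the custom configuration and model classes with the Auto classes so that `AutoConfig` and `AutoModel` can resolve the `mnist_classifier` model type. `MNISTConfig` and `MNISTClassifier` must already be imported; they are not defined in the snippet below.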
from transformers import AutoModel, AutoConfig

# NOTE: MNISTConfig and MNISTClassifier are not defined in this snippet; they
# must be imported first (presumably from this repo's modeling code — confirm).
# transformers checks that this string equals MNISTConfig.model_type.
model_type = "mnist_classifier"
AutoConfig.register(model_type, MNISTConfig)  # model_type string -> config class
AutoModel.register(MNISTConfig, MNISTClassifier)  # config class -> model class
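### Inference

Run the registration step above first, so that the `mnist_classifier` model type can be resolved when loading.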
from transformers import AutoConfig, AutoModel
import torch

# Presumably requires the "Register the model" snippet to have run first, so
# that the "mnist_classifier" model type resolves to MNISTConfig/MNISTClassifier.
repo_id = "jerilseb/mnist-classifier"
config = AutoConfig.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, config=config)
model.eval()  # from_pretrained already sets eval mode; kept explicit for clarity

# Random noise stands in for a real digit, so the predicted class is
# meaningless. Assumes the model accepts input shaped (batch, 28, 28) —
# TODO confirm against MNISTClassifier.forward. Change the first dim for
# a different batch size.
input_tensor = torch.randn(1, 28, 28)

with torch.no_grad():
    output = model(input_tensor)

# Assumes `output` is a logits tensor whose last dim indexes the classes —
# confirm against MNISTClassifier.forward. reshape(-1) yields one prediction
# per image for any batch size, whereas `.item()` would raise for batch > 1.
for predicted_class in output.argmax(dim=-1).reshape(-1).tolist():
    print(f"Predicted class: {predicted_class}")