mnist-classifier / README.md
library_name: transformers
tags: []

Usage

Register the model

from transformers import AutoConfig, AutoModel

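# MNISTConfig and MNISTClassifier are this model's custom classes;
# define or import them before running the lines below.
AutoConfig.register("mnist_classifier", MNISTConfig)
AutoModel.register(MNISTConfig, MNISTClassifier)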
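
The registration calls above assume that `MNISTConfig` and `MNISTClassifier` already exist. Their real definitions are not shown here, so the following is only a minimal sketch of what such classes could look like. The MLP architecture, the `input_size`, `hidden_size`, and `num_classes` parameters, and their default values are all hypothetical; the published weights will only load correctly into a class that matches the actual architecture.

```python
import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class MNISTConfig(PretrainedConfig):
    # Must match the string passed to AutoConfig.register
    model_type = "mnist_classifier"

    def __init__(self, input_size=28 * 28, hidden_size=128, num_classes=10, **kwargs):
        # Hypothetical hyperparameters, not taken from the published model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        super().__init__(**kwargs)


class MNISTClassifier(PreTrainedModel):
    # Links the model class to its config class, as AutoModel.register expects
    config_class = MNISTConfig

    def __init__(self, config):
        super().__init__(config)
        # Assumed architecture: flatten the image, then a two-layer MLP
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(config.input_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.num_classes),
        )

    def forward(self, x):
        # x: (batch, 28, 28) -> logits: (batch, num_classes)
        return self.net(x)
```

With classes of this shape in scope, the two `register` calls can run as written, because `model_type` matches the registered name and `config_class` points at the registered config.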

Inference

from transformers import AutoConfig, AutoModel
import torch

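# Run the registration step first so the "mnist_classifier" type is known.
config = AutoConfig.from_pretrained("jerilseb/mnist-classifier")  # Optional: only needed to inspect settings
model = AutoModel.from_pretrained("jerilseb/mnist-classifier")
model.eval()  # Inference mode: disables dropout and other training-only behavior

input_tensor = torch.randn(1, 28, 28)  # Random stand-in for one 28x28 image; shape is (batch, height, width)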

with torch.no_grad():
    output = model(input_tensor)

predicted_class = output.argmax(-1).item()  # .item() only works here because the batch holds one image
print(f"Predicted class: {predicted_class}")
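
The snippet above calls `.item()`, which only works for a batch of one. For larger batches, a sketch along the following lines may help. It assumes the model returns a plain tensor of raw logits with shape `(batch, 10)`, as the `argmax` call above suggests; the random `logits` tensor is a stand-in for `model(input_tensor)`, not real model output.

```python
import torch

torch.manual_seed(0)

# Stand-in for model(input_tensor) on a batch of 4 images: shape (batch, 10)
logits = torch.randn(4, 10)

probs = torch.softmax(logits, dim=-1)       # each row sums to 1
predicted_classes = logits.argmax(dim=-1)   # shape (batch,)
confidences = probs.max(dim=-1).values      # probability of each predicted class

# .tolist() converts whole tensors to Python lists, unlike .item()
for i, (cls, conf) in enumerate(zip(predicted_classes.tolist(), confidences.tolist())):
    print(f"Image {i}: class {cls} (p={conf:.2f})")
```

Taking `argmax` over the logits or over the softmax probabilities gives the same classes, since softmax preserves ordering; the softmax is only needed when a confidence value is wanted.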