File size: 556 Bytes
0463385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch

# Hugging Face Hub id shared by the image processor and the classifier below.
_CHECKPOINT = "heyitskim1912/AML_A2_Q4"

# Both objects are created once, at import time, and reused by every call to
# ``predict``.
processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)

def predict(image_pil):
    """Classify one image and return the label string of the top-scoring class.

    Args:
        image_pil: Image handed directly to the module-level ``processor``;
            presumably a ``PIL.Image.Image`` given the name (confirm with
            callers).

    Returns:
        ``model.config.id2label`` looked up at the index of the largest logit.
    """
    encoded = processor(image_pil, return_tensors="pt")

    # Pure inference: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        scores = model(**encoded).logits

    # ``.item()`` only succeeds on a one-element tensor, so this path assumes
    # a single image (batch of one).
    top_index = torch.argmax(scores, dim=-1).item()
    return model.config.id2label[top_index]