File size: 2,028 Bytes
3df2514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import functools

import gradio as gr
import torch
from PIL import Image
from torch import nn
from torchvision import models


# class_names[i] is the label reported for output logit i of the classifier head.
class_names = ["pizza", "steak", "sushi"]
# Derived from class_names so the head size and the label->index table cannot
# drift out of sync with the label list.
num_classes = len(class_names)
class_dict = {name: index for index, name in enumerate(class_names)}
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Fine-tuned checkpoint for this app, resolved relative to the current
# working directory.
_MODEL_PATH = "MobileNetV3-Food-Classification.pth"


@functools.lru_cache(maxsize=1)
def _load_model_and_transforms():
    """Build the fine-tuned MobileNetV3 classifier and its preprocessing once.

    Cached so the checkpoint is read from disk on the first request only,
    rather than on every call to ``classify_image``.

    Returns:
        A ``(model, preprocess)`` tuple: the model in eval mode on ``device``,
        and the transforms bundled with the ImageNet MobileNetV3 weights enum.
    """
    weights = models.MobileNet_V3_Large_Weights.IMAGENET1K_V1
    # weights=None: load_state_dict (strict by default) overwrites every
    # parameter and buffer below, so downloading ImageNet weights is wasted work.
    model = models.mobilenet_v3_large(weights=None)
    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
    state_dict = torch.load(_MODEL_PATH, weights_only=True, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    # transforms() only builds the preprocessing pipeline; it needs no download.
    return model, weights.transforms()


def classify_image(image_path):
    """Classify a food photo as one of ``class_names``.

    Args:
        image_path: Path to the image file, as supplied by a ``gr.Image``
            input with ``type="filepath"``; ``None`` when no image was given.

    Returns:
        A human-readable string with the top predicted class and its softmax
        probability, e.g. ``"Prediction: Pizza (97.3% confident)"``, or a
        prompt to upload an image when ``image_path`` is ``None``.
    """
    if image_path is None:
        return "Please upload an image."

    model, preprocess = _load_model_and_transforms()

    # Close the file promptly and force 3-channel RGB: RGBA, grayscale or
    # palette images would otherwise fail the 3-channel ImageNet normalization.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Add a batch dimension and move the input to the model's device.
    input_batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    probabilities = torch.softmax(output[0], dim=0)
    top_probability, top_index = probabilities.max(dim=0)
    top_category = class_names[top_index.item()]
    confidence = top_probability.item()
    return f"Prediction: {top_category.title()} ({confidence:.1%} confident)"


# Gradio UI: one file-path image input feeds classify_image, and the
# prediction string is shown in a labelled text box.
_food_image_input = gr.Image(
    type="filepath",
    label="Upload a food image (pizza, steak, or sushi)",
)
_result_text_output = gr.Text(label="Classification Result")

demo = gr.Interface(
    fn=classify_image,
    inputs=_food_image_input,
    outputs=_result_text_output,
    title="Food Classifier",
    description="This model classifies images of pizza, steak, and sushi.",
)
demo.launch()