akshaikrishna commited on
Commit
3df2514
·
1 Parent(s): e6119fe

Add application file

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import models
4
+ from PIL import Image
5
+ from torch import nn
6
+
7
+
8
+ num_classes = 3
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ class_names = ["pizza", "steak", "sushi"]
11
+ class_dict = {"pizza": 0, "steak": 1, "sushi": 2}
12
+
13
+
14
+ def classify_image(image_path):
15
+ # Load the pre-trained model
16
+ model = models.mobilenet_v3_large(
17
+ weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V1
18
+ )
19
+ model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
20
+ model.load_state_dict(
21
+ torch.load(
22
+ "MobileNetV3-Food-Classification.pth",
23
+ weights_only=True,
24
+ map_location=device,
25
+ )
26
+ )
27
+ model.to(device)
28
+ model.eval()
29
+
30
+ # Get the proper transforms directly from the weights
31
+ weights = models.MobileNet_V3_Large_Weights.IMAGENET1K_V1
32
+ preprocess = weights.transforms()
33
+
34
+ # Load and transform the image
35
+ image = Image.open(image_path)
36
+ input_tensor = preprocess(image)
37
+
38
+ # Add batch dimension
39
+ input_batch = input_tensor.unsqueeze(0)
40
+
41
+ # Move to GPU if available
42
+ input_batch = input_batch.to(device)
43
+
44
+ # Perform inference
45
+ with torch.no_grad():
46
+ output = model(input_batch)
47
+
48
+ # Get predictions
49
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
50
+
51
+ # Get the top prediction
52
+ top_prob, top_catid = torch.topk(probabilities, 1)
53
+ top_category = class_names[top_catid.item()] # type: ignore
54
+ top_probability = top_prob.item()
55
+
56
+ # Format the output string nicely
57
+ return f"Prediction: {top_category.title()} ({top_probability:.1%} confident)"
58
+
59
+
60
+ # Update the interface with label and better output type
61
+ demo = gr.Interface(
62
+ fn=classify_image,
63
+ inputs=gr.Image(
64
+ type="filepath", label="Upload a food image (pizza, steak, or sushi)"
65
+ ),
66
+ outputs=gr.Text(label="Classification Result"),
67
+ title="Food Classifier",
68
+ description="This model classifies images of pizza, steak, and sushi.",
69
+ )
70
+ demo.launch()