TharunSivamani commited on
Commit
db2ca82
·
verified ·
1 Parent(s): f736353

inference.py for simple inference

Browse files
Files changed (1) hide show
  1. inference.py +91 -0
inference.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import os
5
+
6
+ # Define the device
7
+ device = (
8
+ "cuda"
9
+ if torch.cuda.is_available()
10
+ else "mps"
11
+ if torch.backends.mps.is_available()
12
+ else "cpu"
13
+ )
14
+
15
+ class Params:
16
+ def __init__(self):
17
+ self.batch_size = 512
18
+ self.name = "resnet_50"
19
+ self.workers = 16
20
+ self.lr = 0.1
21
+ self.momentum = 0.9
22
+ self.weight_decay = 1e-4
23
+ self.lr_step_size = 30
24
+ self.lr_gamma = 0.1
25
+
26
+ def __repr__(self):
27
+ return str(self.__dict__)
28
+
29
+ def __eq__(self, other):
30
+ return self.__dict__ == other.__dict__
31
+
32
+ params = Params()
33
+
34
+ # Path to the saved model checkpoint
35
+ checkpoint_path = "checkpoints/resnet_50/checkpoint.pth"
36
+
37
+ # Load the model architecture
38
+ from model import ResNet50 # Assuming resnet.py contains your model definition
39
+
40
+ num_classes = 1000 # Adjust this to match your dataset
41
+ model = ResNet50(num_classes=num_classes).to(device)
42
+
43
+ # Load the trained model weights
44
+ checkpoint = torch.load(checkpoint_path)
45
+ model.load_state_dict(checkpoint["model"])
46
+
47
+ model.eval()
48
+
49
+ # Define transformations for inference
50
+ inference_transforms = transforms.Compose([
51
+ transforms.ToTensor(),
52
+ transforms.Resize(size=256),
53
+ transforms.CenterCrop(224),
54
+ transforms.Normalize(mean=[0.485, 0.485, 0.406], std=[0.229, 0.224, 0.225]),
55
+ ])
56
+
57
+ # Load class names from the text file
58
+ def load_class_names(file_path):
59
+ with open(file_path, 'r') as f:
60
+ class_names = [line.strip() for line in f]
61
+ return class_names
62
+
63
+ # Function to make predictions on a single image
64
+ def predict(image_path, model, transforms, class_names=None):
65
+ # Load and transform the image
66
+ image = Image.open(image_path).convert("RGB")
67
+ image_tensor = transforms(image).unsqueeze(0).to(device)
68
+
69
+ # Forward pass
70
+ with torch.no_grad():
71
+ output = model(image_tensor)
72
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
73
+ top_prob, top_class = probabilities.topk(5, largest=True, sorted=True)
74
+
75
+ # Display the top predictions
76
+ print("Predictions:")
77
+ for i in range(top_prob.size(0)):
78
+ class_name = class_names[top_class[i]] if class_names else f"Class {top_class[i].item()}"
79
+ print(f"{class_name}: {top_prob[i].item() * 100:.2f}%")
80
+
81
+ return top_prob, top_class
82
+
83
+ # Path to the ImageNet classes text file
84
+ imagenet_classes_file = "imagenet-classes.txt" # Replace with the actual path to your text file
85
+ class_names = load_class_names(imagenet_classes_file)
86
+
87
+ # Path to the image for inference
88
+ image_path = "dog.png" # Replace with the actual path to your test image
89
+
90
+ # Make a prediction
91
+ predict(image_path, model, inference_transforms, class_names=class_names)