|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import os |
|
|
|
|
|
# Pick the inference device: prefer CUDA, then Apple's MPS backend, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
|
|
|
class Params: |
|
def __init__(self): |
|
self.batch_size = 512 |
|
self.name = "resnet_50" |
|
self.workers = 16 |
|
self.lr = 0.1 |
|
self.momentum = 0.9 |
|
self.weight_decay = 1e-4 |
|
self.lr_step_size = 30 |
|
self.lr_gamma = 0.1 |
|
|
|
def __repr__(self): |
|
return str(self.__dict__) |
|
|
|
def __eq__(self, other): |
|
return self.__dict__ == other.__dict__ |
|
|
|
params = Params()

checkpoint_path = "checkpoints/resnet_50/checkpoint.pth"

from model import ResNet50

num_classes = 1000
model = ResNet50(num_classes=num_classes).to(device)

# map_location remaps every stored tensor onto the device selected above, so a
# checkpoint saved from a CUDA run still loads on an MPS- or CPU-only machine.
# NOTE(review): if this checkpoint also pickles non-tensor objects (e.g. a
# Params instance), torch versions whose torch.load defaults to
# weights_only=True may refuse to load it. Only pass weights_only=False for
# checkpoints from a trusted source - it enables arbitrary unpickling.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model"])

# Switch layers such as dropout / batch norm to inference behaviour.
model.eval()
|
|
|
|
|
# Evaluation-time preprocessing in the standard torchvision ImageNet order:
# resize and crop the PIL image first, then convert to a tensor and normalize.
inference_transforms = transforms.Compose([
    transforms.Resize(size=256),   # int size: the shorter edge is scaled to 256 px
    transforms.CenterCrop(224),
    transforms.ToTensor(),         # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    # Per-channel ImageNet RGB mean/std (the mean previously had a typo: 0.485 for G).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
def load_class_names(file_path):
    """Read class labels from a text file, one label per line.

    Args:
        file_path: Path to the label file. Line ``i`` becomes the label
            for class index ``i``.

    Returns:
        A list of labels with surrounding whitespace stripped. Blank lines are
        kept as empty strings, so line positions still match class indices.
    """
    # Explicit UTF-8 keeps decoding independent of the platform's locale default.
    with open(file_path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]
|
|
|
|
|
def predict(image_path, model, transforms, class_names=None, top_k=5):
    """Classify a single image and print its top-k predictions.

    Args:
        image_path: Path of the image to classify. It is opened with PIL and
            converted to RGB.
        model: Classifier called on a batch containing just this image. Row 0
            of its output is treated as that image's logits.
        transforms: Callable turning the RGB PIL image into a tensor. This
            parameter name shadows the module-level ``transforms`` import
            inside this function.
        class_names: Optional sequence mapping class index -> label. When it is
            falsy (None or empty), labels print as ``"Class <index>"``.
        top_k: Number of predictions to report. Defaults to 5 and is clamped
            to the number of classes the model outputs.

    Returns:
        ``(top_prob, top_class)``: tensors holding the top-k softmax
        probabilities (highest first) and their class indices.
    """
    # Close the file handle as soon as the pixels have been decoded
    # (convert() loads the data into a new image).
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Add a leading batch dimension of size 1 and move to the inference device.
    image_tensor = transforms(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        # topk raises if k exceeds the number of classes, so clamp it.
        k = min(top_k, probabilities.numel())
        top_prob, top_class = probabilities.topk(k, largest=True, sorted=True)

    print("Predictions:")
    for prob, idx in zip(top_prob.tolist(), top_class.tolist()):
        class_name = class_names[idx] if class_names else f"Class {idx}"
        print(f"{class_name}: {prob * 100:.2f}%")

    return top_prob, top_class
|
|
|
|
|
# Script entry: label the sample image with the ImageNet class names.
image_path = "dog.png"
imagenet_classes_file = "imagenet-classes.txt"
class_names = load_class_names(imagenet_classes_file)

predict(
    image_path=image_path,
    model=model,
    transforms=inference_transforms,
    class_names=class_names,
)