"""Gradio demo serving top-5 ImageNet predictions from a ResNet50 checkpoint.

Class labels are read from ``imagenet_classes.json`` at import time; the model
is loaded lazily from ``best_model.pth`` on the first prediction, falling back
to torchvision's pretrained ResNet50 when the checkpoint cannot be loaded.
"""

import json
import os

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

from train_resnet50 import ResNet, Bottleneck

CLASS_LABELS_PATH = "imagenet_classes.json"
MODEL_PATH = "best_model.pth"
NUM_CLASSES = 1000
TOP_K = 5


def _placeholder_labels(num_classes=NUM_CLASSES):
    """Return ``["class_0", ..., "class_{num_classes - 1}"]``.

    A list (not a dict keyed by stringified ids) so the placeholders can be
    indexed by integer class id exactly like a JSON-list label file.
    """
    return [f"class_{i}" for i in range(num_classes)]


def _load_class_labels(path=CLASS_LABELS_PATH):
    """Load class labels from *path*, falling back to placeholders on error.

    The file must contain either a JSON list (position -> label) or a JSON
    object keyed by stringified class ids; any other content, a missing file,
    or a parse error yields placeholder labels and a printed warning.
    """
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            labels = json.load(f)
    except FileNotFoundError:
        print(f"Warning: {path} not found, creating simplified labels")
        return _placeholder_labels()
    except json.JSONDecodeError:
        print(f"Warning: Error parsing {path}, using simplified labels")
        return _placeholder_labels()
    except Exception as e:
        print(f"Warning: Unexpected error loading class labels: {e}")
        return _placeholder_labels()

    if not isinstance(labels, (list, dict)):
        print(
            f"Warning: {path} must contain a JSON list or object, "
            "using simplified labels"
        )
        return _placeholder_labels()

    print(f"Loaded {len(labels)} class labels")
    return labels


# Load ImageNet class labels (a list, or a dict keyed by stringified ids).
class_labels = _load_class_labels()


def _label_for(class_idx):
    """Return the label for integer *class_idx*, or ``class_{idx}`` if unknown.

    Works whether ``class_labels`` is a list or a dict, so the placeholder
    fallback and either JSON layout are all handled by the same lookup.
    """
    if isinstance(class_labels, dict):
        # JSON object keys are always strings.
        label = class_labels.get(str(class_idx))
    elif 0 <= class_idx < len(class_labels):
        label = class_labels[class_idx]
    else:
        label = None
    return f"class_{class_idx}" if label is None else label


def create_model():
    """Build an untrained ResNet50 (Bottleneck blocks, layout [3, 4, 6, 3])."""
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    return model


def _extract_state_dict(checkpoint):
    """Return model weights from a training checkpoint or a bare state dict.

    Training checkpoints store weights under ``"model_state_dict"``; a file
    written with ``torch.save(model.state_dict())`` is the state dict itself.
    """
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    return checkpoint


def _strip_module_prefix(state_dict, prefix="module."):
    """Drop the leading *prefix* that DataParallel/DDP adds to parameter names.

    Only a leading occurrence is removed; ``str.replace`` would also rewrite
    names that merely contain the prefix somewhere else.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(model_path):
    """Load the trained ResNet50 from *model_path* in eval mode.

    If the checkpoint cannot be read or does not match the architecture, the
    error is printed and torchvision's ImageNet-pretrained ResNet50 (fetched
    via ``torch.hub``) is returned instead. Errors from that fallback fetch
    propagate to the caller.
    """
    model = create_model()
    try:
        checkpoint = torch.load(model_path, map_location="cpu")
        state_dict = _strip_module_prefix(_extract_state_dict(checkpoint))
        model.load_state_dict(state_dict)
        model.eval()
        print("Model loaded successfully!")
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Loading pretrained ResNet50 as fallback...")
        fallback = torch.hub.load(
            "pytorch/vision:v0.10.0", "resnet50", pretrained=True
        )
        fallback.eval()
        return fallback


# Preprocessing transform (standard ImageNet resize/crop/normalization).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Lazily-loaded model shared across requests; see predict().
global_model = None


def predict(image):
    """Classify *image* and return its top-5 ImageNet predictions.

    Args:
        image: The uploaded image as passed by ``gr.Image()`` (a numpy array
            with Gradio's default ``type="numpy"``), or a PIL image; ``None``
            when nothing was uploaded.

    Returns:
        A list of up to ``TOP_K`` dicts with keys ``"label"``, ``"class_id"``
        and ``"confidence"`` (softmax probability), highest first; ``None`` if
        there is no image or loading/prediction fails (errors are printed).
    """
    global global_model

    if image is None:
        return None

    # Load model only once; a failed load is retried on the next request.
    if global_model is None:
        try:
            global_model = load_model(MODEL_PATH)
        except Exception as e:
            print(f"Error loading model: {e}")
            return None

    # Bound before the try so the error handler can reference it safely even
    # when the failure happens before topk has run.
    top_ids = None
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Force 3 channels: grayscale or RGBA input would otherwise break the
        # 3-channel Normalize step.
        batch = transform(image.convert("RGB")).unsqueeze(0)

        with torch.no_grad():
            outputs = global_model(batch)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        # Never ask topk for more entries than the model has classes.
        k = min(TOP_K, probabilities.numel())
        top_probs, top_ids = torch.topk(probabilities, k)

        return [
            {
                "label": _label_for(class_idx),
                "class_id": class_idx,
                "confidence": float(prob),
            }
            for prob, class_idx in zip(top_probs.tolist(), top_ids.tolist())
        ]
    except Exception as e:
        print(f"Error during prediction: {e}")
        if top_ids is not None:
            print(f"Class indices: {top_ids.tolist()}")  # Debug info
        return None


# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.JSON(),
    title="ResNet50 ImageNet Classifier",
    description="Upload an image to get top-5 predictions from our trained ResNet50 model.",
)

# Launch the app
if __name__ == "__main__":
    iface.launch(share=True)