import torch from torchvision import models, transforms from PIL import Image import gradio as gr # Define the device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load the trained model def load_model(): model = models.resnet50(pretrained=False) num_classes = 4 # Update based on your rice disease classes model.fc = torch.nn.Sequential( torch.nn.Linear(model.fc.in_features, 256), torch.nn.ReLU(), torch.nn.Linear(256, num_classes) ) model.load_state_dict(torch.load("best_model_epoch_43.pth", map_location=device), strict=False) model = model.to(device) model.eval() return model # Define preprocessing steps transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Color mapping for labels label_colors = { "Brown Spot": "#b2ff00", "Healthy": "#2ecc71", "Leaf Blast": "#ff00d4", "Neck Blast": "#ffd100" } # Function to get color based on confidence def get_confidence_color(confidence): if confidence < 0.25: return "#e74c3c" # Red elif confidence < 0.50: return "#f39c12" elif confidence < 0.75: return "#00b9ff" # Yellow else: return "#13ff00" # Green # Updated Prediction Function def predict(image): # Ensure image is in RGB image = image.convert("RGB") input_tensor = transform(image).unsqueeze(0).to(device) # Perform inference with torch.no_grad(): outputs = model(input_tensor) _, predicted_class = torch.max(outputs, 1) # Map predicted class index to actual labels class_names = ["Brown Spot", "Healthy", "Leaf Blast", "Neck Blast"] predicted_label = class_names[predicted_class.item()] # Calculate confidence scores probabilities = torch.nn.functional.softmax(outputs, dim=1)[0] confidence = probabilities[predicted_class.item()].item() # Generate styled output label_color = label_colors.get(predicted_label, "#FFFFFF") # Default White if not found confidence_color = get_confidence_color(confidence) result = f"