"""Gradio web app that classifies rice-leaf images into disease categories.

A ResNet-50 backbone with a custom two-layer classification head is loaded
from a local checkpoint and served through a Gradio interface that renders
the predicted label and its softmax confidence as colored HTML.
"""

import warnings

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms

# Define the device: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output index -> human-readable label. The order must match the class order
# used when the checkpoint was trained; the classifier head size is derived
# from this list so the two cannot drift apart.
CLASS_NAMES = ["Brown Spot", "Healthy", "Leaf Blast", "Neck Blast"]

# Checkpoint loaded by default when the app starts.
DEFAULT_CHECKPOINT = "best_model_epoch_43.pth"


def load_model(checkpoint_path=DEFAULT_CHECKPOINT):
    """Build the classifier, load trained weights, and prepare it for inference.

    Args:
        checkpoint_path: Path to a ``state_dict`` saved with ``torch.save``.
            Defaults to ``DEFAULT_CHECKPOINT``.

    Returns:
        The ResNet-50 model with its custom head, moved to ``device`` and
        switched to eval mode.
    """
    # No ImageNet weights are downloaded; every trained parameter comes from
    # the checkpoint. ``weights=None`` replaces the deprecated
    # ``pretrained=False`` argument.
    model = models.resnet50(weights=None)
    model.fc = torch.nn.Sequential(
        torch.nn.Linear(model.fc.in_features, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, len(CLASS_NAMES)),
    )
    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False tolerates key mismatches, but silently ignoring them would
    # leave affected layers randomly initialised and make every prediction
    # meaningless, so report any mismatch instead of hiding it.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path!r} did not match the model exactly: "
            f"missing keys={incompatible.missing_keys}, "
            f"unexpected keys={incompatible.unexpected_keys}",
            RuntimeWarning,
            stacklevel=2,
        )
    model = model.to(device)
    model.eval()
    return model


# Define preprocessing steps: resize to the network's input size, convert to
# a tensor, and normalise with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Color mapping for labels (hex colors used in the HTML result).
label_colors = {
    "Brown Spot": "#b2ff00",
    "Healthy": "#2ecc71",
    "Leaf Blast": "#ff00d4",
    "Neck Blast": "#ffd100",
}


def get_confidence_color(confidence):
    """Return a hex color for a confidence value in ``[0, 1]``.

    Lower confidence maps to warmer colors: red below 0.25, orange below 0.50,
    light blue below 0.75, and green otherwise.
    """
    if confidence < 0.25:
        return "#e74c3c"  # Red
    elif confidence < 0.50:
        return "#f39c12"  # Orange
    elif confidence < 0.75:
        return "#00b9ff"  # Light blue
    else:
        return "#13ff00"  # Green


def predict(image):
    """Classify a rice-leaf image and return the result as styled HTML.

    Args:
        image: A PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        An HTML snippet containing the predicted label (colored per
        ``label_colors``) and the softmax confidence (colored per
        ``get_confidence_color``).
    """
    if image is None:
        return "<p>Please upload a rice leaf image.</p>"

    # The network expects 3-channel input; uploads may be RGBA, grayscale
    # or palette images.
    image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim

    # Perform inference without tracking gradients.
    with torch.no_grad():
        outputs = model(input_tensor)

    # Softmax is monotonic, so the most probable class is also the arg-max
    # of the raw logits.
    probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
    confidence_value, class_index = torch.max(probabilities, dim=0)
    predicted_label = CLASS_NAMES[class_index.item()]
    confidence = confidence_value.item()

    # Generate styled output.
    label_color = label_colors.get(predicted_label, "#FFFFFF")  # white fallback
    confidence_color = get_confidence_color(confidence)
    result = (
        f"<div style='font-size: 24px; font-weight: bold; color: {label_color};'>"
        f"{predicted_label}</div>"
    )
    result += (
        f"<div style='font-size: 18px; color: {confidence_color};'>"
        f"Confidence: {confidence*100:.2f}%</div>"
    )
    return result


def launch_interface():
    """Build (but do not launch) the Gradio interface wrapping ``predict``."""
    iface = gr.Interface(
        theme=gr.themes.Citrus(
            primary_hue="emerald",
            neutral_hue="slate",
        ),
        fn=predict,
        inputs=gr.Image(type="pil", label="Upload Rice Leaf Image"),
        outputs=gr.HTML(label="Prediction Results"),
        title="Rice Disease Classification",
        description="Upload a rice leaf image to detect its condition (Brown Spot, Healthy, Leaf Blast, or Neck Blast)",
        examples=[
            ["https://doa.gov.lk/wp-content/uploads/2020/06/brownspot3-1024x683.jpg"],
            ["https://arkansascrops.uada.edu/posts/crops/rice/images/Fig%206%20Rice%20leaf%20blast%20coalesced%20lesions.png"],
            ["https://th.bing.com/th/id/OIP._5ejX_5Z-M0cO5c2QUmPlwHaE7?w=280&h=187&c=7&r=0&o=5&dpr=1.1&pid=1.7"],
        ],
        # ``flagging_mode`` replaces the deprecated ``allow_flagging`` option.
        flagging_mode="never",
    )
    return iface


# Load the model once at import time; ``predict`` reads this global.
model = load_model()

# Launch the interface with a public share link when run as a script.
if __name__ == "__main__":
    interface = launch_interface()
    interface.launch(share=True)