"""Gradio demo that classifies rice leaf diseases with a fine-tuned ResNet-50."""

import logging
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms

logger = logging.getLogger(__name__)

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class labels, indexed by the model's output position. The order must match
# the label encoding used when the checkpoint was trained.
CLASS_NAMES = ["Brown Spot", "Healthy", "Leaf Blast", "Neck Blast"]

# Default location of the trained weights.
CHECKPOINT_PATH = r"/kaggle/input/rice_epoch8/pytorch/default/1/best_model_epoch_8.pth"


# Load the trained model
def load_model(checkpoint_path: str = CHECKPOINT_PATH,
               num_classes: int = len(CLASS_NAMES)) -> torch.nn.Module:
    """Build the ResNet-50 classifier and load its trained weights.

    The stock ``fc`` layer is replaced by a Linear(256) -> ReLU ->
    Linear(num_classes) head before the state dict is loaded.

    Args:
        checkpoint_path: Path to the saved ``state_dict`` file.
        num_classes: Number of output classes of the final linear layer.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # No-argument construction yields an un-pretrained network on both old
    # (``pretrained=``) and new (``weights=``) torchvision APIs, avoiding the
    # deprecation warning that ``pretrained=False`` triggers.
    net = models.resnet50()
    net.fc = torch.nn.Sequential(
        torch.nn.Linear(net.fc.in_features, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, num_classes),
    )

    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source (consider ``weights_only=True`` on PyTorch versions that
    # support it).
    state_dict = torch.load(checkpoint_path, map_location=device)

    # strict=False is kept so a checkpoint with extra/missing keys still
    # loads, but the mismatch is reported instead of silently leaving layers
    # at their random initialisation.
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        logger.warning("Checkpoint is missing keys: %s", incompatible.missing_keys)
    if incompatible.unexpected_keys:
        logger.warning("Checkpoint has unexpected keys: %s", incompatible.unexpected_keys)

    net = net.to(device)
    net.eval()
    return net


# Define preprocessing steps: resize to 224x224, convert to a tensor and
# normalise each channel with the ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Prediction function
def predict(image: Optional[Image.Image]) -> str:
    """Classify one rice leaf image and describe the result.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        A two-line string with the predicted label and its softmax
        confidence as a percentage, or a prompt to upload an image when
        ``image`` is ``None``.
    """
    # Gradio passes None for an empty image input; without this guard the
    # ``.convert`` call below raises AttributeError.
    if image is None:
        return "Please upload a rice leaf image."

    # Ensure image is in RGB, then add the batch dimension.
    input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    # Perform inference
    with torch.no_grad():
        outputs = model(input_tensor)

    # Map predicted class index to actual labels
    predicted_index = int(outputs.argmax(dim=1).item())
    predicted_label = CLASS_NAMES[predicted_index]

    # Calculate confidence scores
    probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
    confidence = probabilities[predicted_index].item()

    return f"Predicted Disease: {predicted_label}\nConfidence: {confidence*100:.2f}%"


# Load the model globally
model = load_model()


# Create Gradio interface
def launch_interface():
    """Build and return the Gradio interface (the caller launches it)."""
    iface = gr.Interface(
        theme="Subh775/orchid_candy",
        fn=predict,
        inputs=gr.Image(type="pil", label="Upload Rice Leaf Image"),
        outputs=gr.Textbox(label="Prediction Results"),
        title="Rice Disease Classification",
        description="Upload a rice leaf image to detect disease type",
        examples=[
            ["https://doa.gov.lk/wp-content/uploads/2020/06/brownspot3-1024x683.jpg"],
            ["https://arkansascrops.uada.edu/posts/crops/rice/images/Fig%206%20Rice%20leaf%20blast%20coalesced%20lesions.png"],
            ["https://th.bing.com/th/id/OIP._5ejX_5Z-M0cO5c2QUmPlwHaE7?w=280&h=187&c=7&r=0&o=5&dpr=1.1&pid=1.7"],
            ["https://www.weknowrice.com/wp-content/uploads/2022/11/how-to-grow-rice.jpeg"],
        ],
        allow_flagging="never",
    )
    return iface


# Launch the interface
if __name__ == "__main__":
    interface = launch_interface()
    interface.launch(share=True)