|
import warnings

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
def load_model(
    checkpoint_path=r"/kaggle/input/rice_epoch8/pytorch/default/1/best_model_epoch_8.pth",
    num_classes=4,
    strict=False,
):
    """Build the ResNet-50 rice-disease classifier and load its fine-tuned weights.

    The stock ResNet-50 ``fc`` layer is replaced with a small head
    (in_features -> 256 -> ReLU -> ``num_classes``) before the checkpoint is
    loaded, so the checkpoint must contain weights for that head.

    Args:
        checkpoint_path: Path to a state_dict saved with ``torch.save``.
            Defaults to the Kaggle input path the app was built against.
        num_classes: Size of the final linear layer. Must match the checkpoint
            and the number of labels ``predict`` maps outputs to.
        strict: Forwarded to ``load_state_dict``. Defaults to ``False`` as
            before, but any missing/unexpected keys are now reported with a
            ``RuntimeWarning`` instead of being silently ignored.

    Returns:
        The model, moved to the module-level ``device`` and set to eval mode.
    """
    # NOTE: torchvision >= 0.13 prefers ``weights=None``; ``pretrained=False``
    # is kept so older torchvision releases keep working.
    model = models.resnet50(pretrained=False)
    model.fc = torch.nn.Sequential(
        torch.nn.Linear(model.fc.in_features, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, num_classes),
    )

    state_dict = torch.load(checkpoint_path, map_location=device)
    result = model.load_state_dict(state_dict, strict=strict)
    # With strict=False, keys that don't line up are dropped without error,
    # which would leave those layers randomly initialised and make every
    # prediction meaningless. Surface that instead of hiding it.
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path!r} does not match the model: "
            f"missing keys={result.missing_keys}, "
            f"unexpected keys={result.unexpected_keys}",
            RuntimeWarning,
            stacklevel=2,
        )

    model = model.to(device)
    # Inference only: disable dropout / use running batch-norm statistics.
    model.eval()
    return model
|
|
|
|
|
# Standard ImageNet per-channel mean/std. Presumably the checkpoint was trained
# with this same normalisation — confirm against the training pipeline.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before inference:
# resize to a fixed 224x224 (aspect ratio is not preserved), convert the PIL
# image to a float tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
|
# Human-readable labels indexed by the model's output position. The order must
# match the label encoding used at training time (not visible here — confirm),
# and its length must equal load_model()'s num_classes (4).
CLASS_NAMES = ["Brown Spot", "Healthy", "Leaf Blast", "Neck Blast"]


def predict(image):
    """Classify a rice leaf image and return a human-readable result string.

    Args:
        image: A PIL image (the Gradio input component is created with
            ``type="pil"``), or ``None`` when the user submits without
            uploading anything.

    Returns:
        ``"Predicted Disease: <label>\\nConfidence: <pct>%"`` on success, or a
        prompt asking for an upload when ``image`` is ``None``.
    """
    # Gradio hands the function None for an empty image input; without this
    # guard the .convert() call below raised AttributeError.
    if image is None:
        return "Please upload a rice leaf image."

    # Normalise any PIL mode (RGBA, L, P, ...) to 3-channel RGB.
    rgb_image = image.convert("RGB")
    # Add a batch dimension of 1 and move the tensor to the model's device.
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)

    # Choose the class from the raw logits, then report its softmax probability.
    _, predicted_class = torch.max(logits, 1)
    class_index = predicted_class.item()
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
    confidence = probabilities[class_index].item()

    predicted_label = CLASS_NAMES[class_index]
    return f"Predicted Disease: {predicted_label}\nConfidence: {confidence*100:.2f}%"
|
|
|
|
|
# The classifier is built a single time, when this module is imported;
# predict() reads this module-level global on every request instead of
# reloading the checkpoint.
model = load_model()
|
|
|
|
|
def launch_interface():
    """Build the Gradio UI for the rice leaf classifier.

    Despite the name, this only constructs the ``gr.Interface``; the caller is
    responsible for calling ``.launch()`` on the returned object.

    Returns:
        A ``gr.Interface`` wiring a PIL image upload to ``predict`` and showing
        its text result, with four clickable example images loaded by URL.
    """
    example_urls = (
        "https://doa.gov.lk/wp-content/uploads/2020/06/brownspot3-1024x683.jpg",
        "https://arkansascrops.uada.edu/posts/crops/rice/images/Fig%206%20Rice%20leaf%20blast%20coalesced%20lesions.png",
        "https://th.bing.com/th/id/OIP._5ejX_5Z-M0cO5c2QUmPlwHaE7?w=280&h=187&c=7&r=0&o=5&dpr=1.1&pid=1.7",
        "https://www.weknowrice.com/wp-content/uploads/2022/11/how-to-grow-rice.jpeg",
    )
    # Interface has a single input, so each example row holds exactly one value.
    example_rows = [[url] for url in example_urls]

    return gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil", label="Upload Rice Leaf Image"),
        outputs=gr.Textbox(label="Prediction Results"),
        title="Rice Disease Classification",
        description="Upload a rice leaf image to detect disease type",
        examples=example_rows,
        # Theme is referenced by Hugging Face Hub name.
        theme="Subh775/orchid_candy",
        # NOTE(review): newer Gradio releases rename this to ``flagging_mode`` —
        # confirm against the installed Gradio version.
        allow_flagging="never",
    )
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI, then start serving it.
    # share=True is passed to Gradio's launch(), which (per the Gradio docs)
    # also requests a temporary public URL — anyone holding it can use the app.
    interface = launch_interface()
    interface.launch(share=True)