Spaces:
Sleeping
Sleeping
from ultralytics import YOLO | |
import gradio as gr | |
import os | |
import numpy as np | |
from PIL import Image | |
# Trained YOLO classification weights. The path is relative to the current
# working directory, so this assumes the app is started from the directory
# that contains "best.pt" (e.g. the root of the Hugging Face Space).
model_path = "best.pt"
model = YOLO(model_path)
# Maps the classifier's output index to a label. predict_image looks up
# class_names[i] for the i-th probability, so this order has to match the
# class order the model was trained with.
class_names = dict(enumerate([
    "bacterial_leaf_blight",
    "brown_spot",
    "healthy",
    "leaf_blast",
    "leaf_scald",
    "narrow_brown_spot",
]))
# User-facing description shown for the predicted class, keyed by the labels
# used in class_names.
# NOTE(review): most entries end in "..." and read like truncated
# placeholders; confirm whether fuller descriptions were intended.
disease_info = dict(
    bacterial_leaf_blight="Bacterial Leaf Blight is caused by Xanthomonas oryzae...",
    brown_spot="Brown Spot is caused by Cochliobolus miyabeanus...",
    healthy="This leaf shows no signs of disease and appears to be healthy.",
    leaf_blast="Leaf Blast is caused by Magnaporthe oryzae...",
    leaf_scald="Leaf Scald is caused by Microdochium oryzae...",
    narrow_brown_spot="Narrow Brown Spot is caused by Cercospora janseana...",
)
# Bullet-list treatment advice shown for the predicted class, keyed by the
# labels used in class_names. Bullets are separated by "\n" so they render
# on separate lines in the output Textbox.
treatment_recommendations = dict(
    bacterial_leaf_blight="• Use disease-free seeds\n• Apply copper-based bactericides...",
    brown_spot="• Use fungicides containing azoxystrobin or propiconazole...",
    healthy="• Continue regular monitoring...",
    leaf_blast="• Apply fungicides containing tricyclazole or azoxystrobin...",
    leaf_scald="• Apply fungicides containing propiconazole...",
    narrow_brown_spot="• Apply fungicides containing propiconazole or azoxystrobin...",
)
def predict_image(image):
    """Classify a rice-leaf image and look up its description and treatment.

    Args:
        image: Array delivered by the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing has been uploaded. Grayscale (H x W), RGB
            and RGBA arrays are all accepted and normalised to RGB.

    Returns:
        A 4-tuple ``(class_probs, result_text, info, treatment)``:
        ``class_probs`` maps every class name to its probability (all zeros
        when *image* is ``None``); ``result_text`` states the top class and
        its confidence to two decimals; ``info`` and ``treatment`` are the
        matching entries of ``disease_info`` and
        ``treatment_recommendations`` (empty strings when *image* is None).
    """
    if image is None:
        return {class_name: 0 for class_name in class_names.values()}, "", "", ""

    # Give the model a PIL image instead of the raw array. Ultralytics treats
    # numpy inputs as OpenCV-style BGR, but Gradio supplies RGB, so passing
    # the array directly would swap the red and blue channels.
    # convert("RGB") also handles 2-D grayscale uploads, which made the old
    # image.shape[2] check raise IndexError, and drops any alpha channel.
    pil_image = Image.fromarray(image).convert("RGB")

    results = model(pil_image)[0]
    probs = results.probs

    # Per-class probabilities for the gr.Label output.
    class_probs = {
        class_names[i]: float(prob) for i, prob in enumerate(probs.data.tolist())
    }

    # Top-1 prediction.
    top_class_name = class_names[probs.top1]
    confidence = probs.top1conf.item()

    info = disease_info[top_class_name]
    treatment = treatment_recommendations[top_class_name]

    result_text = f"Prediction: {top_class_name}\nConfidence: {confidence:.2f}"
    return class_probs, result_text, info, treatment
# Gradio UI: an input column (image upload + analyze button) next to an
# output column (prediction text, per-class probabilities, disease
# information, treatment advice), followed by static help text.
with gr.Blocks(title="Rice Leaf Disease Classifier") as app:
    gr.Markdown("# 🌱 Rice Leaf Disease Classification")
    gr.Markdown("Upload an image of a rice leaf to identify potential diseases.")

    with gr.Row():
        # Left column: user input.
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="numpy")
            submit_btn = gr.Button("Analyze Leaf", variant="primary")

        # Right column: components populated by predict_image.
        with gr.Column():
            result_text = gr.Textbox(label="Prediction Result")
            label_probs = gr.Label(label="Disease Probabilities")
            disease_description = gr.Textbox(label="Disease Information", lines=4)
            treatment_info = gr.Textbox(label="Treatment Recommendations", lines=6)

    gr.Markdown("### Disease Categories")
    gr.Markdown("• Bacterial Leaf Blight\n• Brown Spot\n• Healthy\n• Leaf Blast\n• Leaf Scald\n• Narrow Brown Spot")
    gr.Markdown("### How to Use")
    gr.Markdown("1. Upload a clear image of a rice leaf\n2. Click 'Analyze Leaf'\n3. View the prediction results, disease information, and treatment recommendations")

    # The outputs list is positional: it must follow predict_image's return
    # order (class_probs, result_text, info, treatment).
    submit_btn.click(
        fn=predict_image,
        inputs=[input_image],
        outputs=[label_probs, result_text, disease_description, treatment_info],
    )
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`). Importing the module, for example from tests or
# Gradio's reload mode, still builds `app` but no longer blocks on a server.
# server_name="0.0.0.0" binds on all interfaces so the app is reachable from
# outside the host/container.
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0")