Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,14 +19,14 @@ class_labels = {
|
|
19 |
3: {"label": "Stage Safe Corn Healthy", "treatment": "Continue good agricultural practices: ensure proper irrigation, nutrient supply, and monitor for pests."},
|
20 |
4: {"label": "Stage Corn Northern Leaf Blight", "treatment": "Remove and destroy infected plant debris, apply fungicides, and rotate crops."},
|
21 |
5: {"label": "Stage Rice Brown Spot", "treatment": "Use resistant varieties, improve field drainage, and apply fungicides if necessary."},
|
22 |
-
6: {"label": "Stage
|
23 |
7: {"label": "Stage Rice Leaf Blast", "treatment": "Use resistant varieties, apply fungicides during high-risk periods, and practice good field management."},
|
24 |
8: {"label": "Stage Rice Neck Blast", "treatment": "Plant resistant varieties, improve nutrient management, and apply fungicides if symptoms appear."},
|
25 |
9: {"label": "Stage Sugarcane Bacterial Blight", "treatment": "Use disease-free planting material, practice crop rotation, and destroy infected plants."},
|
26 |
-
10: {"label": "Stage
|
27 |
11: {"label": "Stage Sugarcane Red Rot", "treatment": "Plant resistant varieties and ensure good drainage."},
|
28 |
-
12: {"label": "Stage Wheat
|
29 |
-
13: {"label": "Stage
|
30 |
14: {"label": "Stage Wheat Yellow Rust", "treatment": "Use resistant varieties, apply fungicides, and rotate crops."}
|
31 |
}
|
32 |
|
@@ -46,7 +46,7 @@ def predict(image):
|
|
46 |
predicted_class_idx = logits.argmax(-1).item()
|
47 |
confidence = torch.softmax(logits, dim=-1)[0, predicted_class_idx].item()
|
48 |
|
49 |
-
# If confidence is below the threshold,
|
50 |
if confidence < CONFIDENCE_THRESHOLD:
|
51 |
inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
|
52 |
with torch.no_grad():
|
@@ -55,7 +55,8 @@ def predict(image):
|
|
55 |
|
56 |
# Get the fallback prediction label
|
57 |
fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
|
58 |
-
|
|
|
59 |
|
60 |
# If confidence is above the threshold, return the ViT prediction and treatment advice
|
61 |
predicted_label = labels_list[predicted_class_idx]
|
|
|
19 |
3: {"label": "Stage Safe Corn Healthy", "treatment": "Continue good agricultural practices: ensure proper irrigation, nutrient supply, and monitor for pests."},
|
20 |
4: {"label": "Stage Corn Northern Leaf Blight", "treatment": "Remove and destroy infected plant debris, apply fungicides, and rotate crops."},
|
21 |
5: {"label": "Stage Rice Brown Spot", "treatment": "Use resistant varieties, improve field drainage, and apply fungicides if necessary."},
|
22 |
+
6: {"label": "Stage Safe Rice Healthy", "treatment": "Maintain proper irrigation, fertilization, and pest control measures."},
|
23 |
7: {"label": "Stage Rice Leaf Blast", "treatment": "Use resistant varieties, apply fungicides during high-risk periods, and practice good field management."},
|
24 |
8: {"label": "Stage Rice Neck Blast", "treatment": "Plant resistant varieties, improve nutrient management, and apply fungicides if symptoms appear."},
|
25 |
9: {"label": "Stage Sugarcane Bacterial Blight", "treatment": "Use disease-free planting material, practice crop rotation, and destroy infected plants."},
|
26 |
+
10: {"label": "Stage Safe Sugarcane Healthy", "treatment": "Maintain healthy soil conditions and proper irrigation."},
|
27 |
11: {"label": "Stage Sugarcane Red Rot", "treatment": "Plant resistant varieties and ensure good drainage."},
|
28 |
+
12: {"label": "Stage Wheat Brown Rust", "treatment": "Apply fungicides and practice crop rotation with non-host crops."},
|
29 |
+
13: {"label": "Stage Safe Wheat Healthy", "treatment": "Continue with good management practices, including proper fertilization and weed control."},
|
30 |
14: {"label": "Stage Wheat Yellow Rust", "treatment": "Use resistant varieties, apply fungicides, and rotate crops."}
|
31 |
}
|
32 |
|
|
|
46 |
predicted_class_idx = logits.argmax(-1).item()
|
47 |
confidence = torch.softmax(logits, dim=-1)[0, predicted_class_idx].item()
|
48 |
|
49 |
+
# If confidence is below the threshold, use the fallback model
|
50 |
if confidence < CONFIDENCE_THRESHOLD:
|
51 |
inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
|
52 |
with torch.no_grad():
|
|
|
55 |
|
56 |
# Get the fallback prediction label
|
57 |
fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
|
58 |
+
|
59 |
+
return f"Low confidence in ViT model. ResNet-50 suggests: {fallback_label}\n\nIf this does not match your input, please try another image."
|
60 |
|
61 |
# If confidence is above the threshold, return the ViT prediction and treatment advice
|
62 |
predicted_label = labels_list[predicted_class_idx]
|