Update app.py
Browse files
app.py
CHANGED
@@ -32,15 +32,17 @@ def predict(image):
|
|
32 |
# Get the predicted label
|
33 |
logits = outputs.logits
|
34 |
predicted_class_idx = logits.argmax(-1).item()
|
35 |
-
|
|
|
36 |
# Get the label name
|
37 |
predicted_label = id2label[str(predicted_class_idx)]
|
38 |
-
|
39 |
-
|
|
|
40 |
|
41 |
# Create the Gradio interface
|
42 |
image = gr.Image(type="pil")
|
43 |
label = gr.Label(num_top_classes=3)
|
44 |
|
45 |
gr.Interface(fn=predict, inputs=image, outputs=label, title="Citrus Disease Classification",
|
46 |
-
description="Upload an image of a citrus leaf
|
|
|
32 |
# Get the predicted label
|
33 |
logits = outputs.logits
|
34 |
predicted_class_idx = logits.argmax(-1).item()
|
35 |
+
confidence_score = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
|
36 |
+
|
37 |
# Get the label name
|
38 |
predicted_label = id2label[str(predicted_class_idx)]
|
39 |
+
|
40 |
+
# Return the predicted label and confidence score
|
41 |
+
return predicted_label, f"Confidence: {confidence_score:.2f}"
|
42 |
|
43 |
# Create the Gradio interface
|
44 |
image = gr.Image(type="pil")
|
45 |
label = gr.Label(num_top_classes=3)
|
46 |
|
47 |
gr.Interface(fn=predict, inputs=image, outputs=label, title="Citrus Disease Classification",
|
48 |
+
description="Upload an image of a citrus leaf to classify its disease.").launch()
|