KhadijaAsehnoune12 commited on
Commit
185be2e
·
verified ·
1 Parent(s): c2f9f33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -32,15 +32,17 @@ def predict(image):
32
  # Get the predicted label
33
  logits = outputs.logits
34
  predicted_class_idx = logits.argmax(-1).item()
35
-
 
36
  # Get the label name
37
  predicted_label = id2label[str(predicted_class_idx)]
38
-
39
- return predicted_label
 
40
 
41
  # Create the Gradio interface
42
  image = gr.Image(type="pil")
43
  label = gr.Label(num_top_classes=3)
44
 
45
  gr.Interface(fn=predict, inputs=image, outputs=label, title="Citrus Disease Classification",
46
- description="Upload an image of a citrus leaf or fruit to classify its disease.").launch()
 
32
  # Get the predicted label
33
  logits = outputs.logits
34
  predicted_class_idx = logits.argmax(-1).item()
35
+ confidence_score = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
36
+
37
  # Get the label name
38
  predicted_label = id2label[str(predicted_class_idx)]
39
+
40
+ # Return the predicted label and confidence score
41
+ return predicted_label, f"Confidence: {confidence_score:.2f}"
42
 
43
  # Create the Gradio interface
44
  image = gr.Image(type="pil")
45
  label = gr.Label(num_top_classes=3)
46
 
47
  gr.Interface(fn=predict, inputs=image, outputs=label, title="Citrus Disease Classification",
48
+ description="Upload an image of a citrus leaf to classify its disease.").launch()