iamomtiwari committed on
Commit
7f05610
·
verified ·
1 Parent(s): 7eef8ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -36,32 +36,43 @@ labels_list = [class_labels[i]["label"] for i in range(1, 15)]
36
  # Confidence threshold for ViT model
37
  CONFIDENCE_THRESHOLD = 0.5
38
 
39
- # Inference function
40
  def predict(image):
41
  # First, use the crop disease model (ViT)
42
  inputs = feature_extractor(images=image, return_tensors="pt")
43
  with torch.no_grad():
44
  outputs = model(**inputs)
45
  logits = outputs.logits
 
46
  predicted_class_idx = logits.argmax(-1).item()
47
- confidence = torch.softmax(logits, dim=-1)[0, predicted_class_idx].item()
48
 
49
  # If confidence is below the threshold, use the fallback model
50
  if confidence < CONFIDENCE_THRESHOLD:
51
  inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
52
  with torch.no_grad():
53
  outputs_fallback = fallback_model(**inputs_fallback)
54
- predicted_class_idx_fallback = outputs_fallback.logits.argmax(-1).item()
 
 
 
55
 
56
  # Get the fallback prediction label
57
  fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
58
 
59
- return f"Low confidence in ViT model. ResNet-50 suggests: {fallback_label}\n\nIf this does not match your input, please try another image."
 
 
 
 
60
 
61
  # If confidence is above the threshold, return the ViT prediction and treatment advice
62
  predicted_label = labels_list[predicted_class_idx]
63
  treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
64
- return f"Disease: {predicted_label}\n\nTreatment Advice: {treatment_advice}"
 
 
 
65
 
66
  # Create Gradio Interface
67
  interface = gr.Interface(
 
36
  # Confidence threshold for ViT model
37
  CONFIDENCE_THRESHOLD = 0.5
38
 
39
# Inference function with fuzzy confidence
def predict(image):
    """Classify a crop-disease image, falling back to ResNet-50 on low confidence.

    The primary ViT crop-disease model runs first; when its softmax confidence
    meets CONFIDENCE_THRESHOLD, its label and treatment advice are returned.
    Otherwise the general-purpose ResNet-50 fallback model's prediction is
    reported instead, with both confidence scores shown to the user.

    Returns a human-readable result string in either case.
    """
    # Run the primary crop-disease classifier (ViT).
    vit_inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        vit_logits = model(**vit_inputs).logits

    vit_probs = torch.softmax(vit_logits, dim=-1)
    predicted_class_idx = vit_logits.argmax(-1).item()
    confidence = vit_probs[0, predicted_class_idx].item()

    # Confident enough: report the ViT label plus its treatment advice.
    if confidence >= CONFIDENCE_THRESHOLD:
        predicted_label = labels_list[predicted_class_idx]
        # class_labels is 1-indexed while labels_list is 0-indexed, hence +1.
        treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
        return (
            f"Disease: {predicted_label} ({confidence * 100:.2f}%)\n\n"
            f"Treatment Advice: {treatment_advice}"
        )

    # Below threshold: defer to the ResNet-50 fallback classifier.
    resnet_inputs = fallback_feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        resnet_logits = fallback_model(**resnet_inputs).logits

    resnet_probs = torch.softmax(resnet_logits, dim=-1)
    predicted_class_idx_fallback = resnet_logits.argmax(-1).item()
    fallback_confidence = resnet_probs[0, predicted_class_idx_fallback].item()

    # Map the fallback index to its label via the model's own id2label table.
    fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]

    return (
        f"Low confidence in ViT model ({confidence * 100:.2f}%).\n"
        f"ResNet-50 predicts: {fallback_label} ({fallback_confidence * 100:.2f}%).\n\n"
        "If this does not match your input, please try another image."
    )
76
 
77
  # Create Gradio Interface
78
  interface = gr.Interface(