Harshithtd commited on
Commit
2a1fba5
·
verified ·
1 Parent(s): 3311d3f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models, transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+
6
+ # Updated class names with 'plaque' in front of 'calculus' and 'gingivitis'
7
+ class_names = [
8
+ "plaque_calculus",
9
+ "caries",
10
+ "plaque_gingivitis",
11
+ "hypodontia",
12
+ "mouth_ulcer",
13
+ "tooth_discoloration"
14
+ ]
15
+
16
+ model = models.resnet50(weights=None)
17
+ model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
18
+
19
+ model.load_state_dict(torch.load('tooth_model.pth', map_location=torch.device('cpu')))
20
+ model.eval()
21
+
22
+ preprocess = transforms.Compose([
23
+ transforms.Resize((224, 224)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
26
+ ])
27
+
28
+ def predict_image(image):
29
+ processed_image = preprocess(image).unsqueeze(0)
30
+
31
+ with torch.no_grad():
32
+ outputs = model(processed_image)
33
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
34
+ top_probs, top_indices = torch.topk(probabilities, 3)
35
+ top_classes = [class_names[idx] for idx in top_indices[0]]
36
+
37
+ # Create a result dictionary with class names and probabilities
38
+ result = {top_classes[i]: top_probs[0][i].item() for i in range(3)}
39
+
40
+ return result
41
+
42
+ iface = gr.Interface(
43
+ fn=predict_image,
44
+ inputs=gr.Image(type="pil"),
45
+ outputs="label",
46
+ title="Medical Image Classification",
47
+ description="Upload an image to predict its class with probabilities of top 3 predictions."
48
+ )
49
+
50
+ iface.launch()