tcvrishank commited on
Commit
70d2037
1 Parent(s): 68cf9ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -0
app.py CHANGED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
4
+ from PIL import Image
5
+
6
+ from ensemble_transformers import EnsembleModelForImageClassification
7
+ ensemble = EnsembleModelForImageClassification.from_multiple_pretrained(
8
+ "tcvrishank/histo_train_vit", "tcvrishank/histo_train_segformer", "tcvrishank/histo_train_swin"
9
+ )
10
+
11
+ candidate_labels = ["Benign", "InSitu", "Invasive", "Normal"]
12
+
13
+ def return_prediction(image):
14
+
15
+ with torch.no_grad():
16
+ outputs = ensemble(image, mean_pool = True)
17
+
18
+ logits = outputs.logits[0]
19
+ probs = logits.softmax(dim=-1).numpy()
20
+ scores = probs.tolist()
21
+
22
+ result = [
23
+ {"score": score, "label": candidate_label}
24
+ for score, candidate_label in sorted(zip(probs, candidate_labels), key=lambda x: -x[0])
25
+ ]
26
+ result = result[0]
27
+ final = f"This histopathology image shows cells that are {round(result['score'] * 100, 2)}% certain to be {result['label']}."
28
+ return final
29
+
30
+ demo = gr.Interface(fn=return_prediction, inputs="image", outputs="text")
31
+ demo.launch()