File size: 1,209 Bytes
70d2037
fcb742b
7f51c38
70d2037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f0665c
70d2037
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import gradio as gr
import os
import subprocess
import sys

# Best-effort runtime install of the model dependencies. This has to run
# before the imports below so the packages are importable on a fresh
# environment. Invoking pip through the running interpreter (rather than a
# shell-resolved ``pip3``) guarantees the packages land in the environment
# this script is executing in, and the argument list avoids a shell.
# ``check=False`` keeps the original behaviour of not aborting here when the
# install fails; a missing package then surfaces as an ImportError below.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "torch", "transformers", "Pillow", "ensemble_transformers",
    ],
    check=False,
)
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image

from ensemble_transformers import EnsembleModelForImageClassification

# Identifiers of the three pretrained models that are combined into one
# ensemble; they are passed positionally to ``from_multiple_pretrained``.
_MODEL_IDS = (
    "tcvrishank/histo_train_vit",
    "tcvrishank/histo_train_segformer",
    "tcvrishank/histo_train_swin",
)
ensemble = EnsembleModelForImageClassification.from_multiple_pretrained(*_MODEL_IDS)

# Class names, paired positionally with the softmax scores in
# ``return_prediction``.
candidate_labels = ["Benign", "InSitu", "Invasive", "Normal"]

def return_prediction(image):
    """Classify an image with the ensemble and return a risk-score sentence.

    Runs the module-level ``ensemble`` on ``image`` with ``mean_pool=True``,
    applies a softmax to the first row of the returned logits, and reports the
    highest class score (rounded to two decimals, then offset by 1) inside a
    fixed summary sentence.

    Args:
        image: The value delivered by the Gradio ``"image"`` input; it is
            forwarded to ``ensemble`` unchanged.

    Returns:
        str: The summary sentence containing the computed risk score.
    """
    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = ensemble(image, mean_pool=True)

    probs = outputs.logits[0].softmax(dim=-1).numpy()

    # Only the top-scoring class is needed, so take the maximum directly
    # instead of sorting every (score, label) pair. ``zip`` pairs scores with
    # labels positionally and stops at the shorter of the two sequences.
    top_score, _top_label = max(
        zip(probs, candidate_labels), key=lambda pair: pair[0]
    )

    # NOTE(review): the predicted label is not used and the sentence always
    # reports "high risk" regardless of which class wins -- confirm whether
    # the message should depend on the predicted class.
    return f"This histopathology image shows a cell population that indicates a risk score of {round(top_score, 2) + 1}. Image suggests high risk of recurrence."

# Expose the classifier through a Gradio interface with a single image input
# and a single text output, then start serving it.
demo = gr.Interface(
    fn=return_prediction,
    inputs="image",
    outputs="text",
)
demo.launch()