# Hugging Face Space page header captured with this file: "Spaces: Build error".
import torch | |
from monai.bundle import ConfigParser | |
import gradio as gr | |
from utils import page_utils | |
# Assemble the MONAI workflow pieces from the bundle's JSON configs, then load
# the trained weights onto the CPU.
parser = ConfigParser()
parser.read_config(f="configs/inference.json")  # workflow definition
parser.read_meta(f="configs/metadata.json")  # bundle metadata
inference, network, preprocess = (
    parser.get_parsed_content(key)
    for key in ("inferer", "network_def", "preprocessing")
)
# NOTE(review): torch.load unpickles the checkpoint file; only load trusted files.
state_dict = torch.load("models/model.pt", map_location="cpu")
# strict=True: checkpoint keys must match the network's parameters exactly.
network.load_state_dict(state_dict, strict=True)
# Output index of the classifier -> human-readable class name.
class_names = dict(
    enumerate(("Other", "Inflammatory", "Epithelial", "Spindle-Shaped"))
)
def classify_image(image_file, label_file):
    """Classify a histology image, using its label image as an extra input.

    Args:
        image_file: Path to the histology image (the UI passes a file path).
        label_file: Path to the matching label image.

    Returns:
        dict mapping every class name in ``class_names`` to its softmax
        probability, in the form ``gr.Label`` displays.

    Raises:
        gr.Error: if either input image is missing.
    """
    if image_file is None:
        raise gr.Error("Need a histology image")
    if label_file is None:
        raise gr.Error("Need a label image")

    batch = preprocess({"image": image_file, "label": label_file})
    network.eval()
    with torch.no_grad():
        # NOTE(review): the network is expected to take 4 input channels
        # (3 RGB + 1 label mask) assembled by `preprocess`; confirm against
        # configs/inference.json. unsqueeze adds the batch dimension.
        pred = inference(batch["image"].unsqueeze(dim=0), network)
        # Assumes pred is (1, num_classes) logits — TODO confirm; [0] takes
        # the single batch element after the softmax.
        prob = pred.softmax(-1).detach().cpu().numpy()[0]
    # Index probabilities by each class's own key rather than assuming the
    # keys are exactly 0..len-1.
    return {name: float(prob[idx]) for idx, name in class_names.items()}
# Example inputs as [image_path, label_path] pairs: each histology image in
# sample_data/Images has a same-named label image in sample_data/Labels.
example_files1 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in (
        "test_11_2_0628",
        "test_9_4_0149",
        "test_12_3_0292",
        "test_9_4_0019",
    )
]
example_files2 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in (
        "test_14_3_0433",
        "test_14_4_0544",
        "train_1_1_0095",
        "train_1_3_0020",
    )
]
# Static HTML for the page header, rendered with gr.HTML inside the app.
with open("index.html", encoding="utf-8") as html_file:
    html_content = html_file.read()
# Gradio UI: header HTML, the two image inputs with Clear/Process buttons, the
# class-probability output, and clickable example pairs.
with gr.Blocks(
    theme=gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR,
    ).set(
        button_primary_background_fill='*primary_600',
        button_primary_background_fill_hover='*primary_500',
        button_primary_text_color='white',
    )
) as app:
    gr.HTML(html_content)
    with gr.Row():
        with gr.Column():
            with gr.Row():
                inp_img = gr.Image(type="filepath", image_mode="RGB", label="Histology Image", show_label=True)
                label_img = gr.Image(type="filepath", image_mode="L", label="Label Image", show_label=True)
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                process_btn = gr.Button(value="Process", variant="primary")
        # Show every class; derived from class_names so it cannot drift from
        # the label map.
        out_txt = gr.Label(label="Probabilities", num_top_classes=len(class_names))

    process_btn.click(fn=classify_image, inputs=[inp_img, label_img], outputs=out_txt)
    # Reset both inputs and the output.
    clear_btn.click(
        lambda: (
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
        ),
        inputs=None,
        outputs=[inp_img, label_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    # One gr.Row per example list; each [image, label] pair gets its own
    # gr.Examples so the pairs sit side by side within the row.
    for example_row in (example_files1, example_files2):
        with gr.Row():
            for example in example_row:
                gr.Examples([example], inputs=[inp_img, label_img])
# Start the server only when run as a script (e.g. `python app.py`), not when
# the module is imported.
if __name__ == "__main__":
    app.launch()