import pathlib

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection

from constants import MODELS_NAMES, MODELS_REPO
from style import css, description, title
from visualization import visualize_attention_map, visualize_prediction
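# `constants` is expected to provide the model registry used below. A minimal
# sketch of its shape (the names and repo IDs here are assumptions for
# illustration, not the Space's actual values):
#
#     MODELS_NAMES = ["DETR (ResNet-50)"]
#     MODELS_REPO = {"DETR (ResNet-50)": "facebook/detr-resnet-50"}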


def make_prediction(img, feature_extractor, model):
    """Run a forward pass and return post-processed detections plus attentions."""
    inputs = feature_extractor(img, return_tensors="pt")
    # Request attentions explicitly so the keys accessed below are always present.
    outputs = model(**inputs, output_attentions=True)
    # post_process expects target sizes as (height, width); PIL reports (width, height).
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, img_size)
    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
    )


def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
    """Detect objects on `image_input` and build prediction and attention visualizations."""
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])

    if "DETR" in model_name:
        model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
        model_details = "DETR details"
    else:
        # Fail fast instead of hitting a NameError on `model` below.
        raise ValueError(f"Unsupported model: {model_name}")

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    viz_img = visualize_prediction(
        pil_img=image_input,
        output_dict=processed_outputs,
        threshold=threshold,
        id2label=model.config.id2label,
        display_mask=display_mask,
        mask=img_input_mask
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
        model_details
    )


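# Copy the clicked gr.Dataset sample into the image and mask input components.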
def set_example_image(example: list):
    print(f"Set example image to: {example[0]}")
    print(f"Set example image mask to: {example[1]}")
    return gr.Image.update(value=example[0]), gr.Image.update(value=example[1])


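# Assemble the Gradio UI: image upload and examples on the left, model and
# threshold controls on the right, plus tabs for attention maps and details.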
with gr.Blocks(css=css) as app:
    gr.Markdown(title)

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        img_input = gr.Image(type="pil")
                    img_input_mask = gr.Image(type="pil", visible=False)
                    with gr.Row():
                        example_images = gr.Dataset(
                            components=[img_input, img_input_mask],
                            samples=[
                                [path.as_posix(), path.as_posix().replace("_HE", "_mask")]
                                for path in sorted(
                                    pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                                )
                            ],
                            samples_per_page=2,
                        )
                with gr.Column():
                    with gr.Row():
                        options = gr.Dropdown(
                            value=MODELS_NAMES[0],
                            choices=MODELS_NAMES,
                            label="Select an object detection model",
                            show_label=True,
                        )
                    with gr.Row():
                        slider_input = gr.Slider(
                            minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                        )
                    with gr.Row():
                        display_mask = gr.Checkbox(
                            label="Display masks", default=False
                        )
                    with gr.Row():
                        detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """The selected image with detected bounding boxes by the model"""
                    )
                    img_output_from_upload = gr.Image(shape=(800, 800))
        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(shape=(850, 850))
        with gr.TabItem("Model details"):
            with gr.Row():
                model_details = gr.Markdown(""" """)
        with gr.TabItem("Dataset details"):
            with gr.Row():
                gr.Markdown(description)

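    # Wire the Detect button to the inference function; queue=True routes the
    # request through the queue so heavy model calls are serialized.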
    detect_button.click(
        detect_objects,
        inputs=[options, img_input, slider_input, display_mask, img_input_mask],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
            # cross_att_map_output,
            model_details,
        ],
        queue=True,
    )
    example_images.click(
        fn=set_example_image, inputs=[example_images], outputs=[img_input, img_input_mask],
        show_progress=True
    )

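# enable_queue keeps long-running detection requests from hitting HTTP timeouts.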
app.launch(enable_queue=True)