import pathlib

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection

from style import css, description, title
from visualization import visualize_attention_map, visualize_prediction
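# Note: `visualization` and `style` are local helper modules expected to sit next to
# this app file in the Space repository (visualization.py, style.py); they are not
# installable packages.
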
def make_prediction(img, feature_extractor, model):
    """Run a single PIL image through DETR and return post-processed detections
    plus the raw decoder/encoder attention maps."""
    inputs = feature_extractor(img, return_tensors="pt")
    # output_attentions=True is needed so that decoder_attentions/encoder_attentions
    # are populated instead of being None.
    outputs = model(**inputs, output_attentions=True)
    # post_process expects the target size as (height, width).
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, img_size)
    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
    )
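
# Illustrative call (a sketch, not executed here; the file name is hypothetical):
#   processed, dec_attn, enc_attn = make_prediction(
#       Image.open("cd45rb_test_imgs/sample_HE.png"), feature_extractor, model
#   )
# `processed` is a dict with "scores", "labels" and "boxes" rescaled to the original
# image size; the two attention values are per-layer attention tensors.
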
def construct_model_name(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation=None,
):
    """Map the UI configuration to the Hugging Face Hub repository id of the
    corresponding fine-tuned DETR checkpoint."""
    base = "polejowska/"
    if convbase == "RESNET-50":
        base += "detr-r50"
    elif convbase == "RESNET-101":
        # ResNet-101 configurations map to fixed, pre-published repository ids,
        # so the suffix-building logic below is skipped for them.
        if enc_dec_layers == 6:
            return "polejowska/detr-r101-official"
        elif enc_dec_layers == 4:
            return "polejowska/detr-r101-cd45rb-8ah-4l"
        elif enc_dec_layers == 12:
            return "polejowska/detr-r101-cd45rb-8ah-12l"
    base += "-cd45rb"
    base += f"-{attention_heads_num}ah"
    base += f"-{enc_dec_layers}l"
    if attention_heads_num == 1:
        base += "-corrected"
    if d_model != 256:
        base += f"-{d_model}d"
    if ffn_dim == 1024:
        base += "-1024ffn"
    elif ffn_dim == 4096:
        # Spelling kept as-is: the id has to match the repository name as published.
        base += "-4096ffn-correcetd"
    if act_func == "GeLU":
        base += "-gelu-corrected"
    if dilation == "True":
        base += "-dilation-corrected"
    return base
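
# Example (traced from the branches above): with the UI defaults
#   convbase="RESNET-50", attention_heads_num=8, enc_dec_layers=6,
#   ffn_dim=2048, act_func="ReLU", d_model=256, dilation="False"
# the function returns "polejowska/detr-r50-cd45rb-8ah-6l". Whether a given
# combination resolves to a repository that actually exists on the Hub is not
# checked here; from_pretrained will raise if it does not.
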
def detect_objects(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation,
    image_input,
    threshold=0.7,
    display_mask=False,
    img_input_mask=None,
):
    """Load the checkpoint selected in the UI, run detection on the uploaded image
    and return the detection overlay plus decoder/encoder attention visualizations."""
    # The experiment-type dropdown only offers the five values below, so one of the
    # branches always fires and model_repo is always defined.
    if experiment_type in [
        "Parameters verification",
        "Reproducibility check (1)",
        "Reproducibility check (2)",
        "Reproducibility check (3)",
        "Reproducibility check (4)",
    ]:
        if experiment_type == "Parameters verification":
            model_repo = construct_model_name(
                experiment_type,
                convbase,
                attention_heads_num,
                enc_dec_layers,
                ffn_dim,
                act_func,
                d_model,
                dilation,
            )
        elif experiment_type == "Reproducibility check (1)":
            model_repo = "polejowska/detr-r50-cd45rb-all-2ah"
        elif experiment_type == "Reproducibility check (2)":
            model_repo = "polejowska/detr-r50-cd45rb-all-4ah"
        elif experiment_type == "Reproducibility check (3)":
            model_repo = "polejowska/detr-r50-cd45rb-all-8ah"
        elif experiment_type == "Reproducibility check (4)":
            model_repo = "polejowska/detr-r50-cd45rb-all-16ah"

        model = DetrForObjectDetection.from_pretrained(model_repo)
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_repo)

        (
            processed_outputs,
            decoder_attention_map,
            encoder_attention_map,
        ) = make_prediction(image_input, feature_extractor, model)

        viz_img = visualize_prediction(
            pil_img=image_input,
            output_dict=processed_outputs,
            threshold=threshold,
            id2label=model.config.id2label,
            display_mask=display_mask,
            mask=img_input_mask,
        )
        decoder_attention_map_img = visualize_attention_map(
            image_input, decoder_attention_map
        )
        encoder_attention_map_img = visualize_attention_map(
            image_input, encoder_attention_map
        )

        return (
            viz_img,
            decoder_attention_map_img,
            encoder_attention_map_img,
        )
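
# The three return values of detect_objects are wired, in order, to
# img_output_from_upload, decoder_att_map_output and encoder_att_map_output
# in the click handler at the bottom of this file.
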
def set_example_image(example: list):
    return gr.Image.update(value=example[0]), gr.Image.update(value=example[1])
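
# When a Dataset component is passed as the input of its own click event (as done
# below), Gradio hands the selected sample to the callback, so `example` arrives
# here as the [image_path, mask_path] pair defined in `samples`.
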
with gr.Blocks(css=css) as app:
    gr.Markdown(title)

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                # Left column: model configuration controls.
                with gr.Column():
                    with gr.Row():
                        experiment_type = gr.Dropdown(
                            value="Parameters verification",
                            choices=[
                                "Parameters verification",
                                "Reproducibility check (1)",
                                "Reproducibility check (2)",
                                "Reproducibility check (3)",
                                "Reproducibility check (4)",
                            ],
                            label="Select an experiment type",
                            show_label=True,
                        )
                    with gr.Row():
                        convbase = gr.Dropdown(
                            value="RESNET-50",
                            choices=[
                                "RESNET-50",
                                "RESNET-101",
                            ],
                            label="Select a backbone for the convolutional part",
                            show_label=True,
                        )
                    with gr.Row():
                        attention_heads_num = gr.Dropdown(
                            value=8,
                            choices=[1, 2, 4, 8, 16],
                            label="Number of attention heads in the encoder and decoder",
                            show_label=True,
                        )
                    with gr.Row():
                        enc_dec_layers = gr.Dropdown(
                            value=6,
                            choices=[4, 6, 12],
                            label="Number of layers in the encoder and decoder",
                            show_label=True,
                        )
                    with gr.Row():
                        ffn_dim = gr.Dropdown(
                            value=2048,
                            choices=[1024, 2048, 4096],
                            label="Select the FFN dimension",
                            show_label=True,
                        )
                    with gr.Row():
                        act_func = gr.Dropdown(
                            value="ReLU",
                            choices=[
                                "ReLU",
                                "GeLU",
                            ],
                            label="Select an activation function",
                            show_label=True,
                        )
                    with gr.Row():
                        d_model = gr.Dropdown(
                            value=256,
                            choices=[128, 256, 512],
                            label="Select a hidden size",
                            show_label=True,
                        )
                    with gr.Row():
                        # Kept as strings because construct_model_name compares
                        # against the literal "True".
                        dilation = gr.Dropdown(
                            value="False",
                            choices=[
                                "True",
                                "False",
                            ],
                            label="Use dilation",
                            show_label=True,
                        )
                    with gr.Row():
                        slider_input = gr.Slider(
                            minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                        )
                # Right column: image input, example gallery and actions.
                with gr.Column():
                    with gr.Row():
                        img_input = gr.Image(type="pil")
                        img_input_mask = gr.Image(type="pil", visible=False)
                    with gr.Row():
                        example_images = gr.Dataset(
                            components=[img_input, img_input_mask],
                            samples=[
                                [path.as_posix(), path.as_posix().replace("_HE", "_mask")]
                                for path in sorted(
                                    pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                                )
                            ],
                            samples_per_page=2,
                        )
                    with gr.Row():
                        # `value=` replaces the pre-3.0 `default=` keyword.
                        display_mask = gr.Checkbox(label="Display masks", value=False)
                    with gr.Row():
                        detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                with gr.Column():
                    img_output_from_upload = gr.Image(shape=(900, 900))
        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(shape=(850, 850))
        with gr.TabItem("Dataset details"):
            with gr.Row():
                gr.Markdown(description)

    detect_button.click(
        detect_objects,
        inputs=[
            experiment_type,
            convbase,
            attention_heads_num,
            enc_dec_layers,
            ffn_dim,
            act_func,
            d_model,
            dilation,
            img_input,
            slider_input,
            display_mask,
            img_input_mask,
        ],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
        ],
        queue=True,
    )

    example_images.click(
        fn=set_example_image,
        inputs=[example_images],
        outputs=[img_input, img_input_mask],
        show_progress=True,
    )

app.launch(enable_queue=True)