# Author: polejowska
# Commit: c81fbdf ("Update app.py")
import functools
import pathlib

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection

from style import css, description, title
from visualization import visualize_attention_map, visualize_prediction
def make_prediction(img, feature_extractor, model):
    """Run an object-detection model on one image.

    Args:
        img: Input image. It must expose ``.size`` as ``(width, height)``;
            the app passes PIL images (``gr.Image(type="pil")``).
        feature_extractor: Hugging Face feature extractor used both to
            preprocess ``img`` and to post-process the raw model outputs.
        model: A ``DetrForObjectDetection``-style model.

    Returns:
        A tuple ``(detections, decoder_attentions, encoder_attentions)``.
        ``detections`` is the post-processed result for this single image.
        The two attention entries are taken from the model outputs unchanged.
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only: disable autograd bookkeeping to save memory and time.
    # Attentions are requested explicitly so the two keys read below are
    # present even when a checkpoint's config does not enable them.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # PIL reports (width, height); DETR post-processing takes (height, width).
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, img_size)
    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
    )
def construct_model_name(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation=None,
):
    """Build the Hugging Face Hub repo id for a set of model hyper-parameters.

    ``experiment_type`` is accepted for signature compatibility but is not used.

    For ``convbase == "RESNET-101"`` with 6, 4 or 12 encoder/decoder layers, a
    fixed repo id is returned and all other parameters are ignored. Otherwise
    the id is assembled from suffixes:

    - ``-cd45rb``, ``-{heads}ah`` and ``-{layers}l`` are always added.
    - ``-corrected`` is added for a single head.
    - ``-{d_model}d`` is added when ``d_model`` differs from 256.
    - An FFN suffix is added for 1024 or 4096.
    - A GeLU suffix is added for ``act_func == "GeLU"``.
    - A dilation suffix is added when ``dilation`` is the string ``"True"``.

    The ``detr-r50`` prefix is only added for ``"RESNET-50"``.
    The suffix spellings mirror existing repository names and must not be
    "fixed".
    """
    if convbase == "RESNET-101":
        fixed_repos = (
            (6, "polejowska/detr-r101-official"),
            (4, "polejowska/detr-r101-cd45rb-8ah-4l"),
            (12, "polejowska/detr-r101-cd45rb-8ah-12l"),
        )
        for layer_count, repo_id in fixed_repos:
            if enc_dec_layers == layer_count:
                return repo_id

    pieces = ["polejowska/"]
    if convbase == "RESNET-50":
        pieces.append("detr-r50")
    pieces.extend(
        ["-cd45rb", f"-{attention_heads_num}ah", f"-{enc_dec_layers}l"]
    )

    if attention_heads_num == 1:
        pieces.append("-corrected")
    if d_model != 256:
        pieces.append(f"-{d_model}d")

    if ffn_dim == 1024:
        pieces.append("-1024ffn")
    elif ffn_dim == 4096:
        pieces.append("-4096ffn-correcetd")

    if act_func == "GeLU":
        pieces.append("-gelu-corrected")
    # The UI dropdown delivers dilation as the strings "True"/"False".
    if dilation == "True":
        pieces.append("-dilation-corrected")

    return "".join(pieces)
# Fixed checkpoints for the reproducibility experiments. The keys must match
# the choices offered by the experiment-type dropdown exactly.
_REPRODUCIBILITY_REPOS = {
    "Reproducability check (1)": "polejowska/detr-r50-cd45rb-all-2ah",
    "Reproducability check (2)": "polejowska/detr-r50-cd45rb-all-4ah",
    "Reproducability check (3)": "polejowska/detr-r50-cd45rb-all-8ah",
    "Reproducability check (4)": "polejowska/detr-r50-cd45rb-all-16ah",
}


@functools.lru_cache(maxsize=4)
def _load_model(model_repo):
    """Load a DETR checkpoint and its feature extractor, memoized per repo id.

    Without caching, every click would re-run ``from_pretrained``. Each cached
    entry keeps a full model in memory, so ``maxsize`` bounds that cost.
    """
    model = DetrForObjectDetection.from_pretrained(model_repo)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_repo)
    return model, feature_extractor


def _resolve_model_repo(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation,
):
    """Return the Hub repo id selected by the UI controls.

    Raises:
        gr.Error: if ``experiment_type`` is not one of the known choices.
    """
    if experiment_type == "Parameters verification":
        return construct_model_name(
            experiment_type,
            convbase,
            attention_heads_num,
            enc_dec_layers,
            ffn_dim,
            act_func,
            d_model,
            dilation,
        )
    model_repo = _REPRODUCIBILITY_REPOS.get(experiment_type)
    if model_repo is None:
        raise gr.Error(f"Unknown experiment type: {experiment_type!r}")
    return model_repo


def detect_objects(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation,
    image_input,
    threshold=0.7,
    display_mask=False,
    img_input_mask=None,
):
    """Run the selected DETR model on an image and build the visualizations.

    Args:
        experiment_type: Dropdown choice. "Parameters verification" derives
            the repo from the hyper-parameter controls via
            ``construct_model_name``. The "Reproducability check (N)" choices
            map to fixed repos.
        convbase, attention_heads_num, enc_dec_layers, ffn_dim, act_func,
        d_model, dilation: Hyper-parameter selections forwarded to
            ``construct_model_name``.
        image_input: The image to run detection on.
        threshold: Score threshold forwarded to ``visualize_prediction``.
        display_mask: Whether ``visualize_prediction`` should draw
            ``img_input_mask``.
        img_input_mask: Optional mask image forwarded to
            ``visualize_prediction``.

    Returns:
        ``(prediction_viz, decoder_attention_viz, encoder_attention_viz)``, in
        the order wired to the output components, as produced by
        ``visualize_prediction`` / ``visualize_attention_map``.

    Raises:
        gr.Error: when no image is provided or the experiment type is unknown.
    """
    # Without this guard a missing image fails deep inside preprocessing with
    # an opaque traceback instead of a message shown in the UI.
    if image_input is None:
        raise gr.Error("Please upload an image or select an example first.")

    model_repo = _resolve_model_repo(
        experiment_type,
        convbase,
        attention_heads_num,
        enc_dec_layers,
        ffn_dim,
        act_func,
        d_model,
        dilation,
    )
    model, feature_extractor = _load_model(model_repo)

    processed_outputs, decoder_attention_map, encoder_attention_map = make_prediction(
        image_input, feature_extractor, model
    )

    viz_img = visualize_prediction(
        pil_img=image_input,
        output_dict=processed_outputs,
        threshold=threshold,
        id2label=model.config.id2label,
        display_mask=display_mask,
        mask=img_input_mask,
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )
    return viz_img, decoder_attention_map_img, encoder_attention_map_img
def set_example_image(example: list):
    """Copy a clicked gallery sample into the image and mask inputs.

    Args:
        example: One sample row from the examples dataset. Its first two
            entries are the image path and the mask path.

    Returns:
        A pair of ``gr.Image`` updates: image first, then mask.
    """
    image_path = example[0]
    mask_path = example[1]
    image_update = gr.Image.update(value=image_path)
    mask_update = gr.Image.update(value=mask_path)
    return image_update, mask_update
def _example_samples(root="cd45rb_test_imgs"):
    """Return ``[image_path, mask_path]`` pairs for the examples gallery.

    Every ``*_HE.png`` file under ``root`` (searched recursively, sorted) is
    paired with the sibling file whose name ends in ``_mask.png`` instead.
    Only the file name is rewritten, so directory names that happen to
    contain "_HE" are left untouched.
    """
    suffix = "_HE.png"
    return [
        [
            path.as_posix(),
            path.with_name(path.name[: -len(suffix)] + "_mask.png").as_posix(),
        ]
        for path in sorted(pathlib.Path(root).rglob("*" + suffix))
    ]


with gr.Blocks(css=css) as app:
    gr.Markdown(title)
    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                # Left column: experiment and hyper-parameter selectors.
                with gr.Column():
                    with gr.Row():
                        experiment_type = gr.Dropdown(
                            value="Parameters verification",
                            choices=[
                                "Parameters verification",
                                "Reproducability check (1)",
                                "Reproducability check (2)",
                                "Reproducability check (3)",
                                "Reproducability check (4)",
                            ],
                            label="Select an experiment type",
                            show_label=True,
                        )
                    with gr.Row():
                        convbase = gr.Dropdown(
                            value="RESNET-50",
                            choices=[
                                "RESNET-50",
                                "RESNET-101",
                            ],
                            label="Select a base model for convolution part",
                            show_label=True,
                        )
                    with gr.Row():
                        attention_heads_num = gr.Dropdown(
                            value=8,
                            choices=[1, 2, 4, 8, 16],
                            label="The number of attention heads in encoder and decoder",
                            show_label=True,
                        )
                    with gr.Row():
                        enc_dec_layers = gr.Dropdown(
                            value=6,
                            choices=[4, 6, 12],
                            label="The number of layers in encoder and decoder",
                            show_label=True,
                        )
                    with gr.Row():
                        ffn_dim = gr.Dropdown(
                            value=2048,
                            choices=[1024, 2048, 4096],
                            label="Select FFN dimension",
                            show_label=True,
                        )
                    with gr.Row():
                        act_func = gr.Dropdown(
                            value="ReLU",
                            choices=[
                                "ReLU",
                                "GeLU",
                            ],
                            label="Select an activation function",
                            show_label=True,
                        )
                    with gr.Row():
                        d_model = gr.Dropdown(
                            value=256,
                            choices=[128, 256, 512],
                            label="Select a hidden size",
                            show_label=True,
                        )
                    with gr.Row():
                        # String choices: construct_model_name compares
                        # against "True", not the boolean.
                        dilation = gr.Dropdown(
                            value="False",
                            choices=[
                                "True",
                                "False",
                            ],
                            label="Use dilation",
                            show_label=True,
                        )
                    with gr.Row():
                        slider_input = gr.Slider(
                            minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                        )
                # Right column: input image, examples and the run button.
                with gr.Column():
                    with gr.Row():
                        img_input = gr.Image(type="pil")
                        # Hidden input; set to the matching "_mask" image when
                        # an example is picked and forwarded to detect_objects.
                        img_input_mask = gr.Image(type="pil", visible=False)
                    with gr.Row():
                        example_images = gr.Dataset(
                            components=[img_input, img_input_mask],
                            samples=_example_samples(),
                            samples_per_page=2,
                        )
                    with gr.Row():
                        # `value` is the supported kwarg; `default` is the
                        # deprecated Gradio 2 spelling.
                        display_mask = gr.Checkbox(
                            label="Display masks", value=False
                        )
                    with gr.Row():
                        detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                with gr.Column():
                    img_output_from_upload = gr.Image(shape=(900, 900))
        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(shape=(850, 850))
        with gr.TabItem("Dataset details"):
            with gr.Row():
                gr.Markdown(description)

    # Input order must match detect_objects' positional parameters, and output
    # order must match its return tuple (prediction, decoder, encoder).
    detect_button.click(
        detect_objects,
        inputs=[
            experiment_type,
            convbase,
            attention_heads_num,
            enc_dec_layers,
            ffn_dim,
            act_func,
            d_model,
            dilation,
            img_input,
            slider_input,
            display_mask,
            img_input_mask,
        ],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
        ],
        queue=True,
    )
    example_images.click(
        fn=set_example_image,
        inputs=[example_images],
        outputs=[img_input, img_input_mask],
        show_progress=True,
    )

# Only start the server when run as a script, so importing this module (e.g.
# to reuse detect_objects) does not launch the app.
if __name__ == "__main__":
    app.launch(enable_queue=True)