Spaces:
Runtime error
Runtime error
import functools
import pathlib

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection

from constants import MODELS_NAMES, MODELS_REPO
from style import css, description, title
from visualization import visualize_attention_map, visualize_prediction
def make_prediction(img, feature_extractor, model):
    """Run an object-detection model on one image; return detections and attentions.

    Args:
        img: PIL image to analyse. ``img.size`` is read as ``(width, height)``.
        feature_extractor: Hugging Face feature extractor / image processor
            matching ``model``. Used for preprocessing and post-processing.
        model: Hugging Face detection model (e.g. ``DetrForObjectDetection``).

    Returns:
        Tuple ``(detections, decoder_attentions, encoder_attentions)``:
        ``detections`` is the post-processed result for this single image with
        no score threshold applied (thresholding is left to the caller). The
        two attention entries are taken verbatim from the model output.
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only: no autograd graph is needed.
    # Attentions must be requested explicitly. Otherwise (unless the model
    # config enables them) the attention keys are absent from the output and
    # the lookups below raise KeyError.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([tuple(reversed(img.size))])

    post_process = getattr(feature_extractor, "post_process_object_detection", None)
    if post_process is not None:
        # threshold=0.0 reproduces the deprecated `post_process`, which kept
        # every query. The user-selected threshold is applied at visualization.
        processed_outputs = post_process(
            outputs, threshold=0.0, target_sizes=target_sizes
        )
    else:
        # Older transformers releases only provide the legacy API.
        processed_outputs = feature_extractor.post_process(outputs, target_sizes)

    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
    )
@functools.lru_cache(maxsize=None)
def _load_detr(repo_id):
    """Load the feature extractor and DETR model for ``repo_id`` once and reuse them."""
    # Keys are repo ids looked up in MODELS_REPO, so the cache stays bounded
    # by the set of configured models.
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    model = DetrForObjectDetection.from_pretrained(repo_id)
    return feature_extractor, model


def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
    """Detect objects in ``image_input`` and render predictions plus attention maps.

    Args:
        model_name: Key into ``MODELS_REPO``. Only names containing "DETR" are
            supported.
        image_input: PIL image to analyse (from the Gradio image component).
        threshold: Score threshold forwarded to ``visualize_prediction``.
        display_mask: Whether ``visualize_prediction`` should draw ``img_input_mask``.
        img_input_mask: Optional mask image forwarded to ``visualize_prediction``.

    Returns:
        Tuple ``(prediction_image, decoder_attention_image,
        encoder_attention_image, model_details_markdown)``.

    Raises:
        ValueError: If no image is provided or ``model_name`` is not a DETR model.
        KeyError: If ``model_name`` is not present in ``MODELS_REPO``.
    """
    if image_input is None:
        raise ValueError("Please upload or select an image first.")
    # Without this check, a non-DETR name would leave `model` unbound and fail
    # later with an UnboundLocalError.
    if "DETR" not in model_name:
        raise ValueError(
            f"Unsupported model {model_name!r}: only DETR models are supported."
        )

    # Cached so repeated clicks do not reload weights from disk/hub.
    feature_extractor, model = _load_detr(MODELS_REPO[model_name])
    model_details = "DETR details"

    processed_outputs, decoder_attention_map, encoder_attention_map = make_prediction(
        image_input, feature_extractor, model
    )

    viz_img = visualize_prediction(
        pil_img=image_input,
        output_dict=processed_outputs,
        threshold=threshold,
        id2label=model.config.id2label,
        display_mask=display_mask,
        mask=img_input_mask,
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )
    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
        model_details,
    )
def set_example_image(example: list) -> tuple:
    """Return the image and mask of a clicked dataset example.

    Args:
        example: Sample row from the examples ``gr.Dataset``. Index 0 is the
            image value and index 1 the mask value; extra items are ignored.

    Returns:
        ``(image_value, mask_value)``, used as the new values of the image
        and mask components.
    """
    image_value, mask_value = example[0], example[1]
    print(f"Set example image to: {image_value}")
    print(f"Set example image mask to: {mask_value}")
    # Returning plain values sets each output component's value. This is
    # equivalent to gr.Image.update(value=...) in Gradio 3, and also works in
    # Gradio 4, where gr.Image.update no longer exists.
    return image_value, mask_value
# Build the Gradio UI: an upload/detection tab, attention-map tab, model details
# tab and dataset description tab, then wire the events and launch the server.
with gr.Blocks(css=css) as app:
    gr.Markdown(title)
    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        img_input = gr.Image(type="pil")
                        # Hidden input that carries the selected example's mask so it
                        # can be forwarded to detect_objects.
                        img_input_mask = gr.Image(type="pil", visible=False)
                    with gr.Row():
                        # Each sample pairs "<name>_HE.png" (searched recursively
                        # under cd45rb_test_imgs/) with "<name>_mask.png" in the same
                        # folder. Only the file name is rewritten, never directories.
                        example_images = gr.Dataset(
                            components=[img_input, img_input_mask],
                            samples=[
                                [
                                    path.as_posix(),
                                    path.with_name(
                                        path.name.replace("_HE", "_mask")
                                    ).as_posix(),
                                ]
                                for path in sorted(
                                    pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                                )
                            ],
                            samples_per_page=2,
                        )
                with gr.Column():
                    with gr.Row():
                        options = gr.Dropdown(
                            value=MODELS_NAMES[0],
                            choices=MODELS_NAMES,
                            label="Select an object detection model",
                            show_label=True,
                        )
                    with gr.Row():
                        slider_input = gr.Slider(
                            minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                        )
                    with gr.Row():
                        # `value` sets the initial state; `default` is not a
                        # Checkbox parameter.
                        display_mask = gr.Checkbox(label="Display masks", value=False)
                    with gr.Row():
                        detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """The selected image with detected bounding boxes by the model"""
                    )
                    # `shape` only affected input preprocessing in Gradio 3 and was
                    # removed in Gradio 4, so it is omitted on output-only images.
                    img_output_from_upload = gr.Image()
        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image()
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image()
        with gr.TabItem("Model details"):
            with gr.Row():
                model_details = gr.Markdown(""" """)
        with gr.TabItem("Dataset details"):
            with gr.Row():
                gr.Markdown(description)

    # Output order must match the tuple returned by detect_objects.
    detect_button.click(
        detect_objects,
        inputs=[options, img_input, slider_input, display_mask, img_input_mask],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
            model_details,
        ],
        queue=True,
    )
    example_images.click(
        fn=set_example_image,
        inputs=[example_images],
        outputs=[img_input, img_input_mask],
        show_progress=True,
    )

# `launch(enable_queue=...)` is deprecated/removed; enable queuing via `.queue()`.
app.queue().launch()