import gradio as gr
import json
import numpy as np
import datasets
import cv2
import matplotlib.pyplot as plt
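# Load the two sample datasets from the Hugging Face Hub. Each row is expected to
# provide the columns read below: "Buggy Image", "Correct Image",
# "Segmentation Image (Correct)", "Objects JSON (Correct)", "Victim Name", and "Tag".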
sample_dataset1 = datasets.load_dataset("asgaardlab/SampleDataset", split="validation")
sample_dataset2 = datasets.load_dataset("asgaardlab/SampleDataset2", split="validation")

def overlay_with_transparency(background, overlay, alpha):
    """
    Blend a colored overlay onto a background image.

    Args:
    - background: The image onto which the overlay is blended.
    - overlay: The image to blend on top of the background.
    - alpha: Scalar weight (0-1) applied to the overlay in the blend.

    Returns:
    - The blended image.
    """
    return cv2.addWeighted(background, 1, overlay, alpha, 0)

def generate_overlay_image(buggy_image, objects, segmentation_image_rgb, font_scale=0.5, font_color=(0, 255, 255)):
    """
    Draw each object's segmentation mask and label on top of an image.

    Args:
    - buggy_image: The base image onto which the masks and labels are drawn.
    - objects: List of object dicts (parsed from the objects JSON), each with a "color" and a "labelName".
    - segmentation_image_rgb: The segmentation image whose per-pixel colors identify the objects.
    - font_scale: Scale factor for the label font size.
    - font_color: Color tuple for the label text, in the same channel order as the input image.

    Returns:
    - The overlaid image.
    """
overlaid_img = buggy_image.copy()
for obj in objects:
        # Object colors are stored as RGBA; drop the alpha channel and build a binary
        # mask of the pixels that match this object's color in the segmentation image.
        color = tuple(obj["color"])[:-1]
        mask = np.all(segmentation_image_rgb[:, :, :3] == np.array(color), axis=-1).astype(np.float32)
# Create a colored version of the mask using the object's color
colored_mask = np.zeros_like(overlaid_img)
colored_mask[mask == 1] = color
# Overlay the colored mask onto the original image with 0.3 transparency
overlaid_img = overlay_with_transparency(overlaid_img, colored_mask, 0.3)
        # Find the center of the mask to place the label; skip objects whose color does
        # not appear in the segmentation image (an empty mask would give a NaN center).
        mask_coords = np.argwhere(mask)
        if mask_coords.size == 0:
            continue
        y_center, x_center = np.mean(mask_coords, axis=0).astype(int)
# Draw the object's name at the center with specified font size and color
cv2.putText(overlaid_img, obj["labelName"], (x_center, y_center),
cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_color, 1, cv2.LINE_AA)
return overlaid_img
def generate_annotations(selected_dataset, image_index):
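    """
    Load one sample from the selected dataset and build everything the UI displays:
    the buggy image, the correct image, per-object annotation masks, an overlay image,
    the parsed objects JSON, the object count, the victim name, and the bug type.
    """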
bugs_ds = sample_dataset1 if selected_dataset == 'Western Scene' else sample_dataset2
image_index = int(image_index)
objects_json = bugs_ds[image_index]["Objects JSON (Correct)"]
objects = json.loads(objects_json)
segmentation_image_rgb = bugs_ds[image_index]["Segmentation Image (Correct)"]
segmentation_image_rgb = np.array(segmentation_image_rgb)
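    # Build (mask, label) pairs in the format gr.AnnotatedImage expects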
annotations = []
for obj in objects:
color = tuple(obj["color"])[:-1]
mask = np.all(segmentation_image_rgb[:, :, :3] == np.array(color), axis=-1).astype(np.float32)
annotations.append((mask, obj["labelName"]))
    object_count = len(objects)  # derive the count from the parsed annotations rather than a hard-coded 0
victim_name = bugs_ds[image_index]["Victim Name"]
bug_type = bugs_ds[image_index]["Tag"]
bug_image = bugs_ds[image_index]["Buggy Image"]
correct_image = bugs_ds[image_index]["Correct Image"]
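    # Overlay all object masks and labels onto the buggy image for display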
overlaid_image = generate_overlay_image(np.array(bugs_ds[image_index]["Buggy Image"]), objects, segmentation_image_rgb)
return (
bug_image,
correct_image,
(bugs_ds[image_index]["Correct Image"], annotations),
overlaid_image,
objects,
object_count,
victim_name,
bug_type,
)
def update_slider(selected_dataset):
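    """Resize the image-index slider to match the length of the newly selected dataset."""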
dataset = sample_dataset1 if selected_dataset == 'Western Scene' else sample_dataset2
return gr.update(minimum=0, maximum=len(dataset) - 1, step=1)

# Set up the Gradio interface using the Blocks API
with gr.Blocks() as demo:
    gr.Markdown(
        "Select a dataset, choose an image index, and click **Visualize** to view the segmentation annotations."
    )
    with gr.Row():
        selected_dataset = gr.Dropdown(
            ['Western Scene', 'Viking Village'], value='Western Scene', label="Dataset"
        )
        input_slider = gr.Slider(
            minimum=0, maximum=len(sample_dataset1) - 1, step=1, label="Image Index"
        )
        btn = gr.Button("Visualize")
    with gr.Row():
        bug_image = gr.Image(label="Buggy Image")
        correct_image = gr.Image(label="Correct Image")
    with gr.Row():
        seg_img = gr.AnnotatedImage(label="Segmentation Annotations")
        overlaid_img = gr.Image(label="Overlaid Image")
with gr.Row():
object_count = gr.Number(label="Object Count")
victim_name = gr.Textbox(label="Victim Name")
bug_type = gr.Textbox(label="Bug Type")
    with gr.Row():
        json_data = gr.JSON(label="Objects JSON")
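    # Clicking the button regenerates every output for the selected dataset and image index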
btn.click(
fn=generate_annotations,
inputs=[selected_dataset, input_slider],
outputs=[bug_image, correct_image, seg_img, overlaid_img, json_data, object_count, victim_name, bug_type],
)
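    # Changing the dataset rescales the index slider to that dataset's size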
selected_dataset.change(
fn=update_slider,
inputs=[selected_dataset],
outputs=[input_slider]
)
demo.launch()