import cv2 import inspect import numpy as np import albumentations as A import gradio as gr from typing import get_type_hints from PIL import Image, ImageDraw import base64 import io from PIL import Image from functools import wraps from copy import deepcopy DEFAULT_TRANSFORM = "Rotate" DEFAULT_IMAGE = "images/doctor.webp" DEFAULT_IMAGE_HEIGHT = 400 DEFAULT_IMAGE_WIDTH = 600 DEFAULT_BOXES = [[265, 121, 326, 177], [192, 169, 401, 395]] DEFAULT_KEYPOINTS = [ [(x_min + x_max) // 2, (y_min + y_max) // 2] for x_min, y_min, x_max, y_max in DEFAULT_BOXES ] CORENERS = [[[x_min, y_min], [x_max, y_max], [x_min, y_max], [x_max, y_min]] for x_min, y_min, x_max, y_max in DEFAULT_BOXES] for bbox_corners in CORENERS: DEFAULT_KEYPOINTS += bbox_corners BASE64_DEFAULT_MASKS = [ { "label": "Coverall", # light green color "color": (144, 238, 144), "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIay
sP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==", }, { "label": "Mask", # light blue color "color": (173, 216, 230), "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC", }, ] # Get all the transforms from the albumentations library transforms_map = { name: cls for name, cls in vars(A).items() if inspect.isclass(cls) and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform)) } 
# The two base classes match the subclass filter themselves but are not
# selectable transforms, so drop them from the catalogue.
transforms_map.pop("DualTransform", None)
transforms_map.pop("ImageOnlyTransform", None)
transforms_keys = sorted(transforms_map)

# Decode the masks: each base64 payload is replaced in place by a
# single-channel ("L") numpy array.
for mask in BASE64_DEFAULT_MASKS:
    mask["mask"] = np.array(
        Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L")
    )


def run_with_retry(compose):
    """Wrap ``compose`` so annotation targets a transform rejects are dropped.

    When calling ``compose`` raises ``NotImplementedError``, the error message
    is inspected for the substrings ``"bbox"``, ``"keypoint"`` and ``"mask"``.
    The matching keyword arguments (and, for boxes/keypoints, the matching
    entry of ``compose.processors``) are removed and the call is retried.

    Args:
        compose: Callable pipeline exposing a mutable ``processors`` mapping.

    Returns:
        A wrapper with the same call signature as ``compose``.

    Raises:
        NotImplementedError: If the error names no target that can still be
            dropped, or the attempt budget is exhausted.
        Exception: Any other exception from ``compose`` propagates unchanged.
    """
    max_attempts = 4  # the initial call plus one retry per droppable target

    def discard(mapping, key):
        """Remove ``key`` from ``mapping``; report whether it was present."""
        present = key in mapping
        mapping.pop(key, None)
        return present

    @wraps(compose)
    def wrapper(*args, **kwargs):
        # Snapshot the processors so that dropping them for one call never
        # leaks into later calls of the same pipeline.
        saved_processors = deepcopy(compose.processors)
        try:
            for _ in range(max_attempts):
                try:
                    return compose(*args, **kwargs)
                except NotImplementedError as e:
                    print(f"Caught NotImplementedError: {e}")
                    message = str(e)
                    dropped = False
                    if "bbox" in message:
                        dropped |= discard(kwargs, "bboxes")
                        kwargs.pop("category_id", None)
                        compose.processors.pop("bboxes", None)
                    if "keypoint" in message:
                        dropped |= discard(kwargs, "keypoints")
                        compose.processors.pop("keypoints", None)
                    if "mask" in message:
                        dropped |= discard(kwargs, "mask")
                    if not dropped:
                        # Nothing changed, so a retry would fail identically.
                        raise
            raise NotImplementedError(
                f"Transform still unsupported after {max_attempts} attempts"
            )
        finally:
            compose.processors = saved_processors

    return wrapper


def draw_boxes(image, boxes, color=(255, 0, 0), thickness=2) -> np.ndarray:
    """Draw boxes with PIL.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        boxes: Iterable of boxes whose first four values are
            ``x_min, y_min, x_max, y_max``; extra trailing values are ignored.
        color: Outline colour.
        thickness: Outline width in pixels.

    Returns:
        A new array with the rectangles drawn; ``image`` is not modified.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for box in boxes:
        x_min, y_min, x_max, y_max = box[:4]
        draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
    return np.array(pil_image)


def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
    """Draw keypoints with PIL.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        keypoints: Iterable of points whose first two values are ``x, y``;
            extra trailing values are ignored.
        color: Fill colour of each dot.
        radius: Dot radius in pixels.

    Returns:
        A new array with the keypoints drawn; ``image`` is not modified.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for keypoint in keypoints:
        x, y = keypoint[:2]
        draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
    return np.array(pil_image)


def get_rgb_mask(masks):
    """Get the RGB mask from the binary masks.

    Args:
        masks: Iterable of dicts with a ``"mask"`` array and a ``"color"``
            RGB triple. Each mask must have shape
            ``(DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH)``.

    Returns:
        A ``uint8`` array of shape (height, width, 3). Where masks overlap,
        the later entry wins because it is painted last.
    """
    rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8)
    for data in masks:
        rgb_mask[data["mask"] > 0] = np.array(data["color"])
    return rgb_mask


def draw_mask(image, mask):
    """Blend ``mask`` over ``image`` with equal (0.5 / 0.5) weights."""
    return cv2.addWeighted(image, 0.5, mask, 0.5, 0)


def draw_not_implemented_image(image):
    """Draw a placeholder notice roughly in the middle of ``image``.

    Returns:
        A new array with the notice drawn; ``image`` is not modified.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    # NOTE(review): "IMPLEMETED" is misspelled in the rendered text; the
    # string is kept unchanged here.
    text = "NOT IMPLEMETED FOR THIS TYPE OF ANNOTATIONS"
    length = draw.textlength(text)
    # Centre on the actual image rather than on the default dimensions.
    width, height = pil_image.size
    draw.text(
        (width // 2 - length // 2, height // 2),
        text,
        fill=(255, 0, 0),
        align="center",
    )
    return np.array(pil_image)


def get_formatted_signature(function_or_class, indentation=4):
    """Render a callable's parameters as an editable, one-per-line call.

    Each parameter becomes ``name=default,`` followed by its annotation as a
    trailing comment. Required parameters are rendered as ``name=,`` so the
    user has to fill them in, and ``p`` is always rendered as ``p=1.0``.

    Args:
        function_or_class: Function or class to describe.
        indentation: Number of spaces before every parameter line.

    Returns:
        A string starting with ``(`` and ending with ``)``.
    """
    signature = inspect.signature(function_or_class)

    # For a class, the parameter annotations live on ``__init__``, not on the
    # class body. Resolution may fail (e.g. unresolved forward references), in
    # which case the raw annotations are used instead.
    hints_target = (
        function_or_class.__init__
        if inspect.isclass(function_or_class)
        else function_or_class
    )
    try:
        type_hints = get_type_hints(hints_target)
    except (NameError, TypeError, AttributeError):
        type_hints = {}

    empty = inspect.Parameter.empty
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    args = []
    for param in signature.parameters.values():
        if param.kind in variadic:
            # *args / **kwargs cannot be written as ``name=value``.
            continue

        if param.name == "p":
            str_param = "p=1.0,"
        elif param.default is empty:
            str_param = f"{param.name}=,"
        elif isinstance(param.default, str):
            str_param = f'{param.name}="{param.default}",'
        else:
            str_param = f"{param.name}={param.default},"

        annotation = type_hints.get(param.name, param.annotation)
        if annotation is not empty:
            # Parametrised generics can pass ``isinstance(..., type)`` on some
            # Python versions; ``__origin__`` tells them apart from plain
            # classes so their type arguments are not lost.
            is_plain_class = isinstance(annotation, type) and not hasattr(
                annotation, "__origin__"
            )
            if is_plain_class:
                str_param += f" # {annotation.__name__}"
            else:
                str_annotation = str(annotation).replace("typing.", "")
                str_param += f" # {str_annotation}"

        args.append("\n" + " " * indentation + str_param)

    return "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"


def update(image, code):
    """Apply the transform written in ``code`` and render the three previews.

    Args:
        image: Input image as a numpy array.
        code: Python expression that evaluates to a transform.

    Returns:
        A list of ``(image, caption)`` pairs for the mask, box and keypoint
        previews. A target the transform does not support is replaced by the
        augmented "not implemented" placeholder.
    """
    # SECURITY: ``eval`` executes whatever Python is typed into the code
    # editor. This is only acceptable while the app is used by trusted users.
    augmentation = eval(code)
    compose = A.Compose(
        [augmentation],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
        keypoint_params=A.KeypointParams(format="xy"),
        additional_targets={"not_implemented_image": "image"},
    )
    compose = run_with_retry(compose)  # to prevent NotImplementedError

    keypoints = DEFAULT_KEYPOINTS
    bboxes = DEFAULT_BOXES
    mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
    augmented = compose(
        image=image,
        not_implemented_image=draw_not_implemented_image(image),
        mask=mask,
        keypoints=keypoints,
        bboxes=bboxes,
        category_id=list(range(len(bboxes))),
    )

    image = augmented["image"]
    not_implemented_image = augmented["not_implemented_image"]
    # A target dropped by the retry wrapper is absent from the result.
    mask = augmented.get("mask", None)
    bboxes = augmented.get("bboxes", None)
    keypoints = augmented.get("keypoints", None)

    image_with_mask = (
        draw_mask(image.copy(), mask) if mask is not None else not_implemented_image
    )
    image_with_bboxes = (
        draw_boxes(image.copy(), bboxes)
        if bboxes is not None
        else not_implemented_image
    )
    image_with_keypoints = (
        draw_keypoints(image.copy(), keypoints)
        if keypoints is not None
        else not_implemented_image
    )

    return [
        (image_with_mask, "Mask"),
        (image_with_bboxes, "Boxes"),
        (image_with_keypoints, "Keypoints"),
    ]


def update_image_info(image):
    """Summarise ``image``: height x width, dtype and min/max pixel values."""
    h, w = image.shape[:2]
    dtype = image.dtype
    max_, min_ = image.max(), image.min()
    return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"


def get_formatted_transform(transform_number):
    """Return the ``A.<Name>(...)`` call template for the transform at the index."""
    transform_name = transforms_keys[transform_number]
    transform = transforms_map[transform_name]
    return f"A.{transform.__name__}{get_formatted_signature(transform)}"


def get_formatted_transform_docs(transform_number):
    """Return the docstring of the transform at the index ("" if it has none)."""
    transform_name = transforms_keys[transform_number]
    transform = transforms_map[transform_name]
    return (transform.__doc__ or "").strip("\n")


# NOTE(review): the nesting of the layout containers below was reconstructed
# from a whitespace-collapsed source -- confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                select = gr.Dropdown(
                    label="Select a transformation",
                    choices=transforms_keys,
                    value=DEFAULT_TRANSFORM,
                    type="index",
                    interactive=True,
                )
                with gr.Accordion("Documentation", open=False):
                    docs = gr.TextArea(
                        get_formatted_transform_docs(
                            transforms_keys.index(DEFAULT_TRANSFORM)
                        ),
                        show_label=False,
                        interactive=False,
                    )
                code = gr.Code(
                    language="python",
                    value=get_formatted_transform(
                        transforms_keys.index(DEFAULT_TRANSFORM)
                    ),
                    interactive=True,
                    lines=5,
                )
                info = gr.TextArea(
                    value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                )
            button = gr.Button("Run")
        image = gr.Image(
            value=DEFAULT_IMAGE,
            type="numpy",
            height=500,
            width=300,
            sources=[],
        )
    with gr.Row():
        augmented_image = gr.Gallery(rows=1, columns=3)

    # Keep both the code template and the documentation panel in sync with
    # the selected transform.
    select.change(fn=get_formatted_transform, inputs=[select], outputs=[code])
    select.change(fn=get_formatted_transform_docs, inputs=[select], outputs=[docs])
    button.click(fn=update, inputs=[image, code], outputs=[augmented_image])


if __name__ == "__main__":
    demo.launch()