qubvel-hf HF staff commited on
Commit
212fcfb
1 Parent(s): d7f5bfc
Files changed (5) hide show
  1. .gitignore +5 -0
  2. app.py +252 -0
  3. images/doctor.webp +0 -0
  4. packages.txt +1 -0
  5. requirements.txt +2 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+
5
+ .venv/
app.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import inspect
3
+ import numpy as np
4
+ import albumentations as A
5
+ import gradio as gr
6
+ from typing import get_type_hints
7
+ from PIL import Image, ImageDraw
8
+ import base64
9
+ import io
10
+ from PIL import Image
11
+ from functools import wraps
12
+
13
+
14
+ DEFAULT_TRANSFORM = "Rotate"
15
+
16
+ DEFAULT_IMAGE = "images/doctor.webp"
17
+ DEFAULT_IMAGE_HEIGHT = 400
18
+ DEFAULT_IMAGE_WIDTH = 600
19
+ DEFAULT_BOXES = [[265, 121, 326, 177], [192, 169, 401, 395]]
20
+ DEFAULT_KEYPOINTS = [
21
+ [(x_min + x_max) // 2, (y_min + y_max) // 2]
22
+ for x_min, y_min, x_max, y_max in DEFAULT_BOXES
23
+ ]
24
+ CORENERS = [[[x_min, y_min], [x_max, y_max], [x_min, y_max], [x_max, y_min]] for x_min, y_min, x_max, y_max in DEFAULT_BOXES]
25
+ for bbox_corners in CORENERS:
26
+ DEFAULT_KEYPOINTS += bbox_corners
27
+
28
+ BASE64_DEFAULT_MASKS = [
29
+ {
30
+ "label": "Coverall",
31
+ # light green color
32
+ "color": (144, 238, 144),
33
+ "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==",
34
+ },
35
+ {
36
+ "label": "Mask",
37
+ # light blue color
38
+ "color": (173, 216, 230),
39
+ "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC",
40
+ },
41
+ ]
42
+
43
+ # Get all the transforms from the albumentations library
44
+ transforms_map = {
45
+ name: cls
46
+ for name, cls in vars(A).items()
47
+ if inspect.isclass(cls) and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
48
+ }
49
+ transforms_map.pop("DualTransform", None)
50
+ transforms_map.pop("ImageOnlyTransform", None)
51
+ transforms_keys = list(sorted(transforms_map.keys()))
52
+
53
+ # Decode the masks
54
+ for mask in BASE64_DEFAULT_MASKS:
55
+ mask["mask"] = np.array(Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L"))
56
+
57
+
58
+ def run_with_retry(compose):
59
+ @wraps(compose)
60
+ def wrapper(*args, **kwargs):
61
+ for i in range(4):
62
+ try:
63
+ return compose(*args, **kwargs)
64
+ except NotImplementedError as e:
65
+ print(f"Caught NotImplementedError: {e}")
66
+ if "bbox" in str(e):
67
+ kwargs.pop("bboxes", None)
68
+ kwargs.pop("category_id", None)
69
+ if "keypoint" in str(e):
70
+ kwargs.pop("keypoints", None)
71
+ if "mask" in str(e):
72
+ kwargs.pop("mask", None)
73
+ return wrapper
74
+
75
+
76
+
77
+ def draw_boxes(image, boxes, color=(255, 0, 0), thickness=2) -> np.ndarray:
78
+ """Draw boxes with PIL."""
79
+ pil_image = Image.fromarray(image)
80
+ draw = ImageDraw.Draw(pil_image)
81
+ for box in boxes:
82
+ x_min, y_min, x_max, y_max = box
83
+ draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
84
+ return np.array(pil_image)
85
+
86
+
87
+ def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
88
+ """Draw keypoints with PIL."""
89
+ pil_image = Image.fromarray(image)
90
+ draw = ImageDraw.Draw(pil_image)
91
+ for keypoint in keypoints:
92
+ x, y = keypoint
93
+ draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
94
+ return np.array(pil_image)
95
+
96
+
97
+ def get_rgb_mask(masks):
98
+ """Get the RGB mask from the binary mask."""
99
+ rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8)
100
+ for data in masks:
101
+ mask = data["mask"]
102
+ rgb_mask[mask > 0] = np.array(data["color"])
103
+ return rgb_mask
104
+
105
+
106
+ def draw_mask(image, mask):
107
+ """Draw the mask on the image."""
108
+ image_with_mask = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
109
+ return image_with_mask
110
+
111
+
112
+ def draw_not_implemented_image(image):
113
+ """Draw the image with a text. In the middle."""
114
+ pil_image = Image.fromarray(image)
115
+ draw = ImageDraw.Draw(pil_image)
116
+ draw.text((DEFAULT_IMAGE_WIDTH // 2, DEFAULT_IMAGE_HEIGHT // 2), "Not implemented", fill=(255, 0, 0))
117
+ return np.array(pil_image)
118
+
119
+
120
+ def get_formatted_signature(function_or_class, indentation=4):
121
+
122
+ signature = inspect.signature(function_or_class)
123
+ type_hints = get_type_hints(function_or_class)
124
+
125
+ args = []
126
+ for param in signature.parameters.values():
127
+ if param.default == inspect.Parameter.empty:
128
+ str_param = f"{param.name}=,"
129
+ else:
130
+ if isinstance(param.default, str):
131
+ str_param = f'{param.name}="{param.default}",'
132
+ else:
133
+ str_param = f"{param.name}={param.default},"
134
+
135
+ annotation = type_hints.get(param.name, param.annotation)
136
+ if isinstance(param.annotation, type):
137
+ str_param += f" # {param.annotation.__name__}"
138
+ else:
139
+ str_annotation = str(annotation).replace("typing.", "")
140
+ str_param += f" # {str_annotation}"
141
+ str_param = "\n" + " " * indentation + str_param
142
+ args.append(str_param)
143
+
144
+ result = "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
145
+ return result
146
+
147
+
148
+ def update(image, code):
149
+ try:
150
+ augmentation = eval(code)
151
+ compose = A.Compose(
152
+ [augmentation],
153
+ bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
154
+ keypoint_params=A.KeypointParams(format="xy"),
155
+ additional_targets={"not_implemented_image": "image"}
156
+ )
157
+ compose = run_with_retry(compose) # to prevent NotImplementedError
158
+
159
+ keypoints = DEFAULT_KEYPOINTS
160
+ bboxes = DEFAULT_BOXES
161
+ mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
162
+ augmented = compose(
163
+ image=image,
164
+ not_implemented_image=draw_not_implemented_image(image),
165
+ mask=mask,
166
+ keypoints=keypoints,
167
+ bboxes=bboxes,
168
+ category_id=range(len(bboxes)),
169
+ )
170
+ image = augmented["image"]
171
+ not_implemented_image = augmented["not_implemented_image"]
172
+ mask = augmented.get("mask", None)
173
+ bboxes = augmented.get("bboxes", None)
174
+ keypoints = augmented.get("keypoints", None)
175
+
176
+ image_with_mask = draw_mask(image.copy(), mask) if mask is not None else not_implemented_image
177
+ image_with_bboxes = draw_boxes(image.copy(), bboxes) if bboxes is not None else not_implemented_image
178
+ image_with_keypoints = draw_keypoints(image.copy(), keypoints) if keypoints is not None else not_implemented_image
179
+
180
+ return [
181
+ (image_with_mask, "Mask"),
182
+ (image_with_bboxes, "Boxes"),
183
+ (image_with_keypoints, "Keypoints"),
184
+ ]
185
+ except Exception as e:
186
+ raise e
187
+
188
+
189
+ def update_image_info(image):
190
+ h, w = image.shape[:2]
191
+ dtype = image.dtype
192
+ max_, min_ = image.max(), image.min()
193
+ return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"
194
+
195
+
196
+ def get_formatted_transform(transform_number):
197
+ transform_name = transforms_keys[transform_number]
198
+ transform = transforms_map[transform_name]
199
+ return f"A.{transform.__name__}{get_formatted_signature(transform)}"
200
+
201
+
202
+ def get_formatted_transform_docs(transform_number):
203
+ transform_name = transforms_keys[transform_number]
204
+ transform = transforms_map[transform_name]
205
+ return transform.__doc__.strip("\n")
206
+
207
+
208
+ with gr.Blocks() as demo:
209
+
210
+ with gr.Row():
211
+ with gr.Group():
212
+ select = gr.Dropdown(
213
+ label="Select a transformation",
214
+ choices=transforms_keys,
215
+ value=DEFAULT_TRANSFORM,
216
+ type="index",
217
+ interactive=True,
218
+ )
219
+ with gr.Accordion("Documentation", open=False):
220
+ docs = gr.TextArea(
221
+ get_formatted_transform_docs(
222
+ transforms_keys.index(DEFAULT_TRANSFORM)
223
+ ),
224
+ show_label=False,
225
+ interactive=False,
226
+ )
227
+ code = gr.Code(
228
+ language="python",
229
+ value=get_formatted_transform(transforms_keys.index(DEFAULT_TRANSFORM)),
230
+ interactive=True,
231
+ lines=5,
232
+ )
233
+ #info = gr.Text(interactive=False, label="Image info", value="")
234
+ image = gr.Image(
235
+ value=DEFAULT_IMAGE,
236
+ type="numpy",
237
+ height=500,
238
+ width=300,
239
+ sources=[],
240
+ )
241
+ with gr.Row():
242
+ augmented_image = gr.Gallery(rows=1, columns=3)
243
+ # augmented_image = gr.Image(type="numpy", height=300, width=300)
244
+
245
+ #image.upload(fn=update_image_info, inputs=[image], outputs=[info])
246
+ select.change(fn=get_formatted_transform, inputs=[select], outputs=[code])
247
+ button = gr.Button("Run")
248
+ button.click(fn=update, inputs=[image, code], outputs=[augmented_image])
249
+
250
+
251
+ if __name__ == "__main__":
252
+ demo.launch()
images/doctor.webp ADDED
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ albumentations
2
+ Pillow