qubvel-hf HF staff committed on
Commit
83c2ac2
1 Parent(s): 422d636

Clean up, remove some transforms

Browse files
Files changed (2) hide show
  1. app.py +112 -74
  2. utils.py +31 -0
app.py CHANGED
@@ -1,30 +1,42 @@
1
- import cv2
2
- import inspect
3
- import numpy as np
4
  import albumentations as A
5
- import gradio as gr
6
- from typing import get_type_hints
7
- from PIL import Image, ImageDraw
8
  import base64
 
 
 
9
  import io
10
- from PIL import Image
11
- from functools import wraps
12
  from copy import deepcopy
 
 
 
 
 
13
 
14
 
 
 
 
 
 
 
 
 
 
15
  DEFAULT_TRANSFORM = "Rotate"
16
 
17
  DEFAULT_IMAGE = "images/doctor.webp"
18
  DEFAULT_IMAGE_HEIGHT = 400
19
  DEFAULT_IMAGE_WIDTH = 600
20
- DEFAULT_BOXES = [[265, 121, 326, 177], [192, 169, 401, 395]]
21
- DEFAULT_KEYPOINTS = [
22
- [(x_min + x_max) // 2, (y_min + y_max) // 2]
23
- for x_min, y_min, x_max, y_max in DEFAULT_BOXES
24
  ]
25
- CORENERS = [[[x_min, y_min], [x_max, y_max], [x_min, y_max], [x_max, y_min]] for x_min, y_min, x_max, y_max in DEFAULT_BOXES]
26
- for bbox_corners in CORENERS:
27
- DEFAULT_KEYPOINTS += bbox_corners
 
 
28
 
29
  BASE64_DEFAULT_MASKS = [
30
  {
@@ -45,15 +57,23 @@ BASE64_DEFAULT_MASKS = [
45
  transforms_map = {
46
  name: cls
47
  for name, cls in vars(A).items()
48
- if inspect.isclass(cls) and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
 
 
 
 
49
  }
50
  transforms_map.pop("DualTransform", None)
51
  transforms_map.pop("ImageOnlyTransform", None)
 
52
  transforms_keys = list(sorted(transforms_map.keys()))
53
 
 
54
  # Decode the masks
55
  for mask in BASE64_DEFAULT_MASKS:
56
- mask["mask"] = np.array(Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L"))
 
 
57
 
58
 
59
  def run_with_retry(compose):
@@ -80,6 +100,7 @@ def run_with_retry(compose):
80
  raise e
81
  compose.processors = processors
82
  return result
 
83
  return wrapper
84
 
85
 
@@ -118,12 +139,12 @@ def draw_mask(image, mask):
118
  return image_with_mask
119
 
120
 
121
- def draw_not_implemented_image(image):
122
  """Draw the image with a text. In the middle."""
123
  pil_image = Image.fromarray(image)
124
  draw = ImageDraw.Draw(pil_image)
125
  # align in the center, and make bigger font
126
- text = "NOT IMPLEMETED FOR THIS TYPE OF ANNOTATIONS"
127
  length = draw.textlength(text)
128
  draw.text(
129
  (DEFAULT_IMAGE_WIDTH // 2 - length // 2, DEFAULT_IMAGE_HEIGHT // 2),
@@ -164,45 +185,65 @@ def get_formatted_signature(function_or_class, indentation=4):
164
  return result
165
 
166
 
167
- def update(image, code):
168
- try:
169
- augmentation = eval(code)
170
- compose = A.Compose(
171
- [augmentation],
172
- bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
173
- keypoint_params=A.KeypointParams(format="xy"),
174
- additional_targets={"not_implemented_image": "image"}
175
- )
176
- compose = run_with_retry(compose) # to prevent NotImplementedError
177
-
178
- keypoints = DEFAULT_KEYPOINTS
179
- bboxes = DEFAULT_BOXES
180
- mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
181
- augmented = compose(
182
- image=image,
183
- not_implemented_image=draw_not_implemented_image(image),
184
- mask=mask,
185
- keypoints=keypoints,
186
- bboxes=bboxes,
187
- category_id=range(len(bboxes)),
188
- )
189
- image = augmented["image"]
190
- not_implemented_image = augmented["not_implemented_image"]
191
- mask = augmented.get("mask", None)
192
- bboxes = augmented.get("bboxes", None)
193
- keypoints = augmented.get("keypoints", None)
194
 
195
- image_with_mask = draw_mask(image.copy(), mask) if mask is not None else not_implemented_image
196
- image_with_bboxes = draw_boxes(image.copy(), bboxes) if bboxes is not None else not_implemented_image
197
- image_with_keypoints = draw_keypoints(image.copy(), keypoints) if keypoints is not None else not_implemented_image
198
 
199
- return [
200
- (image_with_mask, "Mask"),
201
- (image_with_bboxes, "Boxes"),
202
- (image_with_keypoints, "Keypoints"),
203
- ]
204
- except Exception as e:
205
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
 
208
  def update_image_info(image):
@@ -212,20 +253,14 @@ def update_image_info(image):
212
  return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"
213
 
214
 
215
- def get_formatted_transform(transform_number):
216
- transform_name = transforms_keys[transform_number]
217
- transform = transforms_map[transform_name]
218
- return f"A.{transform.__name__}{get_formatted_signature(transform)}"
219
-
220
-
221
- def get_formatted_transform_docs(transform_number):
222
- transform_name = transforms_keys[transform_number]
223
- transform = transforms_map[transform_name]
224
- return transform.__doc__.strip("\n")
225
 
226
 
227
  with gr.Blocks() as demo:
228
-
229
  with gr.Row():
230
  with gr.Column():
231
  with gr.Group():
@@ -236,7 +271,7 @@ with gr.Blocks() as demo:
236
  type="index",
237
  interactive=True,
238
  )
239
- with gr.Accordion("Documentation", open=False):
240
  docs = gr.TextArea(
241
  get_formatted_transform_docs(
242
  transforms_keys.index(DEFAULT_TRANSFORM)
@@ -245,8 +280,11 @@ with gr.Blocks() as demo:
245
  interactive=False,
246
  )
247
  code = gr.Code(
 
248
  language="python",
249
- value=get_formatted_transform(transforms_keys.index(DEFAULT_TRANSFORM)),
 
 
250
  interactive=True,
251
  lines=5,
252
  )
@@ -256,7 +294,7 @@ with gr.Blocks() as demo:
256
  lines=1,
257
  max_lines=1,
258
  )
259
- button = gr.Button("Run")
260
  image = gr.Image(
261
  value=DEFAULT_IMAGE,
262
  type="numpy",
@@ -266,11 +304,11 @@ with gr.Blocks() as demo:
266
  )
267
  with gr.Row():
268
  augmented_image = gr.Gallery(rows=1, columns=3)
269
- # augmented_image = gr.Image(type="numpy", height=300, width=300)
270
 
271
- #image.upload(fn=update_image_info, inputs=[image], outputs=[info])
272
- select.change(fn=get_formatted_transform, inputs=[select], outputs=[code])
273
- button.click(fn=update, inputs=[image, code], outputs=[augmented_image])
 
274
 
275
 
276
  if __name__ == "__main__":
 
 
 
 
1
  import albumentations as A
 
 
 
2
  import base64
3
+ import cv2
4
+ import gradio as gr
5
+ import inspect
6
  import io
7
+ import numpy as np
8
+
9
  from copy import deepcopy
10
+ from functools import wraps
11
+ from PIL import Image, ImageDraw
12
+ from typing import get_type_hints
13
+
14
+ from utils import is_not_supported_transform
15
 
16
 
17
+ HEADER = f"""
18
+ <div align="center">
19
+ <p>
20
+ <img src="https://avatars.githubusercontent.com/u/57894582?s=200&v=4" alt="A" width="50" height="50" style="display:inline;">
21
+ <span style="font-size: 30px; vertical-align: bottom;"> lbumentations Demo ({A.__version__})</span>
22
+ </p>
23
+ </div>
24
+ """
25
+
26
  DEFAULT_TRANSFORM = "Rotate"
27
 
28
  DEFAULT_IMAGE = "images/doctor.webp"
29
  DEFAULT_IMAGE_HEIGHT = 400
30
  DEFAULT_IMAGE_WIDTH = 600
31
+ DEFAULT_BOXES = [
32
+ [265, 121, 326, 177], # Mask
33
+ [192, 169, 401, 395], # Coverall
 
34
  ]
35
+
36
+ mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]]
37
+ pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]]
38
+ arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]]
39
+ DEFAULT_KEYPOINTS = mask_keypoints + pocket_keypoints + arm_keypoints
40
 
41
  BASE64_DEFAULT_MASKS = [
42
  {
 
57
  transforms_map = {
58
  name: cls
59
  for name, cls in vars(A).items()
60
+ if (
61
+ inspect.isclass(cls)
62
+ and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
63
+ and not is_not_supported_transform(cls)
64
+ )
65
  }
66
  transforms_map.pop("DualTransform", None)
67
  transforms_map.pop("ImageOnlyTransform", None)
68
+ transforms_map.pop("ReferenceBasedTransform", None)
69
  transforms_keys = list(sorted(transforms_map.keys()))
70
 
71
+
72
  # Decode the masks
73
  for mask in BASE64_DEFAULT_MASKS:
74
+ mask["mask"] = np.array(
75
+ Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L")
76
+ )
77
 
78
 
79
  def run_with_retry(compose):
 
100
  raise e
101
  compose.processors = processors
102
  return result
103
+
104
  return wrapper
105
 
106
 
 
139
  return image_with_mask
140
 
141
 
142
+ def draw_not_implemented_image(image: np.ndarray, annotation_type: str):
143
  """Draw the image with a text. In the middle."""
144
  pil_image = Image.fromarray(image)
145
  draw = ImageDraw.Draw(pil_image)
146
  # align in the center, and make bigger font
147
+ text = f'Transform NOT working with "{annotation_type.upper()}" annotaions.'
148
  length = draw.textlength(text)
149
  draw.text(
150
  (DEFAULT_IMAGE_WIDTH // 2 - length // 2, DEFAULT_IMAGE_HEIGHT // 2),
 
185
  return result
186
 
187
 
188
def get_formatted_transform(transform_number):
    """Build the editable code snippet for the transform at ``transform_number``.

    The index is resolved against ``transforms_keys`` and the matching class is
    looked up in ``transforms_map``. The returned text is ``A.<ClassName>``
    immediately followed by whatever ``get_formatted_signature`` renders for
    that class.
    """
    selected_cls = transforms_map[transforms_keys[transform_number]]
    signature_text = get_formatted_signature(selected_cls)
    return f"A.{selected_cls.__name__}{signature_text}"
192
+
193
+
194
def get_formatted_transform_docs(transform_number):
    """Return the docstring of the transform at ``transform_number``.

    The index is resolved against ``transforms_keys`` and the matching class is
    looked up in ``transforms_map``. Leading and trailing newlines are stripped
    from the docstring; other whitespace is left untouched.

    Returns:
        The class docstring, or an empty string when the class has none.
    """
    transform_name = transforms_keys[transform_number]
    transform = transforms_map[transform_name]
    # ``__doc__`` is None for a class written without a docstring (and for every
    # class under ``python -OO``); fall back to "" instead of raising
    # AttributeError on ``None.strip``.
    return (transform.__doc__ or "").strip("\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
 
 
 
199
 
200
def update_augmented_images(image, code):
    """Apply the transform described by ``code`` and render three previews.

    ``code`` is evaluated to obtain a transform, which is wrapped in an
    ``A.Compose`` configured for "pascal_voc" boxes (labels in ``category_id``)
    and "xy" keypoints. The pipeline is called on ``image`` together with the
    default mask, boxes and keypoints. For every annotation type that is
    present in the result, the augmented image is drawn with that annotation;
    otherwise a placeholder produced by ``draw_not_implemented_image`` is used.

    Returns:
        A list of three ``(image, caption)`` pairs captioned "Mask", "Boxes"
        and "Keypoints", in that order.
    """
    # SECURITY: ``code`` is executed verbatim by ``eval``. If it comes from a
    # user-editable field, anyone who can reach the UI can run arbitrary Python.
    transform = eval(code)

    pipeline = A.Compose(
        [transform],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
        keypoint_params=A.KeypointParams(format="xy"),
    )
    pipeline = run_with_retry(pipeline)  # to prevent NotImplementedError

    source_boxes = DEFAULT_BOXES
    result = pipeline(
        image=image,
        # NOTE(review): the Compose above declares no ``additional_targets`` and
        # this entry is never read back below -- confirm whether it is still
        # needed.
        not_implemented_image=image.copy(),
        mask=get_rgb_mask(BASE64_DEFAULT_MASKS),
        keypoints=DEFAULT_KEYPOINTS,
        bboxes=source_boxes,
        category_id=range(len(source_boxes)),
    )

    out_image = result["image"]
    out_mask = result.get("mask")
    out_boxes = result.get("bboxes")
    out_keypoints = result.get("keypoints")

    # Draw the augmented images (or replace by placeholder if not implemented)
    mask_view = (
        draw_not_implemented_image(out_image.copy(), "mask")
        if out_mask is None
        else draw_mask(out_image.copy(), out_mask)
    )
    boxes_view = (
        draw_not_implemented_image(out_image.copy(), "boxes")
        if out_boxes is None
        else draw_boxes(out_image.copy(), out_boxes)
    )
    keypoints_view = (
        draw_not_implemented_image(out_image.copy(), "keypoints")
        if out_keypoints is None
        else draw_keypoints(out_image.copy(), out_keypoints)
    )

    return [
        (mask_view, "Mask"),
        (boxes_view, "Boxes"),
        (keypoints_view, "Keypoints"),
    ]
247
 
248
 
249
  def update_image_info(image):
 
253
  return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"
254
 
255
 
256
def update_code_and_docs(select):
    """Produce the code snippet and the documentation text for ``select``.

    ``select`` is forwarded unchanged to ``get_formatted_transform`` and
    ``get_formatted_transform_docs``.

    Returns:
        A ``(code, docs)`` tuple, in that order.
    """
    return get_formatted_transform(select), get_formatted_transform_docs(select)
 
 
 
 
 
 
260
 
261
 
262
  with gr.Blocks() as demo:
263
+ gr.Markdown(HEADER)
264
  with gr.Row():
265
  with gr.Column():
266
  with gr.Group():
 
271
  type="index",
272
  interactive=True,
273
  )
274
+ with gr.Accordion("Documentation (click to expand)", open=False):
275
  docs = gr.TextArea(
276
  get_formatted_transform_docs(
277
  transforms_keys.index(DEFAULT_TRANSFORM)
 
280
  interactive=False,
281
  )
282
  code = gr.Code(
283
+ label="Code",
284
  language="python",
285
+ value=get_formatted_transform(
286
+ transforms_keys.index(DEFAULT_TRANSFORM)
287
+ ),
288
  interactive=True,
289
  lines=5,
290
  )
 
294
  lines=1,
295
  max_lines=1,
296
  )
297
+ button = gr.Button("Apply!")
298
  image = gr.Image(
299
  value=DEFAULT_IMAGE,
300
  type="numpy",
 
304
  )
305
  with gr.Row():
306
  augmented_image = gr.Gallery(rows=1, columns=3)
 
307
 
308
+ select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs])
309
+ button.click(
310
+ fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image]
311
+ )
312
 
313
 
314
  if __name__ == "__main__":
utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+ import inspect
3
+ from typing import Callable
4
+
5
+
6
+ FILTER_TRANSFORMS = [
7
+ A.ImageOnlyTransform,
8
+ A.DualTransform,
9
+ A.ReferenceBasedTransform,
10
+ A.TemplateTransform,
11
+ A.Lambda,
12
+ ]
13
+
14
+
15
def _annotation_mentions_callable(annotation):
    """Tell whether a parameter annotation is, or contains, a ``Callable``."""
    import collections.abc
    from typing import get_args, get_origin

    if isinstance(annotation, str):
        # Postponed (PEP 563) annotations reach us as source text.
        return "Callable" in annotation
    if annotation is Callable or annotation is collections.abc.Callable:
        return True
    if get_origin(annotation) is collections.abc.Callable:
        # Parametrized forms such as ``Callable[..., Any]``.
        return True
    # Look inside wrappers such as ``Optional[...]`` / ``Union[...]``.
    return any(
        _annotation_mentions_callable(arg)
        for arg in get_args(annotation)
        if not isinstance(arg, str)
    )


def is_not_supported_transform(transform_cls):
    """Return True when ``transform_cls`` should be filtered out.

    A transform class is rejected when any of the following holds:

    * it is a subclass of ``A.ReferenceBasedTransform``;
    * it is one of the classes listed in ``FILTER_TRANSFORMS`` (identity
      match, so subclasses of those entries are not rejected by this rule);
    * its constructor has a parameter named ``read_fn`` or
      ``reference_images``;
    * its constructor has a parameter annotated as a callable.

    The callable check also recognises parametrized, optional/union-wrapped
    and string annotations, which the previous
    ``issubclass(type(annotation), type(Callable))`` test missed because it
    only matched the bare ``typing.Callable`` object.
    """
    if issubclass(transform_cls, A.ReferenceBasedTransform):
        return True

    if any(transform_cls is excluded for excluded in FILTER_TRANSFORMS):
        return True

    # Only inspect the signature once the cheap checks above have passed.
    for param in inspect.signature(transform_cls).parameters.values():
        if param.name in ("read_fn", "reference_images"):
            return True
        if _annotation_mentions_callable(param.annotation):
            return True

    return False