dvir-bria committed on
Commit
d3b8cf3
1 Parent(s): e2dd0c1

Delete image_processor.py

Files changed (1)
  1. image_processor.py +0 -991
image_processor.py DELETED
@@ -1,991 +0,0 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import math
- import warnings
- from typing import List, Optional, Tuple, Union
-
- import numpy as np
- import PIL.Image
- import torch
- import torch.nn.functional as F
- from PIL import Image, ImageFilter, ImageOps
-
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
- # from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
-
-
- PipelineImageInput = Union[
-     PIL.Image.Image,
-     np.ndarray,
-     torch.FloatTensor,
-     List[PIL.Image.Image],
-     List[np.ndarray],
-     List[torch.FloatTensor],
- ]
-
- PipelineDepthInput = PipelineImageInput
-
-
- class VaeImageProcessor(ConfigMixin):
-     """
-     Image processor for VAE.
-
-     Args:
-         do_resize (`bool`, *optional*, defaults to `True`):
-             Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
-             `height` and `width` arguments from the [`image_processor.VaeImageProcessor.preprocess`] method.
-         vae_scale_factor (`int`, *optional*, defaults to `8`):
-             VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
-         resample (`str`, *optional*, defaults to `lanczos`):
-             Resampling filter to use when resizing the image.
-         do_normalize (`bool`, *optional*, defaults to `True`):
-             Whether to normalize the image to [-1,1].
-         do_binarize (`bool`, *optional*, defaults to `False`):
-             Whether to binarize the image to 0/1.
-         do_convert_rgb (`bool`, *optional*, defaults to `False`):
-             Whether to convert the images to RGB format.
-         do_convert_grayscale (`bool`, *optional*, defaults to `False`):
-             Whether to convert the images to grayscale format.
-     """
-
-     config_name = CONFIG_NAME
-
-     @register_to_config
-     def __init__(
-         self,
-         do_resize: bool = True,
-         vae_scale_factor: int = 8,
-         resample: str = "lanczos",
-         do_normalize: bool = True,
-         do_binarize: bool = False,
-         do_convert_rgb: bool = False,
-         do_convert_grayscale: bool = False,
-     ):
-         super().__init__()
-         if do_convert_rgb and do_convert_grayscale:
-             raise ValueError(
-                 "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
-                 " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
-                 " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
-             )
-             self.config.do_convert_rgb = False
-
-     @staticmethod
-     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-         """
-         Convert a numpy image or a batch of images to a PIL image.
-         """
-         if images.ndim == 3:
-             images = images[None, ...]
-         images = (images * 255).round().astype("uint8")
-         if images.shape[-1] == 1:
-             # special case for grayscale (single channel) images
-             pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
-         else:
-             pil_images = [Image.fromarray(image) for image in images]
-
-         return pil_images
-
-     @staticmethod
-     def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-         """
-         Convert a PIL image or a list of PIL images to NumPy arrays.
-         """
-         if not isinstance(images, list):
-             images = [images]
-         images = [np.array(image).astype(np.float32) / 255.0 for image in images]
-         images = np.stack(images, axis=0)
-
-         return images
-
-     @staticmethod
-     def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
-         """
-         Convert a NumPy image to a PyTorch tensor.
-         """
-         if images.ndim == 3:
-             images = images[..., None]
-
-         images = torch.from_numpy(images.transpose(0, 3, 1, 2))
-         return images
-
-     @staticmethod
-     def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
-         """
-         Convert a PyTorch tensor to a NumPy image.
-         """
-         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
-         return images
-
-     @staticmethod
-     def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-         """
-         Normalize an image array to [-1,1].
-         """
-         return 2.0 * images - 1.0
-
-     @staticmethod
-     def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-         """
-         Denormalize an image array to [0,1].
-         """
-         return (images / 2 + 0.5).clamp(0, 1)
-
-     @staticmethod
-     def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
-         """
-         Converts a PIL image to RGB format.
-         """
-         image = image.convert("RGB")
-
-         return image
-
-     @staticmethod
-     def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
-         """
-         Converts a PIL image to grayscale format.
-         """
-         image = image.convert("L")
-
-         return image
-
-     @staticmethod
-     def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
-         """
-         Applies Gaussian blur to an image.
-         """
-         image = image.filter(ImageFilter.GaussianBlur(blur_factor))
-
-         return image
-
-     @staticmethod
-     def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
-         """
-         Finds a rectangular region that contains all masked areas in an image, and expands the region to match the aspect ratio of the original image;
-         for example, if the user drew a mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
-
-         Args:
-             mask_image (PIL.Image.Image): Mask image.
-             width (int): Width of the image to be processed.
-             height (int): Height of the image to be processed.
-             pad (int, optional): Padding to be added to the crop region. Defaults to 0.
-
-         Returns:
-             tuple: (x1, y1, x2, y2) representing a rectangular region that contains all masked areas in the image and matches the original aspect ratio.
-         """
-
-         mask_image = mask_image.convert("L")
-         mask = np.array(mask_image)
-
-         # 1. find a rectangular region that contains all masked areas in the image
-         h, w = mask.shape
-         crop_left = 0
-         for i in range(w):
-             if not (mask[:, i] == 0).all():
-                 break
-             crop_left += 1
-
-         crop_right = 0
-         for i in reversed(range(w)):
-             if not (mask[:, i] == 0).all():
-                 break
-             crop_right += 1
-
-         crop_top = 0
-         for i in range(h):
-             if not (mask[i] == 0).all():
-                 break
-             crop_top += 1
-
-         crop_bottom = 0
-         for i in reversed(range(h)):
-             if not (mask[i] == 0).all():
-                 break
-             crop_bottom += 1
-
-         # 2. add padding to the crop region
-         x1, y1, x2, y2 = (
-             int(max(crop_left - pad, 0)),
-             int(max(crop_top - pad, 0)),
-             int(min(w - crop_right + pad, w)),
-             int(min(h - crop_bottom + pad, h)),
-         )
-
-         # 3. expand the crop region to match the aspect ratio of the image to be processed
-         ratio_crop_region = (x2 - x1) / (y2 - y1)
-         ratio_processing = width / height
-
-         if ratio_crop_region > ratio_processing:
-             desired_height = (x2 - x1) / ratio_processing
-             desired_height_diff = int(desired_height - (y2 - y1))
-             y1 -= desired_height_diff // 2
-             y2 += desired_height_diff - desired_height_diff // 2
-             if y2 >= mask_image.height:
-                 diff = y2 - mask_image.height
-                 y2 -= diff
-                 y1 -= diff
-             if y1 < 0:
-                 y2 -= y1
-                 y1 -= y1
-             if y2 >= mask_image.height:
-                 y2 = mask_image.height
-         else:
-             desired_width = (y2 - y1) * ratio_processing
-             desired_width_diff = int(desired_width - (x2 - x1))
-             x1 -= desired_width_diff // 2
-             x2 += desired_width_diff - desired_width_diff // 2
-             if x2 >= mask_image.width:
-                 diff = x2 - mask_image.width
-                 x2 -= diff
-                 x1 -= diff
-             if x1 < 0:
-                 x2 -= x1
-                 x1 -= x1
-             if x2 >= mask_image.width:
-                 x2 = mask_image.width
-
-         return x1, y1, x2, y2
-
-     def _resize_and_fill(
-         self,
-         image: PIL.Image.Image,
-         width: int,
-         height: int,
-     ) -> PIL.Image.Image:
-         """
-         Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image.
-
-         Args:
-             image: The image to resize.
-             width: The width to resize the image to.
-             height: The height to resize the image to.
-         """
-
-         ratio = width / height
-         src_ratio = image.width / image.height
-
-         src_w = width if ratio < src_ratio else image.width * height // image.height
-         src_h = height if ratio >= src_ratio else image.height * width // image.width
-
-         resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
-         res = Image.new("RGB", (width, height))
-         res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
-
-         if ratio < src_ratio:
-             fill_height = height // 2 - src_h // 2
-             if fill_height > 0:
-                 res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
-                 res.paste(
-                     resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
-                     box=(0, fill_height + src_h),
-                 )
-         elif ratio > src_ratio:
-             fill_width = width // 2 - src_w // 2
-             if fill_width > 0:
-                 res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
-                 res.paste(
-                     resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
-                     box=(fill_width + src_w, 0),
-                 )
-
-         return res
-
-     def _resize_and_crop(
-         self,
-         image: PIL.Image.Image,
-         width: int,
-         height: int,
-     ) -> PIL.Image.Image:
-         """
-         Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
-
-         Args:
-             image: The image to resize.
-             width: The width to resize the image to.
-             height: The height to resize the image to.
-         """
-         ratio = width / height
-         src_ratio = image.width / image.height
-
-         src_w = width if ratio > src_ratio else image.width * height // image.height
-         src_h = height if ratio <= src_ratio else image.height * width // image.width
-
-         resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
-         res = Image.new("RGB", (width, height))
-         res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
-         return res
-
-     def resize(
-         self,
-         image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
-         height: int,
-         width: int,
-         resize_mode: str = "default",  # "default", "fill", "crop"
-     ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
-         """
-         Resize image.
-
-         Args:
-             image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                 The image input, can be a PIL image, numpy array or pytorch tensor.
-             height (`int`):
-                 The height to resize to.
-             width (`int`):
-                 The width to resize to.
-             resize_mode (`str`, *optional*, defaults to `default`):
-                 The resize mode to use, can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit
-                 within the specified width and height, and it may not maintain the original aspect ratio.
-                 If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                 within the dimensions, filling the empty space with data from the image.
-                 If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                 within the dimensions, cropping the excess.
-                 Note that resize_mode `fill` and `crop` are only supported for PIL image input.
-
-         Returns:
-             `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
-                 The resized image.
-         """
-         if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
-             raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
-         if isinstance(image, PIL.Image.Image):
-             if resize_mode == "default":
-                 image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
-             elif resize_mode == "fill":
-                 image = self._resize_and_fill(image, width, height)
-             elif resize_mode == "crop":
-                 image = self._resize_and_crop(image, width, height)
-             else:
-                 raise ValueError(f"resize_mode {resize_mode} is not supported")
-
-         elif isinstance(image, torch.Tensor):
-             image = torch.nn.functional.interpolate(
-                 image,
-                 size=(height, width),
-             )
-         elif isinstance(image, np.ndarray):
-             image = self.numpy_to_pt(image)
-             image = torch.nn.functional.interpolate(
-                 image,
-                 size=(height, width),
-             )
-             image = self.pt_to_numpy(image)
-         return image
-
-     def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
-         """
-         Create a mask.
-
-         Args:
-             image (`PIL.Image.Image`):
-                 The image input, should be a PIL image.
-
-         Returns:
-             `PIL.Image.Image`:
-                 The binarized image. Values less than 0.5 are set to 0, values greater than or equal to 0.5 are set to 1.
-         """
-         image[image < 0.5] = 0
-         image[image >= 0.5] = 1
-
-         return image
-
-     def get_default_height_width(
-         self,
-         image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-     ) -> Tuple[int, int]:
-         """
-         This function returns the height and width that are downscaled to the next integer multiple of
-         `vae_scale_factor`.
-
-         Args:
-             image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                 The image input, can be a PIL image, numpy array or pytorch tensor. If it is a numpy array, it should have
-                 shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a pytorch tensor, it should
-                 have shape `[batch, channel, height, width]`.
-             height (`int`, *optional*, defaults to `None`):
-                 The height of the preprocessed image. If `None`, will use the height of the `image` input.
-             width (`int`, *optional*, defaults to `None`):
-                 The width of the preprocessed image. If `None`, will use the width of the `image` input.
-         """
-
-         if height is None:
-             if isinstance(image, PIL.Image.Image):
-                 height = image.height
-             elif isinstance(image, torch.Tensor):
-                 height = image.shape[2]
-             else:
-                 height = image.shape[1]
-
-         if width is None:
-             if isinstance(image, PIL.Image.Image):
-                 width = image.width
-             elif isinstance(image, torch.Tensor):
-                 width = image.shape[3]
-             else:
-                 width = image.shape[2]
-
-         width, height = (
-             x - x % self.config.vae_scale_factor for x in (width, height)
-         )  # resize to integer multiple of vae_scale_factor
-
-         return height, width
-
-     def preprocess(
-         self,
-         image: PipelineImageInput,
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         resize_mode: str = "default",  # "default", "fill", "crop"
-         crops_coords: Optional[Tuple[int, int, int, int]] = None,
-     ) -> torch.Tensor:
-         """
-         Preprocess the image input.
-
-         Args:
-             image (`PipelineImageInput`):
-                 The image input; accepted formats are PIL images, NumPy arrays, and PyTorch tensors, as well as lists of these formats.
-             height (`int`, *optional*, defaults to `None`):
-                 The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default height.
-             width (`int`, *optional*, defaults to `None`):
-                 The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
-             resize_mode (`str`, *optional*, defaults to `default`):
-                 The resize mode, can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit
-                 within the specified width and height, and it may not maintain the original aspect ratio.
-                 If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                 within the dimensions, filling the empty space with data from the image.
-                 If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                 within the dimensions, cropping the excess.
-                 Note that resize_mode `fill` and `crop` are only supported for PIL image input.
-             crops_coords (`Tuple[int, int, int, int]`, *optional*, defaults to `None`):
-                 The crop coordinates applied to each image in the batch. If `None`, will not crop the image.
-         """
-         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
-
-         # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
-         if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
-             if isinstance(image, torch.Tensor):
-                 # if image is a pytorch tensor, it could have 2 possible shapes:
-                 # 1. batch x height x width: we should insert the channel dimension at position 1
-                 # 2. channel x height x width: we should insert batch dimension at position 0,
-                 # however, since both the channel and batch dimensions have size 1, it is the same to insert at position 1
-                 # for simplicity, we insert a dimension of size 1 at position 1 for both cases
-                 image = image.unsqueeze(1)
-             else:
-                 # if it is a numpy array, it could have 2 possible shapes:
-                 # 1. batch x height x width: insert channel dimension on last position
-                 # 2. height x width x channel: insert batch dimension on first position
-                 if image.shape[-1] == 1:
-                     image = np.expand_dims(image, axis=0)
-                 else:
-                     image = np.expand_dims(image, axis=-1)
-
-         if isinstance(image, supported_formats):
-             image = [image]
-         elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
-             raise ValueError(
-                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
-             )
-
-         if isinstance(image[0], PIL.Image.Image):
-             if crops_coords is not None:
-                 image = [i.crop(crops_coords) for i in image]
-             if self.config.do_resize:
-                 height, width = self.get_default_height_width(image[0], height, width)
-                 image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
-             if self.config.do_convert_rgb:
-                 image = [self.convert_to_rgb(i) for i in image]
-             elif self.config.do_convert_grayscale:
-                 image = [self.convert_to_grayscale(i) for i in image]
-             image = self.pil_to_numpy(image)  # to np
-             image = self.numpy_to_pt(image)  # to pt
-
-         elif isinstance(image[0], np.ndarray):
-             image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
-
-             image = self.numpy_to_pt(image)
-
-             height, width = self.get_default_height_width(image, height, width)
-             if self.config.do_resize:
-                 image = self.resize(image, height, width)
-
-         elif isinstance(image[0], torch.Tensor):
-             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
-
-             if self.config.do_convert_grayscale and image.ndim == 3:
-                 image = image.unsqueeze(1)
-
-             channel = image.shape[1]
-             # don't need any preprocess if the image is latents
-             if channel >= 4:
-                 return image
-
-             height, width = self.get_default_height_width(image, height, width)
-             if self.config.do_resize:
-                 image = self.resize(image, height, width)
-
-         # expected range [0,1], normalize to [-1,1]
-         do_normalize = self.config.do_normalize
-         if do_normalize and image.min() < 0:
-             warnings.warn(
-                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
-                 f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
-                 FutureWarning,
-             )
-             do_normalize = False
-
-         if do_normalize:
-             image = self.normalize(image)
-
-         if self.config.do_binarize:
-             image = self.binarize(image)
-
-         return image
-
-     def postprocess(
-         self,
-         image: torch.FloatTensor,
-         output_type: str = "pil",
-         do_denormalize: Optional[List[bool]] = None,
-     ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
-         """
-         Postprocess the image output from tensor to `output_type`.
-
-         Args:
-             image (`torch.FloatTensor`):
-                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
-             output_type (`str`, *optional*, defaults to `pil`):
-                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
-             do_denormalize (`List[bool]`, *optional*, defaults to `None`):
-                 Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
-                 `VaeImageProcessor` config.
-
-         Returns:
-             `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
-                 The postprocessed image.
-         """
-         if not isinstance(image, torch.Tensor):
-             raise ValueError(
-                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
-             )
-         if output_type not in ["latent", "pt", "np", "pil"]:
-             deprecation_message = (
-                 f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
-                 "`pil`, `np`, `pt`, `latent`"
-             )
-             deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-             output_type = "np"
-
-         if output_type == "latent":
-             return image
-
-         if do_denormalize is None:
-             do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-         image = torch.stack(
-             [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
-         )
-
-         if output_type == "pt":
-             return image
-
-         image = self.pt_to_numpy(image)
-
-         if output_type == "np":
-             return image
-
-         if output_type == "pil":
-             return self.numpy_to_pil(image)
-
-     def apply_overlay(
-         self,
-         mask: PIL.Image.Image,
-         init_image: PIL.Image.Image,
-         image: PIL.Image.Image,
-         crop_coords: Optional[Tuple[int, int, int, int]] = None,
-     ) -> PIL.Image.Image:
-         """
-         Overlay the inpaint output on the original image.
-         """
-
-         width, height = image.width, image.height
-
-         init_image = self.resize(init_image, width=width, height=height)
-         mask = self.resize(mask, width=width, height=height)
-
-         init_image_masked = PIL.Image.new("RGBa", (width, height))
-         init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
-         init_image_masked = init_image_masked.convert("RGBA")
-
-         if crop_coords is not None:
-             x, y, x2, y2 = crop_coords
-             w = x2 - x
-             h = y2 - y
-             base_image = PIL.Image.new("RGBA", (width, height))
-             image = self.resize(image, height=h, width=w, resize_mode="crop")
-             base_image.paste(image, (x, y))
-             image = base_image.convert("RGB")
-
-         image = image.convert("RGBA")
-         image.alpha_composite(init_image_masked)
-         image = image.convert("RGB")
-
-         return image
-
-
- class VaeImageProcessorLDM3D(VaeImageProcessor):
-     """
-     Image processor for VAE LDM3D.
-
-     Args:
-         do_resize (`bool`, *optional*, defaults to `True`):
-             Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
-         vae_scale_factor (`int`, *optional*, defaults to `8`):
-             VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
-         resample (`str`, *optional*, defaults to `lanczos`):
-             Resampling filter to use when resizing the image.
-         do_normalize (`bool`, *optional*, defaults to `True`):
-             Whether to normalize the image to [-1,1].
-     """
-
-     config_name = CONFIG_NAME
-
-     @register_to_config
-     def __init__(
-         self,
-         do_resize: bool = True,
-         vae_scale_factor: int = 8,
-         resample: str = "lanczos",
-         do_normalize: bool = True,
-     ):
-         super().__init__()
-
-     @staticmethod
-     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-         """
-         Convert a NumPy image or a batch of images to a PIL image.
-         """
-         if images.ndim == 3:
-             images = images[None, ...]
-         images = (images * 255).round().astype("uint8")
-         if images.shape[-1] == 1:
-             # special case for grayscale (single channel) images
-             pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
-         else:
-             pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
-
-         return pil_images
-
-     @staticmethod
-     def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-         """
-         Convert a PIL image or a list of PIL images to NumPy arrays.
-         """
-         if not isinstance(images, list):
-             images = [images]
-
-         images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
-         images = np.stack(images, axis=0)
-         return images
-
-     @staticmethod
-     def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-         """
-         Args:
-             image: RGB-like depth image
-
-         Returns: depth map
-
-         """
-         return image[:, :, 1] * 2**8 + image[:, :, 2]
-
-     def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
-         """
-         Convert a NumPy depth image or a batch of images to a PIL image.
-         """
-         if images.ndim == 3:
-             images = images[None, ...]
-         images_depth = images[:, :, :, 3:]
-         if images.shape[-1] == 6:
-             images_depth = (images_depth * 255).round().astype("uint8")
-             pil_images = [
-                 Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
-             ]
-         elif images.shape[-1] == 4:
-             images_depth = (images_depth * 65535.0).astype(np.uint16)
-             pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
-         else:
-             raise Exception("Not supported")
-
-         return pil_images
-
-     def postprocess(
-         self,
-         image: torch.FloatTensor,
-         output_type: str = "pil",
-         do_denormalize: Optional[List[bool]] = None,
-     ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
-         """
-         Postprocess the image output from tensor to `output_type`.
-
-         Args:
-             image (`torch.FloatTensor`):
-                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
-             output_type (`str`, *optional*, defaults to `pil`):
-                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
-             do_denormalize (`List[bool]`, *optional*, defaults to `None`):
-                 Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
-                 `VaeImageProcessor` config.
-
-         Returns:
-             `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
-                 The postprocessed image.
-         """
-         if not isinstance(image, torch.Tensor):
-             raise ValueError(
-                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
-             )
-         if output_type not in ["latent", "pt", "np", "pil"]:
-             deprecation_message = (
-                 f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
-                 "`pil`, `np`, `pt`, `latent`"
-             )
-             deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-             output_type = "np"
-
-         if do_denormalize is None:
-             do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-         image = torch.stack(
-             [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
-         )
-
-         image = self.pt_to_numpy(image)
-
-         if output_type == "np":
-             if image.shape[-1] == 6:
-                 image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
-             else:
-                 image_depth = image[:, :, :, 3:]
-             return image[:, :, :, :3], image_depth
-
-         if output_type == "pil":
-             return self.numpy_to_pil(image), self.numpy_to_depth(image)
-         else:
-             raise Exception(f"This type {output_type} is not supported")
-
-     def preprocess(
-         self,
-         rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
-         depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         target_res: Optional[int] = None,
-     ) -> torch.Tensor:
-         """
-         Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
-         """
-         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
-
-         # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
-         if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
-             raise Exception("This is not yet supported")
-
-         if isinstance(rgb, supported_formats):
-             rgb = [rgb]
-             depth = [depth]
-         elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
-             raise ValueError(
-                 f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
-             )
-
-         if isinstance(rgb[0], PIL.Image.Image):
-             if self.config.do_convert_rgb:
-                 raise Exception("This is not yet supported")
-                 # rgb = [self.convert_to_rgb(i) for i in rgb]
-                 # depth = [self.convert_to_depth(i) for i in depth]  # TODO define convert_to_depth
-             if self.config.do_resize or target_res:
-                 height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
-                 rgb = [self.resize(i, height, width) for i in rgb]
-                 depth = [self.resize(i, height, width) for i in depth]
-             rgb = self.pil_to_numpy(rgb)  # to np
-             rgb = self.numpy_to_pt(rgb)  # to pt
-
-             depth = self.depth_pil_to_numpy(depth)  # to np
-             depth = self.numpy_to_pt(depth)  # to pt
-
-         elif isinstance(rgb[0], np.ndarray):
-             rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
-             rgb = self.numpy_to_pt(rgb)
-             height, width = self.get_default_height_width(rgb, height, width)
-             if self.config.do_resize:
-                 rgb = self.resize(rgb, height, width)
-
-             depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0)
-             depth = self.numpy_to_pt(depth)
-             height, width = self.get_default_height_width(depth, height, width)
-             if self.config.do_resize:
-                 depth = self.resize(depth, height, width)
-
-         elif isinstance(rgb[0], torch.Tensor):
-             raise Exception("This is not yet supported")
-             # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
-
-             # if self.config.do_convert_grayscale and rgb.ndim == 3:
-             #     rgb = rgb.unsqueeze(1)
-
-             # channel = rgb.shape[1]
-
-             # height, width = self.get_default_height_width(rgb, height, width)
-             # if self.config.do_resize:
-             #     rgb = self.resize(rgb, height, width)
-
-             # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
-
-             # if self.config.do_convert_grayscale and depth.ndim == 3:
-             #     depth = depth.unsqueeze(1)
-
-             # channel = depth.shape[1]
-             # # don't need any preprocess if the image is latents
-             # if depth == 4:
-             #     return rgb, depth
-
-             # height, width = self.get_default_height_width(depth, height, width)
-             # if self.config.do_resize:
-             #     depth = self.resize(depth, height, width)
-         # expected range [0,1], normalize to [-1,1]
-         do_normalize = self.config.do_normalize
-         if rgb.min() < 0 and do_normalize:
-             warnings.warn(
-                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
-                 f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
-                 FutureWarning,
-             )
-             do_normalize = False
-
-         if do_normalize:
-             rgb = self.normalize(rgb)
-             depth = self.normalize(depth)
-
-         if self.config.do_binarize:
-             rgb = self.binarize(rgb)
-             depth = self.binarize(depth)
-
-         return rgb, depth
-
-
- class IPAdapterMaskProcessor(VaeImageProcessor):
-     """
-     Image processor for IP Adapter image masks.
-
-     Args:
-         do_resize (`bool`, *optional*, defaults to `True`):
-             Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
-         vae_scale_factor (`int`, *optional*, defaults to `8`):
-             VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
-         resample (`str`, *optional*, defaults to `lanczos`):
-             Resampling filter to use when resizing the image.
-         do_normalize (`bool`, *optional*, defaults to `False`):
-             Whether to normalize the image to [-1,1].
-         do_binarize (`bool`, *optional*, defaults to `True`):
-             Whether to binarize the image to 0/1.
-         do_convert_grayscale (`bool`, *optional*, defaults to `True`):
-             Whether to convert the images to grayscale format.
-
-     """
-
-     config_name = CONFIG_NAME
-
-     @register_to_config
-     def __init__(
-         self,
-         do_resize: bool = True,
-         vae_scale_factor: int = 8,
-         resample: str = "lanczos",
-         do_normalize: bool = False,
-         do_binarize: bool = True,
-         do_convert_grayscale: bool = True,
-     ):
-         super().__init__(
-             do_resize=do_resize,
-             vae_scale_factor=vae_scale_factor,
-             resample=resample,
-             do_normalize=do_normalize,
-             do_binarize=do_binarize,
-             do_convert_grayscale=do_convert_grayscale,
-         )
-
-     @staticmethod
-     def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
-         """
-         Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
-         If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
-
-         Args:
-             mask (`torch.FloatTensor`):
-                 The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
-             batch_size (`int`):
-                 The batch size.
-             num_queries (`int`):
-                 The number of queries.
-             value_embed_dim (`int`):
-                 The dimensionality of the value embeddings.
-
-         Returns:
-             `torch.FloatTensor`:
-                 The downsampled mask tensor.
-
-         """
-         o_h = mask.shape[1]
-         o_w = mask.shape[2]
-         ratio = o_w / o_h
-         mask_h = int(math.sqrt(num_queries / ratio))
-         mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
-         mask_w = num_queries // mask_h
-
-         mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
-
-         # Repeat batch_size times
-         if mask_downsample.shape[0] < batch_size:
-             mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
-
-         mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
-
-         downsampled_area = mask_h * mask_w
-         # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
-         # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
-         if downsampled_area < num_queries:
-             warnings.warn(
-                 "The aspect ratio of the mask does not match the aspect ratio of the output image. "
-                 "Please update your masks or adjust the output size for optimal performance.",
-                 UserWarning,
-             )
-             mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
-         # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
-         if downsampled_area > num_queries:
-             warnings.warn(
-                 "The aspect ratio of the mask does not match the aspect ratio of the output image. "
-                 "Please update your masks or adjust the output size for optimal performance.",
-                 UserWarning,
-             )
-             mask_downsample = mask_downsample[:, :num_queries]
-
-         # Repeat last dimension to match SDPA output shape
-         mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
-             1, 1, value_embed_dim
-         )
-
-         return mask_downsample
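For reference, a minimal usage sketch of the classes this file contained, assuming the upstream diffusers copies of `VaeImageProcessor` and `IPAdapterMaskProcessor` (which this module mirrored) remain importable from `diffusers.image_processor`; the image sizes, batch size, and query/embedding dimensions below are illustrative only, not values taken from this repository.

import torch
from PIL import Image
from diffusers.image_processor import IPAdapterMaskProcessor, VaeImageProcessor

# Round-trip an image through preprocess/postprocess: resize to a multiple of
# vae_scale_factor, scale to [-1, 1], then map back to a PIL image.
processor = VaeImageProcessor(vae_scale_factor=8)
image = Image.new("RGB", (515, 389))  # arbitrary size, not a multiple of 8
tensor = processor.preprocess(image)  # torch.Tensor, shape [1, 3, 384, 512], scaled to [-1, 1]
restored = processor.postprocess(tensor, output_type="pil")[0]  # PIL image, 512x384

# Downsample a preprocessed IP-Adapter mask to the query length of an attention layer.
mask_processor = IPAdapterMaskProcessor()
mask = mask_processor.preprocess(Image.new("L", (1024, 1024)))  # [1, 1, 1024, 1024], binarized
attn_mask = IPAdapterMaskProcessor.downsample(
    mask[:, 0, :, :],  # [1, H, W]
    batch_size=2,
    num_queries=4096,  # e.g. a 64x64 latent, flattened
    value_embed_dim=64,
)
assert attn_mask.shape == (2, 4096, 64)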