Delete image_processor.py
image_processor.py  +0 -991
image_processor.py  DELETED
@@ -1,991 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-import warnings
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import PIL.Image
-import torch
-import torch.nn.functional as F
-from PIL import Image, ImageFilter, ImageOps
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
-# from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
-
-
-PipelineImageInput = Union[
-    PIL.Image.Image,
-    np.ndarray,
-    torch.FloatTensor,
-    List[PIL.Image.Image],
-    List[np.ndarray],
-    List[torch.FloatTensor],
-]
-
-PipelineDepthInput = PipelineImageInput
-
-
-class VaeImageProcessor(ConfigMixin):
-    """
-    Image processor for VAE.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
-            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
-        vae_scale_factor (`int`, *optional*, defaults to `8`):
-            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
-        resample (`str`, *optional*, defaults to `lanczos`):
-            Resampling filter to use when resizing the image.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether to normalize the image to [-1,1].
-        do_binarize (`bool`, *optional*, defaults to `False`):
-            Whether to binarize the image to 0/1.
-        do_convert_rgb (`bool`, *optional*, defaults to be `False`):
-            Whether to convert the images to RGB format.
-        do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
-            Whether to convert the images to grayscale format.
-    """
-
-    config_name = CONFIG_NAME
-
-    @register_to_config
-    def __init__(
-        self,
-        do_resize: bool = True,
-        vae_scale_factor: int = 8,
-        resample: str = "lanczos",
-        do_normalize: bool = True,
-        do_binarize: bool = False,
-        do_convert_rgb: bool = False,
-        do_convert_grayscale: bool = False,
-    ):
-        super().__init__()
-        if do_convert_rgb and do_convert_grayscale:
-            raise ValueError(
-                "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
-                " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
-                " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
-            )
-            self.config.do_convert_rgb = False
-
-    @staticmethod
-    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a numpy image or a batch of images to a PIL image.
-        """
-        if images.ndim == 3:
-            images = images[None, ...]
-        images = (images * 255).round().astype("uint8")
-        if images.shape[-1] == 1:
-            # special case for grayscale (single channel) images
-            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
-        else:
-            pil_images = [Image.fromarray(image) for image in images]
-
-        return pil_images
-
-    @staticmethod
-    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
-        Convert a PIL image or a list of PIL images to NumPy arrays.
-        """
-        if not isinstance(images, list):
-            images = [images]
-        images = [np.array(image).astype(np.float32) / 255.0 for image in images]
-        images = np.stack(images, axis=0)
-
-        return images
-
-    @staticmethod
-    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
-        """
-        Convert a NumPy image to a PyTorch tensor.
-        """
-        if images.ndim == 3:
-            images = images[..., None]
-
-        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
-        return images
-
-    @staticmethod
-    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
-        """
-        Convert a PyTorch tensor to a NumPy image.
-        """
-        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
-        return images
-
-    @staticmethod
-    def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
-        Normalize an image array to [-1,1].
-        """
-        return 2.0 * images - 1.0
-
-    @staticmethod
-    def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
-        Denormalize an image array to [0,1].
-        """
-        return (images / 2 + 0.5).clamp(0, 1)
-
-    @staticmethod
-    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
-        """
-        Converts a PIL image to RGB format.
-        """
-        image = image.convert("RGB")
-
-        return image
-
-    @staticmethod
-    def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
-        """
-        Converts a PIL image to grayscale format.
-        """
-        image = image.convert("L")
-
-        return image
-
-    @staticmethod
-    def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
-        """
-        Applies Gaussian blur to an image.
-        """
-        image = image.filter(ImageFilter.GaussianBlur(blur_factor))
-
-        return image
-
-    @staticmethod
-    def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
-        """
-        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
-        for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
-
-        Args:
-            mask_image (PIL.Image.Image): Mask image.
-            width (int): Width of the image to be processed.
-            height (int): Height of the image to be processed.
-            pad (int, optional): Padding to be added to the crop region. Defaults to 0.
-
-        Returns:
-            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
-        """
-
-        mask_image = mask_image.convert("L")
-        mask = np.array(mask_image)
-
-        # 1. find a rectangular region that contains all masked ares in an image
-        h, w = mask.shape
-        crop_left = 0
-        for i in range(w):
-            if not (mask[:, i] == 0).all():
-                break
-            crop_left += 1
-
-        crop_right = 0
-        for i in reversed(range(w)):
-            if not (mask[:, i] == 0).all():
-                break
-            crop_right += 1
-
-        crop_top = 0
-        for i in range(h):
-            if not (mask[i] == 0).all():
-                break
-            crop_top += 1
-
-        crop_bottom = 0
-        for i in reversed(range(h)):
-            if not (mask[i] == 0).all():
-                break
-            crop_bottom += 1
-
-        # 2. add padding to the crop region
-        x1, y1, x2, y2 = (
-            int(max(crop_left - pad, 0)),
-            int(max(crop_top - pad, 0)),
-            int(min(w - crop_right + pad, w)),
-            int(min(h - crop_bottom + pad, h)),
-        )
-
-        # 3. expands crop region to match the aspect ratio of the image to be processed
-        ratio_crop_region = (x2 - x1) / (y2 - y1)
-        ratio_processing = width / height
-
-        if ratio_crop_region > ratio_processing:
-            desired_height = (x2 - x1) / ratio_processing
-            desired_height_diff = int(desired_height - (y2 - y1))
-            y1 -= desired_height_diff // 2
-            y2 += desired_height_diff - desired_height_diff // 2
-            if y2 >= mask_image.height:
-                diff = y2 - mask_image.height
-                y2 -= diff
-                y1 -= diff
-            if y1 < 0:
-                y2 -= y1
-                y1 -= y1
-            if y2 >= mask_image.height:
-                y2 = mask_image.height
-        else:
-            desired_width = (y2 - y1) * ratio_processing
-            desired_width_diff = int(desired_width - (x2 - x1))
-            x1 -= desired_width_diff // 2
-            x2 += desired_width_diff - desired_width_diff // 2
-            if x2 >= mask_image.width:
-                diff = x2 - mask_image.width
-                x2 -= diff
-                x1 -= diff
-            if x1 < 0:
-                x2 -= x1
-                x1 -= x1
-            if x2 >= mask_image.width:
-                x2 = mask_image.width
-
-        return x1, y1, x2, y2
-
-    def _resize_and_fill(
-        self,
-        image: PIL.Image.Image,
-        width: int,
-        height: int,
-    ) -> PIL.Image.Image:
-        """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
-
-        Args:
-            image: The image to resize.
-            width: The width to resize the image to.
-            height: The height to resize the image to.
-        """
-
-        ratio = width / height
-        src_ratio = image.width / image.height
-
-        src_w = width if ratio < src_ratio else image.width * height // image.height
-        src_h = height if ratio >= src_ratio else image.height * width // image.width
-
-        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
-        res = Image.new("RGB", (width, height))
-        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
-
-        if ratio < src_ratio:
-            fill_height = height // 2 - src_h // 2
-            if fill_height > 0:
-                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
-                res.paste(
-                    resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
-                    box=(0, fill_height + src_h),
-                )
-        elif ratio > src_ratio:
-            fill_width = width // 2 - src_w // 2
-            if fill_width > 0:
-                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
-                res.paste(
-                    resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
-                    box=(fill_width + src_w, 0),
-                )
-
-        return res
-
-    def _resize_and_crop(
-        self,
-        image: PIL.Image.Image,
-        width: int,
-        height: int,
-    ) -> PIL.Image.Image:
-        """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
-
-        Args:
-            image: The image to resize.
-            width: The width to resize the image to.
-            height: The height to resize the image to.
-        """
-        ratio = width / height
-        src_ratio = image.width / image.height
-
-        src_w = width if ratio > src_ratio else image.width * height // image.height
-        src_h = height if ratio <= src_ratio else image.height * width // image.width
-
-        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
-        res = Image.new("RGB", (width, height))
-        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
-        return res
-
-    def resize(
-        self,
-        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
-        height: int,
-        width: int,
-        resize_mode: str = "default",  # "default", "fill", "crop"
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
-        """
-        Resize image.
-
-        Args:
-            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                The image input, can be a PIL image, numpy array or pytorch tensor.
-            height (`int`):
-                The height to resize to.
-            width (`int`):
-                The width to resize to.
-            resize_mode (`str`, *optional*, defaults to `default`):
-                The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
-
-        Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
-                The resized image.
-        """
-        if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
-            raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
-        if isinstance(image, PIL.Image.Image):
-            if resize_mode == "default":
-                image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
-            elif resize_mode == "fill":
-                image = self._resize_and_fill(image, width, height)
-            elif resize_mode == "crop":
-                image = self._resize_and_crop(image, width, height)
-            else:
-                raise ValueError(f"resize_mode {resize_mode} is not supported")
-
-        elif isinstance(image, torch.Tensor):
-            image = torch.nn.functional.interpolate(
-                image,
-                size=(height, width),
-            )
-        elif isinstance(image, np.ndarray):
-            image = self.numpy_to_pt(image)
-            image = torch.nn.functional.interpolate(
-                image,
-                size=(height, width),
-            )
-            image = self.pt_to_numpy(image)
-        return image
-
-    def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
-        """
-        Create a mask.
-
-        Args:
-            image (`PIL.Image.Image`):
-                The image input, should be a PIL image.
-
-        Returns:
-            `PIL.Image.Image`:
-                The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
-        """
-        image[image < 0.5] = 0
-        image[image >= 0.5] = 1
-
-        return image
-
-    def get_default_height_width(
-        self,
-        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-    ) -> Tuple[int, int]:
-        """
-        This function return the height and width that are downscaled to the next integer multiple of
-        `vae_scale_factor`.
-
-        Args:
-            image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
-                shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
-                have shape `[batch, channel, height, width]`.
-            height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the height of `image` input.
-            width (`int`, *optional*`, defaults to `None`):
-                The width in preprocessed. If `None`, will use the width of the `image` input.
-        """
-
-        if height is None:
-            if isinstance(image, PIL.Image.Image):
-                height = image.height
-            elif isinstance(image, torch.Tensor):
-                height = image.shape[2]
-            else:
-                height = image.shape[1]
-
-        if width is None:
-            if isinstance(image, PIL.Image.Image):
-                width = image.width
-            elif isinstance(image, torch.Tensor):
-                width = image.shape[3]
-            else:
-                width = image.shape[2]
-
-        width, height = (
-            x - x % self.config.vae_scale_factor for x in (width, height)
-        )  # resize to integer multiple of vae_scale_factor
-
-        return height, width
-
-    def preprocess(
-        self,
-        image: PipelineImageInput,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        resize_mode: str = "default",  # "default", "fill", "crop"
-        crops_coords: Optional[Tuple[int, int, int, int]] = None,
-    ) -> torch.Tensor:
-        """
-        Preprocess the image input.
-
-        Args:
-            image (`pipeline_image_input`):
-                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
-            height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
-            width (`int`, *optional*`, defaults to `None`):
-                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
-            resize_mode (`str`, *optional*, defaults to `default`):
-                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
-            crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
-                The crop coordinates for each image in the batch. If `None`, will not crop the image.
-        """
-        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
-
-        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
-        if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
-            if isinstance(image, torch.Tensor):
-                # if image is a pytorch tensor could have 2 possible shapes:
-                #    1. batch x height x width: we should insert the channel dimension at position 1
-                #    2. channel x height x width: we should insert batch dimension at position 0,
-                #       however, since both channel and batch dimension has same size 1, it is same to insert at position 1
-                #    for simplicity, we insert a dimension of size 1 at position 1 for both cases
-                image = image.unsqueeze(1)
-            else:
-                # if it is a numpy array, it could have 2 possible shapes:
-                #   1. batch x height x width: insert channel dimension on last position
-                #   2. height x width x channel: insert batch dimension on first position
-                if image.shape[-1] == 1:
-                    image = np.expand_dims(image, axis=0)
-                else:
-                    image = np.expand_dims(image, axis=-1)
-
-        if isinstance(image, supported_formats):
-            image = [image]
-        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
-            raise ValueError(
-                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
-            )
-
-        if isinstance(image[0], PIL.Image.Image):
-            if crops_coords is not None:
-                image = [i.crop(crops_coords) for i in image]
-            if self.config.do_resize:
-                height, width = self.get_default_height_width(image[0], height, width)
-                image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
-            if self.config.do_convert_rgb:
-                image = [self.convert_to_rgb(i) for i in image]
-            elif self.config.do_convert_grayscale:
-                image = [self.convert_to_grayscale(i) for i in image]
-            image = self.pil_to_numpy(image)  # to np
-            image = self.numpy_to_pt(image)  # to pt
-
-        elif isinstance(image[0], np.ndarray):
-            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
-
-            image = self.numpy_to_pt(image)
-
-            height, width = self.get_default_height_width(image, height, width)
-            if self.config.do_resize:
-                image = self.resize(image, height, width)
-
-        elif isinstance(image[0], torch.Tensor):
-            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
-
-            if self.config.do_convert_grayscale and image.ndim == 3:
-                image = image.unsqueeze(1)
-
-            channel = image.shape[1]
-            # don't need any preprocess if the image is latents
-            if channel >= 4:
-                return image
-
-            height, width = self.get_default_height_width(image, height, width)
-            if self.config.do_resize:
-                image = self.resize(image, height, width)
-
-        # expected range [0,1], normalize to [-1,1]
-        do_normalize = self.config.do_normalize
-        if do_normalize and image.min() < 0:
-            warnings.warn(
-                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
-                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
-                FutureWarning,
-            )
-            do_normalize = False
-
-        if do_normalize:
-            image = self.normalize(image)
-
-        if self.config.do_binarize:
-            image = self.binarize(image)
-
-        return image
-
-    def postprocess(
-        self,
-        image: torch.FloatTensor,
-        output_type: str = "pil",
-        do_denormalize: Optional[List[bool]] = None,
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
-        """
-        Postprocess the image output from tensor to `output_type`.
-
-        Args:
-            image (`torch.FloatTensor`):
-                The image input, should be a pytorch tensor with shape `B x C x H x W`.
-            output_type (`str`, *optional*, defaults to `pil`):
-                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
-            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
-                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
-                `VaeImageProcessor` config.
-
-        Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
-                The postprocessed image.
-        """
-        if not isinstance(image, torch.Tensor):
-            raise ValueError(
-                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
-            )
-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-            output_type = "np"
-
-        if output_type == "latent":
-            return image
-
-        if do_denormalize is None:
-            do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-        image = torch.stack(
-            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
-        )
-
-        if output_type == "pt":
-            return image
-
-        image = self.pt_to_numpy(image)
-
-        if output_type == "np":
-            return image
-
-        if output_type == "pil":
-            return self.numpy_to_pil(image)
-
-    def apply_overlay(
-        self,
-        mask: PIL.Image.Image,
-        init_image: PIL.Image.Image,
-        image: PIL.Image.Image,
-        crop_coords: Optional[Tuple[int, int, int, int]] = None,
-    ) -> PIL.Image.Image:
-        """
-        overlay the inpaint output to the original image
-        """
-
-        width, height = image.width, image.height
-
-        init_image = self.resize(init_image, width=width, height=height)
-        mask = self.resize(mask, width=width, height=height)
-
-        init_image_masked = PIL.Image.new("RGBa", (width, height))
-        init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
-        init_image_masked = init_image_masked.convert("RGBA")
-
-        if crop_coords is not None:
-            x, y, x2, y2 = crop_coords
-            w = x2 - x
-            h = y2 - y
-            base_image = PIL.Image.new("RGBA", (width, height))
-            image = self.resize(image, height=h, width=w, resize_mode="crop")
-            base_image.paste(image, (x, y))
-            image = base_image.convert("RGB")
-
-        image = image.convert("RGBA")
-        image.alpha_composite(init_image_masked)
-        image = image.convert("RGB")
-
-        return image
-
-
-class VaeImageProcessorLDM3D(VaeImageProcessor):
-    """
-    Image processor for VAE LDM3D.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
-        vae_scale_factor (`int`, *optional*, defaults to `8`):
-            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
-        resample (`str`, *optional*, defaults to `lanczos`):
-            Resampling filter to use when resizing the image.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether to normalize the image to [-1,1].
-    """
-
-    config_name = CONFIG_NAME
-
-    @register_to_config
-    def __init__(
-        self,
-        do_resize: bool = True,
-        vae_scale_factor: int = 8,
-        resample: str = "lanczos",
-        do_normalize: bool = True,
-    ):
-        super().__init__()
-
-    @staticmethod
-    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy image or a batch of images to a PIL image.
-        """
-        if images.ndim == 3:
-            images = images[None, ...]
-        images = (images * 255).round().astype("uint8")
-        if images.shape[-1] == 1:
-            # special case for grayscale (single channel) images
-            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
-        else:
-            pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
-
-        return pil_images
-
-    @staticmethod
-    def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
-        Convert a PIL image or a list of PIL images to NumPy arrays.
-        """
-        if not isinstance(images, list):
-            images = [images]
-
-        images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
-        images = np.stack(images, axis=0)
-        return images
-
-    @staticmethod
-    def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
-        Args:
-            image: RGB-like depth image
-
-        Returns: depth map
-
-        """
-        return image[:, :, 1] * 2**8 + image[:, :, 2]
-
-    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy depth image or a batch of images to a PIL image.
-        """
-        if images.ndim == 3:
-            images = images[None, ...]
-        images_depth = images[:, :, :, 3:]
-        if images.shape[-1] == 6:
-            images_depth = (images_depth * 255).round().astype("uint8")
-            pil_images = [
-                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
-            ]
-        elif images.shape[-1] == 4:
-            images_depth = (images_depth * 65535.0).astype(np.uint16)
-            pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
-        else:
-            raise Exception("Not supported")
-
-        return pil_images
-
-    def postprocess(
-        self,
-        image: torch.FloatTensor,
-        output_type: str = "pil",
-        do_denormalize: Optional[List[bool]] = None,
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
-        """
-        Postprocess the image output from tensor to `output_type`.
-
-        Args:
-            image (`torch.FloatTensor`):
-                The image input, should be a pytorch tensor with shape `B x C x H x W`.
-            output_type (`str`, *optional*, defaults to `pil`):
-                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
-            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
-                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
-                `VaeImageProcessor` config.
-
-        Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
-                The postprocessed image.
-        """
-        if not isinstance(image, torch.Tensor):
-            raise ValueError(
-                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
-            )
-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-            output_type = "np"
-
-        if do_denormalize is None:
-            do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-        image = torch.stack(
-            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
-        )
-
-        image = self.pt_to_numpy(image)
-
-        if output_type == "np":
-            if image.shape[-1] == 6:
-                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
-            else:
-                image_depth = image[:, :, :, 3:]
-            return image[:, :, :, :3], image_depth
-
-        if output_type == "pil":
-            return self.numpy_to_pil(image), self.numpy_to_depth(image)
-        else:
-            raise Exception(f"This type {output_type} is not supported")
-
-    def preprocess(
-        self,
-        rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
-        depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        target_res: Optional[int] = None,
-    ) -> torch.Tensor:
-        """
-        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
-        """
-        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
-
-        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
-        if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
-            raise Exception("This is not yet supported")
-
-        if isinstance(rgb, supported_formats):
-            rgb = [rgb]
-            depth = [depth]
-        elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
-            raise ValueError(
-                f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
-            )
-
-        if isinstance(rgb[0], PIL.Image.Image):
-            if self.config.do_convert_rgb:
-                raise Exception("This is not yet supported")
-                # rgb = [self.convert_to_rgb(i) for i in rgb]
-                # depth = [self.convert_to_depth(i) for i in depth]  #TODO define convert_to_depth
-            if self.config.do_resize or target_res:
-                height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
-                rgb = [self.resize(i, height, width) for i in rgb]
-                depth = [self.resize(i, height, width) for i in depth]
-            rgb = self.pil_to_numpy(rgb)  # to np
-            rgb = self.numpy_to_pt(rgb)  # to pt
-
-            depth = self.depth_pil_to_numpy(depth)  # to np
-            depth = self.numpy_to_pt(depth)  # to pt
-
-        elif isinstance(rgb[0], np.ndarray):
-            rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
-            rgb = self.numpy_to_pt(rgb)
-            height, width = self.get_default_height_width(rgb, height, width)
-            if self.config.do_resize:
-                rgb = self.resize(rgb, height, width)
-
-            depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0)
-            depth = self.numpy_to_pt(depth)
-            height, width = self.get_default_height_width(depth, height, width)
-            if self.config.do_resize:
-                depth = self.resize(depth, height, width)
-
-        elif isinstance(rgb[0], torch.Tensor):
-            raise Exception("This is not yet supported")
-            # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
-
-            # if self.config.do_convert_grayscale and rgb.ndim == 3:
-            #     rgb = rgb.unsqueeze(1)
-
-            # channel = rgb.shape[1]
-
-            # height, width = self.get_default_height_width(rgb, height, width)
-            # if self.config.do_resize:
-            #     rgb = self.resize(rgb, height, width)
-
-            # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
-
-            # if self.config.do_convert_grayscale and depth.ndim == 3:
-            #     depth = depth.unsqueeze(1)
-
-            # channel = depth.shape[1]
-            # # don't need any preprocess if the image is latents
-            # if depth == 4:
-            #     return rgb, depth
-
-            # height, width = self.get_default_height_width(depth, height, width)
-            # if self.config.do_resize:
-            #     depth = self.resize(depth, height, width)
-        # expected range [0,1], normalize to [-1,1]
-        do_normalize = self.config.do_normalize
-        if rgb.min() < 0 and do_normalize:
-            warnings.warn(
-                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
-                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
-                FutureWarning,
-            )
-            do_normalize = False
-
-        if do_normalize:
-            rgb = self.normalize(rgb)
-            depth = self.normalize(depth)
-
-        if self.config.do_binarize:
-            rgb = self.binarize(rgb)
-            depth = self.binarize(depth)
-
-        return rgb, depth
-
-
-class IPAdapterMaskProcessor(VaeImageProcessor):
-    """
-    Image processor for IP Adapter image masks.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
-        vae_scale_factor (`int`, *optional*, defaults to `8`):
-            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
-        resample (`str`, *optional*, defaults to `lanczos`):
-            Resampling filter to use when resizing the image.
-        do_normalize (`bool`, *optional*, defaults to `False`):
-            Whether to normalize the image to [-1,1].
-        do_binarize (`bool`, *optional*, defaults to `True`):
-            Whether to binarize the image to 0/1.
-        do_convert_grayscale (`bool`, *optional*, defaults to be `True`):
-            Whether to convert the images to grayscale format.
-
-    """
-
-    config_name = CONFIG_NAME
-
-    @register_to_config
-    def __init__(
-        self,
-        do_resize: bool = True,
-        vae_scale_factor: int = 8,
-        resample: str = "lanczos",
-        do_normalize: bool = False,
-        do_binarize: bool = True,
-        do_convert_grayscale: bool = True,
-    ):
-        super().__init__(
-            do_resize=do_resize,
-            vae_scale_factor=vae_scale_factor,
-            resample=resample,
-            do_normalize=do_normalize,
-            do_binarize=do_binarize,
-            do_convert_grayscale=do_convert_grayscale,
-        )
-
-    @staticmethod
-    def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
-        """
-        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
-        If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
-
-        Args:
-            mask (`torch.FloatTensor`):
-                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
-            batch_size (`int`):
-                The batch size.
-            num_queries (`int`):
-                The number of queries.
-            value_embed_dim (`int`):
-                The dimensionality of the value embeddings.
-
-        Returns:
-            `torch.FloatTensor`:
-                The downsampled mask tensor.
-
-        """
-        o_h = mask.shape[1]
-        o_w = mask.shape[2]
-        ratio = o_w / o_h
-        mask_h = int(math.sqrt(num_queries / ratio))
-        mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
-        mask_w = num_queries // mask_h
-
-        mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
-
-        # Repeat batch_size times
-        if mask_downsample.shape[0] < batch_size:
-            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
-
-        mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
-
-        downsampled_area = mask_h * mask_w
-        # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
-        # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
-        if downsampled_area < num_queries:
-            warnings.warn(
-                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
-                "Please update your masks or adjust the output size for optimal performance.",
-                UserWarning,
-            )
-            mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
-        # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
-        if downsampled_area > num_queries:
-            warnings.warn(
-                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
-                "Please update your masks or adjust the output size for optimal performance.",
-                UserWarning,
-            )
-            mask_downsample = mask_downsample[:, :num_queries]
-
-        # Repeat last dimension to match SDPA output shape
-        mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
-            1, 1, value_embed_dim
-        )
-
-        return mask_downsample
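
For reference, a minimal usage sketch of the processor removed in this commit (not part of the diff above). It assumes the upstream `diffusers.image_processor.VaeImageProcessor`, which the deleted module mirrored, and the file paths "input.png" / "output.png" are illustrative placeholders only:

import PIL.Image

from diffusers.image_processor import VaeImageProcessor  # upstream equivalent of the deleted module

# Same defaults the deleted module used: vae_scale_factor=8, lanczos resampling, normalize to [-1, 1].
processor = VaeImageProcessor(vae_scale_factor=8)

# preprocess: PIL image -> torch.FloatTensor of shape [1, 3, 512, 512] in [-1, 1]
pil_image = PIL.Image.open("input.png").convert("RGB")  # placeholder path
tensor = processor.preprocess(pil_image, height=512, width=512)

# ... a VAE / diffusion pipeline would consume `tensor` here ...

# postprocess: tensor in [-1, 1] -> list of PIL images
images = processor.postprocess(tensor, output_type="pil")
images[0].save("output.png")  # placeholder path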