Inpaint-Anything / utils /mask_processing.py
RysonFeng
Add source code
cdb26a4
raw
history blame contribute delete
No virus
6.4 kB
import cv2
from matplotlib import pyplot as plt
import PIL.Image as Image
import numpy as np
def crop_for_filling_pre(image: np.array, mask: np.array, crop_size: int = 512):
# Calculate the aspect ratio of the image
height, width = image.shape[:2]
aspect_ratio = float(width) / float(height)
# If the shorter side is less than 512, resize the image proportionally
if min(height, width) < crop_size:
if height < width:
new_height = crop_size
new_width = int(new_height * aspect_ratio)
else:
new_width = crop_size
new_height = int(new_width / aspect_ratio)
image = cv2.resize(image, (new_width, new_height))
mask = cv2.resize(mask, (new_width, new_height))
# Find the bounding box of the mask
x, y, w, h = cv2.boundingRect(mask)
# Update the height and width of the resized image
height, width = image.shape[:2]
# # If the 512x512 square cannot cover the entire mask, resize the image accordingly
if w > crop_size or h > crop_size:
# padding to square at first
if height < width:
padding = width - height
image = np.pad(image, ((padding // 2, padding - padding // 2), (0, 0), (0, 0)), 'constant')
mask = np.pad(mask, ((padding // 2, padding - padding // 2), (0, 0)), 'constant')
else:
padding = height - width
image = np.pad(image, ((0, 0), (padding // 2, padding - padding // 2), (0, 0)), 'constant')
mask = np.pad(mask, ((0, 0), (padding // 2, padding - padding // 2)), 'constant')
resize_factor = crop_size / max(w, h)
image = cv2.resize(image, (0, 0), fx=resize_factor, fy=resize_factor)
mask = cv2.resize(mask, (0, 0), fx=resize_factor, fy=resize_factor)
x, y, w, h = cv2.boundingRect(mask)
# Calculate the crop coordinates
crop_x = min(max(x + w // 2 - crop_size // 2, 0), width - crop_size)
crop_y = min(max(y + h // 2 - crop_size // 2, 0), height - crop_size)
# Crop the image
cropped_image = image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size]
cropped_mask = mask[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size]
return cropped_image, cropped_mask
def crop_for_filling_post(
image: np.array,
mask: np.array,
filled_image: np.array,
crop_size: int = 512,
):
image_copy = image.copy()
mask_copy = mask.copy()
# Calculate the aspect ratio of the image
height, width = image.shape[:2]
height_ori, width_ori = height, width
aspect_ratio = float(width) / float(height)
# If the shorter side is less than 512, resize the image proportionally
if min(height, width) < crop_size:
if height < width:
new_height = crop_size
new_width = int(new_height * aspect_ratio)
else:
new_width = crop_size
new_height = int(new_width / aspect_ratio)
image = cv2.resize(image, (new_width, new_height))
mask = cv2.resize(mask, (new_width, new_height))
# Find the bounding box of the mask
x, y, w, h = cv2.boundingRect(mask)
# Update the height and width of the resized image
height, width = image.shape[:2]
# # If the 512x512 square cannot cover the entire mask, resize the image accordingly
if w > crop_size or h > crop_size:
flag_padding = True
# padding to square at first
if height < width:
padding = width - height
image = np.pad(image, ((padding // 2, padding - padding // 2), (0, 0), (0, 0)), 'constant')
mask = np.pad(mask, ((padding // 2, padding - padding // 2), (0, 0)), 'constant')
padding_side = 'h'
else:
padding = height - width
image = np.pad(image, ((0, 0), (padding // 2, padding - padding // 2), (0, 0)), 'constant')
mask = np.pad(mask, ((0, 0), (padding // 2, padding - padding // 2)), 'constant')
padding_side = 'w'
resize_factor = crop_size / max(w, h)
image = cv2.resize(image, (0, 0), fx=resize_factor, fy=resize_factor)
mask = cv2.resize(mask, (0, 0), fx=resize_factor, fy=resize_factor)
x, y, w, h = cv2.boundingRect(mask)
else:
flag_padding = False
# Calculate the crop coordinates
crop_x = min(max(x + w // 2 - crop_size // 2, 0), width - crop_size)
crop_y = min(max(y + h // 2 - crop_size // 2, 0), height - crop_size)
# Fill the image
image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size] = filled_image
if flag_padding:
image = cv2.resize(image, (0, 0), fx=1/resize_factor, fy=1/resize_factor)
if padding_side == 'h':
image = image[padding // 2:padding // 2 + height_ori, :]
else:
image = image[:, padding // 2:padding // 2 + width_ori]
image = cv2.resize(image, (width_ori, height_ori))
image_copy[mask_copy==255] = image[mask_copy==255]
return image_copy
if __name__ == '__main__':
# image = cv2.imread('example/boat.jpg')
# mask = cv2.imread('example/boat_mask_2.png', cv2.IMREAD_GRAYSCALE)
image = cv2.imread('./example/groceries.jpg')
mask = cv2.imread('example/groceries_mask_2.png', cv2.IMREAD_GRAYSCALE)
# image = cv2.imread('example/bridge.jpg')
# mask = cv2.imread('example/bridge_mask_2.png', cv2.IMREAD_GRAYSCALE)
# image = cv2.imread('example/person_umbrella.jpg')
# mask = cv2.imread('example/person_umbrella_mask_2.png', cv2.IMREAD_GRAYSCALE)
# image = cv2.imread('example/hippopotamus.jpg')
# mask = cv2.imread('example/hippopotamus_mask_1.png', cv2.IMREAD_GRAYSCALE)
cropped_image, cropped_mask = crop_for_filling_pre(image, mask)
# ^ ------------------------------------------------------------------------------------
# ^ Please conduct inpainting or filling here on the cropped image with the cropped mask
# ^ ------------------------------------------------------------------------------------
# e.g.
# cropped_image[cropped_mask==255] = 0
cv2.imwrite('cropped_image.jpg', cropped_image)
cv2.imwrite('cropped_mask.jpg', cropped_mask)
print(cropped_image.shape)
print(cropped_mask.shape)
image = crop_for_filling_post(image, mask, cropped_image)
cv2.imwrite('filled_image.jpg', image)
print(image.shape)