zhiweili
commited on
Commit
·
efeb13c
1
Parent(s):
52c565a
add croper
Browse files
croper.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
class Croper:
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
input_image: PIL.Image,
|
10 |
+
target_mask: np.ndarray,
|
11 |
+
mask_size: int = 256,
|
12 |
+
mask_expansion: int = 20,
|
13 |
+
):
|
14 |
+
self.input_image = input_image
|
15 |
+
self.target_mask = target_mask
|
16 |
+
self.mask_size = mask_size
|
17 |
+
self.mask_expansion = mask_expansion
|
18 |
+
|
19 |
+
def corp_mask_image(self):
|
20 |
+
target_mask = self.target_mask
|
21 |
+
input_image = self.input_image
|
22 |
+
mask_expansion = self.mask_expansion
|
23 |
+
original_width, original_height = input_image.size
|
24 |
+
mask_indices = np.where(target_mask)
|
25 |
+
start_y = np.min(mask_indices[0])
|
26 |
+
end_y = np.max(mask_indices[0])
|
27 |
+
start_x = np.min(mask_indices[1])
|
28 |
+
end_x = np.max(mask_indices[1])
|
29 |
+
mask_height = end_y - start_y
|
30 |
+
mask_width = end_x - start_x
|
31 |
+
# choose the max side length
|
32 |
+
max_side_length = max(mask_height, mask_width)
|
33 |
+
# expand the mask area
|
34 |
+
height_diff = (max_side_length - mask_height) // 2
|
35 |
+
width_diff = (max_side_length - mask_width) // 2
|
36 |
+
start_y = start_y - mask_expansion - height_diff
|
37 |
+
if start_y < 0:
|
38 |
+
start_y = 0
|
39 |
+
end_y = end_y + mask_expansion + height_diff
|
40 |
+
if end_y > original_height:
|
41 |
+
end_y = original_height
|
42 |
+
start_x = start_x - mask_expansion - width_diff
|
43 |
+
if start_x < 0:
|
44 |
+
start_x = 0
|
45 |
+
end_x = end_x + mask_expansion + width_diff
|
46 |
+
if end_x > original_width:
|
47 |
+
end_x = original_width
|
48 |
+
expanded_height = end_y - start_y
|
49 |
+
expanded_width = end_x - start_x
|
50 |
+
expanded_max_side_length = max(expanded_height, expanded_width)
|
51 |
+
# calculate the crop area
|
52 |
+
crop_mask = target_mask[start_y:end_y, start_x:end_x]
|
53 |
+
crop_mask_start_y = (expanded_max_side_length - expanded_height) // 2
|
54 |
+
crop_mask_end_y = crop_mask_start_y + expanded_height
|
55 |
+
crop_mask_start_x = (expanded_max_side_length - expanded_width) // 2
|
56 |
+
crop_mask_end_x = crop_mask_start_x + expanded_width
|
57 |
+
# create a square mask
|
58 |
+
square_mask = np.zeros((expanded_max_side_length, expanded_max_side_length), dtype=target_mask.dtype)
|
59 |
+
square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
|
60 |
+
square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))
|
61 |
+
|
62 |
+
crop_image = input_image.crop((start_x, start_y, end_x, end_y))
|
63 |
+
square_image = Image.new("RGB", (expanded_max_side_length, expanded_max_side_length))
|
64 |
+
square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
|
65 |
+
|
66 |
+
self.origin_start_x = start_x
|
67 |
+
self.origin_start_y = start_y
|
68 |
+
self.origin_end_x = end_x
|
69 |
+
self.origin_end_y = end_y
|
70 |
+
|
71 |
+
self.square_start_x = crop_mask_start_x
|
72 |
+
self.square_start_y = crop_mask_start_y
|
73 |
+
self.square_end_x = crop_mask_end_x
|
74 |
+
self.square_end_y = crop_mask_end_y
|
75 |
+
|
76 |
+
self.square_length = expanded_max_side_length
|
77 |
+
self.square_mask_image = square_mask_image
|
78 |
+
self.square_image = square_image
|
79 |
+
self.corp_mask = crop_mask
|
80 |
+
|
81 |
+
mask_size = self.mask_size
|
82 |
+
self.resized_square_mask_image = square_mask_image.resize((mask_size, mask_size))
|
83 |
+
self.resized_square_image = square_image.resize((mask_size, mask_size))
|
84 |
+
|
85 |
+
return self.resized_square_mask_image
|
86 |
+
|
87 |
+
def restore_result(self, generated_image):
|
88 |
+
square_length = self.square_length
|
89 |
+
generated_image = generated_image.resize((square_length, square_length))
|
90 |
+
square_mask_image = self.square_mask_image
|
91 |
+
cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
|
92 |
+
cropped_square_mask_image = square_mask_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
|
93 |
+
|
94 |
+
restored_image = self.input_image.copy()
|
95 |
+
restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y), cropped_square_mask_image)
|
96 |
+
|
97 |
+
return restored_image
|
98 |
+
|
99 |
+
def restore_result_v2(self, generated_image):
|
100 |
+
square_length = self.square_length
|
101 |
+
generated_image = generated_image.resize((square_length, square_length))
|
102 |
+
cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
|
103 |
+
|
104 |
+
restored_image = self.input_image.copy()
|
105 |
+
restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y))
|
106 |
+
|
107 |
+
return restored_image
|
108 |
+
|