zhiweili commited on
Commit
efeb13c
·
1 Parent(s): 52c565a

add croper

Browse files
Files changed (1) hide show
  1. croper.py +108 -0
croper.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import numpy as np
3
+
4
+ from PIL import Image
5
+
6
+ class Croper:
7
+ def __init__(
8
+ self,
9
+ input_image: PIL.Image,
10
+ target_mask: np.ndarray,
11
+ mask_size: int = 256,
12
+ mask_expansion: int = 20,
13
+ ):
14
+ self.input_image = input_image
15
+ self.target_mask = target_mask
16
+ self.mask_size = mask_size
17
+ self.mask_expansion = mask_expansion
18
+
19
+ def corp_mask_image(self):
20
+ target_mask = self.target_mask
21
+ input_image = self.input_image
22
+ mask_expansion = self.mask_expansion
23
+ original_width, original_height = input_image.size
24
+ mask_indices = np.where(target_mask)
25
+ start_y = np.min(mask_indices[0])
26
+ end_y = np.max(mask_indices[0])
27
+ start_x = np.min(mask_indices[1])
28
+ end_x = np.max(mask_indices[1])
29
+ mask_height = end_y - start_y
30
+ mask_width = end_x - start_x
31
+ # choose the max side length
32
+ max_side_length = max(mask_height, mask_width)
33
+ # expand the mask area
34
+ height_diff = (max_side_length - mask_height) // 2
35
+ width_diff = (max_side_length - mask_width) // 2
36
+ start_y = start_y - mask_expansion - height_diff
37
+ if start_y < 0:
38
+ start_y = 0
39
+ end_y = end_y + mask_expansion + height_diff
40
+ if end_y > original_height:
41
+ end_y = original_height
42
+ start_x = start_x - mask_expansion - width_diff
43
+ if start_x < 0:
44
+ start_x = 0
45
+ end_x = end_x + mask_expansion + width_diff
46
+ if end_x > original_width:
47
+ end_x = original_width
48
+ expanded_height = end_y - start_y
49
+ expanded_width = end_x - start_x
50
+ expanded_max_side_length = max(expanded_height, expanded_width)
51
+ # calculate the crop area
52
+ crop_mask = target_mask[start_y:end_y, start_x:end_x]
53
+ crop_mask_start_y = (expanded_max_side_length - expanded_height) // 2
54
+ crop_mask_end_y = crop_mask_start_y + expanded_height
55
+ crop_mask_start_x = (expanded_max_side_length - expanded_width) // 2
56
+ crop_mask_end_x = crop_mask_start_x + expanded_width
57
+ # create a square mask
58
+ square_mask = np.zeros((expanded_max_side_length, expanded_max_side_length), dtype=target_mask.dtype)
59
+ square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
60
+ square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))
61
+
62
+ crop_image = input_image.crop((start_x, start_y, end_x, end_y))
63
+ square_image = Image.new("RGB", (expanded_max_side_length, expanded_max_side_length))
64
+ square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
65
+
66
+ self.origin_start_x = start_x
67
+ self.origin_start_y = start_y
68
+ self.origin_end_x = end_x
69
+ self.origin_end_y = end_y
70
+
71
+ self.square_start_x = crop_mask_start_x
72
+ self.square_start_y = crop_mask_start_y
73
+ self.square_end_x = crop_mask_end_x
74
+ self.square_end_y = crop_mask_end_y
75
+
76
+ self.square_length = expanded_max_side_length
77
+ self.square_mask_image = square_mask_image
78
+ self.square_image = square_image
79
+ self.corp_mask = crop_mask
80
+
81
+ mask_size = self.mask_size
82
+ self.resized_square_mask_image = square_mask_image.resize((mask_size, mask_size))
83
+ self.resized_square_image = square_image.resize((mask_size, mask_size))
84
+
85
+ return self.resized_square_mask_image
86
+
87
+ def restore_result(self, generated_image):
88
+ square_length = self.square_length
89
+ generated_image = generated_image.resize((square_length, square_length))
90
+ square_mask_image = self.square_mask_image
91
+ cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
92
+ cropped_square_mask_image = square_mask_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
93
+
94
+ restored_image = self.input_image.copy()
95
+ restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y), cropped_square_mask_image)
96
+
97
+ return restored_image
98
+
99
+ def restore_result_v2(self, generated_image):
100
+ square_length = self.square_length
101
+ generated_image = generated_image.resize((square_length, square_length))
102
+ cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
103
+
104
+ restored_image = self.input_image.copy()
105
+ restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y))
106
+
107
+ return restored_image
108
+