sergeipetrov commited on
Commit
9c73226
1 Parent(s): 08fa294

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +240 -0
handler.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ import base64
6
+ import torch
7
+
8
+ def merge_images(original, new_image, offset, direction):
9
+ if direction in ["left", "right"]:
10
+ merged_image = np.zeros((original.shape[0], original.shape[1] + offset, 3), dtype=np.uint8)
11
+ elif direction in ["top", "bottom"]:
12
+ merged_image = np.zeros((original.shape[0] + offset, original.shape[1], 3), dtype=np.uint8)
13
+
14
+ if direction == "left":
15
+ merged_image[:, offset:] = original
16
+ merged_image[:, : new_image.shape[1]] = new_image
17
+ elif direction == "right":
18
+ merged_image[:, : original.shape[1]] = original
19
+ merged_image[:, original.shape[1] + offset - new_image.shape[1] : original.shape[1] + offset] = new_image
20
+ elif direction == "top":
21
+ merged_image[offset:, :] = original
22
+ merged_image[: new_image.shape[0], :] = new_image
23
+ elif direction == "bottom":
24
+ merged_image[: original.shape[0], :] = original
25
+ merged_image[original.shape[0] + offset - new_image.shape[0] : original.shape[0] + offset, :] = new_image
26
+
27
+ return merged_image
28
+
29
+
30
+ def slice_image(image):
31
+ height, width, _ = image.shape
32
+ slice_size = min(width // 2, height // 3)
33
+
34
+ slices = []
35
+
36
+ for h in range(3):
37
+ for w in range(2):
38
+ left = w * slice_size
39
+ upper = h * slice_size
40
+ right = left + slice_size
41
+ lower = upper + slice_size
42
+
43
+ if w == 1 and right > width:
44
+ left -= right - width
45
+ right = width
46
+ if h == 2 and lower > height:
47
+ upper -= lower - height
48
+ lower = height
49
+
50
+ slice = image[upper:lower, left:right]
51
+ slices.append(slice)
52
+
53
+ return slices
54
+
55
+
56
+ def process_image(
57
+ image,
58
+ fill_color=(0, 0, 0),
59
+ mask_offset=50,
60
+ blur_radius=500,
61
+ expand_pixels=256,
62
+ direction="left",
63
+ inpaint_mask_color=50,
64
+ max_size=1024,
65
+ ):
66
+ height, width = image.shape[:2]
67
+
68
+ new_height = height + (expand_pixels if direction in ["top", "bottom"] else 0)
69
+ new_width = width + (expand_pixels if direction in ["left", "right"] else 0)
70
+
71
+ if new_height > max_size:
72
+ # If so, crop the image from the opposite side
73
+ if direction == "top":
74
+ image = image[:max_size, :]
75
+ elif direction == "bottom":
76
+ image = image[new_height - max_size :, :]
77
+ new_height = max_size
78
+
79
+ if new_width > max_size:
80
+ # If so, crop the image from the opposite side
81
+ if direction == "left":
82
+ image = image[:, :max_size]
83
+ elif direction == "right":
84
+ image = image[:, new_width - max_size :]
85
+ new_width = max_size
86
+
87
+ height, width = image.shape[:2]
88
+
89
+ new_image = np.full((new_height, new_width, 3), fill_color, dtype=np.uint8)
90
+ mask = np.full_like(new_image, 255, dtype=np.uint8)
91
+ inpaint_mask = np.full_like(new_image, 0, dtype=np.uint8)
92
+
93
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
94
+ inpaint_mask = cv2.cvtColor(inpaint_mask, cv2.COLOR_BGR2GRAY)
95
+
96
+ if direction == "left":
97
+ new_image[:, expand_pixels:] = image[:, : max_size - expand_pixels]
98
+ mask[:, : expand_pixels + mask_offset] = inpaint_mask_color
99
+ inpaint_mask[:, :expand_pixels] = 255
100
+ elif direction == "right":
101
+ new_image[:, :width] = image
102
+ mask[:, width - mask_offset :] = inpaint_mask_color
103
+ inpaint_mask[:, width:] = 255
104
+ elif direction == "top":
105
+ new_image[expand_pixels:, :] = image[: max_size - expand_pixels, :]
106
+ mask[: expand_pixels + mask_offset, :] = inpaint_mask_color
107
+ inpaint_mask[:expand_pixels, :] = 255
108
+ elif direction == "bottom":
109
+ new_image[:height, :] = image
110
+ mask[height - mask_offset :, :] = inpaint_mask_color
111
+ inpaint_mask[height:, :] = 255
112
+
113
+ # mask blur
114
+ if blur_radius % 2 == 0:
115
+ blur_radius += 1
116
+ mask = cv2.GaussianBlur(mask, (blur_radius, blur_radius), 0)
117
+
118
+ # telea inpaint
119
+ _, mask_np = cv2.threshold(inpaint_mask, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
120
+ inpaint = cv2.inpaint(new_image, mask_np, 3, cv2.INPAINT_TELEA)
121
+
122
+ # convert image to tensor
123
+ inpaint = cv2.cvtColor(inpaint, cv2.COLOR_BGR2RGB)
124
+ inpaint = torch.from_numpy(inpaint).permute(2, 0, 1).float()
125
+ inpaint = inpaint / 127.5 - 1
126
+ inpaint = inpaint.unsqueeze(0).to("cuda")
127
+
128
+ # convert mask to tensor
129
+ mask = torch.from_numpy(mask)
130
+ mask = mask.unsqueeze(0).float() / 255.0
131
+ mask = mask.to("cuda")
132
+
133
+ return inpaint, mask
134
+
135
+
136
+ def image_resize(image, new_size=1024):
137
+ height, width = image.shape[:2]
138
+
139
+ aspect_ratio = width / height
140
+ new_width = new_size
141
+ new_height = new_size
142
+
143
+ if aspect_ratio != 1:
144
+ if width > height:
145
+ new_height = int(new_size / aspect_ratio)
146
+ else:
147
+ new_width = int(new_size * aspect_ratio)
148
+
149
+ image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
150
+
151
+ return image
152
+
153
+
154
+ class EndpointHandler():
155
+ def __init__(self, path=""):
156
+ self.pipeline = StableDiffusionXLPipeline.from_pretrained(
157
+ "SG161222/RealVisXL_V4.0",
158
+ torch_dtype=torch.float16,
159
+ variant="fp16",
160
+ custom_pipeline="pipeline_stable_diffusion_xl_differential_img2img",
161
+ ).to("cuda")
162
+ self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)
163
+
164
+ self.pipeline.load_ip_adapter(
165
+ "h94/IP-Adapter",
166
+ subfolder="sdxl_models",
167
+ weight_name=[
168
+ "ip-adapter-plus_sdxl_vit-h.safetensors",
169
+ ],
170
+ image_encoder_folder="models/image_encoder",
171
+ )
172
+ self.pipeline.set_ip_adapter_scale(0.1)
173
+
174
+ def generate_image(prompt, negative_prompt, image, mask, ip_adapter_image, seed: int = None):
175
+ if seed is None:
176
+ seed = random.randint(0, 2**32 - 1)
177
+
178
+ generator = torch.Generator(device="cpu").manual_seed(seed)
179
+
180
+ image = self.pipeline(
181
+ prompt=prompt,
182
+ negative_prompt=negative_prompt,
183
+ width=1024,
184
+ height=1024,
185
+ guidance_scale=4.0,
186
+ num_inference_steps=25,
187
+ original_image=image,
188
+ image=image,
189
+ strength=1.0,
190
+ map=mask,
191
+ generator=generator,
192
+ ip_adapter_image=[ip_adapter_image],
193
+ output_type="np",
194
+ ).images[0]
195
+
196
+ image = (image * 255).astype(np.uint8)
197
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
198
+
199
+ return image
200
+
201
+ def __call__(self, data: Dict[str, Any]):
202
+
203
+ prompt = ""
204
+ negative_prompt = ""
205
+ # direction = "right" # left, right, top, bottom
206
+ inpaint_mask_color = 50 # lighter use more of the Telea inpainting
207
+ # expand_pixels = 256 # I recommend to don't go more than half of the picture so it has context
208
+ # times_to_expand = 4
209
+
210
+ inputs = data.pop("inputs", data)
211
+
212
+ # decode base64 image to PIL
213
+ original = Image.open(BytesIO(base64.b64decode(inputs['image'])))
214
+ mask = Image.open(BytesIO(base64.b64decode(inputs['mask'])))
215
+ original = numpy.array(original)
216
+
217
+ image = image_resize(original)
218
+ expand_pixels_to_square = 1024 - image.shape[1] # image.shape[1] for horizontal, image.shape[0] for vertical
219
+ image, mask = process_image(
220
+ image, expand_pixels=expand_pixels_to_square, direction=direction, inpaint_mask_color=inpaint_mask_color
221
+ )
222
+
223
+ ip_adapter_image = []
224
+ for index, part in enumerate(slice_image(original)):
225
+ ip_adapter_image.append(part)
226
+
227
+ generated = generate_image(prompt, negative_prompt, image, mask, ip_adapter_image)
228
+ final_image = generated
229
+
230
+ for i in range(times_to_expand):
231
+ image, mask = process_image(
232
+ final_image, direction=direction, expand_pixels=expand_pixels, inpaint_mask_color=inpaint_mask_color
233
+ )
234
+
235
+ ip_adapter_image = []
236
+ for index, part in enumerate(slice_image(generated)):
237
+ ip_adapter_image.append(part)
238
+
239
+ generated = generate_image(prompt, negative_prompt, image, mask, ip_adapter_image)
240
+ final_image = merge_images(final_image, generated, 256, direction)