import base64
import random
from io import BytesIO
from typing import Any, Dict

import cv2
import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from PIL import Image

def merge_images(original, new_image, offset, direction):
    """Paste `new_image` onto a canvas grown by `offset` pixels so it overlaps `original`."""
    if direction in ["left", "right"]:
        merged_image = np.zeros((original.shape[0], original.shape[1] + offset, 3), dtype=np.uint8)
    elif direction in ["top", "bottom"]:
        merged_image = np.zeros((original.shape[0] + offset, original.shape[1], 3), dtype=np.uint8)
    else:
        raise ValueError(f"Unknown direction: {direction!r}")

    if direction == "left":
        merged_image[:, offset:] = original
        merged_image[:, : new_image.shape[1]] = new_image
    elif direction == "right":
        merged_image[:, : original.shape[1]] = original
        merged_image[:, original.shape[1] + offset - new_image.shape[1] : original.shape[1] + offset] = new_image
    elif direction == "top":
        merged_image[offset:, :] = original
        merged_image[: new_image.shape[0], :] = new_image
    elif direction == "bottom":
        merged_image[: original.shape[0], :] = original
        merged_image[original.shape[0] + offset - new_image.shape[0] : original.shape[0] + offset, :] = new_image

    return merged_image


def slice_image(image):
    """Cut the image into six tiles (3 rows x 2 columns) used as IP-Adapter reference images."""
    height, width, _ = image.shape
    slice_size = min(width // 2, height // 3)

    slices = []

    for h in range(3):
        for w in range(2):
            left = w * slice_size
            upper = h * slice_size
            right = left + slice_size
            lower = upper + slice_size

            # Clamp the last column/row so tiles never read past the image border.
            if w == 1 and right > width:
                left -= right - width
                right = width
            if h == 2 and lower > height:
                upper -= lower - height
                lower = height

            slices.append(image[upper:lower, left:right])

    return slices


def process_image(
    image,
    fill_color=(0, 0, 0),
    mask_offset=50,
    blur_radius=500,
    expand_pixels=256,
    direction="left",
    inpaint_mask_color=50,
    max_size=1024,
):
    """Expand the canvas by `expand_pixels` in `direction`, pre-fill the new area with a
    Telea inpaint, and return the image and the soft diffusion map as CUDA tensors."""
    height, width = image.shape[:2]

    new_height = height + (expand_pixels if direction in ["top", "bottom"] else 0)
    new_width = width + (expand_pixels if direction in ["left", "right"] else 0)

    # Crop the source so the expanded canvas never exceeds max_size on either axis.
    if new_height > max_size:
        if direction == "top":
            image = image[:max_size, :]
        elif direction == "bottom":
            image = image[new_height - max_size :, :]
        new_height = max_size

    if new_width > max_size:
        if direction == "left":
            image = image[:, :max_size]
        elif direction == "right":
            image = image[:, new_width - max_size :]
        new_width = max_size

    height, width = image.shape[:2]

    new_image = np.full((new_height, new_width, 3), fill_color, dtype=np.uint8)
    mask = np.full_like(new_image, 255, dtype=np.uint8)
    inpaint_mask = np.full_like(new_image, 0, dtype=np.uint8)

    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    inpaint_mask = cv2.cvtColor(inpaint_mask, cv2.COLOR_BGR2GRAY)

    # Paste the source and mark the strip to regenerate: `mask` is the soft map for
    # differential img2img, `inpaint_mask` the hard region for the cv2 pre-fill.
    if direction == "left":
        new_image[:, expand_pixels:] = image[:, : max_size - expand_pixels]
        mask[:, : expand_pixels + mask_offset] = inpaint_mask_color
        inpaint_mask[:, :expand_pixels] = 255
    elif direction == "right":
        new_image[:, :width] = image
        mask[:, width - mask_offset :] = inpaint_mask_color
        inpaint_mask[:, width:] = 255
    elif direction == "top":
        new_image[expand_pixels:, :] = image[: max_size - expand_pixels, :]
        mask[: expand_pixels + mask_offset, :] = inpaint_mask_color
        inpaint_mask[:expand_pixels, :] = 255
    elif direction == "bottom":
        new_image[:height, :] = image
        mask[height - mask_offset :, :] = inpaint_mask_color
        inpaint_mask[height:, :] = 255

    # Feather the map so kept and regenerated pixels blend smoothly.
    # GaussianBlur requires an odd kernel size.
    if blur_radius % 2 == 0:
        blur_radius += 1
    mask = cv2.GaussianBlur(mask, (blur_radius, blur_radius), 0)

    # A rough Telea inpaint of the empty strip gives the sampler a better starting
    # point than flat fill color.
    _, binary_mask = cv2.threshold(inpaint_mask, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    inpaint = cv2.inpaint(new_image, binary_mask, 3, cv2.INPAINT_TELEA)

    # To tensors: image normalized to [-1, 1], mask to [0, 1].
    inpaint = cv2.cvtColor(inpaint, cv2.COLOR_BGR2RGB)
    inpaint = torch.from_numpy(inpaint).permute(2, 0, 1).float()
    inpaint = inpaint / 127.5 - 1
    inpaint = inpaint.unsqueeze(0).to("cuda")

    mask = torch.from_numpy(mask)
    mask = mask.unsqueeze(0).float() / 255.0
    mask = mask.to("cuda")

    return inpaint, mask


def image_resize(image, new_size=1024):
    """Resize so the longer side equals `new_size`, preserving aspect ratio."""
    height, width = image.shape[:2]

    aspect_ratio = width / height
    new_width = new_size
    new_height = new_size

    if aspect_ratio != 1:
        if width > height:
            new_height = int(new_size / aspect_ratio)
        else:
            new_width = int(new_size * aspect_ratio)

    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)


class EndpointHandler:
    def __init__(self, path=""):
        # Differential img2img community pipeline on top of an SDXL checkpoint.
        self.pipeline = StableDiffusionXLPipeline.from_pretrained(
            "SG161222/RealVisXL_V4.0",
            torch_dtype=torch.float16,
            variant="fp16",
            custom_pipeline="pipeline_stable_diffusion_xl_differential_img2img",
        ).to("cuda")
        self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipeline.scheduler.config, use_karras_sigmas=True
        )

        # IP-Adapter keeps the outpainted regions stylistically consistent with the source.
        self.pipeline.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="sdxl_models",
            weight_name=[
                "ip-adapter-plus_sdxl_vit-h.safetensors",
            ],
            image_encoder_folder="models/image_encoder",
        )
        self.pipeline.set_ip_adapter_scale(0.1)

    def generate_image(self, prompt, negative_prompt, image, mask, ip_adapter_image, seed=None):
        if seed is None:
            seed = random.randint(0, 2**32 - 1)

        generator = torch.Generator(device="cpu").manual_seed(seed)

        image = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=1024,
            height=1024,
            guidance_scale=4.0,
            num_inference_steps=25,
            original_image=image,
            image=image,
            strength=1.0,
            map=mask,
            generator=generator,
            ip_adapter_image=[ip_adapter_image],
            output_type="np",
        ).images[0]

        # The pipeline returns float RGB in [0, 1]; convert back to uint8 BGR for OpenCV.
        image = (image * 255).astype(np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        return image

    def __call__(self, data: Dict[str, Any]):
        inputs = data.pop("inputs", data)

        # Request parameters. `direction`, `times_to_expand`, and `expand_pixels` were
        # undefined in the original snippet; reading the first two from the request with
        # these defaults is an assumption.
        prompt = inputs.get("prompt", "")
        negative_prompt = inputs.get("negative_prompt", "")
        direction = inputs.get("direction", "right")
        times_to_expand = int(inputs.get("times_to_expand", 4))
        expand_pixels = 256
        inpaint_mask_color = 50

        original = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
        original = np.array(original)

        # First pass: pad the resized image out to a 1024x1024 square
        # (assumes width is the shorter side after resizing).
        image = image_resize(original)
        expand_pixels_to_square = 1024 - image.shape[1]
        image, mask = process_image(
            image, expand_pixels=expand_pixels_to_square, direction=direction, inpaint_mask_color=inpaint_mask_color
        )

        ip_adapter_image = slice_image(original)

        generated = self.generate_image(prompt, negative_prompt, image, mask, ip_adapter_image)
        final_image = generated

        # Then keep extending the canvas, conditioning each step on slices of the last result.
        for _ in range(times_to_expand):
            image, mask = process_image(
                final_image, direction=direction, expand_pixels=expand_pixels, inpaint_mask_color=inpaint_mask_color
            )

            ip_adapter_image = slice_image(generated)

            generated = self.generate_image(prompt, negative_prompt, image, mask, ip_adapter_image)
            final_image = merge_images(final_image, generated, 256, direction)

        # Return the result as a base64-encoded PNG (the original snippet omitted the return;
        # this response shape is an assumption about the endpoint contract).
        _, buffer = cv2.imencode(".png", final_image)
        return {"image": base64.b64encode(buffer).decode("utf-8")}
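

# Minimal local smoke test: a sketch, not part of the handler contract. Assumes a CUDA
# GPU, network access to download the checkpoints, and a hypothetical `input.png` on disk.
if __name__ == "__main__":
    handler = EndpointHandler()

    with open("input.png", "rb") as f:
        payload = {
            "inputs": {
                "image": base64.b64encode(f.read()).decode("utf-8"),
                "prompt": "",
                "direction": "right",
                "times_to_expand": 2,
            }
        }

    result = handler(payload)
    with open("outpainted.png", "wb") as f:
        f.write(base64.b64decode(result["image"]))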