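"""Inference Endpoint handler that outpaints an input image with
RealVisXL V4.0 via the differential img2img community pipeline, feeding
IP-Adapter tiles of the source image to keep the expansion on-style."""
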
import base64
import random
from io import BytesIO
from typing import Any, Dict

import cv2
import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from PIL import Image


def merge_images(original, new_image, offset, direction):
    """Paste `new_image` onto a canvas that extends `original` by `offset`
    pixels in the given direction, overlapping where the two images meet."""
    if direction in ["left", "right"]:
        merged_image = np.zeros((original.shape[0], original.shape[1] + offset, 3), dtype=np.uint8)
    elif direction in ["top", "bottom"]:
        merged_image = np.zeros((original.shape[0] + offset, original.shape[1], 3), dtype=np.uint8)

    if direction == "left":
        merged_image[:, offset:] = original
        merged_image[:, : new_image.shape[1]] = new_image
    elif direction == "right":
        merged_image[:, : original.shape[1]] = original
        merged_image[:, original.shape[1] + offset - new_image.shape[1] : original.shape[1] + offset] = new_image
    elif direction == "top":
        merged_image[offset:, :] = original
        merged_image[: new_image.shape[0], :] = new_image
    elif direction == "bottom":
        merged_image[: original.shape[0], :] = original
        merged_image[original.shape[0] + offset - new_image.shape[0] : original.shape[0] + offset, :] = new_image

    return merged_image


def slice_image(image):
    """Cut the image into six square tiles (3 rows x 2 columns) to use as
    IP-Adapter reference images; edge tiles are shifted inward so every
    tile stays inside the image bounds."""
    height, width, _ = image.shape
    slice_size = min(width // 2, height // 3)

    slices = []
    for h in range(3):
        for w in range(2):
            left = w * slice_size
            upper = h * slice_size
            right = left + slice_size
            lower = upper + slice_size

            # Clamp the last column/row so the tile does not run off the image.
            if w == 1 and right > width:
                left -= right - width
                right = width
            if h == 2 and lower > height:
                upper -= lower - height
                lower = height

            slices.append(image[upper:lower, left:right])

    return slices


def process_image(
    image,
    fill_color=(0, 0, 0),
    mask_offset=50,
    blur_radius=500,
    expand_pixels=256,
    direction="left",
    inpaint_mask_color=50,
    max_size=1024,
):
    """Expand the canvas by `expand_pixels` in the given direction and build the
    blurred differential-diffusion map. Returns the Telea-inpainted canvas and
    the map, both as CUDA tensors ready for the pipeline."""
    height, width = image.shape[:2]

    new_height = height + (expand_pixels if direction in ["top", "bottom"] else 0)
    new_width = width + (expand_pixels if direction in ["left", "right"] else 0)

    # If the expanded canvas would exceed max_size, crop the image from the opposite side.
    if new_height > max_size:
        if direction == "top":
            image = image[:max_size, :]
        elif direction == "bottom":
            image = image[new_height - max_size :, :]
        new_height = max_size

    if new_width > max_size:
        if direction == "left":
            image = image[:, :max_size]
        elif direction == "right":
            image = image[:, new_width - max_size :]
        new_width = max_size

    height, width = image.shape[:2]

    new_image = np.full((new_height, new_width, 3), fill_color, dtype=np.uint8)
    mask = np.full_like(new_image, 255, dtype=np.uint8)
    inpaint_mask = np.full_like(new_image, 0, dtype=np.uint8)

    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    inpaint_mask = cv2.cvtColor(inpaint_mask, cv2.COLOR_BGR2GRAY)

    # Paste the original into the new canvas and mark the region to generate.
    if direction == "left":
        new_image[:, expand_pixels:] = image[:, : max_size - expand_pixels]
        mask[:, : expand_pixels + mask_offset] = inpaint_mask_color
        inpaint_mask[:, :expand_pixels] = 255
    elif direction == "right":
        new_image[:, :width] = image
        mask[:, width - mask_offset :] = inpaint_mask_color
        inpaint_mask[:, width:] = 255
    elif direction == "top":
        new_image[expand_pixels:, :] = image[: max_size - expand_pixels, :]
        mask[: expand_pixels + mask_offset, :] = inpaint_mask_color
        inpaint_mask[:expand_pixels, :] = 255
    elif direction == "bottom":
        new_image[:height, :] = image
        mask[height - mask_offset :, :] = inpaint_mask_color
        inpaint_mask[height:, :] = 255

    # Blur the map (the Gaussian kernel size must be odd).
    if blur_radius % 2 == 0:
        blur_radius += 1
    mask = cv2.GaussianBlur(mask, (blur_radius, blur_radius), 0)

    # Telea-inpaint the empty region so the pipeline starts from a rough fill.
    _, mask_np = cv2.threshold(inpaint_mask, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    inpaint = cv2.inpaint(new_image, mask_np, 3, cv2.INPAINT_TELEA)

    # Convert the image to a normalized tensor in [-1, 1].
    inpaint = cv2.cvtColor(inpaint, cv2.COLOR_BGR2RGB)
    inpaint = torch.from_numpy(inpaint).permute(2, 0, 1).float()
    inpaint = inpaint / 127.5 - 1
    inpaint = inpaint.unsqueeze(0).to("cuda")

    # Convert the map to a tensor in [0, 1].
    mask = torch.from_numpy(mask)
    mask = mask.unsqueeze(0).float() / 255.0
    mask = mask.to("cuda")

    return inpaint, mask


def image_resize(image, new_size=1024):
    height, width = image.shape[:2]
    aspect_ratio = width / height

    new_width = new_size
    new_height = new_size
    if aspect_ratio != 1:
        if width > height:
            new_height = int(new_size / aspect_ratio)
        else:
            new_width = int(new_size * aspect_ratio)

    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)


class EndpointHandler:
    def __init__(self, path=""):
        # RealVisXL with the differential img2img community pipeline for soft-mask outpainting.
        self.pipeline = StableDiffusionXLPipeline.from_pretrained(
            "SG161222/RealVisXL_V4.0",
            torch_dtype=torch.float16,
            variant="fp16",
            custom_pipeline="pipeline_stable_diffusion_xl_differential_img2img",
        ).to("cuda")
        self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipeline.scheduler.config, use_karras_sigmas=True
        )
        self.pipeline.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="sdxl_models",
            weight_name=[
                "ip-adapter-plus_sdxl_vit-h.safetensors",
            ],
            image_encoder_folder="models/image_encoder",
        )
        self.pipeline.set_ip_adapter_scale(0.1)
    def generate_image(self, prompt, negative_prompt, image, mask, ip_adapter_image, seed: int = None):
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        generator = torch.Generator(device="cpu").manual_seed(seed)

        image = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=1024,
            height=1024,
            guidance_scale=4.0,
            num_inference_steps=25,
            original_image=image,
            image=image,
            strength=1.0,
            map=mask,
            generator=generator,
            ip_adapter_image=[ip_adapter_image],
            output_type="np",
        ).images[0]

        # The pipeline returns RGB floats in [0, 1]; convert to uint8 BGR for the cv2 helpers.
        image = (image * 255).astype(np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        return image
    def __call__(self, data: Dict[str, Any]):
        prompt = ""
        negative_prompt = ""
        direction = "right"  # left, right, top, bottom
        inpaint_mask_color = 50  # lighter values use more of the Telea inpainting
        expand_pixels = 256  # best kept under half the image size so the model has context
        times_to_expand = 4
        inputs = data.pop("inputs", data)

        # Decode the base64-encoded input image to a numpy array
        # (the mask is built by process_image, so only "image" is consumed).
        original = Image.open(BytesIO(base64.b64decode(inputs["image"])))
        original = np.array(original)

        image = image_resize(original)
        # Pad the resized image out to a 1024x1024 square first.
        expand_pixels_to_square = 1024 - image.shape[1]  # image.shape[1] for horizontal, image.shape[0] for vertical
        image, mask = process_image(
            image, expand_pixels=expand_pixels_to_square, direction=direction, inpaint_mask_color=inpaint_mask_color
        )
        ip_adapter_image = []
        for part in slice_image(original):
            ip_adapter_image.append(part)

        generated = self.generate_image(prompt, negative_prompt, image, mask, ip_adapter_image)
        final_image = generated
        # Expand the canvas a fixed number of times, feeding slices of the
        # previous generation to the IP-Adapter to keep the style consistent.
        for _ in range(times_to_expand):
            image, mask = process_image(
                final_image, direction=direction, expand_pixels=expand_pixels, inpaint_mask_color=inpaint_mask_color
            )
            ip_adapter_image = []
            for part in slice_image(generated):
                ip_adapter_image.append(part)
            generated = self.generate_image(prompt, negative_prompt, image, mask, ip_adapter_image)
            final_image = merge_images(final_image, generated, 256, direction)

        # Return the result as a base64-encoded PNG (the response shape is an
        # assumption; the original handler ends without an explicit return).
        final_image = cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB)
        buffer = BytesIO()
        Image.fromarray(final_image).save(buffer, format="PNG")
        return {"image": base64.b64encode(buffer.getvalue()).decode("utf-8")}
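

# A minimal local smoke test, as a sketch: the file name "input.jpg" and the
# payload/response shapes are assumptions that mirror how an Inference
# Endpoint would invoke the handler; a CUDA GPU is required.
if __name__ == "__main__":
    with open("input.jpg", "rb") as f:
        payload = {"inputs": {"image": base64.b64encode(f.read()).decode("utf-8")}}
    handler = EndpointHandler()
    result = handler(payload)
    Image.open(BytesIO(base64.b64decode(result["image"]))).save("outpainted.png")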