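"""Custom handler.py for a Hugging Face Inference Endpoint.

Runs Stable Diffusion 1.5 inpainting guided by the
lllyasviel/control_v11p_sd15_inpaint ControlNet. Requests carry a text prompt
("inputs") plus base64-encoded "image" and "mask_image" payloads.
"""
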
import base64
from io import BytesIO

import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("This handler requires a CUDA-capable GPU.")


class EndpointHandler:
    def __init__(self, path="lllyasviel/control_v11p_sd15_inpaint"):
        # Load the inpainting ControlNet and attach it to a Stable Diffusion 1.5 pipeline.
        self.controlnet = ControlNetModel.from_pretrained(path, torch_dtype=torch.float32).to(device)
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            torch_dtype=torch.float32,
        ).to(device)
        # UniPC is a fast multistep scheduler that works well with few inference steps.
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.generator = torch.Generator(device=device)

    def __call__(self, data):
        # Decode the base64-encoded input image and mask.
        original_image = decode_image(data["image"])
        mask_image = decode_image(data["mask_image"])

        # Optional generation parameters with sensible defaults.
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 7.5)
        negative_prompt = data.pop("negative_prompt", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
        height = data.pop("height", None)
        width = data.pop("width", None)

        # Build the ControlNet inpainting condition from the image and mask.
        control_image = self.make_inpaint_condition(original_image, mask_image)

        # Run the pipeline; "inputs" carries the text prompt.
        output_image = self.pipe(
            data["inputs"],
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            generator=self.generator,
            image=control_image,
            height=height,
            width=width,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        ).images[0]
        return output_image

    def make_inpaint_condition(self, image, mask):
        # Normalize the RGB image to [0, 1] and mark masked pixels with -1.0,
        # the convention expected by the SD 1.5 inpaint ControlNet.
        image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        mask = np.array(mask.convert("L"))
        assert image.shape[:2] == mask.shape[:2], "image and mask must have the same height and width"
        image[mask < 128] = -1.0  # dark mask pixels are treated as masked
        image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)  # HWC -> NCHW
        image = torch.from_numpy(image).to(device)
        return image


def decode_image(encoded_image):
    image_bytes = base64.b64decode(encoded_image)
    image = Image.open(BytesIO(image_bytes))
    return image


def save_image_to_bytes(image):
    output_bytes = BytesIO()
    image.save(output_bytes, format="PNG")
    output_bytes.seek(0)
    return output_bytes.getvalue()
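

# A minimal local smoke-test sketch (an assumption, not part of the deployed
# endpoint): "photo.png" and "mask.png" are hypothetical files, and the payload
# keys mirror what __call__ reads above. Note that make_inpaint_condition treats
# dark mask pixels (< 128) as the region to repaint.
if __name__ == "__main__":
    def _encode(path):
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    result = handler({
        "inputs": "a red brick wall",
        "image": _encode("photo.png"),       # hypothetical input image
        "mask_image": _encode("mask.png"),   # hypothetical mask image
        "num_inference_steps": 30,
        "guidance_scale": 7.5,
    })
    # __call__ returns a PIL image; save_image_to_bytes turns it into PNG bytes.
    with open("inpainted.png", "wb") as f:
        f.write(save_image_to_bytes(result))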