ControlNet / model.py
hysts's picture
hysts HF staff
Support changing base model
253feed
raw
history blame
21.7 kB
# This file is adapted from gradio_*.py in https://github.com/lllyasviel/ControlNet/tree/f4748e3630d8141d7765e2bd9b1e348f47847707
# The original license file is LICENSE.ControlNet in this repo.
from __future__ import annotations
import pathlib
import sys
import cv2
import numpy as np
import PIL.Image
import torch
from diffusers import (ControlNetModel, DiffusionPipeline,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler)
# Put the `ControlNet` directory that sits next to this file on sys.path so
# the `annotator.*` and `share` imports that follow can be resolved.
repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir / 'ControlNet'
sys.path.append(submodule_dir.as_posix())
from annotator.canny import apply_canny
from annotator.hed import apply_hed, nms
from annotator.midas import apply_midas
from annotator.mlsd import apply_mlsd
from annotator.openpose import apply_openpose
from annotator.uniformer import apply_uniformer
from annotator.util import HWC3, resize_image
from share import *
# Task name (as used by Model.process_* / load_controlnet_weight) -> model id
# passed to ControlNetModel.from_pretrained. Note that the task names do not
# always match the checkpoint suffix: 'hough' uses the M-LSD checkpoint and
# 'pose' the OpenPose one.
CONTROLNET_MODEL_IDS = {
    task: f'lllyasviel/sd-controlnet-{checkpoint}'
    for task, checkpoint in (
        ('canny', 'canny'),
        ('hough', 'mlsd'),
        ('hed', 'hed'),
        ('scribble', 'scribble'),
        ('pose', 'openpose'),
        ('seg', 'seg'),
        ('depth', 'depth'),
        ('normal', 'normal'),
    )
}
class Model:
    """Stable Diffusion + ControlNet pipeline with per-task preprocessing.

    Each ``process_<task>`` method converts an input image into a control
    image with one of the ControlNet annotators, makes sure the ControlNet
    weights for that task are loaded (see ``CONTROLNET_MODEL_IDS``), and runs
    the diffusion pipeline. The returned list starts with a visualization of
    the control image, followed by the generated images.
    """

    def __init__(self,
                 base_model_id: str = 'runwayml/stable-diffusion-v1-5',
                 task_name: str = 'canny'):
        # Empty ids and no pipeline guarantee that the first load_pipe call
        # really builds a pipeline instead of taking the reuse shortcut.
        self.base_model_id = ''
        self.task_name = ''
        self.pipe = None
        self.pipe = self.load_pipe(base_model_id, task_name)

    def load_pipe(self, base_model_id: str,
                  task_name: str) -> DiffusionPipeline:
        """Return a ControlNet pipeline for ``base_model_id`` and ``task_name``.

        If a pipeline already exists and was built for the same ids, it is
        returned unchanged. Otherwise the ControlNet weights for
        ``task_name`` and the base model are loaded in float16 (safety checker
        disabled), the scheduler is replaced by UniPC, xformers attention and
        model CPU offload are enabled, and the ids are recorded on ``self``.
        The caller is responsible for storing the returned pipeline.

        Raises:
            KeyError: if ``task_name`` is not a key of CONTROLNET_MODEL_IDS.
        """
        # getattr: __init__ calls this before any pipeline exists.
        current_pipe = getattr(self, 'pipe', None)
        if (current_pipe is not None and base_model_id == self.base_model_id
                and task_name == self.task_name):
            return current_pipe
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_id,
            safety_checker=None,
            controlnet=controlnet,
            torch_dtype=torch.float16)
        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_model_cpu_offload()
        # Record the ids only after everything loaded, so a failed load
        # leaves the previously recorded state untouched.
        self.base_model_id = base_model_id
        self.task_name = task_name
        return pipe

    def set_base_model(self, base_model_id: str) -> str:
        """Switch the base model while keeping the current task.

        An empty ``base_model_id`` is ignored. Returns the base model id in
        effect after the call.
        """
        if not base_model_id:
            return self.base_model_id
        self.pipe = self.load_pipe(base_model_id, self.task_name)
        return self.base_model_id

    def load_controlnet_weight(self, task_name: str) -> None:
        """Swap the pipeline's ControlNet for the one matching ``task_name``.

        Does nothing if ``task_name`` is already loaded.

        Raises:
            KeyError: if ``task_name`` is not a key of CONTROLNET_MODEL_IDS.
        """
        if task_name == self.task_name:
            return
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=torch.float16)
        from accelerate import cpu_offload_with_hook
        # The new module needs its own offload hook because the pipeline's
        # enable_model_cpu_offload() only hooked the original ControlNet.
        # NOTE(review): 'cuda:0' presumably matches the GPU that
        # enable_model_cpu_offload() uses by default — confirm.
        cpu_offload_with_hook(controlnet, torch.device('cuda:0'))
        self.pipe.controlnet = controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        """Join ``prompt`` and ``additional_prompt`` with ``', '``.

        Empty parts are skipped, so no dangling separator is produced.
        """
        return ', '.join(part for part in (prompt, additional_prompt) if part)

    def run_pipe(
        self,
        prompt: str,
        negative_prompt: str,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ):
        """Run the pipeline with a CPU generator seeded with ``seed``.

        Returns the raw pipeline output; callers read its ``.images``.
        """
        generator = torch.Generator().manual_seed(seed)
        return self.pipe(prompt=prompt,
                         negative_prompt=negative_prompt,
                         guidance_scale=guidance_scale,
                         num_images_per_prompt=num_images,
                         num_inference_steps=num_steps,
                         generator=generator,
                         image=control_image)

    def process(
        self,
        task_name: str,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        control_image: PIL.Image.Image,
        vis_control_image: PIL.Image.Image,
        num_samples: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Load the weights for ``task_name`` and generate ``num_samples`` images.

        Returns ``[vis_control_image, *generated_images]``.
        """
        self.load_controlnet_weight(task_name)
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [vis_control_image] + results.images

    @staticmethod
    def _resize_to_resolution(control_image: np.ndarray,
                              input_image: np.ndarray,
                              image_resolution: int,
                              interpolation: int) -> np.ndarray:
        """Resize ``control_image`` to the size ``resize_image`` gives ``input_image``.

        Used by the detector-based preprocessors, which run the annotator at
        ``detect_resolution`` but must hand the pipeline a control image at
        ``image_resolution``.
        """
        height, width = resize_image(input_image, image_resolution).shape[:2]
        return cv2.resize(control_image, (width, height),
                          interpolation=interpolation)

    def preprocess_canny(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        low_threshold: int,
        high_threshold: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """Canny edge map at ``image_resolution`` plus its inverted copy."""
        image = resize_image(HWC3(input_image), image_resolution)
        control_image = apply_canny(image, low_threshold, high_threshold)
        control_image = HWC3(control_image)
        vis_control_image = 255 - control_image
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    @torch.inference_mode()
    def process_canny(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on a Canny edge map of ``input_image``."""
        control_image, vis_control_image = self.preprocess_canny(
            input_image=input_image,
            image_resolution=image_resolution,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
        )
        return self.process(
            task_name='canny',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_hough(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """M-LSD line map resized (nearest) to ``image_resolution``.

        The visualization is the map dilated with a 3x3 kernel, then inverted.
        """
        input_image = HWC3(input_image)
        control_image = apply_mlsd(
            resize_image(input_image, detect_resolution), value_threshold,
            distance_threshold)
        control_image = HWC3(control_image)
        control_image = self._resize_to_resolution(control_image, input_image,
                                                   image_resolution,
                                                   cv2.INTER_NEAREST)
        vis_control_image = 255 - cv2.dilate(
            control_image, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    @torch.inference_mode()
    def process_hough(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on M-LSD straight lines."""
        control_image, vis_control_image = self.preprocess_hough(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
            value_threshold=value_threshold,
            distance_threshold=distance_threshold,
        )
        return self.process(
            task_name='hough',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_hed(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """HED boundary map resized (linear) to ``image_resolution``.

        The map itself doubles as the visualization.
        """
        input_image = HWC3(input_image)
        control_image = apply_hed(resize_image(input_image, detect_resolution))
        control_image = HWC3(control_image)
        control_image = self._resize_to_resolution(control_image, input_image,
                                                   image_resolution,
                                                   cv2.INTER_LINEAR)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    @torch.inference_mode()
    def process_hed(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on a HED boundary map."""
        control_image, vis_control_image = self.preprocess_hed(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='hed',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_scribble(
        self,
        input_image: np.ndarray,
        image_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """Binarize a drawing: pixels whose darkest channel is < 127 become 255."""
        image = resize_image(HWC3(input_image), image_resolution)
        control_image = np.zeros_like(image, dtype=np.uint8)
        control_image[np.min(image, axis=2) < 127] = 255
        vis_control_image = 255 - control_image
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    @torch.inference_mode()
    def process_scribble(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on an uploaded scribble."""
        control_image, vis_control_image = self.preprocess_scribble(
            input_image=input_image,
            image_resolution=image_resolution,
        )
        return self.process(
            task_name='scribble',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_scribble_interactive(
        self,
        input_image: dict[str, np.ndarray],
        image_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """Binarize a drawn mask: pixels > 127 in its first channel become 255.

        ``input_image`` is a mapping; only its ``'mask'`` array is read.
        NOTE(review): presumably the dict produced by a Gradio sketch
        component — confirm against the UI code.
        """
        image = resize_image(HWC3(input_image['mask'][:, :, 0]),
                             image_resolution)
        control_image = np.zeros_like(image, dtype=np.uint8)
        control_image[np.min(image, axis=2) > 127] = 255
        vis_control_image = 255 - control_image
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    @torch.inference_mode()
    def process_scribble_interactive(
        self,
        input_image: dict[str, np.ndarray],
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on an interactively drawn scribble."""
        control_image, vis_control_image = self.preprocess_scribble_interactive(
            input_image=input_image,
            image_resolution=image_resolution,
        )
        return self.process(
            task_name='scribble',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_fake_scribble(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """Synthesize a scribble from a HED map.

        The map is resized (linear) to ``image_resolution``, thinned with
        ``nms``, blurred, and re-binarized (values > 4 become 255, the rest 0).
        """
        input_image = HWC3(input_image)
        control_image = apply_hed(resize_image(input_image, detect_resolution))
        control_image = HWC3(control_image)
        control_image = self._resize_to_resolution(control_image, input_image,
                                                   image_resolution,
                                                   cv2.INTER_LINEAR)
        control_image = nms(control_image, 127, 3.0)
        control_image = cv2.GaussianBlur(control_image, (0, 0), 3.0)
        control_image[control_image > 4] = 255
        control_image[control_image < 255] = 0
        vis_control_image = 255 - control_image
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    @torch.inference_mode()
    def process_fake_scribble(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on a scribble synthesized from HED."""
        control_image, vis_control_image = self.preprocess_fake_scribble(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='scribble',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_pose(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """OpenPose map resized (nearest) to ``image_resolution``."""
        input_image = HWC3(input_image)
        control_image, _ = apply_openpose(
            resize_image(input_image, detect_resolution))
        control_image = HWC3(control_image)
        control_image = self._resize_to_resolution(control_image, input_image,
                                                   image_resolution,
                                                   cv2.INTER_NEAREST)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    @torch.inference_mode()
    def process_pose(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on an OpenPose skeleton."""
        control_image, vis_control_image = self.preprocess_pose(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='pose',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_seg(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """UniFormer segmentation map resized (nearest) to ``image_resolution``."""
        input_image = HWC3(input_image)
        control_image = apply_uniformer(
            resize_image(input_image, detect_resolution))
        control_image = self._resize_to_resolution(control_image, input_image,
                                                   image_resolution,
                                                   cv2.INTER_NEAREST)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    @torch.inference_mode()
    def process_seg(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on a segmentation map."""
        control_image, vis_control_image = self.preprocess_seg(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='seg',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_depth(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """MiDaS depth map resized (linear) to ``image_resolution``."""
        input_image = HWC3(input_image)
        control_image, _ = apply_midas(
            resize_image(input_image, detect_resolution))
        control_image = HWC3(control_image)
        # Previously the map was left at detect_resolution and
        # image_resolution was silently ignored; resize like the other
        # detector-based preprocessors (same interpolation as the normal map).
        control_image = self._resize_to_resolution(control_image, input_image,
                                                   image_resolution,
                                                   cv2.INTER_LINEAR)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    @torch.inference_mode()
    def process_depth(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on a MiDaS depth map."""
        control_image, vis_control_image = self.preprocess_depth(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
        )
        return self.process(
            task_name='depth',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def preprocess_normal(
        self,
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
        bg_threshold: float,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """MiDaS normal map resized (linear) to ``image_resolution``.

        ``bg_threshold`` is forwarded to ``apply_midas`` as ``bg_th``.
        """
        input_image = HWC3(input_image)
        _, control_image = apply_midas(resize_image(input_image,
                                                    detect_resolution),
                                       bg_th=bg_threshold)
        control_image = HWC3(control_image)
        control_image = self._resize_to_resolution(control_image, input_image,
                                                   image_resolution,
                                                   cv2.INTER_LINEAR)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            control_image)

    @torch.inference_mode()
    def process_normal(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_samples: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        bg_threshold: float,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on a MiDaS normal map."""
        control_image, vis_control_image = self.preprocess_normal(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
            bg_threshold=bg_threshold,
        )
        return self.process(
            task_name='normal',
            prompt=prompt,
            additional_prompt=additional_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            vis_control_image=vis_control_image,
            num_samples=num_samples,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )