# internals/pipelines/controlnets.py
import os
from typing import List, Literal, Optional, Union
import cv2
import numpy as np
import torch
from controlnet_aux import (
HEDdetector,
LineartDetector,
OpenposeDetector,
PidiNetDetector,
)
from diffusers import (
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    StableDiffusionAdapterPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetPipeline,
    T2IAdapter,
)
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.util import get_generators
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
get_hf_cache_dir,
get_hf_token,
get_is_sdxl,
get_model_dir,
get_num_return_sequences,
)
CONTROLNET_TYPES = Literal[
"pose", "canny", "scribble", "linearart", "tile_upscaler", "canny_2x"
]
__CN_MODELS = {}
MAX_CN_MODELS = 3
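# Module-level cache of loaded network models, keyed by repo id. Once it holds
# MAX_CN_MODELS entries it is cleared wholesale (rather than evicting a single
# entry) before the next model is loaded.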
def clear_networks():
global __CN_MODELS
__CN_MODELS = {}
def load_network_model_by_key(repo_id: str, pipeline_type: str):
global __CN_MODELS
if repo_id in __CN_MODELS:
return __CN_MODELS[repo_id]
if len(__CN_MODELS) >= MAX_CN_MODELS:
__CN_MODELS = {}
if pipeline_type == "controlnet":
model = ControlNetModel.from_pretrained(
repo_id,
torch_dtype=torch.float16,
cache_dir=get_hf_cache_dir(),
token=get_hf_token(),
).to("cuda")
elif pipeline_type == "t2i":
model = T2IAdapter.from_pretrained(
repo_id,
torch_dtype=torch.float16,
varient="fp16",
token=get_hf_token(),
).to("cuda")
else:
raise Exception("Invalid pipeline type")
__CN_MODELS[repo_id] = model
return model
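# A minimal usage sketch (the repo id is taken from the model tables below):
#
#   cn = load_network_model_by_key(
#       "lllyasviel/control_v11p_sd15_canny", "controlnet"
#   )
#   cn_again = load_network_model_by_key(
#       "lllyasviel/control_v11p_sd15_canny", "controlnet"
#   )  # served from __CN_MODELS, no reload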
class StableDiffusionNetworkModelPipelineLoader:
"""Loads the pipeline for network module, eg: controlnet or t2i.
Does not throw error in case of unsupported configurations, instead it returns None.
"""
def __new__(
cls,
is_sdxl,
is_img2img,
network_model,
pipeline_type,
        base_pipe: Optional[AbstractPipeline] = None,
):
if base_pipe is None:
pretrained = True
kwargs = {
"pretrained_model_name_or_path": get_model_dir(),
"torch_dtype": torch.float16,
"token": get_hf_token(),
"cache_dir": get_hf_cache_dir(),
}
else:
pretrained = False
kwargs = {
**base_pipe.pipe.components, # pyright: ignore
}
if get_is_sdxl():
kwargs.pop("image_encoder", None)
kwargs.pop("feature_extractor", None)
if is_sdxl and is_img2img and pipeline_type == "controlnet":
model = (
StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained
if pretrained
else StableDiffusionXLControlNetImg2ImgPipeline
)
return model(controlnet=network_model, **kwargs).to("cuda")
if is_sdxl and pipeline_type == "controlnet":
model = (
StableDiffusionXLControlNetPipeline.from_pretrained
if pretrained
else StableDiffusionXLControlNetPipeline
)
return model(controlnet=network_model, **kwargs).to("cuda")
if is_sdxl and pipeline_type == "t2i":
model = (
StableDiffusionXLAdapterPipeline.from_pretrained
if pretrained
else StableDiffusionXLAdapterPipeline
)
return model(adapter=network_model, **kwargs).to("cuda")
if is_img2img and pipeline_type == "controlnet":
model = (
StableDiffusionControlNetImg2ImgPipeline.from_pretrained
if pretrained
else StableDiffusionControlNetImg2ImgPipeline
)
return model(controlnet=network_model, **kwargs).to("cuda")
if pipeline_type == "controlnet":
model = (
StableDiffusionControlNetPipeline.from_pretrained
if pretrained
else StableDiffusionControlNetPipeline
)
return model(controlnet=network_model, **kwargs).to("cuda")
if pipeline_type == "t2i":
model = (
StableDiffusionAdapterPipeline.from_pretrained
if pretrained
else StableDiffusionAdapterPipeline
)
return model(adapter=network_model, **kwargs).to("cuda")
print(
f"Warning: Unsupported configuration {is_sdxl=}, {is_img2img=}, {pipeline_type=}"
)
return None
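# A minimal sketch of how the loader is typically used (assumes an SD 1.5
# checkpoint in get_model_dir() and `cn`, an already loaded ControlNetModel):
#
#   pipe = StableDiffusionNetworkModelPipelineLoader(
#       is_sdxl=False,
#       is_img2img=False,
#       network_model=cn,
#       pipeline_type="controlnet",
#   )
#   if pipe is None:
#       ...  # unsupported configuration; a warning has already been printed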
class ControlNet(AbstractPipeline):
__current_task_name = ""
__loaded = False
__pipe_type = None
def init(self, pipeline: AbstractPipeline):
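        # Stored with setattr under the literal name "__pipeline" so that
        # getattr(self, "__pipeline", None) in __load_pipeline can retrieve it
        # without Python's class-private name mangling getting in the way.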
setattr(self, "__pipeline", pipeline)
def unload(self):
"Unloads the network module, pipelines and clears the cache."
if not self.__loaded:
return
self.__loaded = False
self.__pipe_type = None
self.__current_task_name = ""
if hasattr(self, "pipe"):
delattr(self, "pipe")
if hasattr(self, "pipe2"):
delattr(self, "pipe2")
clear_cuda_and_gc()
def load_model(self, task_name: CONTROLNET_TYPES):
"Appropriately loads the network module, pipelines and cache it for reuse."
if self.__current_task_name == task_name:
return
config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
        model = config.get(task_name)
if not model:
raise Exception(f"ControlNet is not supported for {task_name}")
        while model in config:
task_name = model # pyright: ignore
model = config[task_name]
pipeline_type = (
self.__model_sdxl_types[task_name]
if get_is_sdxl()
else self.__model_normal_types[task_name]
)
if "," in model:
model = [m.strip() for m in model.split(",")]
model = self.__load_network_model(model, pipeline_type)
self.__load_pipeline(model, pipeline_type)
self.__current_task_name = task_name
clear_cuda_and_gc()
def __load_network_model(self, model_name, pipeline_type):
"Loads the network module, eg: ControlNet or T2I Adapters"
        if isinstance(model_name, str):
            return load_network_model_by_key(model_name, pipeline_type)
        elif isinstance(model_name, list):
if pipeline_type == "controlnet":
cns = []
for model in model_name:
cns.append(load_network_model_by_key(model, pipeline_type))
return MultiControlNetModel(cns).to("cuda")
elif pipeline_type == "t2i":
raise Exception("Multi T2I adapters are not supported")
raise Exception("Invalid pipeline type")
def __load_pipeline(self, network_model, pipeline_type):
"Load the base pipeline(s) (if not loaded already) based on pipeline type and attaches the network module to the pipeline"
def patch_pipe(pipe):
if not pipe:
# cases where the loader may return None
return None
if get_is_sdxl():
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
pipe.enable_xformers_memory_efficient_attention()
# this scheduler produces good outputs for t2i adapters
if pipeline_type == "t2i":
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config
)
else:
pipe.enable_xformers_memory_efficient_attention()
return pipe
# If the pipeline type is changed we should reload all
# the pipelines
if not self.__loaded or self.__pipe_type != pipeline_type:
# controlnet pipeline for tile upscaler or any pipeline with img2img + network support
pipe = StableDiffusionNetworkModelPipelineLoader(
is_sdxl=get_is_sdxl(),
is_img2img=True,
network_model=network_model,
pipeline_type=pipeline_type,
base_pipe=getattr(self, "__pipeline", None),
)
pipe = patch_pipe(pipe)
if pipe:
self.pipe = pipe
# controlnet pipeline for canny and pose
pipe2 = StableDiffusionNetworkModelPipelineLoader(
is_sdxl=get_is_sdxl(),
is_img2img=False,
network_model=network_model,
pipeline_type=pipeline_type,
base_pipe=getattr(self, "__pipeline", None),
)
pipe2 = patch_pipe(pipe2)
if pipe2:
self.pipe2 = pipe2
self.__loaded = True
self.__pipe_type = pipeline_type
# Set the network module in the pipeline
if pipeline_type == "controlnet":
if hasattr(self, "pipe"):
setattr(self.pipe, "controlnet", network_model)
if hasattr(self, "pipe2"):
setattr(self.pipe2, "controlnet", network_model)
elif pipeline_type == "t2i":
if hasattr(self, "pipe"):
setattr(self.pipe, "adapter", network_model)
if hasattr(self, "pipe2"):
setattr(self.pipe2, "adapter", network_model)
if hasattr(self, "pipe"):
self.pipe = self.pipe.to("cuda")
if hasattr(self, "pipe2"):
self.pipe2 = self.pipe2.to("cuda")
clear_cuda_and_gc()
def process(self, **kwargs):
if self.__current_task_name == "pose":
return self.process_pose(**kwargs)
if self.__current_task_name == "depth":
return self.process_depth(**kwargs)
if self.__current_task_name == "canny":
return self.process_canny(**kwargs)
if self.__current_task_name == "scribble":
return self.process_scribble(**kwargs)
if self.__current_task_name == "linearart":
return self.process_linearart(**kwargs)
if self.__current_task_name == "tile_upscaler":
return self.process_tile_upscaler(**kwargs)
if self.__current_task_name == "canny_2x":
return self.process_canny_2x(**kwargs)
raise Exception("ControlNet is not loaded with any model")
@torch.inference_mode()
def process_canny(
self,
prompt: List[str],
imageUrl: str,
seed: int,
num_inference_steps: int,
negative_prompt: List[str],
height: int,
width: int,
guidance_scale: float = 7.5,
apply_preprocess: bool = True,
**kwargs,
):
if self.__current_task_name != "canny":
raise Exception("ControlNet is not loaded with canny model")
generator = get_generators(seed, get_num_return_sequences())
init_image = self.preprocess_image(imageUrl, width, height)
if apply_preprocess:
init_image = ControlNet.canny_detect_edge(init_image)
init_image = init_image.resize((width, height))
# if get_is_sdxl():
# kwargs["controlnet_conditioning_scale"] = 0.5
kwargs = {
"prompt": prompt,
"image": init_image,
"guidance_scale": guidance_scale,
"num_images_per_prompt": 1,
"negative_prompt": negative_prompt,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
"generator": generator,
**kwargs,
}
print(kwargs)
result = self.pipe2.__call__(**kwargs)
return Result.from_result(result), init_image
@torch.inference_mode()
def process_canny_2x(
self,
prompt: List[str],
imageUrl: str,
seed: int,
num_inference_steps: int,
negative_prompt: List[str],
height: int,
width: int,
guidance_scale: float = 8.5,
**kwargs,
):
if self.__current_task_name != "canny_2x":
raise Exception("ControlNet is not loaded with canny model")
generator = get_generators(seed, get_num_return_sequences())
init_image = self.preprocess_image(imageUrl, width, height)
canny_image = ControlNet.canny_detect_edge(init_image).resize((width, height))
depth_image = ControlNet.depth_image(init_image).resize((width, height))
condition_scale = kwargs.get("controlnet_conditioning_scale", None)
condition_factor = kwargs.get("control_guidance_end", None)
print("condition_scale", condition_scale)
if not get_is_sdxl():
kwargs["guidance_scale"] = 7.5
kwargs["strength"] = 0.8
kwargs["controlnet_conditioning_scale"] = [condition_scale or 1.0, 0.3]
else:
kwargs["controlnet_conditioning_scale"] = [condition_scale or 0.8, 0.3]
kwargs["control_guidance_end"] = [condition_factor or 1.0, 1.0]
kwargs = {
"prompt": prompt[0],
"image": [init_image] * get_num_return_sequences(),
"control_image": [canny_image, depth_image],
"guidance_scale": guidance_scale,
"num_images_per_prompt": get_num_return_sequences(),
"negative_prompt": negative_prompt[0],
"num_inference_steps": num_inference_steps,
"strength": 1.0,
"height": height,
"width": width,
"generator": generator,
**kwargs,
}
print(kwargs)
result = self.pipe.__call__(**kwargs)
return Result.from_result(result), canny_image
@torch.inference_mode()
def process_pose(
self,
prompt: List[str],
image: List[Image.Image],
seed: int,
num_inference_steps: int,
negative_prompt: List[str],
height: int,
width: int,
guidance_scale: float = 7.5,
**kwargs,
):
if self.__current_task_name != "pose":
raise Exception("ControlNet is not loaded with pose model")
generator = get_generators(seed, get_num_return_sequences())
kwargs = {
"prompt": prompt[0],
"image": image,
"num_images_per_prompt": get_num_return_sequences(),
"num_inference_steps": num_inference_steps,
"negative_prompt": negative_prompt[0],
"guidance_scale": guidance_scale,
"height": height,
"width": width,
"generator": generator,
**kwargs,
}
print(kwargs)
result = self.pipe2.__call__(**kwargs)
return Result.from_result(result), image
@torch.inference_mode()
def process_tile_upscaler(
self,
imageUrl: str,
prompt: str,
negative_prompt: str,
num_inference_steps: int,
seed: int,
height: int,
width: int,
resize_dimension: int,
guidance_scale: float = 7.5,
**kwargs,
):
if self.__current_task_name != "tile_upscaler":
raise Exception("ControlNet is not loaded with tile_upscaler model")
init_image = None
        # Recover the per-image seed offset (the trailing index in the file
        # name) and switch imageUrl to its high-res variant.
try:
p = os.path.splitext(imageUrl)[0]
p = p.split("/")[-1]
p = p.split("_")[-1]
seed = seed + int(p)
if "_canny_2x" or "_linearart" in imageUrl:
imageUrl = imageUrl.replace("_canny_2x", "_canny_2x_highres").replace(
"_linearart_highres", ""
)
init_image = download_image(imageUrl)
width, height = init_image.size
print("Setting imageUrl with width and height", imageUrl, width, height)
except Exception as e:
print("Failed to extract seed from imageUrl", e)
print("Setting seed", seed)
generator = get_generators(seed)
if not init_image:
init_image = download_image(imageUrl).resize((width, height))
condition_image = ImageUtil.resize_image(init_image, 1024)
if get_is_sdxl():
condition_image = condition_image.resize(init_image.size)
else:
condition_image = self.__resize_for_condition_image(
init_image, resize_dimension
)
if get_is_sdxl():
kwargs["strength"] = 1.0
kwargs["controlnet_conditioning_scale"] = 1.0
kwargs["image"] = init_image
else:
kwargs["image"] = condition_image
kwargs["guidance_scale"] = guidance_scale
kwargs = {
"prompt": prompt,
"control_image": condition_image,
"num_inference_steps": num_inference_steps,
"negative_prompt": negative_prompt,
"height": condition_image.size[1],
"width": condition_image.size[0],
"generator": generator,
**kwargs,
}
result = self.pipe.__call__(**kwargs)
return Result.from_result(result), condition_image
@torch.inference_mode()
def process_scribble(
self,
image: List[Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]],
num_inference_steps: int,
seed: int,
height: int,
width: int,
guidance_scale: float = 7.5,
apply_preprocess: bool = True,
**kwargs,
):
if self.__current_task_name != "scribble":
raise Exception("ControlNet is not loaded with scribble model")
generator = get_generators(seed, get_num_return_sequences())
if apply_preprocess:
if get_is_sdxl():
                # SDXL uses the sketch (PidiNet) detector
image = [
ControlNet.pidinet_image(image[0]).resize((width, height))
] * len(image)
else:
image = [
ControlNet.scribble_image(image[0]).resize((width, height))
] * len(image)
sdxl_args = (
{
"guidance_scale": guidance_scale,
"adapter_conditioning_scale": 1.0,
"adapter_conditioning_factor": 1.0,
}
if get_is_sdxl()
else {}
)
kwargs = {
"image": image,
"prompt": prompt,
"num_inference_steps": num_inference_steps,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"guidance_scale": guidance_scale,
"generator": generator,
**sdxl_args,
**kwargs,
}
result = self.pipe2.__call__(**kwargs)
return Result.from_result(result), image[0]
@torch.inference_mode()
def process_linearart(
self,
imageUrl: str,
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]],
num_inference_steps: int,
seed: int,
height: int,
width: int,
guidance_scale: float = 7.5,
apply_preprocess: bool = True,
**kwargs,
):
if self.__current_task_name != "linearart":
raise Exception("ControlNet is not loaded with linearart model")
generator = get_generators(seed, get_num_return_sequences())
init_image = self.preprocess_image(imageUrl, width, height)
if apply_preprocess:
condition_image = ControlNet.linearart_condition_image(init_image)
condition_image = condition_image.resize(init_image.size)
else:
condition_image = init_image
        # SDXL linearart runs through a T2I adapter; a fixed conditioning scale
        # and factor are passed alongside the guidance scale
sdxl_args = (
{
"guidance_scale": guidance_scale,
"adapter_conditioning_scale": 1.0,
"adapter_conditioning_factor": 1.0,
}
if get_is_sdxl()
else {}
)
kwargs = {
"image": [condition_image] * get_num_return_sequences(),
"prompt": prompt,
"num_inference_steps": num_inference_steps,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"guidance_scale": guidance_scale,
"generator": generator,
**sdxl_args,
**kwargs,
}
result = self.pipe2.__call__(**kwargs)
return Result.from_result(result), condition_image
@torch.inference_mode()
def process_depth(
self,
imageUrl: str,
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]],
num_inference_steps: int,
seed: int,
height: int,
width: int,
guidance_scale: float = 7.5,
apply_preprocess: bool = True,
**kwargs,
):
if self.__current_task_name != "depth":
raise Exception("ControlNet is not loaded with depth model")
generator = get_generators(seed, get_num_return_sequences())
init_image = self.preprocess_image(imageUrl, width, height)
if apply_preprocess:
condition_image = ControlNet.depth_image(init_image)
condition_image = condition_image.resize(init_image.size)
else:
condition_image = init_image
        # for the depth ControlNet in this SDXL model, these hyperparameters are optimal
sdxl_args = (
{"controlnet_conditioning_scale": 0.2, "control_guidance_end": 0.2}
if get_is_sdxl()
else {}
)
kwargs = {
"image": [condition_image] * get_num_return_sequences(),
"prompt": prompt,
"num_inference_steps": num_inference_steps,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"guidance_scale": guidance_scale,
"generator": generator,
**sdxl_args,
**kwargs,
}
result = self.pipe2.__call__(**kwargs)
return Result.from_result(result), condition_image
    def cleanup(self):
        """No-op: newer diffusers versions perform their own cleanup after
        ControlNet generation."""
        pass
def detect_pose(self, imageUrl: str) -> Image.Image:
detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
image = download_image(imageUrl)
image = detector.__call__(image)
return image
@staticmethod
def scribble_image(image: Image.Image) -> Image.Image:
processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
image = processor.__call__(input_image=image, scribble=True)
return image
@staticmethod
def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image:
processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
if get_is_sdxl():
kwargs = {"detect_resolution": 384, "image_resolution": 1024, **kwargs}
else:
kwargs = {}
image = processor.__call__(input_image=image, **kwargs)
return image
@staticmethod
@torch.inference_mode()
def depth_image(image: Image.Image) -> Image.Image:
global midas, midas_transforms
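        # Lazily load MiDaS and its transforms once per process; they are kept
        # in module globals so repeated calls reuse the same model.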
if "midas" not in globals():
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS").to("cuda")
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.default_transform
cv_image = np.array(image)
img = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
input_batch = transform(img).to("cuda")
with torch.no_grad():
prediction = midas(input_batch)
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=img.shape[:2],
mode="bicubic",
align_corners=False,
).squeeze()
output = prediction.cpu().numpy()
formatted = (output * 255 / np.max(output)).astype("uint8")
img = Image.fromarray(formatted)
return img
@staticmethod
def pidinet_image(image: Image.Image) -> Image.Image:
pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
image = pidinet.__call__(input_image=image, apply_filter=True)
return image
@staticmethod
def canny_detect_edge(image: Image.Image) -> Image.Image:
image_array = np.array(image)
low_threshold = 100
high_threshold = 200
image_array = cv2.Canny(image_array, low_threshold, high_threshold)
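        # cv2.Canny yields a single-channel edge map; replicate it across three
        # channels so it can serve as an RGB conditioning image.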
image_array = image_array[:, :, None]
image_array = np.concatenate([image_array, image_array, image_array], axis=2)
canny_image = Image.fromarray(image_array)
return canny_image
def preprocess_image(self, imageUrl, width, height) -> Image.Image:
image = download_image(imageUrl, mode="RGBA").resize((width, height))
return ImageUtil.alpha_to_white(image)
def __resize_for_condition_image(self, image: Image.Image, resolution: int):
input_image = image.convert("RGB")
W, H = input_image.size
k = float(resolution) / max(W, H)
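        # e.g. a 1600x900 input with resolution=1024 gives k=0.64, so the image
        # is resized to 1024x576 (both dimensions rounded to multiples of 64).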
H *= k
W *= k
H = int(round(H / 64.0)) * 64
W = int(round(W / 64.0)) * 64
img = input_image.resize((W, H), resample=Image.LANCZOS)
return img
__model_normal = {
"pose": "lllyasviel/control_v11f1p_sd15_depth, lllyasviel/control_v11p_sd15_openpose",
"canny": "lllyasviel/control_v11p_sd15_canny",
"linearart": "lllyasviel/control_v11p_sd15_lineart",
"scribble": "lllyasviel/control_v11p_sd15_scribble",
"tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
"canny_2x": "lllyasviel/control_v11p_sd15_canny, lllyasviel/control_v11f1p_sd15_depth",
}
__model_normal_types = {
"pose": "controlnet",
"canny": "controlnet",
"linearart": "controlnet",
"scribble": "controlnet",
"tile_upscaler": "controlnet",
"canny_2x": "controlnet",
}
__model_sdxl = {
"pose": "thibaud/controlnet-openpose-sdxl-1.0",
"canny": "Autodraft/controlnet-canny-sdxl-1.0",
"depth": "Autodraft/controlnet-depth-sdxl-1.0",
"canny_2x": "Autodraft/controlnet-canny-sdxl-1.0, Autodraft/controlnet-depth-sdxl-1.0",
"linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
"scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
"tile_upscaler": "Autodraft/ControlNet_SDXL_tile_upscale",
}
__model_sdxl_types = {
"pose": "controlnet",
"canny": "controlnet",
"canny_2x": "controlnet",
"depth": "controlnet",
"linearart": "t2i",
"scribble": "t2i",
"tile_upscaler": "controlnet",
}
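# Typical end-to-end usage (a minimal sketch; `base_pipe` and the image URL are
# hypothetical placeholders, not defined in this module):
#
#   controlnet = ControlNet()
#   controlnet.init(base_pipe)  # optional: reuse components of a loaded pipeline
#   controlnet.load_model("canny")
#   result, condition_image = controlnet.process(
#       prompt=["a watercolor house"],
#       imageUrl="https://example.com/input.png",
#       seed=42,
#       num_inference_steps=30,
#       negative_prompt=["blurry"],
#       height=512,
#       width=512,
#   )
#   controlnet.unload()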