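"""ControlNet-conditioned Stable Diffusion / SDXL inpainting model for iopaint.

Wraps diffusers' StableDiffusion(XL)ControlNetInpaintPipeline and builds the
control image (canny / openpose / depth / inpaint) from the input image before
each generation.
"""
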
import PIL.Image
import cv2
import torch
from diffusers import ControlNetModel
from loguru import logger
from iopaint.schema import InpaintRequest, ModelType

from .base import DiffusionInpaintModel
from .helper.controlnet_preprocess import (
    make_canny_control_image,
    make_openpose_control_image,
    make_depth_control_image,
    make_inpaint_control_image,
)
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
from .original_sd_configs import get_config_files
from .utils import (
    get_scheduler,
    handle_from_pretrained_exceptions,
    get_torch_dtype,
    enable_low_mem,
    is_local_files_only,
)


class ControlNet(DiffusionInpaintModel):
    name = "controlnet"
    pad_mod = 8
    min_size = 512

    @property
    def lcm_lora_id(self):
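        # Pick the LCM-LoRA checkpoint matching the base model family
        # (SD 1.5 vs. SDXL).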
        if self.model_info.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SD_INPAINT,
        ]:
            return "latent-consistency/lcm-lora-sdv1-5"
        if self.model_info.model_type in [
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]:
            return "latent-consistency/lcm-lora-sdxl"
        raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}")

    def init_model(self, device: torch.device, **kwargs):
        model_info = kwargs["model_info"]
        controlnet_method = kwargs["controlnet_method"]

        self.model_info = model_info
        self.controlnet_method = controlnet_method

        # Start from any pre-built pipeline components passed in by the caller.
        model_kwargs = {
            **kwargs.get("pipe_components", {}),
            "local_files_only": is_local_files_only(**kwargs),
        }
        self.local_files_only = model_kwargs["local_files_only"]

        disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
            "cpu_offload", False
        )
        if disable_nsfw_checker:
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
                dict(
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
            )

        use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
        self.torch_dtype = torch_dtype

        if model_info.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SD_INPAINT,
        ]:
            from diffusers import (
                StableDiffusionControlNetInpaintPipeline as PipeClass,
            )
        elif model_info.model_type in [
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]:
            from diffusers import (
                StableDiffusionXLControlNetInpaintPipeline as PipeClass,
            )
        else:
            # Without this branch, PipeClass would be unbound for any other
            # model type and fail later with a NameError.
            raise NotImplementedError(
                f"Unsupported model type for ControlNet: {model_info.model_type}"
            )

        controlnet = ControlNetModel.from_pretrained(
            pretrained_model_name_or_path=controlnet_method,
            resume_download=True,
            local_files_only=model_kwargs["local_files_only"],
            torch_dtype=self.torch_dtype,
        )
        if model_info.is_single_file_diffusers:
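            # A plain SD UNet has 4 latent input channels; inpainting UNets
            # take 9 (4 latent + 4 masked-image latent + 1 mask channel).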
            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
                model_kwargs["num_in_channels"] = 4
            else:
                model_kwargs["num_in_channels"] = 9

            self.model = PipeClass.from_single_file(
                model_info.path,
                controlnet=controlnet,
                load_safety_checker=not disable_nsfw_checker,
                torch_dtype=torch_dtype,
                config_files=get_config_files(),
                **model_kwargs,
            )
        else:
            self.model = handle_from_pretrained_exceptions(
                PipeClass.from_pretrained,
                pretrained_model_name_or_path=model_info.path,
                controlnet=controlnet,
                variant="fp16",
                torch_dtype=torch_dtype,
                **model_kwargs,
            )

        enable_low_mem(self.model, kwargs.get("low_mem", False))

        if kwargs.get("cpu_offload", False) and use_gpu:
            logger.info("Enable sequential cpu offload")
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)
            if kwargs["sd_cpu_textencoder"]:
                logger.info("Run Stable Diffusion TextEncoder on CPU")
                self.model.text_encoder = CPUTextEncoderWrapper(
                    self.model.text_encoder, torch_dtype
                )

        self.callback = kwargs.pop("callback", None)

    def switch_controlnet_method(self, new_method: str):
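        # Hot-swap only the ControlNet weights; the rest of the pipeline
        # stays loaded.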
        self.controlnet_method = new_method
        controlnet = ControlNetModel.from_pretrained(
            new_method,
            resume_download=True,
            local_files_only=self.local_files_only,
            torch_dtype=self.torch_dtype,
        ).to(self.model.device)
        self.model.controlnet = controlnet

    def _get_control_image(self, image, mask):
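        # Dispatch on substrings of the method name, e.g.
        # "lllyasviel/control_v11p_sd15_canny" contains "canny".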
        if "canny" in self.controlnet_method:
            control_image = make_canny_control_image(image)
        elif "openpose" in self.controlnet_method:
            control_image = make_openpose_control_image(image)
        elif "depth" in self.controlnet_method:
            control_image = make_depth_control_image(image)
        elif "inpaint" in self.controlnet_method:
            control_image = make_inpaint_control_image(image, mask)
        else:
            raise NotImplementedError(f"{self.controlnet_method} not implemented")
        return control_image

    def forward(self, image, mask, config: InpaintRequest):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE
        """
        scheduler_config = self.model.scheduler.config
        scheduler = get_scheduler(config.sd_sampler, scheduler_config)
        self.model.scheduler = scheduler

        img_h, img_w = image.shape[:2]
        control_image = self._get_control_image(image, mask)
        mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
        image = PIL.Image.fromarray(image)

        output = self.model(
            image=image,
            mask_image=mask_image,
            control_image=control_image,
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback_on_step_end=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
            controlnet_conditioning_scale=config.controlnet_conditioning_scale,
        ).images[0]

        # Diffusers returns float RGB in [0, 1]; convert to uint8 BGR to match
        # iopaint's return convention.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output
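
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): in practice this class is created by
# iopaint's model manager. The init kwargs below mirror what init_model reads;
# the checkpoint id is just an example, and the constructor call is an
# assumption about the DiffusionInpaintModel base class.
#
#   model = ControlNet(
#       device=torch.device("cuda"),
#       model_info=model_info,  # iopaint ModelInfo for an SD 1.5 checkpoint
#       controlnet_method="lllyasviel/control_v11p_sd15_canny",
#       disable_nsfw=True,
#       sd_cpu_textencoder=False,
#   )
#   bgr = model.forward(rgb_image, mask, InpaintRequest(prompt="a red sofa"))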