File size: 2,989 Bytes
bb2108c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
import cv2
import torch
import PIL
import numpy as np
import os

class InstantStyle():
    """Wrapper that generates a stylized image with SDXL, a Canny ControlNet and IP-Adapter.

    The constructor loads a ControlNet and an SDXL ControlNet pipeline in
    ``float16`` and wraps them in ``IPAdapterXL``. ``infer_one_image`` then
    conditions generation on the Canny edges of a source image while passing a
    style image to the adapter as the image prompt.
    """

    def __init__(self,
                 device="cuda",
                 weight="stabilityai/stable-diffusion-xl-base-1.0",
                 control_weight="diffusers/controlnet-canny-sdxl-1.0",
                 custom_sdxl_models_folder="sdxl_models"):
        """Load the ControlNet, the SDXL pipeline and the IP-Adapter wrapper.

        Args:
            device: Device the ControlNet is moved to; also handed to
                ``IPAdapterXL``.
            weight: Model id or path passed to
                ``StableDiffusionXLControlNetPipeline.from_pretrained``.
            control_weight: Model id or path passed to
                ``ControlNetModel.from_pretrained``.
            custom_sdxl_models_folder: Folder expected to contain the
                ``image_encoder`` directory and the ``ip-adapter_sdxl.bin``
                checkpoint.
        """
        # Imported lazily so the package-relative adapter code is only
        # required when an instance is actually created.
        from .ip_adapter import IPAdapterXL

        controlnet = ControlNetModel.from_pretrained(
            control_weight,
            use_safetensors=False,
            torch_dtype=torch.float16,
        ).to(device)

        # load SDXL pipeline
        sdxl_control_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            weight,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            add_watermarker=False,
        )
        sdxl_control_pipe.enable_vae_tiling()

        # NOTE(review): target_blocks presumably restricts where the adapter is
        # applied inside the UNet -- confirm against IPAdapterXL.
        self.ip_model = IPAdapterXL(
            sdxl_control_pipe,
            os.path.join(custom_sdxl_models_folder, "image_encoder"),
            os.path.join(custom_sdxl_models_folder, "ip-adapter_sdxl.bin"),
            device,
            target_blocks=["up_blocks.0.attentions.1"],
        )

    @staticmethod
    def _pil_to_cv2(image_pil):
        """Convert an RGB PIL image into a BGR array as used by OpenCV."""
        return cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

    def infer_one_image(self, src_image: PIL.Image.Image = None,
                        style_image: PIL.Image.Image = None,
                        prompt: str = "masterpiece, best quality, high quality",
                        seed: int = 42,
                        negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry"):
        """Generate one image from a source (structure) image and a style image.

        Args:
            src_image: Image whose Canny edge map is used as the ControlNet
                conditioning image. Required despite the ``None`` default.
            style_image: Image passed to the adapter as ``pil_image``.
                Required despite the ``None`` default.
            prompt: Text prompt; ``None`` falls back to the default prompt.
            seed: Seed forwarded to ``ip_model.generate``.
            negative_prompt: Negative text prompt forwarded to
                ``ip_model.generate``.

        Returns:
            The first element of the sequence returned by
            ``ip_model.generate`` (presumably a PIL image -- confirm against
            IPAdapterXL).

        Raises:
            ValueError: If ``src_image`` or ``style_image`` is ``None``.
        """
        # Fail early with a clear message instead of an AttributeError on
        # ``None.convert`` further down.
        if src_image is None or style_image is None:
            raise ValueError("Both src_image and style_image must be provided.")

        # Force RGB so palette, grayscale or alpha inputs are handled uniformly.
        src_image = src_image.convert('RGB')
        style_image = style_image.convert('RGB')

        # control image
        input_image = self._pil_to_cv2(src_image)
        detected_map = cv2.Canny(input_image, 50, 200)
        # Canny yields a single-channel edge map, so it must be expanded with
        # GRAY2RGB; BGR2RGB expects a 3- or 4-channel input.
        canny_map = PIL.Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_GRAY2RGB))

        # generate image
        if prompt is None:
            prompt = "masterpiece, best quality, high quality"
        generated = self.ip_model.generate(
            pil_image=style_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            scale=1.0,
            guidance_scale=5,
            num_samples=1,
            num_inference_steps=30,
            seed=seed,
            image=canny_map,
            controlnet_conditioning_scale=0.6,
        )
        return generated[0]