File size: 3,767 Bytes
dce20cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, ControlNetModel, AutoencoderKL
from PIL import Image
import os
import time

from utils.dl_utils import dl_cn_model, dl_cn_config, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation
from utils.prompt_utils import remove_duplicates

# Asset directories live directly under the current working directory.
path = os.getcwd()
cn_dir, lora_dir = f"{path}/controlnet", f"{path}/lora"

# Both directories must exist before the download helpers are called.
os.makedirs(cn_dir, exist_ok=True)
os.makedirs(lora_dir, exist_ok=True)

# Helpers from utils.dl_utils; presumably they fetch the ControlNet
# weights/config and the LoRA weights into these directories -- confirm there.
dl_cn_model(cn_dir)
dl_cn_config(cn_dir)
dl_lora_model(lora_dir)

def load_model(lora_dir, cn_dir):
    """Assemble the SDXL ControlNet img2img pipeline in half precision.

    Args:
        lora_dir: Directory that contains
            ``Fixhands_anime_bdsqlsz_V1.safetensors``.
        cn_dir: Directory passed to ``ControlNetModel.from_pretrained``.

    Returns:
        The ``StableDiffusionXLControlNetImg2ImgPipeline`` with model CPU
        offload enabled and the LoRA weights loaded.
    """
    half = torch.float16

    # Dict literals evaluate in order, so the VAE is still loaded before the
    # ControlNet, exactly as before.
    components = {
        "vae": AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=half),
        "controlnet": ControlNetModel.from_pretrained(cn_dir, torch_dtype=half, use_safetensors=True),
    }

    pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "cagliostrolab/animagine-xl-3.1", torch_dtype=half, **components
    )
    pipeline.enable_model_cpu_offload()
    pipeline.load_lora_weights(lora_dir, weight_name="Fixhands_anime_bdsqlsz_V1.safetensors")
    return pipeline

@spaces.GPU(duration=120)
def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
    """Generate an image from a control image and a text prompt.

    A white canvas the size of the input is used as the img2img source
    (``strength=1.0``), while the uploaded image is passed as the ControlNet
    conditioning image. The result is resized back to the input's size.

    Args:
        input_image_path: Filesystem path of the uploaded control image.
        prompt: User prompt; appended to a fixed quality/style prefix.
            ``None`` is treated as an empty string.
        negative_prompt: Negative prompt forwarded to the pipeline unchanged.
        controlnet_scale: ControlNet conditioning scale (cast to ``float``).

    Returns:
        The generated ``PIL.Image.Image`` at the input image's size.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Validate before the expensive model load so a missing upload produces a
    # readable message in the UI instead of an opaque failure in Image.open.
    if not input_image_path:
        raise gr.Error("Please provide an input image.")

    pipe = load_model(lora_dir, cn_dir)
    input_image = Image.open(input_image_path)
    base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
    resize_image = resize_image_aspect_ratio(input_image)
    resize_base_image = resize_image_aspect_ratio(base_image)
    # Fixed seed: the same inputs give reproducible outputs.
    generator = torch.manual_seed(0)
    last_time = time.time()
    prompt = "masterpiece, best quality, simple background, white background, bald, nude, " + (prompt or "")
    prompt = remove_duplicates(prompt)
    print(prompt)

    output_image = pipe(
        image=resize_base_image,
        control_image=resize_image,
        strength=1.0,
        prompt=prompt,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=float(controlnet_scale),
        generator=generator,
        num_inference_steps=30,
        eta=1.0,
    ).images[0]
    print(f"Time taken: {time.time() - last_time}")
    # Undo the aspect-ratio resize so the output matches the upload's size.
    output_image = output_image.resize(input_image.size, Image.LANCZOS)
    return output_image

class Img2Img:
    """Gradio front end that wires the input widgets to ``predict``.

    Attributes set by ``layout``: ``input_image_path``, ``prompt``,
    ``negative_prompt``, ``controlnet_scale`` and ``output_image`` (Gradio
    components). ``demo`` is the resulting ``gr.Blocks`` app.
    """

    def __init__(self):
        # Initialise plain state BEFORE building the layout: layout() stores
        # the gr.Image component in self.input_image_path, and resetting that
        # attribute afterwards would throw the component reference away.
        self.tagger_model = None
        self.input_image_path = None
        self.canny_image = None
        self.demo = self.layout()

    def layout(self):
        """Build the Blocks UI and register the generate-button handler.

        Returns:
            The constructed ``gr.Blocks`` instance.
        """
        css = """
        #intro{
            max-width: 32rem;
            text-align: center;
            margin: 0 auto;
        }
        """
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                # Left column: inputs and the trigger button.
                with gr.Column():
                    self.input_image_path = gr.Image(label="input_image", type='filepath')
                    self.prompt = gr.Textbox(label="prompt", lines=3)
                    self.negative_prompt = gr.Textbox(label="negative_prompt", lines=3, value="nsfw, nipples, bad anatomy, liquid fingers, low quality, worst quality, out of focus, ugly, error, jpeg artifacts, lowers, blurry, bokeh")
                    self.controlnet_scale = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.01, label="Stick_fidelity")
                    generate_button = gr.Button("generate")
                # Right column: generated result.
                with gr.Column():
                    self.output_image = gr.Image(type="pil", label="output_image")

            # Input order must match predict's positional parameters.
            generate_button.click(
                fn=predict,
                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image
            )
        return demo

# Build the UI once at import time and start serving it.
# NOTE(review): share=True presumably requests a public Gradio share link;
# confirm it is wanted in the deployment environment.
img2img = Img2Img()
img2img.demo.launch(share=True)