File size: 4,219 Bytes
dfdd03b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import random

import gradio as gr
# ZeroGPU: `spaces` must be imported before anything initialises CUDA.
import spaces
import numpy as np
import torch
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from groundingdino.util.inference import load_model, predict
from huggingface_hub import snapshot_download
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_inpainting import StableDiffusionXLInpaintPipeline

# Runtime device and the Kolors inpainting checkpoint snapshot, fetched with
# huggingface_hub.snapshot_download.
device = "cuda"
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors-Inpainting")

# Inpainting setup: ChatGLM text encoder/tokenizer, VAE, scheduler and UNet are
# all loaded from the snapshot; the neural-network weights are cast to fp16
# and moved to `device`.
text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)

pipe = StableDiffusionXLInpaintPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
)

pipe.to(device)
# Compute attention in slices to reduce peak GPU memory.
pipe.enable_attention_slicing()

# GroundingDINO and SAM setup.
# Checkpoint locations can be overridden through environment variables; the
# defaults are the paths this script originally hard-coded (the GroundingDINO
# ones are placeholders that must be pointed at real files).
GROUNDINGDINO_CONFIG = os.environ.get("GROUNDINGDINO_CONFIG", "path/to/groundingdino/config.yaml")
GROUNDINGDINO_CHECKPOINT = os.environ.get("GROUNDINGDINO_CHECKPOINT", "path/to/groundingdino/model.pth")
SAM_CHECKPOINT = os.environ.get("SAM_CHECKPOINT", "model/sam_vit_h_4b8939.pth")

model_dino = load_model(GROUNDINGDINO_CONFIG, GROUNDINGDINO_CHECKPOINT)

# SamAutomaticMaskGenerator wraps an already-built Sam model; it does not accept
# model_type/checkpoint itself, so build the ViT-H model via the registry first.
sam_model = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(device)
sam = SamAutomaticMaskGenerator(sam_model)

# Constants
# Largest seed offered by the UI / drawn when randomising (int32 max).
MAX_SEED = np.iinfo(np.int32).max

def _xywh_to_xyxy(box):
    """Convert an ``(x, y, w, h)`` box (SAM's ``bbox`` format) to ``(x0, y0, x1, y1)``."""
    x, y, w, h = box
    return (x, y, x + w, y + h)


def _box_iou(a, b):
    """Return the intersection-over-union of two ``(x0, y0, x1, y1)`` boxes (0.0 if degenerate)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def generate_mask(image: Image.Image, text_prompt: str = "object",
                  box_threshold: float = 0.35, text_threshold: float = 0.25) -> Image.Image:
    """Build a binary inpainting mask for *image* with GroundingDINO + SAM.

    GroundingDINO locates the region described by *text_prompt*. SAM's automatic
    generator proposes candidate masks, and the one whose bounding box best
    overlaps the highest-scoring detection is kept. When GroundingDINO detects
    nothing, the largest SAM mask is used instead.

    Args:
        image: Input picture; converted to RGB internally.
        text_prompt: Caption GroundingDINO should ground (the region to repaint).
        box_threshold: Minimum box confidence passed to GroundingDINO's ``predict``.
        text_threshold: Minimum text confidence passed to GroundingDINO's ``predict``.

    Returns:
        An ``"L"``-mode PIL image the size of *image*: 255 inside the selected
        region, 0 elsewhere.

    Raises:
        gr.Error: If SAM proposes no masks at all.
    """
    # Transforms from the GroundingDINO package this file already depends on.
    import groundingdino.datasets.transforms as T

    rgb = image.convert("RGB")
    width, height = rgb.size

    # GroundingDINO needs a normalised CHW tensor (the same transform its own
    # `load_image` helper applies); SAM needs an HWC uint8 array.
    dino_transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dino_input, _ = dino_transform(rgb, None)
    boxes, logits, _phrases = predict(model_dino, dino_input, text_prompt, box_threshold, text_threshold)

    masks = sam.generate(np.asarray(rgb))
    if not masks:
        raise gr.Error("Could not segment any region of the input image.")

    if len(boxes) > 0:
        # Detection boxes are normalised (cx, cy, w, h); scale the top-scoring
        # one to pixel (x0, y0, x1, y1) in the original image.
        cx, cy, bw, bh = boxes[int(logits.argmax())].tolist()
        target = (
            (cx - bw / 2) * width,
            (cy - bh / 2) * height,
            (cx + bw / 2) * width,
            (cy + bh / 2) * height,
        )
        best = max(masks, key=lambda m: _box_iou(_xywh_to_xyxy(m["bbox"]), target))
    else:
        best = max(masks, key=lambda m: m["area"])

    # Boolean segmentation -> uint8 0/255 so the pipeline sees a plain grayscale mask.
    return Image.fromarray(best["segmentation"].astype(np.uint8) * 255)

@spaces.GPU
def infer(prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps):
    """Inpaint an automatically segmented region of *image* according to *prompt*.

    Args:
        prompt: Text describing what to paint into the masked region.
        image: Value from the ``gr.ImageEditor`` component, either a dict with
            ``"background"``/``"composite"`` entries or a PIL image.
        negative_prompt: Text describing what to avoid.
        seed: Seed for the sampling generator; replaced when *randomize_seed* is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.

    Returns:
        The first generated PIL image.

    Raises:
        gr.Error: If no image was provided or it is smaller than 8x8 pixels.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # gr.ImageEditor delivers a dict rather than an image: prefer the uploaded
    # background (no brush strokes) and fall back to the composite. A bare PIL
    # image is accepted too so the function can be called directly.
    if isinstance(image, dict):
        background = image.get("background")
        image = background if background is not None else image.get("composite")
    if image is None:
        raise gr.Error("Please upload an image to inpaint.")
    init_image = image.convert("RGB")

    # Generate mask using GroundingDINO + SAM
    mask_image = generate_mask(init_image)

    # The SDXL inpainting pipeline rejects sizes not divisible by 8, so round
    # down. The pipeline is assumed to resize image and mask to height/width
    # (true for diffusers; confirm for the Kolors fork).
    width = init_image.width - init_image.width % 8
    height = init_image.height - init_image.height % 8
    if width == 0 or height == 0:
        raise gr.Error("The input image must be at least 8x8 pixels.")

    # Slider values may arrive as floats; manual_seed and the step count need ints.
    generator = torch.Generator().manual_seed(int(seed))
    result = pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        generator=generator,
        num_inference_steps=int(num_inference_steps),
        negative_prompt=negative_prompt,
        num_images_per_prompt=1,
        strength=0.999,  # near-full denoising of the masked region (value kept from the original)
    ).images[0]

    return result

# Page CSS passed to gr.Blocks: centres the two layout columns and caps their
# widths. The ids match the elem_id values given to the gr.Column calls in the UI.
css="""
#col-left {
    margin: 0 auto;
    max-width: 600px;
}
#col-right {
    margin: 0 auto;
    max-width: 700px;
}
"""

def load_description(fp):
    """Read the UTF-8 text file at *fp* and return its whole contents as a string."""
    with open(fp, encoding='utf-8') as handle:
        return handle.read()

# Gradio UI: prompt, image editor and advanced sampling controls on the left,
# the generated image on the right. Layout follows the order in which
# components are created inside each context manager.
with gr.Blocks(css=css) as Kolors:
    # Page header read from a file path relative to the working directory.
    gr.HTML(load_description("assets/title.md"))
    
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt", lines=2)
            # NOTE(review): gr.ImageEditor hands its value to the callback as a
            # dict (background / layers / composite), not a bare PIL image; the
            # callback has to unwrap it.
            image = gr.ImageEditor(label="Image", type="pil", image_mode='RGB')
            
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(label="Negative prompt", value="low quality, bad anatomy")
                # Ignored by infer when "Randomize seed" is checked.
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=6.0)
                num_inference_steps = gr.Slider(label="Number of inference steps", minimum=10, maximum=50, step=1, value=25)
            
            run_button = gr.Button("Run")
        
        with gr.Column(elem_id="col-right"):
            result = gr.Image(label="Result", show_label=False)
    
    # Inputs are passed to infer positionally, so this list must stay in the
    # same order as infer's parameters.
    run_button.click(
        fn=infer,
        inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
        outputs=[result]
    )

# Enable the request queue and start the server.
# NOTE(review): this runs at import time; consider an `if __name__ == "__main__":` guard.
Kolors.queue().launch(debug=True)