rishh76 commited on
Commit
dfdd03b
·
1 Parent(s): aa52d88

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import random
4
+ import torch
5
+ from huggingface_hub import snapshot_download
6
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_inpainting import StableDiffusionXLInpaintPipeline
7
+ from kolors.models.modeling_chatglm import ChatGLMModel
8
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9
+ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
10
+ from groundingdino.util.inference import load_model, predict
11
+ from segment_anything import SamAutomaticMaskGenerator
12
+ from PIL import Image
13
+ import numpy as np
14
+ import os
15
+
16
+ # Download model checkpoints
17
+ device = "cuda"
18
+ ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors-Inpainting")
19
+
20
+ # Inpainting setup
21
+ text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder',torch_dtype=torch.float16).half().to(device)
22
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
23
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
24
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
25
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
26
+
27
+ pipe = StableDiffusionXLInpaintPipeline(
28
+ vae=vae,
29
+ text_encoder=text_encoder,
30
+ tokenizer=tokenizer,
31
+ unet=unet,
32
+ scheduler=scheduler
33
+ )
34
+
35
+ pipe.to(device)
36
+ pipe.enable_attention_slicing()
37
+
38
+ # GroundingDINO and SAM setup
39
+ model_dino = load_model("path/to/groundingdino/config.yaml", "path/to/groundingdino/model.pth")
40
+ sam = SamAutomaticMaskGenerator(model_type="vit_h", checkpoint="model/sam_vit_h_4b8939.pth")
41
+
42
+ # Constants
43
+ MAX_SEED = np.iinfo(np.int32).max
44
+
45
+ def generate_mask(image: Image):
46
+ boxes, logits, phrases = predict(model_dino, image, "prompt") # Provide the proper prompt for detection
47
+ masks = sam.generate(image)
48
+ mask = masks[0]["segmentation"] # Use the first detected mask as an example
49
+ return Image.fromarray(mask)
50
+
51
+ @spaces.GPU
52
+ def infer(prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps):
53
+ if randomize_seed:
54
+ seed = random.randint(0, MAX_SEED)
55
+
56
+ # Generate mask using GroundingDINO + SAM
57
+ mask_image = generate_mask(image)
58
+
59
+ generator = torch.Generator().manual_seed(seed)
60
+ result = pipe(
61
+ prompt=prompt,
62
+ image=image,
63
+ mask_image=mask_image,
64
+ height=image.height,
65
+ width=image.width,
66
+ guidance_scale=guidance_scale,
67
+ generator=generator,
68
+ num_inference_steps=num_inference_steps,
69
+ negative_prompt=negative_prompt,
70
+ num_images_per_prompt=1,
71
+ strength=0.999
72
+ ).images[0]
73
+
74
+ return result
75
+
76
+ css="""
77
+ #col-left {
78
+ margin: 0 auto;
79
+ max-width: 600px;
80
+ }
81
+ #col-right {
82
+ margin: 0 auto;
83
+ max-width: 700px;
84
+ }
85
+ """
86
+
87
+ def load_description(fp):
88
+ with open(fp, 'r', encoding='utf-8') as f:
89
+ content = f.read()
90
+ return content
91
+
92
+ with gr.Blocks(css=css) as Kolors:
93
+ gr.HTML(load_description("assets/title.md"))
94
+
95
+ with gr.Row():
96
+ with gr.Column(elem_id="col-left"):
97
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt", lines=2)
98
+ image = gr.ImageEditor(label="Image", type="pil", image_mode='RGB')
99
+
100
+ with gr.Accordion("Advanced Settings", open=False):
101
+ negative_prompt = gr.Textbox(label="Negative prompt", value="low quality, bad anatomy")
102
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
103
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
104
+ guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=6.0)
105
+ num_inference_steps = gr.Slider(label="Number of inference steps", minimum=10, maximum=50, step=1, value=25)
106
+
107
+ run_button = gr.Button("Run")
108
+
109
+ with gr.Column(elem_id="col-right"):
110
+ result = gr.Image(label="Result", show_label=False)
111
+
112
+ run_button.click(
113
+ fn=infer,
114
+ inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
115
+ outputs=[result]
116
+ )
117
+
118
+ Kolors.queue().launch(debug=True)