juxuan27 commited on
Commit
5bf7c30
·
1 Parent(s): f7b7d4e

Add application file

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ file filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,38 @@
1
- ---
2
- title: BrushNet
3
- emoji: 🔥
4
- colorFrom: yellow
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 4.23.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BrushNet
2
+
3
+ This repository contains the gradio demo of the paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"
4
+
5
+ Keywords: Image Inpainting, Diffusion Models, Image Generation
6
+
7
+ > [Xuan Ju](https://github.com/juxuan27)<sup>12</sup>, [Xian Liu](https://alvinliu0.github.io/)<sup>12</sup>, [Xintao Wang](https://xinntao.github.io/)<sup>1*</sup>, [Yuxuan Bian](https://scholar.google.com.hk/citations?user=HzemVzoAAAAJ&hl=zh-CN&oi=ao)<sup>2</sup>, [Ying Shan](https://www.linkedin.com/in/YingShanProfile/)<sup>1</sup>, [Qiang Xu](https://cure-lab.github.io/)<sup>2*</sup><br>
8
+ > <sup>1</sup>ARC Lab, Tencent PCG <sup>2</sup>The Chinese University of Hong Kong <sup>*</sup>Corresponding Author
9
+
10
+
11
+ <p align="center">
12
+ <a href="https://tencentarc.github.io/BrushNet/">Project Page</a> |
13
+ <a href="https://github.com/TencentARC/BrushNet">Code</a> |
14
+ <a href="https://arxiv.org/abs/2403.06976">Arxiv</a> |
15
+ <a href="https://forms.gle/9TgMZ8tm49UYsZ9s5">Data</a> |
16
+ <a href="https://drive.google.com/file/d/1IkEBWcd2Fui2WHcckap4QFPcCI0gkHBh/view">Video</a> |
17
+ </p>
18
+
19
+
20
+ ## 🤝🏼 Cite Us
21
+
22
+ ```
23
+ @misc{ju2024brushnet,
24
+ title={BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion},
25
+ author={Xuan Ju and Xian Liu and Xintao Wang and Yuxuan Bian and Ying Shan and Qiang Xu},
26
+ year={2024},
27
+ eprint={2403.06976},
28
+ archivePrefix={arXiv},
29
+ primaryClass={cs.CV}
30
+ }
31
+ ```
32
+
33
+
34
+ ## 💖 Acknowledgement
35
+ <span id="acknowledgement"></span>
36
+
37
+ Our code is modified based on [diffusers](https://github.com/huggingface/diffusers), thanks to all the contributors!
38
+
app.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import gradio as gr
4
+ import os
5
+ import cv2
6
+ from PIL import Image
7
+ import numpy as np
8
+ from segment_anything import SamPredictor, sam_model_registry
9
+ import torch
10
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
11
+ import random
12
+
13
+ mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
14
+ mobile_sam.eval()
15
+ mobile_predictor = SamPredictor(mobile_sam)
16
+ colors = [(255, 0, 0), (0, 255, 0)]
17
+ markers = [1, 5]
18
+
19
+ # - - - - - examples - - - - - #
20
+ image_examples = [
21
+ ["examples/brushnet/src/test_image.jpg", "A beautiful cake on the table", "examples/brushnet/src/test_mask.jpg", 0, []],
22
+ ]
23
+
24
+
25
+ # choose the base model here
26
+ base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
27
+ # base_model_path = "runwayml/stable-diffusion-v1-5"
28
+
29
+ # input brushnet ckpt path
30
+ brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"
31
+
32
+ # input source image / mask image path and the text prompt
33
+ image_path="examples/brushnet/src/test_image.jpg"
34
+ mask_path="examples/brushnet/src/test_mask.jpg"
35
+ caption="A cake on the table."
36
+
37
+ # conditioning scale
38
+ paintingnet_conditioning_scale=1.0
39
+
40
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
41
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
42
+ base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
43
+ )
44
+
45
+ # speed up diffusion process with faster scheduler and memory optimization
46
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
47
+ # remove following line if xformers is not installed or when using Torch 2.0.
48
+ # pipe.enable_xformers_memory_efficient_attention()
49
+ # memory optimization.
50
+ pipe.enable_model_cpu_offload()
51
+
52
+ def resize_image(input_image, resolution):
53
+ H, W, C = input_image.shape
54
+ H = float(H)
55
+ W = float(W)
56
+ k = float(resolution) / min(H, W)
57
+ H *= k
58
+ W *= k
59
+ H = int(np.round(H / 64.0)) * 64
60
+ W = int(np.round(W / 64.0)) * 64
61
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
62
+ return img
63
+
64
+
65
+ def process(input_image,
66
+ original_image,
67
+ original_mask,
68
+ input_mask,
69
+ selected_points,
70
+ prompt,
71
+ negative_prompt,
72
+ blended,
73
+ invert_mask,
74
+ control_strength,
75
+ seed,
76
+ randomize_seed,
77
+ guidance_scale,
78
+ num_inference_steps):
79
+ if original_image is None:
80
+ raise gr.Error('Please upload the input image')
81
+ if (original_mask is None or len(selected_points)==0) and input_mask is None:
82
+ raise gr.Error("Please click the region where you hope unchanged/changed, or upload a white-black Mask image")
83
+
84
+ # load example image
85
+ if isinstance(original_image, int):
86
+ image_name = image_examples[original_image][0]
87
+ original_image = cv2.imread(image_name)
88
+ original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
89
+
90
+ if input_mask is not None:
91
+ H,W=original_image.shape[:2]
92
+ original_mask = cv2.resize(input_mask, (W, H))
93
+ else:
94
+ original_mask = np.clip(255 - original_mask, 0, 255).astype(np.uint8)
95
+
96
+ if invert_mask:
97
+ original_mask=255-original_mask
98
+
99
+ mask = 1.*(original_mask.sum(-1)>255)[:,:,np.newaxis]
100
+ masked_image = original_image * (1-mask)
101
+
102
+ init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
103
+ mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")
104
+
105
+ generator = torch.Generator("cuda").manual_seed(random.randint(0,2147483647) if randomize_seed else seed)
106
+
107
+ image = pipe(
108
+ [prompt]*2,
109
+ init_image,
110
+ mask_image,
111
+ num_inference_steps=num_inference_steps,
112
+ guidance_scale=guidance_scale,
113
+ generator=generator,
114
+ brushnet_conditioning_scale=float(control_strength),
115
+ negative_prompt=[negative_prompt]*2,
116
+ ).images
117
+
118
+ if blended:
119
+ if control_strength<1.0:
120
+ raise gr.Error('Using blurred blending with control strength less than 1.0 is not allowed')
121
+ blended_image=[]
122
+ # blur, you can adjust the parameters for better performance
123
+ mask = cv2.GaussianBlur(mask*255, (21, 21), 0)/255
124
+ mask = mask[:,:,np.newaxis]
125
+ for image_i in image:
126
+ image_np=np.array(image_i)
127
+ image_pasted=original_image * (1-mask) + image_np*mask
128
+
129
+ image_pasted=image_pasted.astype(image_np.dtype)
130
+ blended_image.append(Image.fromarray(image_pasted))
131
+
132
+ image=blended_image
133
+
134
+ return image
135
+
136
+ block = gr.Blocks(
137
+ theme=gr.themes.Soft(
138
+ radius_size=gr.themes.sizes.radius_none,
139
+ text_size=gr.themes.sizes.text_md
140
+ )
141
+ ).queue()
142
+ with block:
143
+ with gr.Row():
144
+ with gr.Column():
145
+
146
+ gr.HTML(f"""
147
+ <div style="text-align: center;">
148
+ <h1>BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion</h1>
149
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
150
+ <a href=""></a>
151
+ <a href='https://tencentarc.github.io/BrushNet/'><img src='https://img.shields.io/badge/Project_Page-BrushNet-green' alt='Project Page'></a>
152
+ <a href='https://arxiv.org/abs/2403.06976'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
153
+ </div>
154
+ </br>
155
+ </div>
156
+ """)
157
+
158
+
159
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
160
+ with gr.Row(equal_height=True):
161
+ gr.Markdown("""
162
+ - ⭐️ <b>step1: </b>Upload or select one image from Example
163
+ - ⭐️ <b>step2: </b>Click on Input-image to select the object to be retained (or upload a white-black Mask image, in which white color indicates the region you want to keep unchanged). You can tick the 'Invert Mask' box to switch region unchanged and change.
164
+ - ⭐️ <b>step3: </b>Input prompt for generating new contents
165
+ - ⭐️ <b>step4: </b>Click Run button
166
+ """)
167
+ with gr.Row():
168
+ with gr.Column():
169
+ with gr.Column(elem_id="Input"):
170
+ with gr.Row():
171
+ with gr.Tabs(elem_classes=["feedback"]):
172
+ with gr.TabItem("Input Image"):
173
+ input_image = gr.Image(type="numpy", label="input",scale=2, height=640)
174
+ original_image = gr.State(value=None,label="index")
175
+ original_mask = gr.State(value=None)
176
+ selected_points = gr.State([],label="select points")
177
+ with gr.Row(elem_id="Seg"):
178
+ radio = gr.Radio(['foreground', 'background'], label='Click to seg: ', value='foreground',scale=2)
179
+ undo_button = gr.Button('Undo seg', elem_id="btnSEG",scale=1)
180
+ prompt = gr.Textbox(label="Prompt", placeholder="Please input your prompt",value='',lines=1)
181
+ negative_prompt = gr.Text(
182
+ label="Negative Prompt",
183
+ max_lines=5,
184
+ placeholder="Please input your negative prompt",
185
+ value='ugly, low quality',lines=1
186
+ )
187
+ with gr.Group():
188
+ with gr.Row():
189
+ blending = gr.Checkbox(label="Blurred Blending", value=False)
190
+ invert_mask = gr.Checkbox(label="Invert Mask", value=True)
191
+ run_button = gr.Button("Run",elem_id="btn")
192
+
193
+ with gr.Accordion("More input params (highly-recommended)", open=False, elem_id="accordion1"):
194
+ control_strength = gr.Slider(
195
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
196
+ )
197
+ with gr.Group():
198
+ seed = gr.Slider(
199
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=551793204
200
+ )
201
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
202
+
203
+ with gr.Group():
204
+ with gr.Row():
205
+ guidance_scale = gr.Slider(
206
+ label="Guidance scale",
207
+ minimum=1,
208
+ maximum=12,
209
+ step=0.1,
210
+ value=12,
211
+ )
212
+ num_inference_steps = gr.Slider(
213
+ label="Number of inference steps",
214
+ minimum=1,
215
+ maximum=50,
216
+ step=1,
217
+ value=50,
218
+ )
219
+ with gr.Row(elem_id="Image"):
220
+ with gr.Tabs(elem_classes=["feedback1"]):
221
+ with gr.TabItem("User-specified Mask Image (Optional)"):
222
+ input_mask = gr.Image(type="numpy", label="Mask Image", height=640)
223
+
224
+ with gr.Column():
225
+ with gr.Tabs(elem_classes=["feedback"]):
226
+ with gr.TabItem("Outputs"):
227
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True)
228
+ with gr.Row():
229
+ def process_example(input_image, prompt, input_mask, original_image, selected_points): #
230
+ return input_image, prompt, input_mask, original_image, []
231
+ example = gr.Examples(
232
+ label="Input Example",
233
+ examples=image_examples,
234
+ inputs=[input_image, prompt, input_mask, original_image, selected_points],
235
+ outputs=[input_image, prompt, input_mask, original_image, selected_points],
236
+ fn=process_example,
237
+ run_on_click=True,
238
+ examples_per_page=10
239
+ )
240
+
241
+ # once user upload an image, the original image is stored in `original_image`
242
+ def store_img(img):
243
+ # image upload is too slow
244
+ if min(img.shape[0], img.shape[1]) > 512:
245
+ img = resize_image(img, 512)
246
+ if max(img.shape[0], img.shape[1])*1.0/min(img.shape[0], img.shape[1])>2.0:
247
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
248
+ return img, img, [], None # when new image is uploaded, `selected_points` should be empty
249
+
250
+ input_image.upload(
251
+ store_img,
252
+ [input_image],
253
+ [input_image, original_image, selected_points]
254
+ )
255
+
256
+ # user click the image to get points, and show the points on the image
257
+ def segmentation(img, sel_pix):
258
+ # online show seg mask
259
+ points = []
260
+ labels = []
261
+ for p, l in sel_pix:
262
+ points.append(p)
263
+ labels.append(l)
264
+ mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
265
+ with torch.no_grad():
266
+ masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
267
+
268
+ output_mask = np.ones((masks.shape[1], masks.shape[2], 3))*255
269
+ for i in range(3):
270
+ output_mask[masks[0] == True, i] = 0.0
271
+
272
+ mask_all = np.ones((masks.shape[1], masks.shape[2], 3))
273
+ color_mask = np.random.random((1, 3)).tolist()[0]
274
+ for i in range(3):
275
+ mask_all[masks[0] == True, i] = color_mask[i]
276
+ masked_img = img / 255 * 0.3 + mask_all * 0.7
277
+ masked_img = masked_img*255
278
+ ## draw points
279
+ for point, label in sel_pix:
280
+ cv2.drawMarker(masked_img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
281
+ return masked_img, output_mask
282
+
283
+ def get_point(img, sel_pix, point_type, evt: gr.SelectData):
284
+ if point_type == 'foreground':
285
+ sel_pix.append((evt.index, 1)) # append the foreground_point
286
+ elif point_type == 'background':
287
+ sel_pix.append((evt.index, 0)) # append the background_point
288
+ else:
289
+ sel_pix.append((evt.index, 1)) # default foreground_point
290
+
291
+ if isinstance(img, int):
292
+ image_name = image_examples[img][0]
293
+ img = cv2.imread(image_name)
294
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
295
+
296
+ # online show seg mask
297
+ masked_img, output_mask = segmentation(img, sel_pix)
298
+ return masked_img.astype(np.uint8), output_mask
299
+
300
+ input_image.select(
301
+ get_point,
302
+ [original_image, selected_points, radio],
303
+ [input_image, original_mask],
304
+ )
305
+
306
+ # undo the selected point
307
+ def undo_points(orig_img, sel_pix):
308
+ # draw points
309
+ output_mask = None
310
+ if len(sel_pix) != 0:
311
+ if isinstance(orig_img, int): # if orig_img is int, the image if select from examples
312
+ temp = cv2.imread(image_examples[orig_img][0])
313
+ temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
314
+ else:
315
+ temp = orig_img.copy()
316
+ sel_pix.pop()
317
+ # online show seg mask
318
+ if len(sel_pix) !=0:
319
+ temp, output_mask = segmentation(temp, sel_pix)
320
+ return temp.astype(np.uint8), output_mask
321
+ else:
322
+ gr.Error("Nothing to Undo")
323
+
324
+ undo_button.click(
325
+ undo_points,
326
+ [original_image, selected_points],
327
+ [input_image, original_mask]
328
+ )
329
+
330
+ ips=[input_image, original_image, original_mask, input_mask, selected_points, prompt, negative_prompt, blending, invert_mask, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps]
331
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
332
+
333
+
334
+ block.launch()
examples/brushnet/src/test_image.jpg ADDED
examples/brushnet/src/test_mask.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.12.1+cu116
2
+ torchvision==0.13.1+cu116
3
+ torchaudio==0.12.1
4
+ transformers>=4.25.1
5
+ ftfy
6
+ tensorboard
7
+ datasets
8
+ Pillow==9.5.0
9
+ opencv-python
10
+ imgaug
11
+ accelerate==0.20.3
12
+ image-reward
13
+ hpsv2
14
+ torchmetrics
15
+ open-clip-torch
16
+ clip
17
+ gradio==3.50.0
18
+ segment_anything
19
+ git+https://github.com/TencentARC/BrushNet.git