Freak-ppa committed on
Commit b15e30e · verified · 1 Parent(s): 51622a7

Update ComfyUI/custom_nodes/ComfyUI-BrushNet/brushnet_nodes.py

ComfyUI/custom_nodes/ComfyUI-BrushNet/brushnet_nodes.py CHANGED
@@ -1,1085 +1,1098 @@
1
- import os
2
- import types
3
- from typing import Tuple
4
-
5
- import torch
6
- import torchvision.transforms as T
7
- import torch.nn.functional as F
8
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
9
-
10
- import comfy
11
- import folder_paths
12
-
13
- from .model_patch import add_model_patch_option, patch_model_function_wrapper
14
-
15
- from .brushnet.brushnet import BrushNetModel
16
- from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel
17
-
18
- from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens
19
-
20
- current_directory = os.path.dirname(os.path.abspath(__file__))
21
- brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json')
22
- brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json')
23
- powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json')
24
-
25
- sd15_scaling_factor = 0.18215
26
- sdxl_scaling_factor = 0.13025
27
-
28
- ModelsToUnload = [comfy.sd1_clip.SD1ClipModel,
29
- comfy.ldm.models.autoencoder.AutoencoderKL
30
- ]
31
-
32
-
33
- class BrushNetLoader:
34
-
35
- @classmethod
36
- def INPUT_TYPES(self):
37
- self.inpaint_files = get_files_with_extension('inpaint')
38
- return {"required":
39
- {
40
- "brushnet": ([file for file in self.inpaint_files], ),
41
- "dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
42
- },
43
- }
44
-
45
- CATEGORY = "inpaint"
46
- RETURN_TYPES = ("BRMODEL",)
47
- RETURN_NAMES = ("brushnet",)
48
-
49
- FUNCTION = "brushnet_loading"
50
-
51
- def brushnet_loading(self, brushnet, dtype):
52
- brushnet_file = os.path.join(self.inpaint_files[brushnet], brushnet)
53
- is_SDXL = False
54
- is_PP = False
55
- sd = comfy.utils.load_torch_file(brushnet_file)
56
- brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd)
57
- del sd
58
- if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
59
- is_SDXL = False
60
- if keys == 322:
61
- is_PP = False
62
- print('BrushNet model type: SD1.5')
63
- else:
64
- is_PP = True
65
- print('PowerPaint model type: SD1.5')
66
- elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
67
- print('BrushNet model type: Loading SDXL')
68
- is_SDXL = True
69
- is_PP = False
70
- else:
71
- raise Exception("Unknown BrushNet model")
72
-
73
- with init_empty_weights():
74
- if is_SDXL:
75
- brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
76
- brushnet_model = BrushNetModel.from_config(brushnet_config)
77
- elif is_PP:
78
- brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
79
- brushnet_model = PowerPaintModel.from_config(brushnet_config)
80
- else:
81
- brushnet_config = BrushNetModel.load_config(brushnet_config_file)
82
- brushnet_model = BrushNetModel.from_config(brushnet_config)
83
-
84
- if is_PP:
85
- print("PowerPaint model file:", brushnet_file)
86
- else:
87
- print("BrushNet model file:", brushnet_file)
88
-
89
- if dtype == 'float16':
90
- torch_dtype = torch.float16
91
- elif dtype == 'bfloat16':
92
- torch_dtype = torch.bfloat16
93
- elif dtype == 'float32':
94
- torch_dtype = torch.float32
95
- else:
96
- torch_dtype = torch.float64
97
-
98
- brushnet_model = load_checkpoint_and_dispatch(
99
- brushnet_model,
100
- brushnet_file,
101
- device_map="sequential",
102
- max_memory=None,
103
- offload_folder=None,
104
- offload_state_dict=False,
105
- dtype=torch_dtype,
106
- force_hooks=False,
107
- )
108
-
109
- if is_PP:
110
- print("PowerPaint model is loaded")
111
- elif is_SDXL:
112
- print("BrushNet SDXL model is loaded")
113
- else:
114
- print("BrushNet SD1.5 model is loaded")
115
-
116
- return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, )
117
-
118
-
119
- class PowerPaintCLIPLoader:
120
-
121
- @classmethod
122
- def INPUT_TYPES(self):
123
- self.inpaint_files = get_files_with_extension('inpaint', ['.bin'])
124
- self.clip_files = get_files_with_extension('clip')
125
- return {"required":
126
- {
127
- "base": ([file for file in self.clip_files], ),
128
- "powerpaint": ([file for file in self.inpaint_files], ),
129
- },
130
- }
131
-
132
- CATEGORY = "inpaint"
133
- RETURN_TYPES = ("CLIP",)
134
- RETURN_NAMES = ("clip",)
135
-
136
- FUNCTION = "ppclip_loading"
137
-
138
- def ppclip_loading(self, base, powerpaint):
139
- base_CLIP_file = os.path.join(self.clip_files[base], base)
140
- pp_CLIP_file = os.path.join(self.inpaint_files[powerpaint], powerpaint)
141
-
142
- pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file])
143
-
144
- print('PowerPaint base CLIP file: ', base_CLIP_file)
145
-
146
- pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
147
- pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
148
-
149
- add_tokens(
150
- tokenizer = pp_tokenizer,
151
- text_encoder = pp_text_encoder,
152
- placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"],
153
- initialize_tokens = ["a", "a", "a"],
154
- num_vectors_per_token = 10,
155
- )
156
-
157
- pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False)
158
-
159
- print('PowerPaint CLIP file: ', pp_CLIP_file)
160
-
161
- pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
162
- pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
163
-
164
- return (pp_clip,)
165
-
166
-
167
- class PowerPaint:
168
-
169
- @classmethod
170
- def INPUT_TYPES(s):
171
- return {"required":
172
- {
173
- "model": ("MODEL",),
174
- "vae": ("VAE", ),
175
- "image": ("IMAGE",),
176
- "mask": ("MASK",),
177
- "powerpaint": ("BRMODEL", ),
178
- "clip": ("CLIP", ),
179
- "positive": ("CONDITIONING", ),
180
- "negative": ("CONDITIONING", ),
181
- "fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
182
- "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ),
183
- "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
184
- "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
185
- "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
186
- "save_memory": (['none', 'auto', 'max'], ),
187
- },
188
- }
189
-
190
- CATEGORY = "inpaint"
191
- RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
192
- RETURN_NAMES = ("model","positive","negative","latent",)
193
-
194
- FUNCTION = "model_update"
195
-
196
- def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
197
-
198
- is_SDXL, is_PP = check_compatibilty(model, powerpaint)
199
- if not is_PP:
200
- raise Exception("BrushNet model was loaded, please use BrushNet node")
201
-
202
- # Make a copy of the model so that we're not patching it everywhere in the workflow.
203
- model = model.clone()
204
-
205
- # prepare image and mask
206
- # no batches for original image and mask
207
- masked_image, mask = prepare_image(image, mask)
208
-
209
- batch = masked_image.shape[0]
210
- #width = masked_image.shape[2]
211
- #height = masked_image.shape[1]
212
-
213
- if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
214
- scaling_factor = model.model.model_config.latent_format.scale_factor
215
- else:
216
- scaling_factor = sd15_scaling_factor
217
-
218
- torch_dtype = powerpaint['dtype']
219
-
220
- # prepare conditioning latents
221
- conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
222
- conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
223
- conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
224
-
225
- # prepare embeddings
226
-
227
- if function == "object removal":
228
- promptA = "P_ctxt"
229
- promptB = "P_ctxt"
230
- negative_promptA = "P_obj"
231
- negative_promptB = "P_obj"
232
- print('You should add to positive prompt: "empty scene blur"')
233
- #positive = positive + " empty scene blur"
234
- elif function == "context aware":
235
- promptA = "P_ctxt"
236
- promptB = "P_ctxt"
237
- negative_promptA = ""
238
- negative_promptB = ""
239
- #positive = positive + " empty scene"
240
- print('You should add to positive prompt: "empty scene"')
241
- elif function == "shape guided":
242
- promptA = "P_shape"
243
- promptB = "P_ctxt"
244
- negative_promptA = "P_shape"
245
- negative_promptB = "P_ctxt"
246
- elif function == "image outpainting":
247
- promptA = "P_ctxt"
248
- promptB = "P_ctxt"
249
- negative_promptA = "P_obj"
250
- negative_promptB = "P_obj"
251
- #positive = positive + " empty scene"
252
- print('You should add to positive prompt: "empty scene"')
253
- else:
254
- promptA = "P_obj"
255
- promptB = "P_obj"
256
- negative_promptA = "P_obj"
257
- negative_promptB = "P_obj"
258
-
259
- tokens = clip.tokenize(promptA)
260
- prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
261
-
262
- tokens = clip.tokenize(negative_promptA)
263
- negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
264
-
265
- tokens = clip.tokenize(promptB)
266
- prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
267
-
268
- tokens = clip.tokenize(negative_promptB)
269
- negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
270
-
271
- prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
272
- negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
273
-
274
- # unload vae and CLIPs
275
- del vae
276
- del clip
277
- for loaded_model in comfy.model_management.current_loaded_models:
278
- if type(loaded_model.model.model) in ModelsToUnload:
279
- comfy.model_management.current_loaded_models.remove(loaded_model)
280
- loaded_model.model_unload()
281
- del loaded_model
282
-
283
- # apply patch to model
284
-
285
- brushnet_conditioning_scale = scale
286
- control_guidance_start = start_at
287
- control_guidance_end = end_at
288
-
289
- if save_memory != 'none':
290
- powerpaint['brushnet'].set_attention_slice(save_memory)
291
-
292
- add_brushnet_patch(model,
293
- powerpaint['brushnet'],
294
- torch_dtype,
295
- conditioning_latents,
296
- (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
297
- negative_prompt_embeds_pp, prompt_embeds_pp,
298
- None, None, None,
299
- False)
300
-
301
- latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=powerpaint['brushnet'].device)
302
-
303
- return (model, positive, negative, {"samples":latent},)
304
-
305
-
306
- class BrushNet:
307
-
308
- @classmethod
309
- def INPUT_TYPES(s):
310
- return {"required":
311
- {
312
- "model": ("MODEL",),
313
- "vae": ("VAE", ),
314
- "image": ("IMAGE",),
315
- "mask": ("MASK",),
316
- "brushnet": ("BRMODEL", ),
317
- "positive": ("CONDITIONING", ),
318
- "negative": ("CONDITIONING", ),
319
- "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
320
- "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
321
- "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
322
- },
323
- }
324
-
325
- CATEGORY = "inpaint"
326
- RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
327
- RETURN_NAMES = ("model","positive","negative","latent",)
328
-
329
- FUNCTION = "model_update"
330
-
331
- def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
332
-
333
- is_SDXL, is_PP = check_compatibilty(model, brushnet)
334
-
335
- if is_PP:
336
- raise Exception("PowerPaint model was loaded, please use PowerPaint node")
337
-
338
- # Make a copy of the model so that we're not patching it everywhere in the workflow.
339
- model = model.clone()
340
-
341
- # prepare image and mask
342
- # no batches for original image and mask
343
- masked_image, mask = prepare_image(image, mask)
344
-
345
- batch = masked_image.shape[0]
346
- width = masked_image.shape[2]
347
- height = masked_image.shape[1]
348
-
349
- if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
350
- scaling_factor = model.model.model_config.latent_format.scale_factor
351
- elif is_SDXL:
352
- scaling_factor = sdxl_scaling_factor
353
- else:
354
- scaling_factor = sd15_scaling_factor
355
-
356
- torch_dtype = brushnet['dtype']
357
-
358
- # prepare conditioning latents
359
- conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
360
- conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
361
- conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
362
-
363
- # unload vae
364
- del vae
365
- for loaded_model in comfy.model_management.current_loaded_models:
366
- if type(loaded_model.model.model) in ModelsToUnload:
367
- comfy.model_management.current_loaded_models.remove(loaded_model)
368
- loaded_model.model_unload()
369
- del loaded_model
370
-
371
- # prepare embeddings
372
-
373
- prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
374
- negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
375
-
376
- max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
377
- if prompt_embeds.shape[1] < max_tokens:
378
- multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
379
- prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1)
380
- print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds')
381
- if negative_prompt_embeds.shape[1] < max_tokens:
382
- multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
383
- negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1)
384
- print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds')
385
-
386
- if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
387
- pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
388
- else:
389
- print('BrushNet: positive conditioning has not pooled_output')
390
- if is_SDXL:
391
- print('BrushNet will not produce correct results')
392
- pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
393
-
394
- if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
395
- negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
396
- else:
397
- print('BrushNet: negative conditioning has not pooled_output')
398
- if is_SDXL:
399
- print('BrushNet will not produce correct results')
400
- negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
401
-
402
- time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(brushnet['brushnet'].device)
403
-
404
- if not is_SDXL:
405
- pooled_prompt_embeds = None
406
- negative_pooled_prompt_embeds = None
407
- time_ids = None
408
-
409
- # apply patch to model
410
-
411
- brushnet_conditioning_scale = scale
412
- control_guidance_start = start_at
413
- control_guidance_end = end_at
414
-
415
- add_brushnet_patch(model,
416
- brushnet['brushnet'],
417
- torch_dtype,
418
- conditioning_latents,
419
- (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
420
- prompt_embeds, negative_prompt_embeds,
421
- pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
422
- False)
423
-
424
- latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=brushnet['brushnet'].device)
425
-
426
- return (model, positive, negative, {"samples":latent},)
427
-
428
-
429
- class BlendInpaint:
430
-
431
- @classmethod
432
- def INPUT_TYPES(s):
433
- return {"required":
434
- {
435
- "inpaint": ("IMAGE",),
436
- "original": ("IMAGE",),
437
- "mask": ("MASK",),
438
- "kernel": ("INT", {"default": 10, "min": 1, "max": 1000}),
439
- "sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}),
440
- },
441
- "optional":
442
- {
443
- "origin": ("VECTOR",),
444
- },
445
- }
446
-
447
- CATEGORY = "inpaint"
448
- RETURN_TYPES = ("IMAGE","MASK",)
449
- RETURN_NAMES = ("image","MASK",)
450
-
451
- FUNCTION = "blend_inpaint"
452
-
453
- def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma:int, origin=None) -> Tuple[torch.Tensor]:
454
-
455
- original, mask = check_image_mask(original, mask, 'Blend Inpaint')
456
-
457
- if len(inpaint.shape) < 4:
458
- # image tensor shape should be [B, H, W, C], but batch somehow is missing
459
- inpaint = inpaint[None,:,:,:]
460
-
461
- if inpaint.shape[0] < original.shape[0]:
462
- print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0]))
463
- original= original[:inpaint.shape[0],:,:]
464
- mask = mask[:inpaint.shape[0],:,:]
465
-
466
- if inpaint.shape[0] > original.shape[0]:
467
- # batch over inpaint
468
- count = 0
469
- original_list = []
470
- mask_list = []
471
- origin_list = []
472
- while (count < inpaint.shape[0]):
473
- for i in range(original.shape[0]):
474
- original_list.append(original[i][None,:,:,:])
475
- mask_list.append(mask[i][None,:,:])
476
- if origin is not None:
477
- origin_list.append(origin[i][None,:])
478
- count += 1
479
- if count >= inpaint.shape[0]:
480
- break
481
- original = torch.concat(original_list, dim=0)
482
- mask = torch.concat(mask_list, dim=0)
483
- if origin is not None:
484
- origin = torch.concat(origin_list, dim=0)
485
-
486
- if kernel % 2 == 0:
487
- kernel += 1
488
- transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))
489
-
490
- ret = []
491
- blurred = []
492
- for i in range(inpaint.shape[0]):
493
- if origin is None:
494
- blurred_mask = transform(mask[i][None,None,:,:]).to(original.device).to(original.dtype)
495
- blurred.append(blurred_mask[0])
496
-
497
- result = torch.nn.functional.interpolate(
498
- inpaint[i][None,:,:,:].permute(0, 3, 1, 2),
499
- size=(
500
- original[i].shape[0],
501
- original[i].shape[1],
502
- )
503
- ).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
504
- else:
505
- # got mask from CutForInpaint
506
- height, width, _ = original[i].shape
507
- x0 = origin[i][0].item()
508
- y0 = origin[i][1].item()
509
-
510
- if mask[i].shape[0] < height or mask[i].shape[1] < width:
511
- padded_mask = F.pad(input=mask[i], pad=(x0, width-x0-mask[i].shape[1],
512
- y0, height-y0-mask[i].shape[0]), mode='constant', value=0)
513
- else:
514
- padded_mask = mask[i]
515
- blurred_mask = transform(padded_mask[None,None,:,:]).to(original.device).to(original.dtype)
516
- blurred.append(blurred_mask[0][0])
517
-
518
- result = F.pad(input=inpaint[i], pad=(0, 0, x0, width-x0-inpaint[i].shape[1],
519
- y0, height-y0-inpaint[i].shape[0]), mode='constant', value=0)
520
- result = result[None,:,:,:].to(original.device).to(original.dtype)
521
-
522
- ret.append(original[i] * (1.0 - blurred_mask[0][0][:,:,None]) + result[0] * blurred_mask[0][0][:,:,None])
523
-
524
- return (torch.stack(ret), torch.stack(blurred), )
525
-
526
-
527
- class CutForInpaint:
528
-
529
- @classmethod
530
- def INPUT_TYPES(s):
531
- return {"required":
532
- {
533
- "image": ("IMAGE",),
534
- "mask": ("MASK",),
535
- "width": ("INT", {"default": 512, "min": 64, "max": 2048}),
536
- "height": ("INT", {"default": 512, "min": 64, "max": 2048}),
537
- },
538
- }
539
-
540
- CATEGORY = "inpaint"
541
- RETURN_TYPES = ("IMAGE","MASK","VECTOR",)
542
- RETURN_NAMES = ("image","mask","origin",)
543
-
544
- FUNCTION = "cut_for_inpaint"
545
-
546
- def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, width: int, height: int):
547
-
548
- image, mask = check_image_mask(image, mask, 'BrushNet')
549
-
550
- ret = []
551
- msk = []
552
- org = []
553
- for i in range(image.shape[0]):
554
- x0, y0, w, h = cut_with_mask(mask[i], width, height)
555
- ret.append((image[i][y0:y0+h,x0:x0+w,:]))
556
- msk.append((mask[i][y0:y0+h,x0:x0+w]))
557
- org.append(torch.IntTensor([x0,y0]))
558
-
559
- return (torch.stack(ret), torch.stack(msk), torch.stack(org), )
560
-
561
-
562
- #### Utility function
563
-
564
- def get_files_with_extension(folder_name, extension=['.safetensors']):
565
-
566
- try:
567
- folders = folder_paths.get_folder_paths(folder_name)
568
- except:
569
- folders = []
570
-
571
- if not folders:
572
- folders = [os.path.join(folder_paths.models_dir, folder_name)]
573
- if not os.path.isdir(folders[0]):
574
- folders = [os.path.join(folder_paths.base_path, folder_name)]
575
- if not os.path.isdir(folders[0]):
576
- return {}
577
-
578
- filtered_folders = []
579
- for x in folders:
580
- if not os.path.isdir(x):
581
- continue
582
- the_same = False
583
- for y in filtered_folders:
584
- if os.path.samefile(x, y):
585
- the_same = True
586
- break
587
- if not the_same:
588
- filtered_folders.append(x)
589
-
590
- if not filtered_folders:
591
- return {}
592
-
593
- output = {}
594
- for x in filtered_folders:
595
- files, folders_all = folder_paths.recursive_search(x, excluded_dir_names=[".git"])
596
- filtered_files = folder_paths.filter_files_extensions(files, extension)
597
-
598
- for f in filtered_files:
599
- output[f] = x
600
-
601
- return output
602
-
603
-
604
- # get blocks from state_dict so we could know which model it is
605
- def brushnet_blocks(sd):
606
- brushnet_down_block = 0
607
- brushnet_mid_block = 0
608
- brushnet_up_block = 0
609
- for key in sd:
610
- if 'brushnet_down_block' in key:
611
- brushnet_down_block += 1
612
- if 'brushnet_mid_block' in key:
613
- brushnet_mid_block += 1
614
- if 'brushnet_up_block' in key:
615
- brushnet_up_block += 1
616
- return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))
617
-
618
-
619
- # Check models compatibility
620
- def check_compatibilty(model, brushnet):
621
- is_SDXL = False
622
- is_PP = False
623
- if isinstance(model.model.model_config, comfy.supported_models.SD15):
624
- print('Base model type: SD1.5')
625
- is_SDXL = False
626
- if brushnet["SDXL"]:
627
- raise Exception("Base model is SD15, but BrushNet is SDXL type")
628
- if brushnet["PP"]:
629
- is_PP = True
630
- elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
631
- print('Base model type: SDXL')
632
- is_SDXL = True
633
- if not brushnet["SDXL"]:
634
- raise Exception("Base model is SDXL, but BrushNet is SD15 type")
635
- else:
636
- print('Base model type: ', type(model.model.model_config))
637
- raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
638
-
639
- return (is_SDXL, is_PP)
640
-
641
-
642
- def check_image_mask(image, mask, name):
643
- if len(image.shape) < 4:
644
- # image tensor shape should be [B, H, W, C], but batch somehow is missing
645
- image = image[None,:,:,:]
646
-
647
- if len(mask.shape) > 3:
648
- # mask tensor shape should be [B, H, W] but we get [B, H, W, C], image may be?
649
- # take first mask, red channel
650
- mask = (mask[:,:,:,0])[:,:,:]
651
- elif len(mask.shape) < 3:
652
- # mask tensor shape should be [B, H, W] but batch somehow is missing
653
- mask = mask[None,:,:]
654
-
655
- if image.shape[0] > mask.shape[0]:
656
- print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
657
- if mask.shape[0] == 1:
658
- print(name, "will copy the mask to fill batch")
659
- mask = torch.cat([mask] * image.shape[0], dim=0)
660
- else:
661
- print(name, "will add empty masks to fill batch")
662
- empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
663
- mask = torch.cat([mask, empty_mask], dim=0)
664
- elif image.shape[0] < mask.shape[0]:
665
- print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
666
- mask = mask[:image.shape[0],:,:]
667
-
668
- return (image, mask)
669
-
670
-
671
- # Prepare image and mask
672
- def prepare_image(image, mask):
673
-
674
- image, mask = check_image_mask(image, mask, 'BrushNet')
675
-
676
- print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)
677
-
678
- if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
679
- raise Exception("Image and mask should be the same size")
680
-
681
- # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
682
- mask = mask.round()
683
-
684
- masked_image = image * (1.0 - mask[:,:,:,None])
685
-
686
- return (masked_image, mask)
687
-
688
-
689
- # Get origin of the mask
690
- def cut_with_mask(mask, width, height):
691
- iy, ix = (mask == 1).nonzero(as_tuple=True)
692
-
693
- h0, w0 = mask.shape
694
-
695
- if iy.numel() == 0:
696
- x_c = w0 / 2.0
697
- y_c = h0 / 2.0
698
- else:
699
- x_min = ix.min().item()
700
- x_max = ix.max().item()
701
- y_min = iy.min().item()
702
- y_max = iy.max().item()
703
-
704
- if x_max - x_min > width or y_max - y_min > height:
705
- raise Exception("Masked area is bigger than provided dimensions")
706
-
707
- x_c = (x_min + x_max) / 2.0
708
- y_c = (y_min + y_max) / 2.0
709
-
710
- width2 = width / 2.0
711
- height2 = height / 2.0
712
-
713
- if w0 <= width:
714
- x0 = 0
715
- w = w0
716
- else:
717
- x0 = max(0, x_c - width2)
718
- w = width
719
- if x0 + width > w0:
720
- x0 = w0 - width
721
-
722
- if h0 <= height:
723
- y0 = 0
724
- h = h0
725
- else:
726
- y0 = max(0, y_c - height2)
727
- h = height
728
- if y0 + height > h0:
729
- y0 = h0 - height
730
-
731
- return (int(x0), int(y0), int(w), int(h))
732
-
733
-
734
- # Prepare conditioning_latents
735
- @torch.inference_mode()
736
- def get_image_latents(masked_image, mask, vae, scaling_factor):
737
- processed_image = masked_image.to(vae.device)
738
- image_latents = vae.encode(processed_image[:,:,:,:3]) * scaling_factor
739
- processed_mask = 1. - mask[:,None,:,:]
740
- interpolated_mask = torch.nn.functional.interpolate(
741
- processed_mask,
742
- size=(
743
- image_latents.shape[-2],
744
- image_latents.shape[-1]
745
- )
746
- )
747
- interpolated_mask = interpolated_mask.to(image_latents.device)
748
-
749
- conditioning_latents = [image_latents, interpolated_mask]
750
-
751
- print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape)
752
-
753
- return conditioning_latents
754
-
755
-
756
- # Main function where magic happens
757
- @torch.inference_mode()
758
- def brushnet_inference(x, timesteps, transformer_options, debug):
759
- if 'model_patch' not in transformer_options:
760
- print('BrushNet inference: there is no model_patch key in transformer_options')
761
- return ([], 0, [])
762
- mp = transformer_options['model_patch']
763
- if 'brushnet' not in mp:
764
- print('BrushNet inference: there is no brushnet key in mdel_patch')
765
- return ([], 0, [])
766
- bo = mp['brushnet']
767
- if 'model' not in bo:
768
- print('BrushNet inference: there is no model key in brushnet')
769
- return ([], 0, [])
770
- brushnet = bo['model']
771
- if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
772
- print('BrushNet model is not a BrushNetModel class')
773
- return ([], 0, [])
774
-
775
- torch_dtype = bo['dtype']
776
- cl_list = bo['latents']
777
- brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
778
- pe = bo['prompt_embeds']
779
- npe = bo['negative_prompt_embeds']
780
- ppe, nppe, time_ids = bo['add_embeds']
781
-
782
- #do_classifier_free_guidance = mp['free_guidance']
783
- do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1
784
-
785
- x = x.detach().clone()
786
- x = x.to(torch_dtype).to(brushnet.device)
787
-
788
- timesteps = timesteps.detach().clone()
789
- timesteps = timesteps.to(torch_dtype).to(brushnet.device)
790
-
791
- total_steps = mp['total_steps']
792
- step = mp['step']
793
-
794
- added_cond_kwargs = {}
795
-
796
- if do_classifier_free_guidance and step == 0:
797
- print('BrushNet inference: do_classifier_free_guidance is True')
798
-
799
- sub_idx = None
800
- if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
801
- sub_idx = transformer_options['ad_params']['sub_idxs']
802
-
803
- # we have batch input images
804
- batch = cl_list[0].shape[0]
805
- # we have incoming latents
806
- latents_incoming = x.shape[0]
807
- # and we already got some
808
- latents_got = bo['latent_id']
809
- if step == 0 or batch > 1:
810
- print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
811
- % (step, batch, latents_incoming, latents_got))
812
-
813
- image_latents = []
814
- masks = []
815
- prompt_embeds = []
816
- negative_prompt_embeds = []
817
- pooled_prompt_embeds = []
818
- negative_pooled_prompt_embeds = []
819
- if sub_idx:
820
- # AnimateDiff indexes detected
821
- if step == 0:
822
- print('BrushNet inference: AnimateDiff indexes detected and applied')
823
-
824
- batch = len(sub_idx)
825
-
826
- if do_classifier_free_guidance:
827
- for i in sub_idx:
828
- image_latents.append(cl_list[0][i][None,:,:,:])
829
- masks.append(cl_list[1][i][None,:,:,:])
830
- prompt_embeds.append(pe)
831
- negative_prompt_embeds.append(npe)
832
- pooled_prompt_embeds.append(ppe)
833
- negative_pooled_prompt_embeds.append(nppe)
834
- for i in sub_idx:
835
- image_latents.append(cl_list[0][i][None,:,:,:])
836
- masks.append(cl_list[1][i][None,:,:,:])
837
- else:
838
- for i in sub_idx:
839
- image_latents.append(cl_list[0][i][None,:,:,:])
840
- masks.append(cl_list[1][i][None,:,:,:])
841
- prompt_embeds.append(pe)
842
- pooled_prompt_embeds.append(ppe)
843
- else:
844
- # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
845
- continue_batch = True
846
- for i in range(latents_incoming):
847
- number = latents_got + i
848
- if number < batch:
849
- # 1st pass, cond
850
- image_latents.append(cl_list[0][number][None,:,:,:])
851
- masks.append(cl_list[1][number][None,:,:,:])
852
- prompt_embeds.append(pe)
853
- pooled_prompt_embeds.append(ppe)
854
- elif do_classifier_free_guidance and number < batch * 2:
855
- # 2nd pass, uncond
856
- image_latents.append(cl_list[0][number-batch][None,:,:,:])
857
- masks.append(cl_list[1][number-batch][None,:,:,:])
858
- negative_prompt_embeds.append(npe)
859
- negative_pooled_prompt_embeds.append(nppe)
860
- else:
861
- # latent batch
862
- image_latents.append(cl_list[0][0][None,:,:,:])
863
- masks.append(cl_list[1][0][None,:,:,:])
864
- prompt_embeds.append(pe)
865
- pooled_prompt_embeds.append(ppe)
866
- latents_got = -i
867
- continue_batch = False
868
-
869
- if continue_batch:
870
- # we don't have full batch yet
871
- if do_classifier_free_guidance:
872
- if number < batch * 2 - 1:
873
- bo['latent_id'] = number + 1
874
- else:
875
- bo['latent_id'] = 0
876
- else:
877
- if number < batch - 1:
878
- bo['latent_id'] = number + 1
879
- else:
880
- bo['latent_id'] = 0
881
- else:
882
- bo['latent_id'] = 0
883
-
884
- cl = []
885
- for il, m in zip(image_latents, masks):
886
- cl.append(torch.concat([il, m], dim=1))
887
- cl2apply = torch.concat(cl, dim=0)
888
-
889
- conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)
890
-
891
- prompt_embeds.extend(negative_prompt_embeds)
892
- prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
893
-
894
- if ppe is not None:
895
- added_cond_kwargs = {}
896
- added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)
897
-
898
- pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
899
- pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
900
- added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
901
- else:
902
- added_cond_kwargs = None
903
-
904
- if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
905
- if step == 0:
906
- print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
907
- conditioning_latents = torch.nn.functional.interpolate(
908
- conditioning_latents, size=(
909
- x.shape[2],
910
- x.shape[3],
911
- ), mode='bicubic',
912
- ).to(torch_dtype).to(brushnet.device)
913
-
914
- if step == 0:
915
- print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)
916
-
917
- if debug: print('BrushNet: step =', step)
918
-
919
- if step < control_guidance_start or step > control_guidance_end:
920
- cond_scale = 0.0
921
- else:
922
- cond_scale = brushnet_conditioning_scale
923
-
924
- return brushnet(x,
925
- encoder_hidden_states=prompt_embeds,
926
- brushnet_cond=conditioning_latents,
927
- timestep = timesteps,
928
- conditioning_scale=cond_scale,
929
- guess_mode=False,
930
- added_cond_kwargs=added_cond_kwargs,
931
- return_dict=False,
932
- debug=debug,
933
- )
934
-
935
-
936
- # This is main patch function
937
- def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
938
- controls,
939
- prompt_embeds, negative_prompt_embeds,
940
- pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
941
- debug):
942
-
943
- is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
944
-
945
- if is_SDXL:
946
- input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
947
- [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
948
- [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
949
- [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
950
- [4, comfy.ldm.modules.attention.SpatialTransformer],
951
- [5, comfy.ldm.modules.attention.SpatialTransformer],
952
- [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
953
- [7, comfy.ldm.modules.attention.SpatialTransformer],
954
- [8, comfy.ldm.modules.attention.SpatialTransformer]]
955
- middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
956
- output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
957
- [1, comfy.ldm.modules.attention.SpatialTransformer],
958
- [2, comfy.ldm.modules.attention.SpatialTransformer],
959
- [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
960
- [3, comfy.ldm.modules.attention.SpatialTransformer],
961
- [4, comfy.ldm.modules.attention.SpatialTransformer],
962
- [5, comfy.ldm.modules.attention.SpatialTransformer],
963
- [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
964
- [6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
965
- [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
966
- [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
967
- else:
968
- input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
969
- [1, comfy.ldm.modules.attention.SpatialTransformer],
970
- [2, comfy.ldm.modules.attention.SpatialTransformer],
971
- [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
972
- [4, comfy.ldm.modules.attention.SpatialTransformer],
973
- [5, comfy.ldm.modules.attention.SpatialTransformer],
974
- [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
975
- [7, comfy.ldm.modules.attention.SpatialTransformer],
976
- [8, comfy.ldm.modules.attention.SpatialTransformer],
977
- [9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
978
- [10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
979
- [11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
980
- middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
981
- output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
982
- [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
983
- [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
984
- [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
985
- [3, comfy.ldm.modules.attention.SpatialTransformer],
986
- [4, comfy.ldm.modules.attention.SpatialTransformer],
987
- [5, comfy.ldm.modules.attention.SpatialTransformer],
988
- [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
989
- [6, comfy.ldm.modules.attention.SpatialTransformer],
990
- [7, comfy.ldm.modules.attention.SpatialTransformer],
991
- [8, comfy.ldm.modules.attention.SpatialTransformer],
992
- [8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
993
- [9, comfy.ldm.modules.attention.SpatialTransformer],
994
- [10, comfy.ldm.modules.attention.SpatialTransformer],
995
- [11, comfy.ldm.modules.attention.SpatialTransformer]]
996
-
997
- def last_layer_index(block, tp):
998
- layer_list = []
999
- for layer in block:
1000
- layer_list.append(type(layer))
1001
- layer_list.reverse()
1002
- if tp not in layer_list:
1003
- return -1, layer_list.reverse()
1004
- return len(layer_list) - 1 - layer_list.index(tp), layer_list
1005
-
1006
- def brushnet_forward(model, x, timesteps, transformer_options, control):
1007
- if 'brushnet' not in transformer_options['model_patch']:
1008
- input_samples = []
1009
- mid_sample = 0
1010
- output_samples = []
1011
- else:
1012
- # brushnet inference
1013
- input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)
1014
-
1015
- # give additional samples to blocks
1016
- for i, tp in input_blocks:
1017
- idx, layer_list = last_layer_index(model.input_blocks[i], tp)
1018
- if idx < 0:
1019
- print("BrushNet can't find", tp, "layer in", i,"input block:", layer_list)
1020
- continue
1021
- model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0
1022
-
1023
- idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
1024
- if idx < 0:
1025
- print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
1026
- model.middle_block[idx].add_sample_after = mid_sample
1027
-
1028
- for i, tp in output_blocks:
1029
- idx, layer_list = last_layer_index(model.output_blocks[i], tp)
1030
- if idx < 0:
1031
- print("BrushNet can't find", tp, "layer in", i,"outnput block:", layer_list)
1032
- continue
1033
- model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0
1034
-
1035
- patch_model_function_wrapper(model, brushnet_forward)
1036
-
1037
- to = add_model_patch_option(model)
1038
- mp = to['model_patch']
1039
- if 'brushnet' not in mp:
1040
- mp['brushnet'] = {}
1041
- bo = mp['brushnet']
1042
-
1043
- bo['model'] = brushnet
1044
- bo['dtype'] = torch_dtype
1045
- bo['latents'] = conditioning_latents
1046
- bo['controls'] = controls
1047
- bo['prompt_embeds'] = prompt_embeds
1048
- bo['negative_prompt_embeds'] = negative_prompt_embeds
1049
- bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
1050
- bo['latent_id'] = 0
1051
-
1052
- # patch layers `forward` so we can apply brushnet
1053
- def forward_patched_by_brushnet(self, x, *args, **kwargs):
1054
- h = self.original_forward(x, *args, **kwargs)
1055
- if hasattr(self, 'add_sample_after') and type(self):
1056
- to_add = self.add_sample_after
1057
- if torch.is_tensor(to_add):
1058
- # interpolate due to RAUNet
1059
- if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
1060
- to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
1061
- h += to_add.to(h.dtype).to(h.device)
1062
- else:
1063
- h += self.add_sample_after
1064
- self.add_sample_after = 0
1065
- return h
1066
-
1067
- for i, block in enumerate(model.model.diffusion_model.input_blocks):
1068
- for j, layer in enumerate(block):
1069
- if not hasattr(layer, 'original_forward'):
1070
- layer.original_forward = layer.forward
1071
- layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1072
- layer.add_sample_after = 0
1073
-
1074
- for j, layer in enumerate(model.model.diffusion_model.middle_block):
1075
- if not hasattr(layer, 'original_forward'):
1076
- layer.original_forward = layer.forward
1077
- layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1078
- layer.add_sample_after = 0
1079
-
1080
- for i, block in enumerate(model.model.diffusion_model.output_blocks):
1081
- for j, layer in enumerate(block):
1082
- if not hasattr(layer, 'original_forward'):
1083
- layer.original_forward = layer.forward
1084
- layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1085
- layer.add_sample_after = 0
1
+ import os
2
+ import types
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torchvision.transforms as T
7
+ import torch.nn.functional as F
8
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
9
+
10
+ import comfy
11
+ import folder_paths
12
+
13
+ from .model_patch import add_model_patch_option, patch_model_function_wrapper
14
+
15
+ from .brushnet.brushnet import BrushNetModel
16
+ from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel
17
+
18
+ from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens
19
+
20
+ current_directory = os.path.dirname(os.path.abspath(__file__))
21
+ brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json')
22
+ brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json')
23
+ powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json')
24
+
25
+ sd15_scaling_factor = 0.18215
26
+ sdxl_scaling_factor = 0.13025
27
+
28
+ ModelsToUnload = [comfy.sd1_clip.SD1ClipModel,
29
+ comfy.ldm.models.autoencoder.AutoencoderKL
30
+ ]
31
+
32
+
33
+ class BrushNetLoader:
34
+
35
+ @classmethod
36
+ def INPUT_TYPES(self):
37
+ self.inpaint_files = get_files_with_extension('inpaint')
38
+ return {"required":
39
+ {
40
+ "brushnet": ([file for file in self.inpaint_files], ),
41
+ "dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
42
+ },
43
+ }
44
+
45
+ CATEGORY = "inpaint"
46
+ RETURN_TYPES = ("BRMODEL",)
47
+ RETURN_NAMES = ("brushnet",)
48
+
49
+ FUNCTION = "brushnet_loading"
50
+
51
+ def brushnet_loading(self, brushnet, dtype):
52
+ brushnet_file = os.path.join(self.inpaint_files[brushnet], brushnet)
53
+ is_SDXL = False
54
+ is_PP = False
55
+ sd = comfy.utils.load_torch_file(brushnet_file)
56
+ brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd)
57
+ del sd
58
+ if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
59
+ is_SDXL = False
60
+ if keys == 322:
61
+ is_PP = False
62
+ print('BrushNet model type: SD1.5')
63
+ else:
64
+ is_PP = True
65
+ print('PowerPaint model type: SD1.5')
66
+ elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
67
+ print('BrushNet model type: Loading SDXL')
68
+ is_SDXL = True
69
+ is_PP = False
70
+ else:
71
+ raise Exception("Unknown BrushNet model")
72
+
73
+ with init_empty_weights():
74
+ if is_SDXL:
75
+ brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
76
+ brushnet_model = BrushNetModel.from_config(brushnet_config)
77
+ elif is_PP:
78
+ brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
79
+ brushnet_model = PowerPaintModel.from_config(brushnet_config)
80
+ else:
81
+ brushnet_config = BrushNetModel.load_config(brushnet_config_file)
82
+ brushnet_model = BrushNetModel.from_config(brushnet_config)
83
+
84
+ if is_PP:
85
+ print("PowerPaint model file:", brushnet_file)
86
+ else:
87
+ print("BrushNet model file:", brushnet_file)
88
+
89
+ if dtype == 'float16':
90
+ torch_dtype = torch.float16
91
+ elif dtype == 'bfloat16':
92
+ torch_dtype = torch.bfloat16
93
+ elif dtype == 'float32':
94
+ torch_dtype = torch.float32
95
+ else:
96
+ torch_dtype = torch.float64
97
+
98
+ brushnet_model = load_checkpoint_and_dispatch(
99
+ brushnet_model,
100
+ brushnet_file,
101
+ device_map="sequential",
102
+ max_memory=None,
103
+ offload_folder=None,
104
+ offload_state_dict=False,
105
+ dtype=torch_dtype,
106
+ force_hooks=False,
107
+ )
108
+
109
+ if is_PP:
110
+ print("PowerPaint model is loaded")
111
+ elif is_SDXL:
112
+ print("BrushNet SDXL model is loaded")
113
+ else:
114
+ print("BrushNet SD1.5 model is loaded")
115
+
116
+ return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, )
117
+
118
+
119
+ class PowerPaintCLIPLoader:
120
+
121
+ @classmethod
122
+ def INPUT_TYPES(self):
123
+ self.inpaint_files = get_files_with_extension('inpaint', ['.bin'])
124
+ self.clip_files = get_files_with_extension('clip')
125
+ return {"required":
126
+ {
127
+ "base": ([file for file in self.clip_files], ),
128
+ "powerpaint": ([file for file in self.inpaint_files], ),
129
+ },
130
+ }
131
+
132
+ CATEGORY = "inpaint"
133
+ RETURN_TYPES = ("CLIP",)
134
+ RETURN_NAMES = ("clip",)
135
+
136
+ FUNCTION = "ppclip_loading"
137
+
138
+ def ppclip_loading(self, base, powerpaint):
139
+ base_CLIP_file = os.path.join(self.clip_files[base], base)
140
+ pp_CLIP_file = os.path.join(self.inpaint_files[powerpaint], powerpaint)
141
+
142
+ pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file])
143
+
144
+ print('PowerPaint base CLIP file: ', base_CLIP_file)
145
+
146
+ pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
147
+ pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
148
+
149
+ add_tokens(
150
+ tokenizer = pp_tokenizer,
151
+ text_encoder = pp_text_encoder,
152
+ placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"],
153
+ initialize_tokens = ["a", "a", "a"],
154
+ num_vectors_per_token = 10,
155
+ )
156
+
157
+ pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False)
158
+
159
+ print('PowerPaint CLIP file: ', pp_CLIP_file)
160
+
161
+ pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
162
+ pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
163
+
164
+ return (pp_clip,)
165
+
166
+
167
+ class PowerPaint:
168
+
169
+ @classmethod
170
+ def INPUT_TYPES(s):
171
+ return {"required":
172
+ {
173
+ "model": ("MODEL",),
174
+ "vae": ("VAE", ),
175
+ "image": ("IMAGE",),
176
+ "mask": ("MASK",),
177
+ "powerpaint": ("BRMODEL", ),
178
+ "clip": ("CLIP", ),
179
+ "positive": ("CONDITIONING", ),
180
+ "negative": ("CONDITIONING", ),
181
+ "fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
182
+ "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ),
183
+ "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
184
+ "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
185
+ "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
186
+ "save_memory": (['none', 'auto', 'max'], ),
187
+ },
188
+ }
189
+
190
+ CATEGORY = "inpaint"
191
+ RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
192
+ RETURN_NAMES = ("model","positive","negative","latent",)
193
+
194
+ FUNCTION = "model_update"
195
+
196
+ def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
197
+
198
+ is_SDXL, is_PP = check_compatibilty(model, powerpaint)
199
+ if not is_PP:
200
+ raise Exception("BrushNet model was loaded, please use BrushNet node")
201
+
202
+ # Make a copy of the model so that we're not patching it everywhere in the workflow.
203
+ model = model.clone()
204
+
205
+ # prepare image and mask
206
+ # no batches for original image and mask
207
+ masked_image, mask = prepare_image(image, mask)
208
+
209
+ batch = masked_image.shape[0]
210
+ #width = masked_image.shape[2]
211
+ #height = masked_image.shape[1]
212
+
213
+ if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
214
+ scaling_factor = model.model.model_config.latent_format.scale_factor
215
+ else:
216
+ scaling_factor = sd15_scaling_factor
217
+
218
+ torch_dtype = powerpaint['dtype']
219
+
220
+ # prepare conditioning latents
221
+ conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
222
+ conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
223
+ conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
224
+
225
+ # prepare embeddings
226
+
227
+ if function == "object removal":
228
+ promptA = "P_ctxt"
229
+ promptB = "P_ctxt"
230
+ negative_promptA = "P_obj"
231
+ negative_promptB = "P_obj"
232
+ print('You should add to positive prompt: "empty scene blur"')
233
+ #positive = positive + " empty scene blur"
234
+ elif function == "context aware":
235
+ promptA = "P_ctxt"
236
+ promptB = "P_ctxt"
237
+ negative_promptA = ""
238
+ negative_promptB = ""
239
+ #positive = positive + " empty scene"
240
+ print('You should add to positive prompt: "empty scene"')
241
+ elif function == "shape guided":
242
+ promptA = "P_shape"
243
+ promptB = "P_ctxt"
244
+ negative_promptA = "P_shape"
245
+ negative_promptB = "P_ctxt"
246
+ elif function == "image outpainting":
247
+ promptA = "P_ctxt"
248
+ promptB = "P_ctxt"
249
+ negative_promptA = "P_obj"
250
+ negative_promptB = "P_obj"
251
+ #positive = positive + " empty scene"
252
+ print('You should add to positive prompt: "empty scene"')
253
+ else:
254
+ promptA = "P_obj"
255
+ promptB = "P_obj"
256
+ negative_promptA = "P_obj"
257
+ negative_promptB = "P_obj"
258
+
259
+ tokens = clip.tokenize(promptA)
260
+ prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
261
+
262
+ tokens = clip.tokenize(negative_promptA)
263
+ negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
264
+
265
+ tokens = clip.tokenize(promptB)
266
+ prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
267
+
268
+ tokens = clip.tokenize(negative_promptB)
269
+ negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
270
+
271
+ prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
272
+ negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
273
+
274
+ # unload vae and CLIPs
275
+ del vae
276
+ del clip
277
+ for loaded_model in comfy.model_management.current_loaded_models:
278
+ if type(loaded_model.model.model) in ModelsToUnload:
279
+ comfy.model_management.current_loaded_models.remove(loaded_model)
280
+ loaded_model.model_unload()
281
+ del loaded_model
282
+
283
+ # apply patch to model
284
+
285
+ brushnet_conditioning_scale = scale
286
+ control_guidance_start = start_at
287
+ control_guidance_end = end_at
288
+
289
+ if save_memory != 'none':
290
+ powerpaint['brushnet'].set_attention_slice(save_memory)
291
+
292
+ add_brushnet_patch(model,
293
+ powerpaint['brushnet'],
294
+ torch_dtype,
295
+ conditioning_latents,
296
+ (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
297
+ negative_prompt_embeds_pp, prompt_embeds_pp,
298
+ None, None, None,
299
+ False)
300
+
301
+ latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=powerpaint['brushnet'].device)
302
+
303
+ return (model, positive, negative, {"samples":latent},)
304
+
305
+
306
+ class BrushNet:
307
+
308
+ @classmethod
309
+ def INPUT_TYPES(s):
310
+ return {"required":
311
+ {
312
+ "model": ("MODEL",),
313
+ "vae": ("VAE", ),
314
+ "image": ("IMAGE",),
315
+ "mask": ("MASK",),
316
+ "brushnet": ("BRMODEL", ),
317
+ "positive": ("CONDITIONING", ),
318
+ "negative": ("CONDITIONING", ),
319
+ "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
320
+ "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
321
+ "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
322
+ },
323
+ }
324
+
325
+ CATEGORY = "inpaint"
326
+ RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
327
+ RETURN_NAMES = ("model","positive","negative","latent",)
328
+
329
+ FUNCTION = "model_update"
330
+
331
+ def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
332
+
333
+ is_SDXL, is_PP = check_compatibilty(model, brushnet)
334
+
335
+ if is_PP:
336
+ raise Exception("PowerPaint model was loaded, please use PowerPaint node")
337
+
338
+ # Make a copy of the model so that we're not patching it everywhere in the workflow.
339
+ model = model.clone()
340
+
341
+ # prepare image and mask
342
+ # no batches for original image and mask
343
+ masked_image, mask = prepare_image(image, mask)
344
+
345
+ batch = masked_image.shape[0]
346
+ width = masked_image.shape[2]
347
+ height = masked_image.shape[1]
348
+
349
+ if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
350
+ scaling_factor = model.model.model_config.latent_format.scale_factor
351
+ elif is_SDXL:
352
+ scaling_factor = sdxl_scaling_factor
353
+ else:
354
+ scaling_factor = sd15_scaling_factor
355
+
356
+ torch_dtype = brushnet['dtype']
357
+
358
+ # prepare conditioning latents
359
+ conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
360
+ conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
361
+ conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
362
+
363
+ # unload vae
364
+ del vae
365
+ for loaded_model in comfy.model_management.current_loaded_models:
366
+ if type(loaded_model.model.model) in ModelsToUnload:
367
+ comfy.model_management.current_loaded_models.remove(loaded_model)
368
+ loaded_model.model_unload()
369
+ del loaded_model
370
+
371
+ # prepare embeddings
372
+
373
+ prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
374
+ negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
375
+
376
+ max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
377
+ if prompt_embeds.shape[1] < max_tokens:
378
+ multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
379
+ prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1)
380
+ print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds')
381
+ if negative_prompt_embeds.shape[1] < max_tokens:
382
+ multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
383
+ negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1)
384
+ print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds')
385
+
386
+ if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
387
+ pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
388
+ else:
389
+ print('BrushNet: positive conditioning has not pooled_output')
390
+ if is_SDXL:
391
+ print('BrushNet will not produce correct results')
392
+ pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
393
+
394
+ if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
395
+ negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
396
+ else:
397
+ print('BrushNet: negative conditioning has not pooled_output')
398
+ if is_SDXL:
399
+ print('BrushNet will not produce correct results')
400
+ negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
401
+
402
+ time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(brushnet['brushnet'].device)
403
+
404
+ if not is_SDXL:
405
+ pooled_prompt_embeds = None
406
+ negative_pooled_prompt_embeds = None
407
+ time_ids = None
408
+
409
+ # apply patch to model
410
+
411
+ brushnet_conditioning_scale = scale
412
+ control_guidance_start = start_at
413
+ control_guidance_end = end_at
414
+
415
+ add_brushnet_patch(model,
416
+ brushnet['brushnet'],
417
+ torch_dtype,
418
+ conditioning_latents,
419
+ (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
420
+ prompt_embeds, negative_prompt_embeds,
421
+ pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
422
+ False)
423
+
424
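+ # empty latent batch sized to the conditioning latents; returned as the LATENT output for the sampler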
+ latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=brushnet['brushnet'].device)
425
+
426
+ return (model, positive, negative, {"samples":latent},)
427
+
428
+
429
+ class BlendInpaint:
430
+
431
+ @classmethod
432
+ def INPUT_TYPES(s):
433
+ return {"required":
434
+ {
435
+ "inpaint": ("IMAGE",),
436
+ "original": ("IMAGE",),
437
+ "mask": ("MASK",),
438
+ "kernel": ("INT", {"default": 10, "min": 1, "max": 1000}),
439
+ "sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}),
440
+ },
441
+ "optional":
442
+ {
443
+ "origin": ("VECTOR",),
444
+ },
445
+ }
446
+
447
+ CATEGORY = "inpaint"
448
+ RETURN_TYPES = ("IMAGE","MASK",)
449
+ RETURN_NAMES = ("image","MASK",)
450
+
451
+ FUNCTION = "blend_inpaint"
452
+
453
+ def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma:float, origin=None):
454
+ original, mask = check_image_mask(original, mask, 'Blend Inpaint')
455
+
456
+ if len(inpaint.shape) < 4:
457
+ inpaint = inpaint[None,:,:,:]
458
+
459
+ if inpaint.shape[0] < original.shape[0]:
460
+ original = original[:inpaint.shape[0],:,:]
461
+ mask = mask[:inpaint.shape[0],:,:]
462
+
463
+ if inpaint.shape[0] > original.shape[0]:
464
+ count = 0
465
+ original_list = []
466
+ mask_list = []
467
+ origin_list = []
468
+ while (count < inpaint.shape[0]):
469
+ for i in range(original.shape[0]):
470
+ original_list.append(original[i][None,:,:,:])
471
+ mask_list.append(mask[i][None,:,:])
472
+ if origin is not None:
473
+ origin_list.append(origin[i][None,:])
474
+ count += 1
475
+ if count >= inpaint.shape[0]:
476
+ break
477
+ original = torch.concat(original_list, dim=0)
478
+ mask = torch.concat(mask_list, dim=0)
479
+ if origin is not None:
480
+ origin = torch.concat(origin_list, dim=0)
481
+
482
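+ # torchvision's GaussianBlur requires an odd kernel size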
+ if kernel % 2 == 0:
483
+ kernel += 1
484
+ transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))
485
+
486
+ ret = []
487
+ blurred = []
488
+ for i in range(inpaint.shape[0]):
489
+ height, width, _ = original[i].shape
490
+ if origin is None:
+ # assumed default: no origin vector was provided, so blend over the whole original frame
+ x0, y0, cut_width, cut_height = 0, 0, width, height
+ else:
+ x0, y0, cut_width, cut_height = origin[i]
491
+
492
+ # Ensure cut dimensions don't exceed original image dimensions
493
+ cut_width = min(cut_width, width - x0)
494
+ cut_height = min(cut_height, height - y0)
495
+
496
+ scaled_inpaint = F.interpolate(inpaint[i].permute(2, 0, 1).unsqueeze(0), size=(cut_height, cut_width), mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0)
497
+
498
+ result = original[i].clone()
499
+ result[y0:y0+cut_height, x0:x0+cut_width] = scaled_inpaint
500
+
501
+ # Create a new mask for blending
502
+ blend_mask = torch.zeros((height, width), device=mask.device, dtype=mask.dtype)
503
+ blend_mask[y0:y0+cut_height, x0:x0+cut_width] = 1.0
504
+
505
+ # Apply Gaussian blur to the blend mask
506
+ blurred_mask = transform(blend_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
507
+ blurred.append(blurred_mask)
508
+
509
+ # Apply the blended mask
510
+ ret.append(original[i] * (1.0 - blurred_mask[:,:,None]) + result * blurred_mask[:,:,None])
511
+
512
+ return (torch.stack(ret), torch.stack(blurred))
513
+
514
+
515
+
516
+ def scale_mask_and_image(image, mask, width, height):
517
+ h0, w0 = mask.shape
518
+ iy, ix = (mask == 1).nonzero(as_tuple=True)
519
+
520
+ if iy.numel() == 0:
521
+ x_c, y_c = w0 / 2.0, h0 / 2.0
522
+ mask_width, mask_height = 1, 1
523
+ else:
524
+ x_min, x_max = ix.min().item(), ix.max().item()
525
+ y_min, y_max = iy.min().item(), iy.max().item()
526
+ x_c, y_c = (x_min + x_max) / 2.0, (y_min + y_max) / 2.0
527
+ mask_width, mask_height = x_max - x_min + 1, y_max - y_min + 1
528
+
529
+ aspect_ratio = width / height
530
+ mask_aspect_ratio = mask_width / mask_height
531
+
532
+ if mask_aspect_ratio > aspect_ratio:
533
+ new_mask_width = mask_width
534
+ new_mask_height = mask_width / aspect_ratio
535
+ else:
536
+ new_mask_height = mask_height
537
+ new_mask_width = mask_height * aspect_ratio
538
+
539
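+ # expand the crop 30% beyond the mask bounding box on each side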
+ margin = 0.3
540
+ cut_width = int(new_mask_width * (1 + 2 * margin))
541
+ cut_height = int(new_mask_height * (1 + 2 * margin))
542
+
543
+ x0 = max(0, min(w0 - cut_width, int(x_c - cut_width / 2)))
544
+ y0 = max(0, min(h0 - cut_height, int(y_c - cut_height / 2)))
545
+
546
+ # Adjust cut dimensions if they exceed image dimensions
547
+ cut_width = min(cut_width, w0 - x0)
548
+ cut_height = min(cut_height, h0 - y0)
549
+
550
+ cut_image = image[y0:y0+cut_height, x0:x0+cut_width]
551
+ cut_mask = mask[y0:y0+cut_height, x0:x0+cut_width]
552
+
553
+ if cut_width >= width and cut_height >= height:
554
+ # For large masks, return without scaling
555
+ return cut_image, cut_mask, (x0, y0, cut_width, cut_height)
556
+ else:
557
+ # For small masks, scale up to the specified size
558
+ scaled_image = F.interpolate(cut_image.permute(2, 0, 1).unsqueeze(0), size=(height, width), mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0)
559
+ scaled_mask = F.interpolate(cut_mask.unsqueeze(0).unsqueeze(0).float(), size=(height, width), mode='nearest').squeeze(0).squeeze(0)
560
+ return scaled_image, scaled_mask, (x0, y0, cut_width, cut_height)
561
+
562
+ class CutForInpaint:
563
+
564
+ @classmethod
565
+ def INPUT_TYPES(s):
566
+ return {"required":
567
+ {
568
+ "image": ("IMAGE",),
569
+ "mask": ("MASK",),
570
+ "width": ("INT", {"default": 512, "min": 64, "max": 2048}),
571
+ "height": ("INT", {"default": 512, "min": 64, "max": 2048}),
572
+ },
573
+ }
574
+
575
+ CATEGORY = "inpaint"
576
+ RETURN_TYPES = ("IMAGE","MASK","VECTOR",)
577
+ RETURN_NAMES = ("image","mask","origin",)
578
+
579
+ FUNCTION = "cut_for_inpaint"
580
+
581
+ def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, width: int, height: int):
582
+ ret = []
583
+ msk = []
584
+ org = []
585
+ for i in range(image.shape[0]):
586
+ cut_image, cut_mask, (x0, y0, cut_width, cut_height) = scale_mask_and_image(image[i], mask[i], width, height)
587
+ ret.append(cut_image)
588
+ msk.append(cut_mask)
589
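+ # origin records where the crop was taken from, so BlendInpaint can paste the result back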
+ org.append(torch.IntTensor([x0, y0, cut_width, cut_height]))
590
+
591
+ return (torch.stack(ret), torch.stack(msk), torch.stack(org))
592
+
593
+
594
+ #### Utility function
595
+
596
+ def get_files_with_extension(folder_name, extension=['.safetensors']):
597
+
598
+ try:
599
+ folders = folder_paths.get_folder_paths(folder_name)
600
+ except Exception:
601
+ folders = []
602
+
603
+ if not folders:
604
+ folders = [os.path.join(folder_paths.models_dir, folder_name)]
605
+ if not os.path.isdir(folders[0]):
606
+ folders = [os.path.join(folder_paths.base_path, folder_name)]
607
+ if not os.path.isdir(folders[0]):
608
+ return {}
609
+
610
+ filtered_folders = []
611
+ for x in folders:
612
+ if not os.path.isdir(x):
613
+ continue
614
+ the_same = False
615
+ for y in filtered_folders:
616
+ if os.path.samefile(x, y):
617
+ the_same = True
618
+ break
619
+ if not the_same:
620
+ filtered_folders.append(x)
621
+
622
+ if not filtered_folders:
623
+ return {}
624
+
625
+ output = {}
626
+ for x in filtered_folders:
627
+ files, folders_all = folder_paths.recursive_search(x, excluded_dir_names=[".git"])
628
+ filtered_files = folder_paths.filter_files_extensions(files, extension)
629
+
630
+ for f in filtered_files:
631
+ output[f] = x
632
+
633
+ return output
634
+
635
+
636
+ # count blocks in the state_dict so we can tell which model type it is
637
+ def brushnet_blocks(sd):
638
+ brushnet_down_block = 0
639
+ brushnet_mid_block = 0
640
+ brushnet_up_block = 0
641
+ for key in sd:
642
+ if 'brushnet_down_block' in key:
643
+ brushnet_down_block += 1
644
+ if 'brushnet_mid_block' in key:
645
+ brushnet_mid_block += 1
646
+ if 'brushnet_up_block' in key:
647
+ brushnet_up_block += 1
648
+ return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))
649
+
650
+
651
+ # Check models compatibility
652
+ def check_compatibilty(model, brushnet):
653
+ is_SDXL = False
654
+ is_PP = False
655
+ if isinstance(model.model.model_config, comfy.supported_models.SD15):
656
+ print('Base model type: SD1.5')
657
+ is_SDXL = False
658
+ if brushnet["SDXL"]:
659
+ raise Exception("Base model is SD15, but BrushNet is SDXL type")
660
+ if brushnet["PP"]:
661
+ is_PP = True
662
+ elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
663
+ print('Base model type: SDXL')
664
+ is_SDXL = True
665
+ if not brushnet["SDXL"]:
666
+ raise Exception("Base model is SDXL, but BrushNet is SD15 type")
667
+ else:
668
+ print('Base model type: ', type(model.model.model_config))
669
+ raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
670
+
671
+ return (is_SDXL, is_PP)
672
+
673
+
674
+ def check_image_mask(image, mask, name):
675
+ if len(image.shape) < 4:
676
+ # image tensor shape should be [B, H, W, C], but the batch dimension is missing
677
+ image = image[None,:,:,:]
678
+
679
+ if len(mask.shape) > 3:
680
+ # mask tensor shape should be [B, H, W] but we got [B, H, W, C]; it is probably an image,
681
+ # so take the red channel as the mask
682
+ mask = mask[:,:,:,0]
683
+ elif len(mask.shape) < 3:
684
+ # mask tensor shape should be [B, H, W], but the batch dimension is missing
685
+ mask = mask[None,:,:]
686
+
687
+ if image.shape[0] > mask.shape[0]:
688
+ print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
689
+ if mask.shape[0] == 1:
690
+ print(name, "will copy the mask to fill batch")
691
+ mask = torch.cat([mask] * image.shape[0], dim=0)
692
+ else:
693
+ print(name, "will add empty masks to fill batch")
694
+ empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
695
+ mask = torch.cat([mask, empty_mask], dim=0)
696
+ elif image.shape[0] < mask.shape[0]:
697
+ print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
698
+ mask = mask[:image.shape[0],:,:]
699
+
700
+ return (image, mask)
701
+
702
+
703
+ # Prepare image and mask
704
+ def prepare_image(image, mask):
705
+
706
+ image, mask = check_image_mask(image, mask, 'BrushNet')
707
+
708
+ print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)
709
+
710
+ if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
711
+ raise Exception("Image and mask should be the same size")
712
+
713
+ # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
714
+ mask = mask.round()
715
+
716
+ masked_image = image * (1.0 - mask[:,:,:,None])
717
+
718
+ return (masked_image, mask)
719
+
720
+
721
+ # Compute the crop rectangle (origin and size) around the mask
722
+ def cut_with_mask(mask, width, height):
723
+ iy, ix = (mask == 1).nonzero(as_tuple=True)
724
+ h0, w0 = mask.shape
725
+
726
+ if iy.numel() == 0:
727
+ x_c, y_c = w0 / 2.0, h0 / 2.0
728
+ mask_width, mask_height = 0, 0
729
+ else:
730
+ x_min, x_max = ix.min().item(), ix.max().item()
731
+ y_min, y_max = iy.min().item(), iy.max().item()
732
+ x_c, y_c = (x_min + x_max) / 2.0, (y_min + y_max) / 2.0
733
+ mask_width, mask_height = x_max - x_min + 1, y_max - y_min + 1
734
+
735
+ cut_width = max(width, mask_width * 1.4) # 140% of mask width
736
+ cut_height = max(height, mask_height * 1.4) # 140% of mask height
737
+
738
+ cut_width = min(cut_width, w0)
739
+ cut_height = min(cut_height, h0)
740
+
741
+ x0 = max(0, min(w0 - cut_width, x_c - cut_width / 2))
742
+ y0 = max(0, min(h0 - cut_height, y_c - cut_height / 2))
743
+
744
+ return (int(x0), int(y0), int(cut_width), int(cut_height))
745
+
746
+
747
+ # Prepare conditioning_latents
748
+ @torch.inference_mode()
749
+ def get_image_latents(masked_image, mask, vae, scaling_factor):
750
+ processed_image = masked_image.to(vae.device)
751
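+ # ComfyUI's VAE.encode returns unscaled latents, so multiply by the model's latent scale factor here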
+ image_latents = vae.encode(processed_image[:,:,:,:3]) * scaling_factor
752
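+ # invert the mask, add a channel dimension and downsample it to the latent resolution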
+ processed_mask = 1. - mask[:,None,:,:]
753
+ interpolated_mask = torch.nn.functional.interpolate(
754
+ processed_mask,
755
+ size=(
756
+ image_latents.shape[-2],
757
+ image_latents.shape[-1]
758
+ )
759
+ )
760
+ interpolated_mask = interpolated_mask.to(image_latents.device)
761
+
762
+ conditioning_latents = [image_latents, interpolated_mask]
763
+
764
+ print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape)
765
+
766
+ return conditioning_latents
767
+
768
+
769
+ # Main function where magic happens
770
+ @torch.inference_mode()
771
+ def brushnet_inference(x, timesteps, transformer_options, debug):
772
+ if 'model_patch' not in transformer_options:
773
+ print('BrushNet inference: there is no model_patch key in transformer_options')
774
+ return ([], 0, [])
775
+ mp = transformer_options['model_patch']
776
+ if 'brushnet' not in mp:
777
+ print('BrushNet inference: there is no brushnet key in model_patch')
778
+ return ([], 0, [])
779
+ bo = mp['brushnet']
780
+ if 'model' not in bo:
781
+ print('BrushNet inference: there is no model key in brushnet')
782
+ return ([], 0, [])
783
+ brushnet = bo['model']
784
+ if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
785
+ print('BrushNet model is not a BrushNetModel or PowerPaintModel instance')
786
+ return ([], 0, [])
787
+
788
+ torch_dtype = bo['dtype']
789
+ cl_list = bo['latents']
790
+ brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
791
+ pe = bo['prompt_embeds']
792
+ npe = bo['negative_prompt_embeds']
793
+ ppe, nppe, time_ids = bo['add_embeds']
794
+
795
+ #do_classifier_free_guidance = mp['free_guidance']
796
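+ # ComfyUI batches cond and uncond passes together; more than one entry in cond_or_uncond means CFG is active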
+ do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1
797
+
798
+ x = x.detach().clone()
799
+ x = x.to(torch_dtype).to(brushnet.device)
800
+
801
+ timesteps = timesteps.detach().clone()
802
+ timesteps = timesteps.to(torch_dtype).to(brushnet.device)
803
+
804
+ total_steps = mp['total_steps']
805
+ step = mp['step']
806
+
807
+ added_cond_kwargs = {}
808
+
809
+ if do_classifier_free_guidance and step == 0:
810
+ print('BrushNet inference: do_classifier_free_guidance is True')
811
+
812
+ sub_idx = None
813
+ if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
814
+ sub_idx = transformer_options['ad_params']['sub_idxs']
815
+
816
+ # we have batch input images
817
+ batch = cl_list[0].shape[0]
818
+ # we have incoming latents
819
+ latents_incoming = x.shape[0]
820
+ # and we already got some
821
+ latents_got = bo['latent_id']
822
+ if step == 0 or batch > 1:
823
+ print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
824
+ % (step, batch, latents_incoming, latents_got))
825
+
826
+ image_latents = []
827
+ masks = []
828
+ prompt_embeds = []
829
+ negative_prompt_embeds = []
830
+ pooled_prompt_embeds = []
831
+ negative_pooled_prompt_embeds = []
832
+ if sub_idx:
833
+ # AnimateDiff indexes detected
834
+ if step == 0:
835
+ print('BrushNet inference: AnimateDiff indexes detected and applied')
836
+
837
+ batch = len(sub_idx)
838
+
839
+ if do_classifier_free_guidance:
840
+ for i in sub_idx:
841
+ image_latents.append(cl_list[0][i][None,:,:,:])
842
+ masks.append(cl_list[1][i][None,:,:,:])
843
+ prompt_embeds.append(pe)
844
+ negative_prompt_embeds.append(npe)
845
+ pooled_prompt_embeds.append(ppe)
846
+ negative_pooled_prompt_embeds.append(nppe)
847
+ for i in sub_idx:
848
+ image_latents.append(cl_list[0][i][None,:,:,:])
849
+ masks.append(cl_list[1][i][None,:,:,:])
850
+ else:
851
+ for i in sub_idx:
852
+ image_latents.append(cl_list[0][i][None,:,:,:])
853
+ masks.append(cl_list[1][i][None,:,:,:])
854
+ prompt_embeds.append(pe)
855
+ pooled_prompt_embeds.append(ppe)
856
+ else:
857
+ # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
858
+ continue_batch = True
859
+ for i in range(latents_incoming):
860
+ number = latents_got + i
861
+ if number < batch:
862
+ # 1st pass, cond
863
+ image_latents.append(cl_list[0][number][None,:,:,:])
864
+ masks.append(cl_list[1][number][None,:,:,:])
865
+ prompt_embeds.append(pe)
866
+ pooled_prompt_embeds.append(ppe)
867
+ elif do_classifier_free_guidance and number < batch * 2:
868
+ # 2nd pass, uncond
869
+ image_latents.append(cl_list[0][number-batch][None,:,:,:])
870
+ masks.append(cl_list[1][number-batch][None,:,:,:])
871
+ negative_prompt_embeds.append(npe)
872
+ negative_pooled_prompt_embeds.append(nppe)
873
+ else:
874
+ # latent batch
875
+ image_latents.append(cl_list[0][0][None,:,:,:])
876
+ masks.append(cl_list[1][0][None,:,:,:])
877
+ prompt_embeds.append(pe)
878
+ pooled_prompt_embeds.append(ppe)
879
+ latents_got = -i
880
+ continue_batch = False
881
+
882
+ if continue_batch:
883
+ # we don't have full batch yet
884
+ if do_classifier_free_guidance:
885
+ if number < batch * 2 - 1:
886
+ bo['latent_id'] = number + 1
887
+ else:
888
+ bo['latent_id'] = 0
889
+ else:
890
+ if number < batch - 1:
891
+ bo['latent_id'] = number + 1
892
+ else:
893
+ bo['latent_id'] = 0
894
+ else:
895
+ bo['latent_id'] = 0
896
+
897
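+ # BrushNet conditioning: image latents concatenated with the downsampled mask along the channel dimension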
+ cl = []
898
+ for il, m in zip(image_latents, masks):
899
+ cl.append(torch.concat([il, m], dim=1))
900
+ cl2apply = torch.concat(cl, dim=0)
901
+
902
+ conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)
903
+
904
+ prompt_embeds.extend(negative_prompt_embeds)
905
+ prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
906
+
907
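+ # SDXL needs pooled text embeddings and time_ids as added conditioning; for SD1.5 these stay None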
+ if ppe is not None:
908
+ added_cond_kwargs = {}
909
+ added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)
910
+
911
+ pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
912
+ pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
913
+ added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
914
+ else:
915
+ added_cond_kwargs = None
916
+
917
+ if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
918
+ if step == 0:
919
+ print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
920
+ conditioning_latents = torch.nn.functional.interpolate(
921
+ conditioning_latents, size=(
922
+ x.shape[2],
923
+ x.shape[3],
924
+ ), mode='bicubic',
925
+ ).to(torch_dtype).to(brushnet.device)
926
+
927
+ if step == 0:
928
+ print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)
929
+
930
+ if debug: print('BrushNet: step =', step)
931
+
932
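+ # outside the configured start/end step range the BrushNet residuals are scaled to zero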
+ if step < control_guidance_start or step > control_guidance_end:
933
+ cond_scale = 0.0
934
+ else:
935
+ cond_scale = brushnet_conditioning_scale
936
+
937
+ return brushnet(x,
938
+ encoder_hidden_states=prompt_embeds,
939
+ brushnet_cond=conditioning_latents,
940
+ timestep = timesteps,
941
+ conditioning_scale=cond_scale,
942
+ guess_mode=False,
943
+ added_cond_kwargs=added_cond_kwargs,
944
+ return_dict=False,
945
+ debug=debug,
946
+ )
947
+
948
+
949
+ # This is main patch function
950
+ def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
951
+ controls,
952
+ prompt_embeds, negative_prompt_embeds,
953
+ pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
954
+ debug):
955
+
956
+ is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
957
+
958
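+ # per-block map of (index, layer type); the last layer of each listed type receives a BrushNet residual sample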
+ if is_SDXL:
959
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
960
+ [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
961
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
962
+ [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
963
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
964
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
965
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
966
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
967
+ [8, comfy.ldm.modules.attention.SpatialTransformer]]
968
+ middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
969
+ output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
970
+ [1, comfy.ldm.modules.attention.SpatialTransformer],
971
+ [2, comfy.ldm.modules.attention.SpatialTransformer],
972
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
973
+ [3, comfy.ldm.modules.attention.SpatialTransformer],
974
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
975
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
976
+ [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
977
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
978
+ [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
979
+ [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
980
+ else:
981
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
982
+ [1, comfy.ldm.modules.attention.SpatialTransformer],
983
+ [2, comfy.ldm.modules.attention.SpatialTransformer],
984
+ [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
985
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
986
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
987
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
988
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
989
+ [8, comfy.ldm.modules.attention.SpatialTransformer],
990
+ [9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
991
+ [10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
992
+ [11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
993
+ middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
994
+ output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
995
+ [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
996
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
997
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
998
+ [3, comfy.ldm.modules.attention.SpatialTransformer],
999
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
1000
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
1001
+ [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
1002
+ [6, comfy.ldm.modules.attention.SpatialTransformer],
1003
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
1004
+ [8, comfy.ldm.modules.attention.SpatialTransformer],
1005
+ [8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
1006
+ [9, comfy.ldm.modules.attention.SpatialTransformer],
1007
+ [10, comfy.ldm.modules.attention.SpatialTransformer],
1008
+ [11, comfy.ldm.modules.attention.SpatialTransformer]]
1009
+
1010
+ def last_layer_index(block, tp):
1011
+ layer_list = []
1012
+ for layer in block:
1013
+ layer_list.append(type(layer))
1014
+ layer_list.reverse()
1015
+ if tp not in layer_list:
1016
+ # list.reverse() returns None, so return the (already reversed) list itself
+ return -1, layer_list
1017
+ return len(layer_list) - 1 - layer_list.index(tp), layer_list
1018
+
1019
+ def brushnet_forward(model, x, timesteps, transformer_options, control):
1020
+ if 'brushnet' not in transformer_options['model_patch']:
1021
+ input_samples = []
1022
+ mid_sample = 0
1023
+ output_samples = []
1024
+ else:
1025
+ # brushnet inference
1026
+ input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)
1027
+
1028
+ # give additional samples to blocks
1029
+ for i, tp in input_blocks:
1030
+ idx, layer_list = last_layer_index(model.input_blocks[i], tp)
1031
+ if idx < 0:
1032
+ print("BrushNet can't find", tp, "layer in", i,"input block:", layer_list)
1033
+ continue
1034
+ model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0
1035
+
1036
+ idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
1037
+ if idx < 0:
1038
+ print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
1039
+ else:
+ model.middle_block[idx].add_sample_after = mid_sample
1040
+
1041
+ for i, tp in output_blocks:
1042
+ idx, layer_list = last_layer_index(model.output_blocks[i], tp)
1043
+ if idx < 0:
1044
+ print("BrushNet can't find", tp, "layer in", i,"outnput block:", layer_list)
1045
+ continue
1046
+ model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0
1047
+
1048
+ patch_model_function_wrapper(model, brushnet_forward)
1049
+
1050
+ to = add_model_patch_option(model)
1051
+ mp = to['model_patch']
1052
+ if 'brushnet' not in mp:
1053
+ mp['brushnet'] = {}
1054
+ bo = mp['brushnet']
1055
+
1056
+ bo['model'] = brushnet
1057
+ bo['dtype'] = torch_dtype
1058
+ bo['latents'] = conditioning_latents
1059
+ bo['controls'] = controls
1060
+ bo['prompt_embeds'] = prompt_embeds
1061
+ bo['negative_prompt_embeds'] = negative_prompt_embeds
1062
+ bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
1063
+ bo['latent_id'] = 0
1064
+
1065
+ # patch layers `forward` so we can apply brushnet
1066
+ def forward_patched_by_brushnet(self, x, *args, **kwargs):
1067
+ h = self.original_forward(x, *args, **kwargs)
1068
+ if hasattr(self, 'add_sample_after'):
1069
+ to_add = self.add_sample_after
1070
+ if torch.is_tensor(to_add):
1071
+ # interpolate due to RAUNet
1072
+ if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
1073
+ to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
1074
+ h += to_add.to(h.dtype).to(h.device)
1075
+ else:
1076
+ h += self.add_sample_after
1077
+ self.add_sample_after = 0
1078
+ return h
1079
+
1080
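+ # patch every layer in the UNet's input, middle and output blocks; add_sample_after defaults to 0 (a no-op)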
+ for i, block in enumerate(model.model.diffusion_model.input_blocks):
1081
+ for j, layer in enumerate(block):
1082
+ if not hasattr(layer, 'original_forward'):
1083
+ layer.original_forward = layer.forward
1084
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1085
+ layer.add_sample_after = 0
1086
+
1087
+ for j, layer in enumerate(model.model.diffusion_model.middle_block):
1088
+ if not hasattr(layer, 'original_forward'):
1089
+ layer.original_forward = layer.forward
1090
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1091
+ layer.add_sample_after = 0
1092
+
1093
+ for i, block in enumerate(model.model.diffusion_model.output_blocks):
1094
+ for j, layer in enumerate(block):
1095
+ if not hasattr(layer, 'original_forward'):
1096
+ layer.original_forward = layer.forward
1097
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1098
+ layer.add_sample_after = 0