|
import torch
|
|
import comfy
|
|
|
|
|
|
|
|
def add_model_patch_option(model):
    """Ensure ``model.model_options`` holds a ``transformer_options`` dict
    that in turn holds a ``model_patch`` dict, creating either when missing.

    Entries that already exist are left untouched.

    Returns:
        The ``transformer_options`` dict stored inside ``model.model_options``
        (the same object, so later mutations are visible through the model).
    """
    transformer_opts = model.model_options.setdefault('transformer_options', {})
    transformer_opts.setdefault("model_patch", {})
    return transformer_opts
|
|
|
|
|
|
|
|
def patch_model_function_wrapper(model, forward_patch, remove=False):
    """Register (or unregister) a BrushNet forward patch on ``model``.

    Installs ``brushnet_model_function_wrapper`` as the model's UNet function
    wrapper and records patch state under
    ``model.model_options['transformer_options']['model_patch']``
    (keys ``'SDXL'``, ``'forward'``, ``'unet'``, ``'step'``, ``'total_steps'``).
    On first use it also replaces ``comfy.samplers.sample`` and
    ``openaimodel.apply_control`` with the BrushNet versions, keeping the
    originals as ``original_sample`` / ``original_apply_control``.

    Args:
        model: the ComfyUI model object to patch. It must provide
            ``model_options``, ``set_model_unet_function_wrapper`` and
            ``model.model_config`` / ``model.diffusion_model`` /
            ``model.model_sampling`` (presumably a ModelPatcher; confirm).
        forward_patch: callable invoked as
            ``forward_patch(unet, xc, t, transformer_options, control)``
            before every model call.
        remove: if True, remove ``forward_patch`` from the registered list
            instead of appending it (no-op when it is not registered).

    Raises:
        Exception: if the base model config is neither SD1.5 nor SDXL.
    """

    def brushnet_model_function_wrapper(apply_model_method, options_dict):
        """Run the registered forward patches, then the normal model call.

        Reads ``'input'``, ``'timestep'`` and ``'c'`` (the conditioning kwargs)
        from ``options_dict``. When no patches are registered, this simply
        forwards to ``apply_model_method``.
        """
        cond = options_dict['c']
        to = cond['transformer_options']
        control = cond.get('control')

        x = options_dict['input']
        timestep = options_dict['timestep']

        # Nothing registered: behave exactly like the unwrapped model.
        if 'model_patch' not in to or 'forward' not in to['model_patch']:
            return apply_model_method(x, timestep, **cond)

        mp = to['model_patch']
        unet = mp['unet']

        # Find the current step by matching the current sigma against the full
        # schedule that modified_sample stores in 'all_sigmas'.
        # NOTE(review): raises KeyError when sampling did not go through
        # modified_sample (no 'all_sigmas' recorded) — confirm this is intended.
        all_sigmas = mp['all_sigmas']
        sigma = to['sigmas'][0].item()
        total_steps = all_sigmas.shape[0] - 1
        step = torch.argmin((all_sigmas - sigma).abs()).item()

        mp['step'] = step
        mp['total_steps'] = total_steps

        # Build the UNet input and timestep the patches receive, using the
        # model's own sampling helpers (plus any c_concat channels).
        xc = model.model.model_sampling.calculate_input(timestep, x)
        if 'c_concat' in cond and cond['c_concat'] is not None:
            xc = torch.cat([xc, cond['c_concat']], dim=1)
        t = model.model.model_sampling.timestep(timestep).float()

        for method in mp['forward']:
            method(unet, xc, t, to, control)

        return apply_model_method(x, timestep, **cond)

    existing_wrapper = model.model_options.get("model_function_wrapper")
    if existing_wrapper:
        print('BrushNet is going to replace existing model_function_wrapper:', existing_wrapper)
    model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)

    to = add_model_patch_option(model)
    mp = to['model_patch']

    model_config = model.model.model_config
    if isinstance(model_config, comfy.supported_models.SD15):
        mp['SDXL'] = False
    elif isinstance(model_config, comfy.supported_models.SDXL):
        mp['SDXL'] = True
    else:
        print('Base model type: ', type(model_config))
        # Format into one message; passing the type as a second argument made
        # str(exc) render as a tuple.
        raise Exception("Unsupported model type: {}".format(type(model_config)))

    forward_patches = mp.setdefault('forward', [])
    if remove:
        if forward_patch in forward_patches:
            forward_patches.remove(forward_patch)
    else:
        forward_patches.append(forward_patch)

    mp['unet'] = model.model.diffusion_model
    mp['step'] = 0
    mp['total_steps'] = 1

    # Install the BrushNet replacements only once: both replacements carry the
    # word 'BrushNet' in their docstrings, which is what these checks look for.
    if comfy.samplers.sample.__doc__ is None or 'BrushNet' not in comfy.samplers.sample.__doc__:
        comfy.samplers.original_sample = comfy.samplers.sample
        comfy.samplers.sample = modified_sample

    openaimodel = comfy.ldm.modules.diffusionmodules.openaimodel
    if openaimodel.apply_control.__doc__ is None or \
            'BrushNet' not in openaimodel.apply_control.__doc__:
        openaimodel.original_apply_control = openaimodel.apply_control
        openaimodel.apply_control = modified_apply_control
|
|
|
|
|
|
|
|
|
|
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options=None,
                    latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    '''
    Modified by BrushNet nodes

    Replacement for ``comfy.samplers.sample``, installed by
    patch_model_function_wrapper. It builds a ``CFGGuider`` from
    ``positive`` / ``negative`` / ``cfg`` and samples with it. Before doing so,
    it stores the full ``sigmas`` schedule as
    ``transformer_options['model_patch']['all_sigmas']`` in
    ``model.model_options``, so the BrushNet model wrapper can locate the
    current step.

    ``device`` and ``model_options`` are accepted only to keep the signature of
    the function being replaced; neither is used here.

    The word "BrushNet" in this docstring is load-bearing:
    patch_model_function_wrapper checks ``comfy.samplers.sample.__doc__`` for it
    to avoid installing this replacement twice. Do not remove it.
    '''
    cfg_guider = comfy.samplers.CFGGuider(model)
    cfg_guider.set_conds(positive, negative)
    cfg_guider.set_cfg(cfg)

    # Record the whole schedule for step lookup inside the model wrapper.
    to = add_model_patch_option(model)
    to['model_patch']['all_sigmas'] = sigmas

    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
|
|
|
|
|
|
|
def modified_apply_control(h, control, name):
    '''
    Modified by BrushNet nodes

    Replacement for ``openaimodel.apply_control``, installed by
    patch_model_function_wrapper. Pops the last entry of ``control[name]``
    (if there is one) and adds it to ``h`` in place. If the spatial sizes
    (dims 2 and 3) differ, the control is first resized bicubically to ``h``'s
    size and cast to ``h``'s dtype and device. If the addition fails, a warning
    is printed and that control is skipped.

    Returns ``h`` (the same tensor object, possibly modified in place).

    The word "BrushNet" in this docstring is load-bearing: it marks this
    replacement as already installed. Do not remove it.
    '''
    if control is None or name not in control or len(control[name]) == 0:
        return h

    # pop() consumes one control entry per call.
    ctrl = control[name].pop()
    if ctrl is None:
        return h

    if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
        ctrl = torch.nn.functional.interpolate(
            ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic'
        ).to(h.dtype).to(h.device)

    try:
        h += ctrl
    except RuntimeError:
        # torch raises RuntimeError for non-broadcastable shapes or mismatched
        # devices. Report the problem and skip this control instead of crashing.
        # (The previous print.warning(...) call itself raised AttributeError.)
        print("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
    return h
|
|
|
|
|