|
import torch.nn.functional as F |
|
import comfy |
|
|
|
from .model_patch import add_model_patch_option, patch_model_function_wrapper |
|
|
|
|
|
|
|
class RAUNet:
    """ComfyUI node that returns a clone of the input model with the RAUNet patch applied.

    The four integer inputs are step bounds forwarded unchanged to
    ``add_raunet_patch``, which stores them in the model's
    ``model_patch['raunet']`` options. There they are compared against
    ``model_patch['step']``.
    """

    CATEGORY = "inpaint"
    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "model_update"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs: a MODEL plus four INT step bounds."""

        def step_bound(default):
            # Every step bound shares the same INT widget range; only the default differs.
            return ("INT", {"default": default, "min": 0, "max": 10000})

        required = {
            "model": ("MODEL",),
            "du_start": step_bound(0),
            "du_end": step_bound(4),
            "xa_start": step_bound(4),
            "xa_end": step_bound(10),
        }
        return {"required": required}

    def model_update(self, model, du_start, du_end, xa_start, xa_end):
        """Clone ``model`` and install the RAUNet patch on the clone.

        The caller's model object is not passed to ``add_raunet_patch``;
        only the clone is. Returns a 1-tuple, as ComfyUI node functions do.
        """
        patched = model.clone()
        add_raunet_patch(patched, du_start, du_end, xa_start, xa_end)
        return (patched,)
|
|
|
|
|
|
|
def add_raunet_patch(model, du_start, du_end, xa_start, xa_end):
    """Register the RAUNet hooks on ``model`` and record their step windows.

    Three hooks are installed:

    * ``raunet_forward`` via ``patch_model_function_wrapper``. It picks the
      first sub-module of ``input_blocks[6]`` (SDXL) or ``input_blocks[3]``
      (otherwise). If that is a ``Downsample``, it sets ``block.op`` to
      stride 4 / padding 2 / dilation 2 while ``du_start <= step < du_end``,
      and back to stride 2 / padding 1 / dilation 1 otherwise.
    * ``in_xattn_patch`` as the model's input-block patch.
    * ``out_xattn_patch`` as the model's output-block patch.

    The four bounds are written to ``model_patch['raunet']`` (creating that
    dict if needed), and the hooks read them from there at call time.

    Args:
        model: the (already cloned) model patcher to modify in place.
        du_start, du_end: step window for the modified downsample conv.
        xa_start, xa_end: step window for the input/output block patches.
    """

    def raunet_forward(model, x, timesteps, transformer_options, control):
        # NOTE: ``model`` here is the object exposing ``input_blocks`` (it
        # shadows the outer patcher argument). The parameter names are kept
        # as-is because the wrapper's calling convention is defined in
        # model_patch and is not visible here.
        if 'model_patch' not in transformer_options:
            print("RAUNet: 'model_patch' not in transformer_options, skip")
            return

        mp = transformer_options['model_patch']
        is_SDXL = mp['SDXL']

        # The first Downsample sits at a different input-block index per architecture.
        block = model.input_blocks[6 if is_SDXL else 3][0]
        downsample_cls = comfy.ldm.modules.diffusionmodules.openaimodel.Downsample
        if not isinstance(block, downsample_cls):
            if is_SDXL:
                print('RAUNet: model is SDXL, but input[6] != Downsample, skip')
            else:
                print('RAUNet: model is not SDXL, but input[3] != Downsample, skip')
            return

        if 'raunet' not in mp:
            print('RAUNet: "raunet" not in model_patch options, skip')
            return

        # Bounds are read at call time so later updates to the options take effect.
        ro = mp['raunet']
        step = mp['step']
        if ro['du_start'] <= step < ro['du_end']:
            stride, padding, dilation = (4, 4), (2, 2), (2, 2)
        else:
            stride, padding, dilation = (2, 2), (1, 1), (1, 1)

        # NOTE(review): this mutates the shared conv module in place. The
        # (2, 2)/(1, 1)/(1, 1) values are only restored when a later call sees
        # a step outside [du_start, du_end). If sampling ends inside the
        # window, or a later call returns early above, the modified values
        # remain on the module.
        block.op.stride = stride
        block.op.padding = padding
        block.op.dilation = dilation

    patch_model_function_wrapper(model, raunet_forward)
    model.set_model_input_block_patch(in_xattn_patch)
    model.set_model_output_block_patch(out_xattn_patch)

    to = add_model_patch_option(model)
    ro = to['model_patch'].setdefault('raunet', {})
    ro['du_start'] = du_start
    ro['du_end'] = du_end
    ro['xa_start'] = xa_start
    ro['xa_end'] = xa_end
|
|
|
|
|
def in_xattn_patch(h, transformer_options):
    """Input-block patch: average-pool ``h`` with a 2x2 kernel at block ("input", 4).

    The pooling only happens while ``xa_start <= step < xa_end``, using the
    values in ``transformer_options['model_patch']``. In every other case,
    including missing options, ``h`` is returned untouched. The missing-option
    cases also print a diagnostic.
    """
    # Only one input block is patched; everything else passes straight through.
    if transformer_options["block"] != ("input", 4):
        return h

    if 'model_patch' not in transformer_options:
        print("RAUNet (i-x-p): 'model_patch' not in transformer_options")
        return h
    patch_opts = transformer_options['model_patch']

    if 'raunet' not in patch_opts:
        print("RAUNet (i-x-p): 'raunet' not in model_patch options")
        return h

    current_step = patch_opts['step']
    raunet_opts = patch_opts['raunet']
    window_start = raunet_opts['xa_start']
    window_end = raunet_opts['xa_end']

    if window_start <= current_step < window_end:
        return F.avg_pool2d(h, kernel_size=(2, 2))
    return h
|
|
|
|
|
def out_xattn_patch(h, hsp, transformer_options):
    """Output-block patch: resize ``h`` to the spatial size of ``hsp``.

    The resize (bicubic ``F.interpolate``) only happens on the target block,
    ("output", 5) for SDXL and ("output", 8) otherwise, and only while
    ``xa_start <= step < xa_end``. Those values are read from
    ``transformer_options['model_patch']``.

    Args:
        h: feature tensor; its last two dims are treated as spatial
            (assumed NCHW — TODO confirm against caller).
        hsp: companion tensor whose last two dims give the target size;
            it is never modified.
        transformer_options: dict containing "block" and, normally,
            "model_patch".

    Returns:
        Always the pair ``(h, hsp)``, with ``h`` possibly resized.
    """
    if 'model_patch' not in transformer_options:
        print("RAUNet (o-x-p): 'model_patch' not in transformer_options")
        return h, hsp

    mp = transformer_options['model_patch']
    if 'raunet' not in mp:
        print("RAUNet (o-x-p): 'raunet' not in model_patch options")
        # Every path must return the (h, hsp) pair. A bare tensor would be
        # unpacked along its batch dimension by the caller.
        return h, hsp

    target_block = ("output", 5) if mp['SDXL'] else ("output", 8)
    if transformer_options["block"] != target_block:
        return h, hsp

    step = mp['step']
    ro = mp['raunet']
    if step < ro['xa_start'] or step >= ro['xa_end']:
        return h, hsp

    # Resize to hsp's exact spatial size rather than a fixed 2x factor:
    # a fixed factor cannot recover odd sizes that a 2x2 avg_pool2d floored.
    re_size = (hsp.shape[-2], hsp.shape[-1])
    h = F.interpolate(h, size=re_size, mode='bicubic')
    return h, hsp
|
|
|
|
|
|