File size: 4,783 Bytes
7cf0db3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import torch.nn.functional as F
import comfy
from .model_patch import add_model_patch_option, patch_model_function_wrapper
class RAUNet:
    """Node that returns a clone of a model with the RAUNet patches applied.

    The four integer inputs are step bounds that ``add_raunet_patch`` stores
    in the model's patch options: ``du_start``/``du_end`` bound the steps on
    which the Downsample convolution is reconfigured, and
    ``xa_start``/``xa_end`` bound the steps on which the input/output block
    patches resize the hidden state.
    """

    CATEGORY = "inpaint"
    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "model_update"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: a model plus four step bounds."""

        def step_input(default):
            # Every step bound shares the same integer range; only the
            # default differs. A fresh dict is built on each call.
            return ("INT", {"default": default, "min": 0, "max": 10000})

        required = {
            "model": ("MODEL",),
            "du_start": step_input(0),
            "du_end": step_input(4),
            "xa_start": step_input(4),
            "xa_end": step_input(10),
        }
        return {"required": required}

    def model_update(self, model, du_start, du_end, xa_start, xa_end):
        """Clone ``model``, patch the clone, and return it as a 1-tuple."""
        # Work on a clone so the caller's model object is not modified here.
        patched = model.clone()
        add_raunet_patch(patched, du_start, du_end, xa_start, xa_end)
        return (patched,)
# This is the main patch function.
def add_raunet_patch(model, du_start, du_end, xa_start, xa_end):
    """Install the RAUNet hooks on ``model`` and record their step windows.

    Registers ``raunet_forward`` through ``patch_model_function_wrapper``,
    sets ``in_xattn_patch`` / ``out_xattn_patch`` as the model's input and
    output block patches, and stores the four bounds under the ``'raunet'``
    key of the model's ``'model_patch'`` options.

    Args:
        model: Model patcher object; it is modified in place.
        du_start: First step (inclusive) on which the Downsample convolution
            is switched to stride 4 / padding 2 / dilation 2.
        du_end: Step (exclusive) at which the convolution is restored.
        xa_start: First step (inclusive) read by the block patches.
        xa_end: Step (exclusive) read by the block patches.
    """

    def raunet_forward(model, x, timesteps, transformer_options, control):
        """Reconfigure the Downsample convolution for the current step.

        ``model`` here is the argument supplied by the wrapper and shadows
        the outer ``model``; presumably it is the UNet itself, since it is
        indexed via ``input_blocks`` (confirm against the wrapper). Always
        returns ``None``.
        """
        if 'model_patch' not in transformer_options:
            print("RAUNet: 'model_patch' not in transformer_options, skip")
            return

        mp = transformer_options['model_patch']
        is_SDXL = mp['SDXL']

        # The Downsample layer is looked up at input block 6 for SDXL and at
        # input block 3 otherwise.
        block_index = 6 if is_SDXL else 3
        block = model.input_blocks[block_index][0]

        downsample_cls = comfy.ldm.modules.diffusionmodules.openaimodel.Downsample
        # isinstance also accepts subclasses of Downsample, unlike an exact
        # type comparison.
        if not isinstance(block, downsample_cls):
            if is_SDXL:
                print('RAUNet: model is SDXL, but input[6] != Downsample, skip')
            else:
                print('RAUNet: model is not SDXL, but input[3] != Downsample, skip')
            return

        if 'raunet' not in mp:
            print('RAUNet: "raunet" not in model_patch options, skip')
            return

        step = mp['step']
        window = mp['raunet']
        window_start = window['du_start']
        window_end = window['du_end']

        if window_start <= step < window_end:
            # Inside the window: stride 4 with dilation 2.
            stride, padding, dilation = (4, 4), (2, 2), (2, 2)
        else:
            # Outside the window: put the stride-2 configuration back.
            stride, padding, dilation = (2, 2), (1, 1), (1, 1)

        block.op.stride = stride
        block.op.padding = padding
        block.op.dilation = dilation

    patch_model_function_wrapper(model, raunet_forward)
    model.set_model_input_block_patch(in_xattn_patch)
    model.set_model_output_block_patch(out_xattn_patch)

    to = add_model_patch_option(model)
    mp = to['model_patch']
    if 'raunet' not in mp:
        mp['raunet'] = {}
    # Update key by key so any other entries already under 'raunet' survive.
    ro = mp['raunet']
    ro['du_start'] = du_start
    ro['du_end'] = du_end
    ro['xa_start'] = xa_start
    ro['xa_end'] = xa_end
def in_xattn_patch(h, transformer_options):
    """Input-block patch: halve ``h`` spatially at input block 4.

    ``h`` is returned unchanged when the current block is not
    ``("input", 4)``, when the ``'model_patch'`` / ``'raunet'`` options are
    missing (a message is printed), or when the current step lies outside
    ``[xa_start, xa_end)``. Otherwise the result of a 2x2 average pool over
    ``h`` is returned.
    """
    # The same block index is used for both SDXL and SD15.
    if transformer_options["block"] != ("input", 4):
        return h

    if 'model_patch' not in transformer_options:
        print("RAUNet (i-x-p): 'model_patch' not in transformer_options")
        return h

    patch_opts = transformer_options['model_patch']
    if 'raunet' not in patch_opts:
        print("RAUNet (i-x-p): 'raunet' not in model_patch options")
        return h

    current_step = patch_opts['step']
    raunet_opts = patch_opts['raunet']
    first_step = raunet_opts['xa_start']
    stop_step = raunet_opts['xa_end']

    outside_window = current_step < first_step or current_step >= stop_step
    return h if outside_window else F.avg_pool2d(h, kernel_size=(2, 2))
def out_xattn_patch(h, hsp, transformer_options):
    """Output-block patch: resize ``h`` to the spatial size of ``hsp``.

    The resize happens only at output block 5 (SDXL) or output block 8
    (otherwise), and only while the current step is inside
    ``[xa_start, xa_end)``. On every other path the inputs are returned
    unchanged.

    Args:
        h: Tensor to resize; its last two dimensions are treated as spatial.
        hsp: Tensor whose last two dimensions give the target size
            (presumably the skip-connection tensor; confirm against the
            caller).
        transformer_options: Options dict holding ``"block"`` and
            ``'model_patch'``.

    Returns:
        Always a ``(h, hsp)`` pair; ``hsp`` is never modified.
    """
    if 'model_patch' not in transformer_options:
        print("RAUNet (o-x-p): 'model_patch' not in transformer_options")
        return h, hsp

    mp = transformer_options['model_patch']
    if 'raunet' not in mp:
        print("RAUNet (o-x-p): 'raunet' not in model_patch options")
        # Return the pair here as well, so this path matches every other
        # return in the function.
        return h, hsp

    step = mp['step']
    is_SDXL = mp['SDXL']
    ro = mp['raunet']
    xa_start = ro['xa_start']
    xa_end = ro['xa_end']

    target_block = ("output", 5) if is_SDXL else ("output", 8)
    if transformer_options["block"] != target_block:
        # wrong block
        return h, hsp

    if step < xa_start or step >= xa_end:
        return h, hsp

    # The hidiffusion codebase doubles h's size, which is only correct for
    # particular sizes; taking the target size from hsp avoids that.
    re_size = (hsp.shape[-2], hsp.shape[-1])
    h = F.interpolate(h, size=re_size, mode='bicubic')
    return h, hsp
|