File size: 4,783 Bytes
7cf0db3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import torch.nn.functional as F
import comfy

from .model_patch import add_model_patch_option, patch_model_function_wrapper



class RAUNet:
    """Node that returns a clone of a model with the RAUNet patch applied.

    Inputs are the model plus four integer step bounds (each 0..10000):
    ``du_start``/``du_end`` and ``xa_start``/``xa_end``. They are forwarded
    unchanged to ``add_raunet_patch``.
    """

    CATEGORY = "inpaint"
    FUNCTION = "model_update"
    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a MODEL and four bounded INT step values."""

        def step_input(default):
            # Every step bound shares the same range; only the default differs.
            return ("INT", {"default": default, "min": 0, "max": 10000})

        required = {"model": ("MODEL",)}
        required["du_start"] = step_input(0)
        required["du_end"] = step_input(4)
        required["xa_start"] = step_input(4)
        required["xa_end"] = step_input(10)
        return {"required": required}

    def model_update(self, model, du_start, du_end, xa_start, xa_end):
        """Clone ``model``, patch the clone and return it as a 1-tuple."""
        # Patch a clone so the incoming model object is not modified here.
        patched = model.clone()
        add_raunet_patch(patched, du_start, du_end, xa_start, xa_end)
        return (patched,)


# This is the main patch function.
def add_raunet_patch(model, du_start, du_end, xa_start, xa_end):
    """Install the RAUNet patches on ``model`` and store their step windows.

    Registers ``raunet_forward`` via ``patch_model_function_wrapper``, sets
    ``in_xattn_patch`` / ``out_xattn_patch`` as the input/output block
    patches, and writes the four bounds into ``model_patch['raunet']`` of the
    options returned by ``add_model_patch_option``.

    Args:
        model: object providing ``set_model_input_block_patch`` and
            ``set_model_output_block_patch``; it is modified in place.
        du_start: first step (inclusive) at which the downsample conv is
            switched to stride 4 / padding 2 / dilation 2.
        du_end: step (exclusive) at which that switch ends.
        xa_start: first step (inclusive) of the window read by the block
            patches.
        xa_end: step (exclusive) ending that window.
    """

    def raunet_forward(model, x, timesteps, transformer_options, control):
        """Set the downsample conv geometry for the current sampling step.

        Prints a message and returns without touching the model when the
        required options are missing or the expected block is not a
        ``Downsample``.
        """
        if 'model_patch' not in transformer_options:
            print("RAUNet: 'model_patch' not in transformer_options, skip")
            return

        mp = transformer_options['model_patch']
        is_SDXL = mp['SDXL']

        # The patched layer sits at input block 6 for SDXL, 3 otherwise.
        block = model.input_blocks[6 if is_SDXL else 3][0]
        downsample_cls = comfy.ldm.modules.diffusionmodules.openaimodel.Downsample
        if not isinstance(block, downsample_cls):
            if is_SDXL:
                print('RAUNet: model is SDXL, but input[6] != Downsample, skip')
            else:
                print('RAUNet: model is not SDXL, but input[3] != Downsample, skip')
            return

        if 'raunet' not in mp:
            print('RAUNet: "raunet" not in model_patch options, skip')
            return

        step = mp['step']
        ro = mp['raunet']
        # Distinct names: the enclosing function already has du_start/du_end.
        window_start = ro['du_start']
        window_end = ro['du_end']

        if window_start <= step < window_end:
            block.op.stride = (4, 4)
            block.op.padding = (2, 2)
            block.op.dilation = (2, 2)
        else:
            # Outside the window, restore the stride-2 configuration.
            block.op.stride = (2, 2)
            block.op.padding = (1, 1)
            block.op.dilation = (1, 1)

    patch_model_function_wrapper(model, raunet_forward)
    model.set_model_input_block_patch(in_xattn_patch)
    model.set_model_output_block_patch(out_xattn_patch)

    to = add_model_patch_option(model)
    mp = to['model_patch']
    # Reuse an existing 'raunet' options dict if one is already present.
    ro = mp.setdefault('raunet', {})
    ro.update(
        du_start=du_start,
        du_end=du_end,
        xa_start=xa_start,
        xa_end=xa_end,
    )


def in_xattn_patch(h, transformer_options):
    """Input-block patch: halve ``h`` spatially inside the xa step window.

    Acts only on block ``("input", 4)`` (the same index is used for SDXL and
    SD15). When ``xa_start <= step < xa_end`` the result is ``h`` after a 2x2
    average pool; in every other case ``h`` is returned unchanged.
    """
    if transformer_options["block"] != ("input", 4):
        # Not the block this patch targets.
        return h

    if 'model_patch' not in transformer_options:
        print("RAUNet (i-x-p): 'model_patch' not in transformer_options")
        return h

    patch_opts = transformer_options['model_patch']
    if 'raunet' not in patch_opts:
        print("RAUNet (i-x-p): 'raunet' not in model_patch options")
        return h

    current_step = patch_opts['step']
    raunet_opts = patch_opts['raunet']
    first_step = raunet_opts['xa_start']
    stop_step = raunet_opts['xa_end']

    if first_step <= current_step < stop_step:
        return F.avg_pool2d(h, kernel_size=(2, 2))
    return h


def out_xattn_patch(h, hsp, transformer_options):
    """Output-block patch: resize ``h`` to the spatial size of ``hsp``.

    Acts only on block ``("output", 5)`` for SDXL or ``("output", 8)``
    otherwise, and only while ``xa_start <= step < xa_end``. In that case
    ``h`` is bicubically interpolated to ``hsp``'s last two dimensions.

    Returns:
        Always a ``(h, hsp)`` pair; ``hsp`` is passed through untouched and
        ``h`` is unchanged on every early-exit path.
    """
    if 'model_patch' not in transformer_options:
        print("RAUNet (o-x-p): 'model_patch' not in transformer_options")
        return h, hsp
    mp = transformer_options['model_patch']
    if 'raunet' not in mp:
        print("RAUNet (o-x-p): 'raunet' not in model_patch options")
        # Return the pair like every other path (a bare ``h`` was returned
        # here before, which breaks callers that unpack two values).
        return h, hsp

    step = mp['step']
    is_SDXL = mp['SDXL']
    ro = mp['raunet']
    xa_start = ro['xa_start']
    xa_end = ro['xa_end']

    target_block = ("output", 5) if is_SDXL else ("output", 8)
    if transformer_options["block"] != target_block:
        # wrong block
        return h, hsp

    if step < xa_start or step >= xa_end:
        return h, hsp

    # The hidiffusion codebase doubles h's size, which only matches for
    # particular sizes; take the target size from hsp instead.
    re_size = (hsp.shape[-2], hsp.shape[-1])
    h = F.interpolate(h, size=re_size, mode='bicubic')

    return h, hsp