Spaces:
Runtime error
Runtime error
Delete
#15
by
Baykon
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- MagicQuill/.DS_Store +0 -0
- MagicQuill/brushnet/brushnet.json +0 -58
- MagicQuill/brushnet/brushnet.py +0 -949
- MagicQuill/brushnet/brushnet_ca.py +0 -983
- MagicQuill/brushnet/brushnet_xl.json +0 -63
- MagicQuill/brushnet/powerpaint.json +0 -57
- MagicQuill/brushnet/powerpaint_utils.py +0 -496
- MagicQuill/brushnet/unet_2d_blocks.py +0 -0
- MagicQuill/brushnet/unet_2d_condition.py +0 -1355
- MagicQuill/brushnet_nodes.py +0 -1094
- MagicQuill/comfy/.DS_Store +0 -0
- MagicQuill/comfy/checkpoint_pickle.py +0 -13
- MagicQuill/comfy/cldm/__pycache__/cldm.cpython-310.pyc +0 -0
- MagicQuill/comfy/cldm/cldm.py +0 -313
- MagicQuill/comfy/cli_args.py +0 -143
- MagicQuill/comfy/clip_config_bigg.json +0 -23
- MagicQuill/comfy/clip_model.py +0 -194
- MagicQuill/comfy/clip_vision.py +0 -117
- MagicQuill/comfy/clip_vision_config_g.json +0 -18
- MagicQuill/comfy/clip_vision_config_h.json +0 -18
- MagicQuill/comfy/clip_vision_config_vitl.json +0 -18
- MagicQuill/comfy/conds.py +0 -83
- MagicQuill/comfy/controlnet.py +0 -554
- MagicQuill/comfy/diffusers_convert.py +0 -281
- MagicQuill/comfy/diffusers_load.py +0 -36
- MagicQuill/comfy/extra_samplers/__pycache__/uni_pc.cpython-310.pyc +0 -0
- MagicQuill/comfy/extra_samplers/uni_pc.py +0 -875
- MagicQuill/comfy/gligen.py +0 -343
- MagicQuill/comfy/k_diffusion/__pycache__/sampling.cpython-310.pyc +0 -0
- MagicQuill/comfy/k_diffusion/__pycache__/utils.cpython-310.pyc +0 -0
- MagicQuill/comfy/k_diffusion/sampling.py +0 -843
- MagicQuill/comfy/k_diffusion/utils.py +0 -313
- MagicQuill/comfy/latent_formats.py +0 -141
- MagicQuill/comfy/ldm/.DS_Store +0 -0
- MagicQuill/comfy/ldm/__pycache__/util.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/audio/__pycache__/autoencoder.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/audio/__pycache__/dit.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/audio/__pycache__/embedders.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/audio/autoencoder.py +0 -282
- MagicQuill/comfy/ldm/audio/dit.py +0 -888
- MagicQuill/comfy/ldm/audio/embedders.py +0 -108
- MagicQuill/comfy/ldm/cascade/__pycache__/common.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/cascade/__pycache__/controlnet.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/cascade/__pycache__/stage_a.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/cascade/__pycache__/stage_b.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/cascade/__pycache__/stage_c.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/cascade/__pycache__/stage_c_coder.cpython-310.pyc +0 -0
- MagicQuill/comfy/ldm/cascade/common.py +0 -161
- MagicQuill/comfy/ldm/cascade/controlnet.py +0 -93
- MagicQuill/comfy/ldm/cascade/stage_a.py +0 -255
MagicQuill/.DS_Store
DELETED
Binary file (6.15 kB)
|
|
MagicQuill/brushnet/brushnet.json
DELETED
@@ -1,58 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"_class_name": "BrushNetModel",
|
3 |
-
"_diffusers_version": "0.27.0.dev0",
|
4 |
-
"_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
|
5 |
-
"act_fn": "silu",
|
6 |
-
"addition_embed_type": null,
|
7 |
-
"addition_embed_type_num_heads": 64,
|
8 |
-
"addition_time_embed_dim": null,
|
9 |
-
"attention_head_dim": 8,
|
10 |
-
"block_out_channels": [
|
11 |
-
320,
|
12 |
-
640,
|
13 |
-
1280,
|
14 |
-
1280
|
15 |
-
],
|
16 |
-
"brushnet_conditioning_channel_order": "rgb",
|
17 |
-
"class_embed_type": null,
|
18 |
-
"conditioning_channels": 5,
|
19 |
-
"conditioning_embedding_out_channels": [
|
20 |
-
16,
|
21 |
-
32,
|
22 |
-
96,
|
23 |
-
256
|
24 |
-
],
|
25 |
-
"cross_attention_dim": 768,
|
26 |
-
"down_block_types": [
|
27 |
-
"DownBlock2D",
|
28 |
-
"DownBlock2D",
|
29 |
-
"DownBlock2D",
|
30 |
-
"DownBlock2D"
|
31 |
-
],
|
32 |
-
"downsample_padding": 1,
|
33 |
-
"encoder_hid_dim": null,
|
34 |
-
"encoder_hid_dim_type": null,
|
35 |
-
"flip_sin_to_cos": true,
|
36 |
-
"freq_shift": 0,
|
37 |
-
"global_pool_conditions": false,
|
38 |
-
"in_channels": 4,
|
39 |
-
"layers_per_block": 2,
|
40 |
-
"mid_block_scale_factor": 1,
|
41 |
-
"mid_block_type": "MidBlock2D",
|
42 |
-
"norm_eps": 1e-05,
|
43 |
-
"norm_num_groups": 32,
|
44 |
-
"num_attention_heads": null,
|
45 |
-
"num_class_embeds": null,
|
46 |
-
"only_cross_attention": false,
|
47 |
-
"projection_class_embeddings_input_dim": null,
|
48 |
-
"resnet_time_scale_shift": "default",
|
49 |
-
"transformer_layers_per_block": 1,
|
50 |
-
"up_block_types": [
|
51 |
-
"UpBlock2D",
|
52 |
-
"UpBlock2D",
|
53 |
-
"UpBlock2D",
|
54 |
-
"UpBlock2D"
|
55 |
-
],
|
56 |
-
"upcast_attention": false,
|
57 |
-
"use_linear_projection": false
|
58 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/brushnet/brushnet.py
DELETED
@@ -1,949 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from torch import nn
|
6 |
-
from torch.nn import functional as F
|
7 |
-
|
8 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
-
from diffusers.utils import BaseOutput, logging
|
10 |
-
from diffusers.models.attention_processor import (
|
11 |
-
ADDED_KV_ATTENTION_PROCESSORS,
|
12 |
-
CROSS_ATTENTION_PROCESSORS,
|
13 |
-
AttentionProcessor,
|
14 |
-
AttnAddedKVProcessor,
|
15 |
-
AttnProcessor,
|
16 |
-
)
|
17 |
-
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
18 |
-
from diffusers.models.modeling_utils import ModelMixin
|
19 |
-
|
20 |
-
from .unet_2d_blocks import (
|
21 |
-
CrossAttnDownBlock2D,
|
22 |
-
DownBlock2D,
|
23 |
-
UNetMidBlock2D,
|
24 |
-
UNetMidBlock2DCrossAttn,
|
25 |
-
get_down_block,
|
26 |
-
get_mid_block,
|
27 |
-
get_up_block,
|
28 |
-
MidBlock2D
|
29 |
-
)
|
30 |
-
|
31 |
-
from .unet_2d_condition import UNet2DConditionModel
|
32 |
-
|
33 |
-
|
34 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
35 |
-
|
36 |
-
|
37 |
-
@dataclass
class BrushNetOutput(BaseOutput):
    """
    The output of [`BrushNetModel`].

    Args:
        up_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can
            be used to condition the original UNet's upsampling activations.
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor
            should be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`.
            Output can be used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    # Variable-length tuples: one entry per residual, not a 1-tuple.
    up_block_res_samples: Tuple[torch.Tensor, ...]
    down_block_res_samples: Tuple[torch.Tensor, ...]
    mid_block_res_sample: torch.Tensor
|
60 |
-
|
61 |
-
|
62 |
-
class BrushNetModel(ModelMixin, ConfigMixin):
|
63 |
-
"""
|
64 |
-
A BrushNet model.
|
65 |
-
|
66 |
-
Args:
|
67 |
-
in_channels (`int`, defaults to 4):
|
68 |
-
The number of channels in the input sample.
|
69 |
-
flip_sin_to_cos (`bool`, defaults to `True`):
|
70 |
-
Whether to flip the sin to cos in the time embedding.
|
71 |
-
freq_shift (`int`, defaults to 0):
|
72 |
-
The frequency shift to apply to the time embedding.
|
73 |
-
down_block_types (`tuple[str]`, defaults to `("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D")`):
|
74 |
-
The tuple of downsample blocks to use.
|
75 |
-
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
76 |
-
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
77 |
-
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
78 |
-
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D")`):
|
79 |
-
The tuple of upsample blocks to use.
|
80 |
-
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
81 |
-
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
82 |
-
The tuple of output channels for each block.
|
83 |
-
layers_per_block (`int`, defaults to 2):
|
84 |
-
The number of layers per block.
|
85 |
-
downsample_padding (`int`, defaults to 1):
|
86 |
-
The padding to use for the downsampling convolution.
|
87 |
-
mid_block_scale_factor (`float`, defaults to 1):
|
88 |
-
The scale factor to use for the mid block.
|
89 |
-
act_fn (`str`, defaults to "silu"):
|
90 |
-
The activation function to use.
|
91 |
-
norm_num_groups (`int`, *optional*, defaults to 32):
|
92 |
-
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
93 |
-
in post-processing.
|
94 |
-
norm_eps (`float`, defaults to 1e-5):
|
95 |
-
The epsilon to use for the normalization.
|
96 |
-
cross_attention_dim (`int`, defaults to 1280):
|
97 |
-
The dimension of the cross attention features.
|
98 |
-
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
99 |
-
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
100 |
-
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
101 |
-
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
102 |
-
encoder_hid_dim (`int`, *optional*, defaults to None):
|
103 |
-
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
104 |
-
dimension to `cross_attention_dim`.
|
105 |
-
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
106 |
-
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
107 |
-
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
108 |
-
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
109 |
-
The dimension of the attention heads.
|
110 |
-
use_linear_projection (`bool`, defaults to `False`):
|
111 |
-
class_embed_type (`str`, *optional*, defaults to `None`):
|
112 |
-
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
113 |
-
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
114 |
-
addition_embed_type (`str`, *optional*, defaults to `None`):
|
115 |
-
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
116 |
-
"text". "text" will use the `TextTimeEmbedding` layer.
|
117 |
-
num_class_embeds (`int`, *optional*, defaults to `None`):
|
118 |
-
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
119 |
-
class conditioning with `class_embed_type` equal to `None`.
|
120 |
-
upcast_attention (`bool`, defaults to `False`):
|
121 |
-
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
122 |
-
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
123 |
-
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
124 |
-
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
125 |
-
`class_embed_type="projection"`.
|
126 |
-
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
127 |
-
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
128 |
-
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
129 |
-
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
130 |
-
global_pool_conditions (`bool`, defaults to `False`):
|
131 |
-
TODO(Patrick) - unused parameter.
|
132 |
-
addition_embed_type_num_heads (`int`, defaults to 64):
|
133 |
-
The number of heads to use for the `TextTimeEmbedding` layer.
|
134 |
-
"""
|
135 |
-
|
136 |
-
_supports_gradient_checkpointing = True
|
137 |
-
|
138 |
-
@register_to_config
def __init__(
    self,
    in_channels: int = 4,
    conditioning_channels: int = 5,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str, ...] = (
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type: Optional[str] = "UNetMidBlock2D",
    up_block_types: Tuple[str, ...] = (
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
    layers_per_block: int = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: int = 1280,
    transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
    encoder_hid_dim: Optional[int] = None,
    encoder_hid_dim_type: Optional[str] = None,
    attention_head_dim: Union[int, Tuple[int, ...]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    projection_class_embeddings_input_dim: Optional[int] = None,
    brushnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
    global_pool_conditions: bool = False,
    addition_embed_type_num_heads: int = 64,
):
    """
    Build the BrushNet layers.

    Creates, in this order: the input convolution over `in_channels + conditioning_channels` channels, the
    timestep projection/embedding, the optional encoder-hidden-state projection, the optional class and
    additional embeddings, and the down / mid / up blocks. Every down, mid and up stage is paired with 1x1
    convolutions stored in `brushnet_down_blocks`, `brushnet_mid_block` and `brushnet_up_blocks`; each of
    those convolutions is passed through `zero_module` (defined elsewhere in this file; presumably it zeroes
    the parameters -- confirm there).

    `brushnet_conditioning_channel_order`, `conditioning_embedding_out_channels` and `global_pool_conditions`
    are accepted but not read in this method.

    Raises:
        ValueError: If `down_block_types`, `up_block_types`, `block_out_channels`, `only_cross_attention` or
            `num_attention_heads` have inconsistent lengths, if `encoder_hid_dim_type` is set without
            `encoder_hid_dim` or has an unsupported value, if `class_embed_type == "projection"` is used
            without `projection_class_embeddings_input_dim`, or if `addition_embed_type` is unsupported.
    """
    super().__init__()

    def _zero_conv(channels: int) -> nn.Module:
        # 1x1 convolution with matching in/out channels, passed through `zero_module`.
        return zero_module(nn.Conv2d(channels, channels, kernel_size=1))

    # If `num_attention_heads` is not defined (which is the case for most models)
    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
    # The reason for this behavior is to correct for incorrectly named variables that were introduced
    # when this library was created. The incorrect naming was only discovered much later in
    # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards
    # breaking, which is why we correct for the naming here.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Check inputs
    if len(down_block_types) != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

    # input
    conv_in_kernel = 3
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in_condition = nn.Conv2d(
        in_channels + conditioning_channels,
        block_out_channels[0],
        kernel_size=conv_in_kernel,
        padding=conv_in_padding,
    )

    # time
    time_embed_dim = block_out_channels[0] * 4
    self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
    timestep_input_dim = block_out_channels[0]
    self.time_embedding = TimestepEmbedding(
        timestep_input_dim,
        time_embed_dim,
        act_fn=act_fn,
    )

    if encoder_hid_dim_type is None and encoder_hid_dim is not None:
        encoder_hid_dim_type = "text_proj"
        self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
        logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

    if encoder_hid_dim is None and encoder_hid_dim_type is not None:
        raise ValueError(
            f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
        )

    if encoder_hid_dim_type == "text_proj":
        self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
    elif encoder_hid_dim_type == "text_image_proj":
        # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
        # it is set to `cross_attention_dim` here, as this is exactly the required dimension for the
        # currently only use case of `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
        self.encoder_hid_proj = TextImageProjection(
            text_embed_dim=encoder_hid_dim,
            image_embed_dim=cross_attention_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type is not None:
        raise ValueError(
            f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
        )
    else:
        self.encoder_hid_proj = None

    # class embedding
    if class_embed_type is None and num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
    elif class_embed_type == "identity":
        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type == "projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
            )
        # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
        # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
        # 2. it projects from an arbitrary input dimension.
        #
        # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
        # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
        # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
        self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    else:
        self.class_embedding = None

    if addition_embed_type == "text":
        if encoder_hid_dim is not None:
            text_time_embedding_from_dim = encoder_hid_dim
        else:
            text_time_embedding_from_dim = cross_attention_dim

        self.add_embedding = TextTimeEmbedding(
            text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
        )
    elif addition_embed_type == "text_image":
        # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the
        # __init__ too much they are set to `cross_attention_dim` here, as this is exactly the required
        # dimension for the currently only use case of `addition_embed_type == "text_image"` (Kandinsky 2.1).
        self.add_embedding = TextImageTimeEmbedding(
            text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
        )
    elif addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif addition_embed_type is not None:
        # 'text_time' is handled above, so it is listed among the accepted values.
        raise ValueError(
            f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'."
        )

    self.down_blocks = nn.ModuleList([])
    self.brushnet_down_blocks = nn.ModuleList([])

    # Broadcast scalar per-block settings to one value per down block.
    if isinstance(only_cross_attention, bool):
        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    # down
    output_channel = block_out_channels[0]

    # One extra conv ahead of the first down block.
    self.brushnet_down_blocks.append(_zero_conv(output_channel))

    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block,
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[i],
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            downsample_padding=downsample_padding,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
        self.down_blocks.append(down_block)

        # One conv per layer of the block, plus one more when the block downsamples.
        for _ in range(layers_per_block):
            self.brushnet_down_blocks.append(_zero_conv(output_channel))

        if not is_final_block:
            self.brushnet_down_blocks.append(_zero_conv(output_channel))

    # mid
    mid_block_channel = block_out_channels[-1]

    # Assigned before `mid_block` so the module registration order is unchanged.
    self.brushnet_mid_block = _zero_conv(mid_block_channel)

    self.mid_block = get_mid_block(
        mid_block_type,
        transformer_layers_per_block=transformer_layers_per_block[-1],
        in_channels=mid_block_channel,
        temb_channels=time_embed_dim,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        output_scale_factor=mid_block_scale_factor,
        resnet_time_scale_shift=resnet_time_scale_shift,
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads[-1],
        resnet_groups=norm_num_groups,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
    )

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
    only_cross_attention = list(reversed(only_cross_attention))

    output_channel = reversed_block_out_channels[0]

    self.up_blocks = nn.ModuleList([])
    self.brushnet_up_blocks = nn.ModuleList([])

    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        add_upsample = not is_final_block
        if add_upsample:
            self.num_upsamplers += 1

        up_block = get_up_block(
            up_block_type,
            num_layers=layers_per_block + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resolution_idx=i,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=reversed_num_attention_heads[i],
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            # NOTE(review): unlike `num_attention_heads`, `attention_head_dim` is indexed un-reversed here;
            # kept as-is to preserve the existing module layout -- confirm whether that is intended.
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
        )
        self.up_blocks.append(up_block)

        # One conv per layer of the up block, plus one more when the block upsamples.
        for _ in range(layers_per_block + 1):
            self.brushnet_up_blocks.append(_zero_conv(output_channel))

        if not is_final_block:
            self.brushnet_up_blocks.append(_zero_conv(output_channel))
|
451 |
-
|
452 |
-
|
453 |
-
@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    brushnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
    load_weights_from_unet: bool = True,
    conditioning_channels: int = 5,
):
    r"""
    Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].

    Parameters:
        unet (`UNet2DConditionModel`):
            The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
            where applicable.
        brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            Forwarded unchanged to the [`BrushNetModel`] constructor.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            Forwarded unchanged to the [`BrushNetModel`] constructor.
        load_weights_from_unet (`bool`, defaults to `True`):
            Whether to copy the input convolution, time embedding, class embedding (if any) and the
            down / mid / up block weights from `unet`.
        conditioning_channels (`int`, defaults to 5):
            Number of extra input channels of `conv_in_condition` on top of `unet.config.in_channels`.

    Returns:
        [`BrushNetModel`]: The newly created model. Every block type is replaced by its attention-free
        variant (`DownBlock2D`, `MidBlock2D`, `UpBlock2D`).
    """
    transformer_layers_per_block = (
        unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
    )
    encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
    encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
    addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
    addition_time_embed_dim = (
        unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
    )

    # Same number of stages as the UNet, but without attention blocks.
    num_stages = len(unet.config.down_block_types)

    brushnet = cls(
        in_channels=unet.config.in_channels,
        conditioning_channels=conditioning_channels,
        flip_sin_to_cos=unet.config.flip_sin_to_cos,
        freq_shift=unet.config.freq_shift,
        down_block_types=["DownBlock2D"] * num_stages,
        mid_block_type="MidBlock2D",
        up_block_types=["UpBlock2D"] * num_stages,
        only_cross_attention=unet.config.only_cross_attention,
        block_out_channels=unet.config.block_out_channels,
        layers_per_block=unet.config.layers_per_block,
        downsample_padding=unet.config.downsample_padding,
        mid_block_scale_factor=unet.config.mid_block_scale_factor,
        act_fn=unet.config.act_fn,
        norm_num_groups=unet.config.norm_num_groups,
        norm_eps=unet.config.norm_eps,
        cross_attention_dim=unet.config.cross_attention_dim,
        transformer_layers_per_block=transformer_layers_per_block,
        encoder_hid_dim=encoder_hid_dim,
        encoder_hid_dim_type=encoder_hid_dim_type,
        attention_head_dim=unet.config.attention_head_dim,
        num_attention_heads=unet.config.num_attention_heads,
        use_linear_projection=unet.config.use_linear_projection,
        class_embed_type=unet.config.class_embed_type,
        addition_embed_type=addition_embed_type,
        addition_time_embed_dim=addition_time_embed_dim,
        num_class_embeds=unet.config.num_class_embeds,
        upcast_attention=unet.config.upcast_attention,
        resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
        projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
        brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
        conditioning_embedding_out_channels=conditioning_embedding_out_channels,
    )

    if load_weights_from_unet:
        unet_conv_in = unet.conv_in
        # Number of input channels of the UNet's first conv (4 for the configs this code was written for).
        latent_channels = unet_conv_in.weight.shape[1]

        # The UNet kernel is copied into the first two `latent_channels`-wide input slices; any remaining
        # conditioning channels start from zero. Requires `conditioning_channels >= latent_channels`.
        # NOTE(review): presumably the slices are noisy latent / masked-image latent / mask -- confirm.
        with torch.no_grad():
            conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
            conv_in_condition_weight[:, :latent_channels, ...] = unet_conv_in.weight
            conv_in_condition_weight[:, latent_channels : 2 * latent_channels, ...] = unet_conv_in.weight
        brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)

        # Clone the bias: assigning `unet.conv_in.bias` directly would make both models share one Parameter,
        # so updating one model would silently change the other.
        unet_bias = unet_conv_in.bias
        brushnet.conv_in_condition.bias = (
            None if unet_bias is None else torch.nn.Parameter(unet_bias.detach().clone())
        )

        brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
        brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

        # Explicit None check instead of relying on the truthiness of an nn.Module.
        if brushnet.class_embedding is not None:
            brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

        # strict=False: the UNet blocks may carry attention weights that the attention-free blocks lack.
        brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
        brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
        brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)

    return brushnet
|
532 |
-
|
533 |
-
@property
|
534 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
535 |
-
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
536 |
-
r"""
|
537 |
-
Returns:
|
538 |
-
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
539 |
-
indexed by its weight name.
|
540 |
-
"""
|
541 |
-
# set recursively
|
542 |
-
processors = {}
|
543 |
-
|
544 |
-
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
545 |
-
if hasattr(module, "get_processor"):
|
546 |
-
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
547 |
-
|
548 |
-
for sub_name, child in module.named_children():
|
549 |
-
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
550 |
-
|
551 |
-
return processors
|
552 |
-
|
553 |
-
for name, module in self.named_children():
|
554 |
-
fn_recursive_add_processors(name, module, processors)
|
555 |
-
|
556 |
-
return processors
|
557 |
-
|
558 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
559 |
-
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
560 |
-
r"""
|
561 |
-
Sets the attention processor to use to compute attention.
|
562 |
-
|
563 |
-
Parameters:
|
564 |
-
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
565 |
-
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
566 |
-
for **all** `Attention` layers.
|
567 |
-
|
568 |
-
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
569 |
-
processor. This is strongly recommended when setting trainable attention processors.
|
570 |
-
|
571 |
-
"""
|
572 |
-
count = len(self.attn_processors.keys())
|
573 |
-
|
574 |
-
if isinstance(processor, dict) and len(processor) != count:
|
575 |
-
raise ValueError(
|
576 |
-
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
577 |
-
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
578 |
-
)
|
579 |
-
|
580 |
-
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
581 |
-
if hasattr(module, "set_processor"):
|
582 |
-
if not isinstance(processor, dict):
|
583 |
-
module.set_processor(processor)
|
584 |
-
else:
|
585 |
-
module.set_processor(processor.pop(f"{name}.processor"))
|
586 |
-
|
587 |
-
for sub_name, child in module.named_children():
|
588 |
-
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
589 |
-
|
590 |
-
for name, module in self.named_children():
|
591 |
-
fn_recursive_attn_processor(name, module, processor)
|
592 |
-
|
593 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
594 |
-
def set_default_attn_processor(self):
|
595 |
-
"""
|
596 |
-
Disables custom attention processors and sets the default attention implementation.
|
597 |
-
"""
|
598 |
-
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
599 |
-
processor = AttnAddedKVProcessor()
|
600 |
-
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
601 |
-
processor = AttnProcessor()
|
602 |
-
else:
|
603 |
-
raise ValueError(
|
604 |
-
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
605 |
-
)
|
606 |
-
|
607 |
-
self.set_attn_processor(processor)
|
608 |
-
|
609 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
610 |
-
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
611 |
-
r"""
|
612 |
-
Enable sliced attention computation.
|
613 |
-
|
614 |
-
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
615 |
-
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
616 |
-
|
617 |
-
Args:
|
618 |
-
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
619 |
-
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
620 |
-
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
621 |
-
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
622 |
-
must be a multiple of `slice_size`.
|
623 |
-
"""
|
624 |
-
sliceable_head_dims = []
|
625 |
-
|
626 |
-
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
627 |
-
if hasattr(module, "set_attention_slice"):
|
628 |
-
sliceable_head_dims.append(module.sliceable_head_dim)
|
629 |
-
|
630 |
-
for child in module.children():
|
631 |
-
fn_recursive_retrieve_sliceable_dims(child)
|
632 |
-
|
633 |
-
# retrieve number of attention layers
|
634 |
-
for module in self.children():
|
635 |
-
fn_recursive_retrieve_sliceable_dims(module)
|
636 |
-
|
637 |
-
num_sliceable_layers = len(sliceable_head_dims)
|
638 |
-
|
639 |
-
if slice_size == "auto":
|
640 |
-
# half the attention head size is usually a good trade-off between
|
641 |
-
# speed and memory
|
642 |
-
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
643 |
-
elif slice_size == "max":
|
644 |
-
# make smallest slice possible
|
645 |
-
slice_size = num_sliceable_layers * [1]
|
646 |
-
|
647 |
-
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
648 |
-
|
649 |
-
if len(slice_size) != len(sliceable_head_dims):
|
650 |
-
raise ValueError(
|
651 |
-
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
652 |
-
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
653 |
-
)
|
654 |
-
|
655 |
-
for i in range(len(slice_size)):
|
656 |
-
size = slice_size[i]
|
657 |
-
dim = sliceable_head_dims[i]
|
658 |
-
if size is not None and size > dim:
|
659 |
-
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
660 |
-
|
661 |
-
# Recursively walk through all the children.
|
662 |
-
# Any children which exposes the set_attention_slice method
|
663 |
-
# gets the message
|
664 |
-
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
665 |
-
if hasattr(module, "set_attention_slice"):
|
666 |
-
module.set_attention_slice(slice_size.pop())
|
667 |
-
|
668 |
-
for child in module.children():
|
669 |
-
fn_recursive_set_attention_slice(child, slice_size)
|
670 |
-
|
671 |
-
reversed_slice_size = list(reversed(slice_size))
|
672 |
-
for module in self.children():
|
673 |
-
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
674 |
-
|
675 |
-
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
676 |
-
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
677 |
-
module.gradient_checkpointing = value
|
678 |
-
|
679 |
-
def forward(
|
680 |
-
self,
|
681 |
-
sample: torch.FloatTensor,
|
682 |
-
encoder_hidden_states: torch.Tensor,
|
683 |
-
brushnet_cond: torch.FloatTensor,
|
684 |
-
timestep = None,
|
685 |
-
time_emb = None,
|
686 |
-
conditioning_scale: float = 1.0,
|
687 |
-
class_labels: Optional[torch.Tensor] = None,
|
688 |
-
timestep_cond: Optional[torch.Tensor] = None,
|
689 |
-
attention_mask: Optional[torch.Tensor] = None,
|
690 |
-
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
691 |
-
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
692 |
-
guess_mode: bool = False,
|
693 |
-
return_dict: bool = True,
|
694 |
-
debug = False,
|
695 |
-
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
696 |
-
"""
|
697 |
-
The [`BrushNetModel`] forward method.
|
698 |
-
|
699 |
-
Args:
|
700 |
-
sample (`torch.FloatTensor`):
|
701 |
-
The noisy input tensor.
|
702 |
-
timestep (`Union[torch.Tensor, float, int]`):
|
703 |
-
The number of timesteps to denoise an input.
|
704 |
-
encoder_hidden_states (`torch.Tensor`):
|
705 |
-
The encoder hidden states.
|
706 |
-
brushnet_cond (`torch.FloatTensor`):
|
707 |
-
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
708 |
-
conditioning_scale (`float`, defaults to `1.0`):
|
709 |
-
The scale factor for BrushNet outputs.
|
710 |
-
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
711 |
-
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
712 |
-
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
713 |
-
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
714 |
-
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
715 |
-
embeddings.
|
716 |
-
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
717 |
-
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
718 |
-
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
719 |
-
negative values to the attention scores corresponding to "discard" tokens.
|
720 |
-
added_cond_kwargs (`dict`):
|
721 |
-
Additional conditions for the Stable Diffusion XL UNet.
|
722 |
-
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
723 |
-
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
724 |
-
guess_mode (`bool`, defaults to `False`):
|
725 |
-
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
|
726 |
-
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
727 |
-
return_dict (`bool`, defaults to `True`):
|
728 |
-
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
|
729 |
-
|
730 |
-
Returns:
|
731 |
-
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
|
732 |
-
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
|
733 |
-
returned where the first element is the sample tensor.
|
734 |
-
"""
|
735 |
-
|
736 |
-
# check channel order
|
737 |
-
channel_order = self.config.brushnet_conditioning_channel_order
|
738 |
-
|
739 |
-
if channel_order == "rgb":
|
740 |
-
# in rgb order by default
|
741 |
-
...
|
742 |
-
elif channel_order == "bgr":
|
743 |
-
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
744 |
-
else:
|
745 |
-
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
746 |
-
|
747 |
-
# prepare attention_mask
|
748 |
-
if attention_mask is not None:
|
749 |
-
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
750 |
-
attention_mask = attention_mask.unsqueeze(1)
|
751 |
-
|
752 |
-
if timestep is None and time_emb is None:
|
753 |
-
raise ValueError(f"`timestep` and `emb` are both None")
|
754 |
-
|
755 |
-
#print("BN: sample.device", sample.device)
|
756 |
-
#print("BN: TE.device", self.time_embedding.linear_1.weight.device)
|
757 |
-
|
758 |
-
if timestep is not None:
|
759 |
-
# 1. time
|
760 |
-
timesteps = timestep
|
761 |
-
if not torch.is_tensor(timesteps):
|
762 |
-
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
763 |
-
# This would be a good case for the `match` statement (Python 3.10+)
|
764 |
-
is_mps = sample.device.type == "mps"
|
765 |
-
if isinstance(timestep, float):
|
766 |
-
dtype = torch.float32 if is_mps else torch.float64
|
767 |
-
else:
|
768 |
-
dtype = torch.int32 if is_mps else torch.int64
|
769 |
-
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
770 |
-
elif len(timesteps.shape) == 0:
|
771 |
-
timesteps = timesteps[None].to(sample.device)
|
772 |
-
|
773 |
-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
774 |
-
timesteps = timesteps.expand(sample.shape[0])
|
775 |
-
|
776 |
-
t_emb = self.time_proj(timesteps)
|
777 |
-
|
778 |
-
# timesteps does not contain any weights and will always return f32 tensors
|
779 |
-
# but time_embedding might actually be running in fp16. so we need to cast here.
|
780 |
-
# there might be better ways to encapsulate this.
|
781 |
-
t_emb = t_emb.to(dtype=sample.dtype)
|
782 |
-
|
783 |
-
#print("t_emb.device =",t_emb.device)
|
784 |
-
|
785 |
-
emb = self.time_embedding(t_emb, timestep_cond)
|
786 |
-
aug_emb = None
|
787 |
-
|
788 |
-
#print('emb.shape', emb.shape)
|
789 |
-
|
790 |
-
if self.class_embedding is not None:
|
791 |
-
if class_labels is None:
|
792 |
-
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
793 |
-
|
794 |
-
if self.config.class_embed_type == "timestep":
|
795 |
-
class_labels = self.time_proj(class_labels)
|
796 |
-
|
797 |
-
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
798 |
-
emb = emb + class_emb
|
799 |
-
|
800 |
-
if self.config.addition_embed_type is not None:
|
801 |
-
if self.config.addition_embed_type == "text":
|
802 |
-
aug_emb = self.add_embedding(encoder_hidden_states)
|
803 |
-
|
804 |
-
elif self.config.addition_embed_type == "text_time":
|
805 |
-
if "text_embeds" not in added_cond_kwargs:
|
806 |
-
raise ValueError(
|
807 |
-
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
808 |
-
)
|
809 |
-
text_embeds = added_cond_kwargs.get("text_embeds")
|
810 |
-
if "time_ids" not in added_cond_kwargs:
|
811 |
-
raise ValueError(
|
812 |
-
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
813 |
-
)
|
814 |
-
time_ids = added_cond_kwargs.get("time_ids")
|
815 |
-
time_embeds = self.add_time_proj(time_ids.flatten())
|
816 |
-
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
817 |
-
|
818 |
-
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
819 |
-
add_embeds = add_embeds.to(emb.dtype)
|
820 |
-
aug_emb = self.add_embedding(add_embeds)
|
821 |
-
|
822 |
-
#print('text_embeds', text_embeds.shape, 'time_ids', time_ids.shape, 'time_embeds', time_embeds.shape, 'add__embeds', add_embeds.shape, 'aug_emb', aug_emb.shape)
|
823 |
-
|
824 |
-
emb = emb + aug_emb if aug_emb is not None else emb
|
825 |
-
else:
|
826 |
-
emb = time_emb
|
827 |
-
|
828 |
-
# 2. pre-process
|
829 |
-
|
830 |
-
brushnet_cond=torch.concat([sample,brushnet_cond],1)
|
831 |
-
sample = self.conv_in_condition(brushnet_cond)
|
832 |
-
|
833 |
-
# 3. down
|
834 |
-
down_block_res_samples = (sample,)
|
835 |
-
for downsample_block in self.down_blocks:
|
836 |
-
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
837 |
-
sample, res_samples = downsample_block(
|
838 |
-
hidden_states=sample,
|
839 |
-
temb=emb,
|
840 |
-
encoder_hidden_states=encoder_hidden_states,
|
841 |
-
attention_mask=attention_mask,
|
842 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
843 |
-
)
|
844 |
-
else:
|
845 |
-
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
846 |
-
|
847 |
-
down_block_res_samples += res_samples
|
848 |
-
|
849 |
-
# 4. PaintingNet down blocks
|
850 |
-
brushnet_down_block_res_samples = ()
|
851 |
-
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
852 |
-
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
853 |
-
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
854 |
-
|
855 |
-
|
856 |
-
# 5. mid
|
857 |
-
if self.mid_block is not None:
|
858 |
-
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
859 |
-
sample = self.mid_block(
|
860 |
-
sample,
|
861 |
-
emb,
|
862 |
-
encoder_hidden_states=encoder_hidden_states,
|
863 |
-
attention_mask=attention_mask,
|
864 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
865 |
-
)
|
866 |
-
else:
|
867 |
-
sample = self.mid_block(sample, emb)
|
868 |
-
|
869 |
-
# 6. BrushNet mid blocks
|
870 |
-
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
|
871 |
-
|
872 |
-
# 7. up
|
873 |
-
up_block_res_samples = ()
|
874 |
-
for i, upsample_block in enumerate(self.up_blocks):
|
875 |
-
is_final_block = i == len(self.up_blocks) - 1
|
876 |
-
|
877 |
-
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
878 |
-
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
879 |
-
|
880 |
-
# if we have not reached the final block and need to forward the
|
881 |
-
# upsample size, we do it here
|
882 |
-
if not is_final_block:
|
883 |
-
upsample_size = down_block_res_samples[-1].shape[2:]
|
884 |
-
|
885 |
-
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
886 |
-
sample, up_res_samples = upsample_block(
|
887 |
-
hidden_states=sample,
|
888 |
-
temb=emb,
|
889 |
-
res_hidden_states_tuple=res_samples,
|
890 |
-
encoder_hidden_states=encoder_hidden_states,
|
891 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
892 |
-
upsample_size=upsample_size,
|
893 |
-
attention_mask=attention_mask,
|
894 |
-
return_res_samples=True
|
895 |
-
)
|
896 |
-
else:
|
897 |
-
sample, up_res_samples = upsample_block(
|
898 |
-
hidden_states=sample,
|
899 |
-
temb=emb,
|
900 |
-
res_hidden_states_tuple=res_samples,
|
901 |
-
upsample_size=upsample_size,
|
902 |
-
return_res_samples=True
|
903 |
-
)
|
904 |
-
|
905 |
-
up_block_res_samples += up_res_samples
|
906 |
-
|
907 |
-
# 8. BrushNet up blocks
|
908 |
-
brushnet_up_block_res_samples = ()
|
909 |
-
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
910 |
-
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
911 |
-
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
912 |
-
|
913 |
-
# 6. scaling
|
914 |
-
if guess_mode and not self.config.global_pool_conditions:
|
915 |
-
scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device) # 0.1 to 1.0
|
916 |
-
scales = scales * conditioning_scale
|
917 |
-
|
918 |
-
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
|
919 |
-
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
920 |
-
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])]
|
921 |
-
else:
|
922 |
-
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
|
923 |
-
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
924 |
-
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
925 |
-
|
926 |
-
|
927 |
-
if self.config.global_pool_conditions:
|
928 |
-
brushnet_down_block_res_samples = [
|
929 |
-
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
930 |
-
]
|
931 |
-
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
932 |
-
brushnet_up_block_res_samples = [
|
933 |
-
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
934 |
-
]
|
935 |
-
|
936 |
-
if not return_dict:
|
937 |
-
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
938 |
-
|
939 |
-
return BrushNetOutput(
|
940 |
-
down_block_res_samples=brushnet_down_block_res_samples,
|
941 |
-
mid_block_res_sample=brushnet_mid_block_res_sample,
|
942 |
-
up_block_res_samples=brushnet_up_block_res_samples
|
943 |
-
)
|
944 |
-
|
945 |
-
|
946 |
-
def zero_module(module):
|
947 |
-
for p in module.parameters():
|
948 |
-
nn.init.zeros_(p)
|
949 |
-
return module
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/brushnet/brushnet_ca.py
DELETED
@@ -1,983 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from torch import nn
|
6 |
-
|
7 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
-
from diffusers.utils import BaseOutput, logging
|
9 |
-
from diffusers.models.attention_processor import (
|
10 |
-
ADDED_KV_ATTENTION_PROCESSORS,
|
11 |
-
CROSS_ATTENTION_PROCESSORS,
|
12 |
-
AttentionProcessor,
|
13 |
-
AttnAddedKVProcessor,
|
14 |
-
AttnProcessor,
|
15 |
-
)
|
16 |
-
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
17 |
-
from diffusers.models.modeling_utils import ModelMixin
|
18 |
-
|
19 |
-
from .unet_2d_blocks import (
|
20 |
-
CrossAttnDownBlock2D,
|
21 |
-
DownBlock2D,
|
22 |
-
UNetMidBlock2D,
|
23 |
-
UNetMidBlock2DCrossAttn,
|
24 |
-
get_down_block,
|
25 |
-
get_mid_block,
|
26 |
-
get_up_block,
|
27 |
-
MidBlock2D
|
28 |
-
)
|
29 |
-
|
30 |
-
from .unet_2d_condition import UNet2DConditionModel
|
31 |
-
|
32 |
-
|
33 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
34 |
-
|
35 |
-
|
36 |
-
@dataclass
|
37 |
-
class BrushNetOutput(BaseOutput):
|
38 |
-
"""
|
39 |
-
The output of [`BrushNetModel`].
|
40 |
-
|
41 |
-
Args:
|
42 |
-
up_block_res_samples (`tuple[torch.Tensor]`):
|
43 |
-
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
|
44 |
-
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
45 |
-
used to condition the original UNet's upsampling activations.
|
46 |
-
down_block_res_samples (`tuple[torch.Tensor]`):
|
47 |
-
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
48 |
-
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
49 |
-
used to condition the original UNet's downsampling activations.
|
50 |
-
mid_down_block_re_sample (`torch.Tensor`):
|
51 |
-
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
52 |
-
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
53 |
-
Output can be used to condition the original UNet's middle block activation.
|
54 |
-
"""
|
55 |
-
|
56 |
-
up_block_res_samples: Tuple[torch.Tensor]
|
57 |
-
down_block_res_samples: Tuple[torch.Tensor]
|
58 |
-
mid_block_res_sample: torch.Tensor
|
59 |
-
|
60 |
-
|
61 |
-
class BrushNetModel(ModelMixin, ConfigMixin):
|
62 |
-
"""
|
63 |
-
A BrushNet model.
|
64 |
-
|
65 |
-
Args:
|
66 |
-
in_channels (`int`, defaults to 4):
|
67 |
-
The number of channels in the input sample.
|
68 |
-
flip_sin_to_cos (`bool`, defaults to `True`):
|
69 |
-
Whether to flip the sin to cos in the time embedding.
|
70 |
-
freq_shift (`int`, defaults to 0):
|
71 |
-
The frequency shift to apply to the time embedding.
|
72 |
-
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
73 |
-
The tuple of downsample blocks to use.
|
74 |
-
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
75 |
-
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
76 |
-
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
77 |
-
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
78 |
-
The tuple of upsample blocks to use.
|
79 |
-
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
80 |
-
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
81 |
-
The tuple of output channels for each block.
|
82 |
-
layers_per_block (`int`, defaults to 2):
|
83 |
-
The number of layers per block.
|
84 |
-
downsample_padding (`int`, defaults to 1):
|
85 |
-
The padding to use for the downsampling convolution.
|
86 |
-
mid_block_scale_factor (`float`, defaults to 1):
|
87 |
-
The scale factor to use for the mid block.
|
88 |
-
act_fn (`str`, defaults to "silu"):
|
89 |
-
The activation function to use.
|
90 |
-
norm_num_groups (`int`, *optional*, defaults to 32):
|
91 |
-
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
92 |
-
in post-processing.
|
93 |
-
norm_eps (`float`, defaults to 1e-5):
|
94 |
-
The epsilon to use for the normalization.
|
95 |
-
cross_attention_dim (`int`, defaults to 1280):
|
96 |
-
The dimension of the cross attention features.
|
97 |
-
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
98 |
-
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
99 |
-
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
100 |
-
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
101 |
-
encoder_hid_dim (`int`, *optional*, defaults to None):
|
102 |
-
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
103 |
-
dimension to `cross_attention_dim`.
|
104 |
-
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
105 |
-
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
106 |
-
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
107 |
-
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
108 |
-
The dimension of the attention heads.
|
109 |
-
use_linear_projection (`bool`, defaults to `False`):
|
110 |
-
class_embed_type (`str`, *optional*, defaults to `None`):
|
111 |
-
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
112 |
-
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
113 |
-
addition_embed_type (`str`, *optional*, defaults to `None`):
|
114 |
-
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
115 |
-
"text". "text" will use the `TextTimeEmbedding` layer.
|
116 |
-
num_class_embeds (`int`, *optional*, defaults to 0):
|
117 |
-
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
118 |
-
class conditioning with `class_embed_type` equal to `None`.
|
119 |
-
upcast_attention (`bool`, defaults to `False`):
|
120 |
-
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
121 |
-
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
122 |
-
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
123 |
-
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
124 |
-
`class_embed_type="projection"`.
|
125 |
-
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
126 |
-
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
127 |
-
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
128 |
-
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
129 |
-
global_pool_conditions (`bool`, defaults to `False`):
|
130 |
-
TODO(Patrick) - unused parameter.
|
131 |
-
addition_embed_type_num_heads (`int`, defaults to 64):
|
132 |
-
The number of heads to use for the `TextTimeEmbedding` layer.
|
133 |
-
"""
|
134 |
-
|
135 |
-
_supports_gradient_checkpointing = True
|
136 |
-
|
137 |
-
@register_to_config
|
138 |
-
def __init__(
|
139 |
-
self,
|
140 |
-
in_channels: int = 4,
|
141 |
-
conditioning_channels: int = 5,
|
142 |
-
flip_sin_to_cos: bool = True,
|
143 |
-
freq_shift: int = 0,
|
144 |
-
down_block_types: Tuple[str, ...] = (
|
145 |
-
"CrossAttnDownBlock2D",
|
146 |
-
"CrossAttnDownBlock2D",
|
147 |
-
"CrossAttnDownBlock2D",
|
148 |
-
"DownBlock2D",
|
149 |
-
),
|
150 |
-
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
151 |
-
up_block_types: Tuple[str, ...] = (
|
152 |
-
"UpBlock2D",
|
153 |
-
"CrossAttnUpBlock2D",
|
154 |
-
"CrossAttnUpBlock2D",
|
155 |
-
"CrossAttnUpBlock2D",
|
156 |
-
),
|
157 |
-
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
158 |
-
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
159 |
-
layers_per_block: int = 2,
|
160 |
-
downsample_padding: int = 1,
|
161 |
-
mid_block_scale_factor: float = 1,
|
162 |
-
act_fn: str = "silu",
|
163 |
-
norm_num_groups: Optional[int] = 32,
|
164 |
-
norm_eps: float = 1e-5,
|
165 |
-
cross_attention_dim: int = 1280,
|
166 |
-
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
167 |
-
encoder_hid_dim: Optional[int] = None,
|
168 |
-
encoder_hid_dim_type: Optional[str] = None,
|
169 |
-
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
170 |
-
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
171 |
-
use_linear_projection: bool = False,
|
172 |
-
class_embed_type: Optional[str] = None,
|
173 |
-
addition_embed_type: Optional[str] = None,
|
174 |
-
addition_time_embed_dim: Optional[int] = None,
|
175 |
-
num_class_embeds: Optional[int] = None,
|
176 |
-
upcast_attention: bool = False,
|
177 |
-
resnet_time_scale_shift: str = "default",
|
178 |
-
projection_class_embeddings_input_dim: Optional[int] = None,
|
179 |
-
brushnet_conditioning_channel_order: str = "rgb",
|
180 |
-
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
181 |
-
global_pool_conditions: bool = False,
|
182 |
-
addition_embed_type_num_heads: int = 64,
|
183 |
-
):
|
184 |
-
super().__init__()
|
185 |
-
|
186 |
-
# If `num_attention_heads` is not defined (which is the case for most models)
|
187 |
-
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
188 |
-
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
189 |
-
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
190 |
-
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
191 |
-
# which is why we correct for the naming here.
|
192 |
-
num_attention_heads = num_attention_heads or attention_head_dim
|
193 |
-
|
194 |
-
# Check inputs
|
195 |
-
if len(down_block_types) != len(up_block_types):
|
196 |
-
raise ValueError(
|
197 |
-
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
198 |
-
)
|
199 |
-
|
200 |
-
if len(block_out_channels) != len(down_block_types):
|
201 |
-
raise ValueError(
|
202 |
-
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
203 |
-
)
|
204 |
-
|
205 |
-
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
206 |
-
raise ValueError(
|
207 |
-
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
208 |
-
)
|
209 |
-
|
210 |
-
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
211 |
-
raise ValueError(
|
212 |
-
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
213 |
-
)
|
214 |
-
|
215 |
-
if isinstance(transformer_layers_per_block, int):
|
216 |
-
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
217 |
-
|
218 |
-
# input
|
219 |
-
conv_in_kernel = 3
|
220 |
-
conv_in_padding = (conv_in_kernel - 1) // 2
|
221 |
-
self.conv_in_condition = nn.Conv2d(
|
222 |
-
in_channels + conditioning_channels,
|
223 |
-
block_out_channels[0],
|
224 |
-
kernel_size=conv_in_kernel,
|
225 |
-
padding=conv_in_padding,
|
226 |
-
)
|
227 |
-
|
228 |
-
# time
|
229 |
-
time_embed_dim = block_out_channels[0] * 4
|
230 |
-
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
231 |
-
timestep_input_dim = block_out_channels[0]
|
232 |
-
self.time_embedding = TimestepEmbedding(
|
233 |
-
timestep_input_dim,
|
234 |
-
time_embed_dim,
|
235 |
-
act_fn=act_fn,
|
236 |
-
)
|
237 |
-
|
238 |
-
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
239 |
-
encoder_hid_dim_type = "text_proj"
|
240 |
-
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
241 |
-
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
242 |
-
|
243 |
-
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
244 |
-
raise ValueError(
|
245 |
-
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
246 |
-
)
|
247 |
-
|
248 |
-
if encoder_hid_dim_type == "text_proj":
|
249 |
-
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
250 |
-
elif encoder_hid_dim_type == "text_image_proj":
|
251 |
-
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
252 |
-
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
253 |
-
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
254 |
-
self.encoder_hid_proj = TextImageProjection(
|
255 |
-
text_embed_dim=encoder_hid_dim,
|
256 |
-
image_embed_dim=cross_attention_dim,
|
257 |
-
cross_attention_dim=cross_attention_dim,
|
258 |
-
)
|
259 |
-
|
260 |
-
elif encoder_hid_dim_type is not None:
|
261 |
-
raise ValueError(
|
262 |
-
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
263 |
-
)
|
264 |
-
else:
|
265 |
-
self.encoder_hid_proj = None
|
266 |
-
|
267 |
-
# class embedding
|
268 |
-
if class_embed_type is None and num_class_embeds is not None:
|
269 |
-
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
270 |
-
elif class_embed_type == "timestep":
|
271 |
-
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
272 |
-
elif class_embed_type == "identity":
|
273 |
-
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
274 |
-
elif class_embed_type == "projection":
|
275 |
-
if projection_class_embeddings_input_dim is None:
|
276 |
-
raise ValueError(
|
277 |
-
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
278 |
-
)
|
279 |
-
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
280 |
-
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
281 |
-
# 2. it projects from an arbitrary input dimension.
|
282 |
-
#
|
283 |
-
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
284 |
-
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
285 |
-
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
286 |
-
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
287 |
-
else:
|
288 |
-
self.class_embedding = None
|
289 |
-
|
290 |
-
if addition_embed_type == "text":
|
291 |
-
if encoder_hid_dim is not None:
|
292 |
-
text_time_embedding_from_dim = encoder_hid_dim
|
293 |
-
else:
|
294 |
-
text_time_embedding_from_dim = cross_attention_dim
|
295 |
-
|
296 |
-
self.add_embedding = TextTimeEmbedding(
|
297 |
-
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
298 |
-
)
|
299 |
-
elif addition_embed_type == "text_image":
|
300 |
-
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
301 |
-
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
302 |
-
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
303 |
-
self.add_embedding = TextImageTimeEmbedding(
|
304 |
-
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
305 |
-
)
|
306 |
-
elif addition_embed_type == "text_time":
|
307 |
-
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
308 |
-
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
309 |
-
|
310 |
-
elif addition_embed_type is not None:
|
311 |
-
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
312 |
-
|
313 |
-
self.down_blocks = nn.ModuleList([])
|
314 |
-
self.brushnet_down_blocks = nn.ModuleList([])
|
315 |
-
|
316 |
-
if isinstance(only_cross_attention, bool):
|
317 |
-
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
318 |
-
|
319 |
-
if isinstance(attention_head_dim, int):
|
320 |
-
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
321 |
-
|
322 |
-
if isinstance(num_attention_heads, int):
|
323 |
-
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
324 |
-
|
325 |
-
# down
|
326 |
-
output_channel = block_out_channels[0]
|
327 |
-
|
328 |
-
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
329 |
-
brushnet_block = zero_module(brushnet_block)
|
330 |
-
self.brushnet_down_blocks.append(brushnet_block)
|
331 |
-
|
332 |
-
for i, down_block_type in enumerate(down_block_types):
|
333 |
-
input_channel = output_channel
|
334 |
-
output_channel = block_out_channels[i]
|
335 |
-
is_final_block = i == len(block_out_channels) - 1
|
336 |
-
|
337 |
-
down_block = get_down_block(
|
338 |
-
down_block_type,
|
339 |
-
num_layers=layers_per_block,
|
340 |
-
transformer_layers_per_block=transformer_layers_per_block[i],
|
341 |
-
in_channels=input_channel,
|
342 |
-
out_channels=output_channel,
|
343 |
-
temb_channels=time_embed_dim,
|
344 |
-
add_downsample=not is_final_block,
|
345 |
-
resnet_eps=norm_eps,
|
346 |
-
resnet_act_fn=act_fn,
|
347 |
-
resnet_groups=norm_num_groups,
|
348 |
-
cross_attention_dim=cross_attention_dim,
|
349 |
-
num_attention_heads=num_attention_heads[i],
|
350 |
-
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
351 |
-
downsample_padding=downsample_padding,
|
352 |
-
use_linear_projection=use_linear_projection,
|
353 |
-
only_cross_attention=only_cross_attention[i],
|
354 |
-
upcast_attention=upcast_attention,
|
355 |
-
resnet_time_scale_shift=resnet_time_scale_shift,
|
356 |
-
)
|
357 |
-
self.down_blocks.append(down_block)
|
358 |
-
|
359 |
-
for _ in range(layers_per_block):
|
360 |
-
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
361 |
-
brushnet_block = zero_module(brushnet_block)
|
362 |
-
self.brushnet_down_blocks.append(brushnet_block)
|
363 |
-
|
364 |
-
if not is_final_block:
|
365 |
-
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
366 |
-
brushnet_block = zero_module(brushnet_block)
|
367 |
-
self.brushnet_down_blocks.append(brushnet_block)
|
368 |
-
|
369 |
-
# mid
|
370 |
-
mid_block_channel = block_out_channels[-1]
|
371 |
-
|
372 |
-
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
373 |
-
brushnet_block = zero_module(brushnet_block)
|
374 |
-
self.brushnet_mid_block = brushnet_block
|
375 |
-
|
376 |
-
self.mid_block = get_mid_block(
|
377 |
-
mid_block_type,
|
378 |
-
transformer_layers_per_block=transformer_layers_per_block[-1],
|
379 |
-
in_channels=mid_block_channel,
|
380 |
-
temb_channels=time_embed_dim,
|
381 |
-
resnet_eps=norm_eps,
|
382 |
-
resnet_act_fn=act_fn,
|
383 |
-
output_scale_factor=mid_block_scale_factor,
|
384 |
-
resnet_time_scale_shift=resnet_time_scale_shift,
|
385 |
-
cross_attention_dim=cross_attention_dim,
|
386 |
-
num_attention_heads=num_attention_heads[-1],
|
387 |
-
resnet_groups=norm_num_groups,
|
388 |
-
use_linear_projection=use_linear_projection,
|
389 |
-
upcast_attention=upcast_attention,
|
390 |
-
)
|
391 |
-
|
392 |
-
# count how many layers upsample the images
|
393 |
-
self.num_upsamplers = 0
|
394 |
-
|
395 |
-
# up
|
396 |
-
reversed_block_out_channels = list(reversed(block_out_channels))
|
397 |
-
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
398 |
-
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
399 |
-
only_cross_attention = list(reversed(only_cross_attention))
|
400 |
-
|
401 |
-
output_channel = reversed_block_out_channels[0]
|
402 |
-
|
403 |
-
self.up_blocks = nn.ModuleList([])
|
404 |
-
self.brushnet_up_blocks = nn.ModuleList([])
|
405 |
-
|
406 |
-
for i, up_block_type in enumerate(up_block_types):
|
407 |
-
is_final_block = i == len(block_out_channels) - 1
|
408 |
-
|
409 |
-
prev_output_channel = output_channel
|
410 |
-
output_channel = reversed_block_out_channels[i]
|
411 |
-
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
412 |
-
|
413 |
-
# add upsample block for all BUT final layer
|
414 |
-
if not is_final_block:
|
415 |
-
add_upsample = True
|
416 |
-
self.num_upsamplers += 1
|
417 |
-
else:
|
418 |
-
add_upsample = False
|
419 |
-
|
420 |
-
up_block = get_up_block(
|
421 |
-
up_block_type,
|
422 |
-
num_layers=layers_per_block + 1,
|
423 |
-
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
424 |
-
in_channels=input_channel,
|
425 |
-
out_channels=output_channel,
|
426 |
-
prev_output_channel=prev_output_channel,
|
427 |
-
temb_channels=time_embed_dim,
|
428 |
-
add_upsample=add_upsample,
|
429 |
-
resnet_eps=norm_eps,
|
430 |
-
resnet_act_fn=act_fn,
|
431 |
-
resolution_idx=i,
|
432 |
-
resnet_groups=norm_num_groups,
|
433 |
-
cross_attention_dim=cross_attention_dim,
|
434 |
-
num_attention_heads=reversed_num_attention_heads[i],
|
435 |
-
use_linear_projection=use_linear_projection,
|
436 |
-
only_cross_attention=only_cross_attention[i],
|
437 |
-
upcast_attention=upcast_attention,
|
438 |
-
resnet_time_scale_shift=resnet_time_scale_shift,
|
439 |
-
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
440 |
-
)
|
441 |
-
self.up_blocks.append(up_block)
|
442 |
-
prev_output_channel = output_channel
|
443 |
-
|
444 |
-
for _ in range(layers_per_block + 1):
|
445 |
-
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
446 |
-
brushnet_block = zero_module(brushnet_block)
|
447 |
-
self.brushnet_up_blocks.append(brushnet_block)
|
448 |
-
|
449 |
-
if not is_final_block:
|
450 |
-
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
451 |
-
brushnet_block = zero_module(brushnet_block)
|
452 |
-
self.brushnet_up_blocks.append(brushnet_block)
|
453 |
-
|
454 |
-
@classmethod
|
455 |
-
def from_unet(
|
456 |
-
cls,
|
457 |
-
unet: UNet2DConditionModel,
|
458 |
-
brushnet_conditioning_channel_order: str = "rgb",
|
459 |
-
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
460 |
-
load_weights_from_unet: bool = True,
|
461 |
-
conditioning_channels: int = 5,
|
462 |
-
):
|
463 |
-
r"""
|
464 |
-
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
465 |
-
|
466 |
-
Parameters:
|
467 |
-
unet (`UNet2DConditionModel`):
|
468 |
-
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
|
469 |
-
where applicable.
|
470 |
-
"""
|
471 |
-
transformer_layers_per_block = (
|
472 |
-
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
473 |
-
)
|
474 |
-
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
475 |
-
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
476 |
-
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
477 |
-
addition_time_embed_dim = (
|
478 |
-
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
479 |
-
)
|
480 |
-
|
481 |
-
brushnet = cls(
|
482 |
-
in_channels=unet.config.in_channels,
|
483 |
-
conditioning_channels=conditioning_channels,
|
484 |
-
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
485 |
-
freq_shift=unet.config.freq_shift,
|
486 |
-
# down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
|
487 |
-
down_block_types=[
|
488 |
-
"CrossAttnDownBlock2D",
|
489 |
-
"CrossAttnDownBlock2D",
|
490 |
-
"CrossAttnDownBlock2D",
|
491 |
-
"DownBlock2D",
|
492 |
-
],
|
493 |
-
# mid_block_type='MidBlock2D',
|
494 |
-
mid_block_type="UNetMidBlock2DCrossAttn",
|
495 |
-
# up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
|
496 |
-
up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
497 |
-
only_cross_attention=unet.config.only_cross_attention,
|
498 |
-
block_out_channels=unet.config.block_out_channels,
|
499 |
-
layers_per_block=unet.config.layers_per_block,
|
500 |
-
downsample_padding=unet.config.downsample_padding,
|
501 |
-
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
502 |
-
act_fn=unet.config.act_fn,
|
503 |
-
norm_num_groups=unet.config.norm_num_groups,
|
504 |
-
norm_eps=unet.config.norm_eps,
|
505 |
-
cross_attention_dim=unet.config.cross_attention_dim,
|
506 |
-
transformer_layers_per_block=transformer_layers_per_block,
|
507 |
-
encoder_hid_dim=encoder_hid_dim,
|
508 |
-
encoder_hid_dim_type=encoder_hid_dim_type,
|
509 |
-
attention_head_dim=unet.config.attention_head_dim,
|
510 |
-
num_attention_heads=unet.config.num_attention_heads,
|
511 |
-
use_linear_projection=unet.config.use_linear_projection,
|
512 |
-
class_embed_type=unet.config.class_embed_type,
|
513 |
-
addition_embed_type=addition_embed_type,
|
514 |
-
addition_time_embed_dim=addition_time_embed_dim,
|
515 |
-
num_class_embeds=unet.config.num_class_embeds,
|
516 |
-
upcast_attention=unet.config.upcast_attention,
|
517 |
-
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
518 |
-
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
519 |
-
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
|
520 |
-
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
521 |
-
)
|
522 |
-
|
523 |
-
if load_weights_from_unet:
|
524 |
-
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
|
525 |
-
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
|
526 |
-
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
|
527 |
-
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
|
528 |
-
brushnet.conv_in_condition.bias = unet.conv_in.bias
|
529 |
-
|
530 |
-
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
531 |
-
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
532 |
-
|
533 |
-
if brushnet.class_embedding:
|
534 |
-
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
535 |
-
|
536 |
-
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
537 |
-
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
538 |
-
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
|
539 |
-
|
540 |
-
return brushnet.to(unet.dtype)
|
541 |
-
|
542 |
-
@property
|
543 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
544 |
-
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
545 |
-
r"""
|
546 |
-
Returns:
|
547 |
-
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
548 |
-
indexed by its weight name.
|
549 |
-
"""
|
550 |
-
# set recursively
|
551 |
-
processors = {}
|
552 |
-
|
553 |
-
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
554 |
-
if hasattr(module, "get_processor"):
|
555 |
-
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
556 |
-
|
557 |
-
for sub_name, child in module.named_children():
|
558 |
-
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
559 |
-
|
560 |
-
return processors
|
561 |
-
|
562 |
-
for name, module in self.named_children():
|
563 |
-
fn_recursive_add_processors(name, module, processors)
|
564 |
-
|
565 |
-
return processors
|
566 |
-
|
567 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
568 |
-
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
569 |
-
r"""
|
570 |
-
Sets the attention processor to use to compute attention.
|
571 |
-
|
572 |
-
Parameters:
|
573 |
-
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
574 |
-
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
575 |
-
for **all** `Attention` layers.
|
576 |
-
|
577 |
-
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
578 |
-
processor. This is strongly recommended when setting trainable attention processors.
|
579 |
-
|
580 |
-
"""
|
581 |
-
count = len(self.attn_processors.keys())
|
582 |
-
|
583 |
-
if isinstance(processor, dict) and len(processor) != count:
|
584 |
-
raise ValueError(
|
585 |
-
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
586 |
-
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
587 |
-
)
|
588 |
-
|
589 |
-
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
590 |
-
if hasattr(module, "set_processor"):
|
591 |
-
if not isinstance(processor, dict):
|
592 |
-
module.set_processor(processor)
|
593 |
-
else:
|
594 |
-
module.set_processor(processor.pop(f"{name}.processor"))
|
595 |
-
|
596 |
-
for sub_name, child in module.named_children():
|
597 |
-
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
598 |
-
|
599 |
-
for name, module in self.named_children():
|
600 |
-
fn_recursive_attn_processor(name, module, processor)
|
601 |
-
|
602 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
603 |
-
def set_default_attn_processor(self):
|
604 |
-
"""
|
605 |
-
Disables custom attention processors and sets the default attention implementation.
|
606 |
-
"""
|
607 |
-
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
608 |
-
processor = AttnAddedKVProcessor()
|
609 |
-
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
610 |
-
processor = AttnProcessor()
|
611 |
-
else:
|
612 |
-
raise ValueError(
|
613 |
-
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
614 |
-
)
|
615 |
-
|
616 |
-
self.set_attn_processor(processor)
|
617 |
-
|
618 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
619 |
-
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
620 |
-
r"""
|
621 |
-
Enable sliced attention computation.
|
622 |
-
|
623 |
-
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
624 |
-
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
625 |
-
|
626 |
-
Args:
|
627 |
-
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
628 |
-
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
629 |
-
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
630 |
-
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
631 |
-
must be a multiple of `slice_size`.
|
632 |
-
"""
|
633 |
-
sliceable_head_dims = []
|
634 |
-
|
635 |
-
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
636 |
-
if hasattr(module, "set_attention_slice"):
|
637 |
-
sliceable_head_dims.append(module.sliceable_head_dim)
|
638 |
-
|
639 |
-
for child in module.children():
|
640 |
-
fn_recursive_retrieve_sliceable_dims(child)
|
641 |
-
|
642 |
-
# retrieve number of attention layers
|
643 |
-
for module in self.children():
|
644 |
-
fn_recursive_retrieve_sliceable_dims(module)
|
645 |
-
|
646 |
-
num_sliceable_layers = len(sliceable_head_dims)
|
647 |
-
|
648 |
-
if slice_size == "auto":
|
649 |
-
# half the attention head size is usually a good trade-off between
|
650 |
-
# speed and memory
|
651 |
-
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
652 |
-
elif slice_size == "max":
|
653 |
-
# make smallest slice possible
|
654 |
-
slice_size = num_sliceable_layers * [1]
|
655 |
-
|
656 |
-
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
657 |
-
|
658 |
-
if len(slice_size) != len(sliceable_head_dims):
|
659 |
-
raise ValueError(
|
660 |
-
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
661 |
-
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
662 |
-
)
|
663 |
-
|
664 |
-
for i in range(len(slice_size)):
|
665 |
-
size = slice_size[i]
|
666 |
-
dim = sliceable_head_dims[i]
|
667 |
-
if size is not None and size > dim:
|
668 |
-
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
669 |
-
|
670 |
-
# Recursively walk through all the children.
|
671 |
-
# Any children which exposes the set_attention_slice method
|
672 |
-
# gets the message
|
673 |
-
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
674 |
-
if hasattr(module, "set_attention_slice"):
|
675 |
-
module.set_attention_slice(slice_size.pop())
|
676 |
-
|
677 |
-
for child in module.children():
|
678 |
-
fn_recursive_set_attention_slice(child, slice_size)
|
679 |
-
|
680 |
-
reversed_slice_size = list(reversed(slice_size))
|
681 |
-
for module in self.children():
|
682 |
-
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
683 |
-
|
684 |
-
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
685 |
-
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
686 |
-
module.gradient_checkpointing = value
|
687 |
-
|
688 |
-
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    brushnet_cond: torch.FloatTensor,
    conditioning_scale: float = 1.0,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guess_mode: bool = False,
    return_dict: bool = True,
    debug=False,
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
    """
    The [`BrushNetModel`] forward method.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor.
        timestep (`Union[torch.Tensor, float, int]`):
            The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states.
        brushnet_cond (`torch.FloatTensor`):
            The conditional input tensor. It is concatenated with `sample` along dim 1 before being passed to
            `self.conv_in_condition`, so it must match `sample` in every other dimension.
        conditioning_scale (`float`, defaults to `1.0`):
            The scale factor for BrushNet outputs.
        class_labels (`torch.Tensor`, *optional*, defaults to `None`):
            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
        timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
            Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
            timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
            embeddings.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        added_cond_kwargs (`dict`):
            Additional conditions for the Stable Diffusion XL UNet. Must contain `text_embeds` and `time_ids`
            when the config param `addition_embed_type` is `"text_time"`.
        cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
            A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
        guess_mode (`bool`, defaults to `False`):
            In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
            you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
        return_dict (`bool`, defaults to `True`):
            Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
        debug (`bool`, defaults to `False`):
            When truthy, progress messages are printed at each stage and the flag is forwarded to the down blocks.

    Returns:
        [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
            If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple
            `(down_block_res_samples, mid_block_res_sample, up_block_res_samples)` is returned.

    Raises:
        ValueError: If the configured channel order is unknown, if `class_labels` is missing while a class
            embedding is configured, or if `added_cond_kwargs` lacks `text_embeds` / `time_ids` for the
            `"text_time"` addition embedding.
    """
    # check channel order ("rgb" is the default and needs no change)
    channel_order = self.config.brushnet_conditioning_channel_order

    if channel_order == "bgr":
        brushnet_cond = torch.flip(brushnet_cond, dims=[1])
    elif channel_order != "rgb":
        raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")

    if debug:
        print('BrushNet CA: attn mask')

    # prepare attention_mask: turn the keep/discard mask into an additive bias
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    if debug:
        print('BrushNet CA: time')

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)

    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    if self.config.addition_embed_type is not None:
        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)

        elif self.config.addition_embed_type == "text_time":
            # Treat a missing dict like an empty one so the caller gets the
            # descriptive ValueError below instead of a TypeError on `in None`.
            if added_cond_kwargs is None:
                added_cond_kwargs = {}
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)

    emb = emb + aug_emb if aug_emb is not None else emb

    if debug:
        print('BrushNet CA: pre-process')

    # 2. pre-process
    brushnet_cond = torch.concat([sample, brushnet_cond], 1)
    sample = self.conv_in_condition(brushnet_cond)

    if debug:
        print('BrushNet CA: down')

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            if debug:
                print('BrushNet CA (down block with XA): ', type(downsample_block))
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                debug=debug,
            )
        else:
            if debug:
                print('BrushNet CA (down block): ', type(downsample_block))
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb, debug=debug)

        down_block_res_samples += res_samples

    if debug:
        print('BrushNet CA: PP down')

    # 4. PaintingNet down blocks
    brushnet_down_block_res_samples = ()
    for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
        down_block_res_sample = brushnet_down_block(down_block_res_sample)
        brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)

    if debug:
        print('BrushNet CA: PP mid')

    # 5. mid
    if self.mid_block is not None:
        if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample = self.mid_block(sample, emb)

    if debug:
        print('BrushNet CA: mid')

    # 6. BrushNet mid blocks
    brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)

    if debug:
        print('BrushNet CA: PP up')

    # 7. up
    up_block_res_samples = ()
    # Initialised once so the final block (which never assigns it) cannot hit
    # an unbound name when there is only a single up block. It is deliberately
    # not reset per iteration: the final block keeps the previous value.
    upsample_size = None
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample, up_res_samples = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                return_res_samples=True,
            )
        else:
            sample, up_res_samples = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
                return_res_samples=True,
            )

        up_block_res_samples += up_res_samples

    if debug:
        print('BrushNet CA: up')

    # 8. BrushNet up blocks
    brushnet_up_block_res_samples = ()
    for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
        up_block_res_sample = brushnet_up_block(up_block_res_sample)
        brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)

    if debug:
        print('BrushNet CA: scaling')

    # 9. scaling
    if guess_mode and not self.config.global_pool_conditions:
        scales = torch.logspace(
            -1,
            0,
            len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
            device=sample.device,
        )  # 0.1 to 1.0
        scales = scales * conditioning_scale

        # scales are laid out as [down..., mid, up...]
        brushnet_down_block_res_samples = [
            res * scale
            for res, scale in zip(
                brushnet_down_block_res_samples, scales[: len(brushnet_down_block_res_samples)]
            )
        ]
        brushnet_mid_block_res_sample = (
            brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
        )
        brushnet_up_block_res_samples = [
            res * scale
            for res, scale in zip(
                brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1 :]
            )
        ]
    else:
        brushnet_down_block_res_samples = [
            res * conditioning_scale for res in brushnet_down_block_res_samples
        ]
        brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
        brushnet_up_block_res_samples = [res * conditioning_scale for res in brushnet_up_block_res_samples]

    if self.config.global_pool_conditions:
        brushnet_down_block_res_samples = [
            torch.mean(res, dim=(2, 3), keepdim=True) for res in brushnet_down_block_res_samples
        ]
        brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
        brushnet_up_block_res_samples = [
            torch.mean(res, dim=(2, 3), keepdim=True) for res in brushnet_up_block_res_samples
        ]

    if debug:
        print('BrushNet CA: finish')

    if not return_dict:
        return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)

    return BrushNetOutput(
        down_block_res_samples=brushnet_down_block_res_samples,
        mid_block_res_sample=brushnet_mid_block_res_sample,
        up_block_res_samples=brushnet_up_block_res_samples,
    )
|
978 |
-
|
979 |
-
|
980 |
-
def zero_module(module):
    """Zero every parameter of ``module`` in place and return the same module.

    Args:
        module: Any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The ``module`` that was passed in, so the call can be used inline.
    """
    for param in module.parameters():
        nn.init.zeros_(param)
    return module
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/brushnet/brushnet_xl.json
DELETED
@@ -1,63 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"_class_name": "BrushNetModel",
|
3 |
-
"_diffusers_version": "0.27.0.dev0",
|
4 |
-
"_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
|
5 |
-
"act_fn": "silu",
|
6 |
-
"addition_embed_type": "text_time",
|
7 |
-
"addition_embed_type_num_heads": 64,
|
8 |
-
"addition_time_embed_dim": 256,
|
9 |
-
"attention_head_dim": [
|
10 |
-
5,
|
11 |
-
10,
|
12 |
-
20
|
13 |
-
],
|
14 |
-
"block_out_channels": [
|
15 |
-
320,
|
16 |
-
640,
|
17 |
-
1280
|
18 |
-
],
|
19 |
-
"brushnet_conditioning_channel_order": "rgb",
|
20 |
-
"class_embed_type": null,
|
21 |
-
"conditioning_channels": 5,
|
22 |
-
"conditioning_embedding_out_channels": [
|
23 |
-
16,
|
24 |
-
32,
|
25 |
-
96,
|
26 |
-
256
|
27 |
-
],
|
28 |
-
"cross_attention_dim": 2048,
|
29 |
-
"down_block_types": [
|
30 |
-
"DownBlock2D",
|
31 |
-
"DownBlock2D",
|
32 |
-
"DownBlock2D"
|
33 |
-
],
|
34 |
-
"downsample_padding": 1,
|
35 |
-
"encoder_hid_dim": null,
|
36 |
-
"encoder_hid_dim_type": null,
|
37 |
-
"flip_sin_to_cos": true,
|
38 |
-
"freq_shift": 0,
|
39 |
-
"global_pool_conditions": false,
|
40 |
-
"in_channels": 4,
|
41 |
-
"layers_per_block": 2,
|
42 |
-
"mid_block_scale_factor": 1,
|
43 |
-
"mid_block_type": "MidBlock2D",
|
44 |
-
"norm_eps": 1e-05,
|
45 |
-
"norm_num_groups": 32,
|
46 |
-
"num_attention_heads": null,
|
47 |
-
"num_class_embeds": null,
|
48 |
-
"only_cross_attention": false,
|
49 |
-
"projection_class_embeddings_input_dim": 2816,
|
50 |
-
"resnet_time_scale_shift": "default",
|
51 |
-
"transformer_layers_per_block": [
|
52 |
-
1,
|
53 |
-
2,
|
54 |
-
10
|
55 |
-
],
|
56 |
-
"up_block_types": [
|
57 |
-
"UpBlock2D",
|
58 |
-
"UpBlock2D",
|
59 |
-
"UpBlock2D"
|
60 |
-
],
|
61 |
-
"upcast_attention": null,
|
62 |
-
"use_linear_projection": true
|
63 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/brushnet/powerpaint.json
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"_class_name": "BrushNetModel",
|
3 |
-
"_diffusers_version": "0.27.2",
|
4 |
-
"act_fn": "silu",
|
5 |
-
"addition_embed_type": null,
|
6 |
-
"addition_embed_type_num_heads": 64,
|
7 |
-
"addition_time_embed_dim": null,
|
8 |
-
"attention_head_dim": 8,
|
9 |
-
"block_out_channels": [
|
10 |
-
320,
|
11 |
-
640,
|
12 |
-
1280,
|
13 |
-
1280
|
14 |
-
],
|
15 |
-
"brushnet_conditioning_channel_order": "rgb",
|
16 |
-
"class_embed_type": null,
|
17 |
-
"conditioning_channels": 5,
|
18 |
-
"conditioning_embedding_out_channels": [
|
19 |
-
16,
|
20 |
-
32,
|
21 |
-
96,
|
22 |
-
256
|
23 |
-
],
|
24 |
-
"cross_attention_dim": 768,
|
25 |
-
"down_block_types": [
|
26 |
-
"CrossAttnDownBlock2D",
|
27 |
-
"CrossAttnDownBlock2D",
|
28 |
-
"CrossAttnDownBlock2D",
|
29 |
-
"DownBlock2D"
|
30 |
-
],
|
31 |
-
"downsample_padding": 1,
|
32 |
-
"encoder_hid_dim": null,
|
33 |
-
"encoder_hid_dim_type": null,
|
34 |
-
"flip_sin_to_cos": true,
|
35 |
-
"freq_shift": 0,
|
36 |
-
"global_pool_conditions": false,
|
37 |
-
"in_channels": 4,
|
38 |
-
"layers_per_block": 2,
|
39 |
-
"mid_block_scale_factor": 1,
|
40 |
-
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
41 |
-
"norm_eps": 1e-05,
|
42 |
-
"norm_num_groups": 32,
|
43 |
-
"num_attention_heads": null,
|
44 |
-
"num_class_embeds": null,
|
45 |
-
"only_cross_attention": false,
|
46 |
-
"projection_class_embeddings_input_dim": null,
|
47 |
-
"resnet_time_scale_shift": "default",
|
48 |
-
"transformer_layers_per_block": 1,
|
49 |
-
"up_block_types": [
|
50 |
-
"UpBlock2D",
|
51 |
-
"CrossAttnUpBlock2D",
|
52 |
-
"CrossAttnUpBlock2D",
|
53 |
-
"CrossAttnUpBlock2D"
|
54 |
-
],
|
55 |
-
"upcast_attention": false,
|
56 |
-
"use_linear_projection": false
|
57 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/brushnet/powerpaint_utils.py
DELETED
@@ -1,496 +0,0 @@
|
|
1 |
-
import copy
|
2 |
-
import random
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
from transformers import CLIPTokenizer
|
7 |
-
from typing import Any, List, Optional, Union
|
8 |
-
|
9 |
-
class TokenizerWrapper:
    """Tokenizer wrapper for CLIPTokenizer. Only support CLIPTokenizer
    currently. This wrapper is modified from
    https://github.com/huggingface/diffusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.py#L358  # noqa

    The wrapper keeps a ``token_map`` from each placeholder token to the list
    of concrete tokens it expands to, rewrites text accordingly before it is
    handed to the wrapped tokenizer, and forwards any other attribute access
    to the wrapped tokenizer.

    Args:
        tokenizer (CLIPTokenizer): The tokenizer instance to wrap.
    """

    def __init__(self, tokenizer: "CLIPTokenizer"):
        self.wrapped = tokenizer
        # placeholder token -> list of concrete tokens added to the tokenizer
        self.token_map = {}

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attribute lookups to the wrapped tokenizer.

        ``__getattr__`` only runs after normal lookup has failed, so the
        wrapped tokenizer is read straight from ``__dict__``: going through
        ``self.wrapped`` would recurse forever while ``wrapped`` is not set
        yet (e.g. on the blank instance created by ``copy`` / ``pickle``).

        Raises:
            AttributeError: If neither the wrapper nor the wrapped tokenizer
                has an attribute called ``name``.
        """
        try:
            wrapped = self.__dict__["wrapped"]
        except KeyError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            ) from None

        try:
            return getattr(wrapped, name)
        except AttributeError as err:
            raise AttributeError(
                f"'{name}' cannot be found in both "
                f"'{self.__class__.__name__}' and "
                f"'{self.__class__.__name__}.tokenizer'."
            ) from err

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
            *args, **kwargs: Forwarded to ``self.wrapped.add_tokens``.

        Raises:
            AssertionError: If the wrapped tokenizer reports that no token
                was added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index in
        the current tokenizer.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: The information of the token, including its start and end
                index in current tokenizer.
        """
        token_ids = self(token).input_ids
        # The first and last ids are skipped (presumably BOS/EOS added by the
        # tokenizer -- confirm against the tokenizer in use).
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
        """Add placeholder tokens to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.

        Raises:
            ValueError: If an already registered placeholder token is a
                substring of ``placeholder_token``.
        """
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output = [placeholder_token]
        else:
            output = []
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)

        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token} "
                    "keep placeholder tokens independent"
                )
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(
        self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
    ) -> Union[str, List[str]]:
        """Replace the keywords in text with placeholder tokens. This function
        will be called in `self.__call__` and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            # Forward both options so list input behaves like string input.
            return [
                self.replace_placeholder_tokens_in_text(
                    item, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                )
                for item in text
            ]

        for placeholder_token, all_tokens in self.token_map.items():
            if placeholder_token in text:
                tokens = all_tokens[: 1 + int(len(all_tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    # shuffle a copy so the stored token order is untouched
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """Replace the placeholder tokens in text with the original keywords.
        This function will be called in `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            return [self.replace_text_with_placeholder_tokens(item) for item in text]

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """The call function of the wrapper.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )

        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text to token index.

        Args:
            text (Union[str, List[str]]): The text to be encoded.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
        """Decode the token index to text.

        Args:
            token_ids: The token index to be decoded.
            return_raw: Whether to return the decoded text as is, without
                merging expanded tokens back into their placeholder token.
                Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        return self.replace_text_with_placeholder_tokens(text)

    def __repr__(self):
        """The representation of the wrapper."""
        wrapped = self.__dict__.get("wrapped")
        # These attributes are not set by this class; fall back to values
        # derived from the wrapped object instead of raising AttributeError.
        module_cls = getattr(self, "_module_cls", type(wrapped))
        module_name = getattr(self, "_module_name", type(wrapped).__name__)
        from_pretrained = getattr(self, "_from_pretrained", None)

        prefix = f"Wrapped Module Class: {module_cls}\n"
        prefix += f"Wrapped Module Name: {module_name}\n"
        if from_pretrained:
            prefix += f"From Pretrained: {from_pretrained}\n"
        return prefix + super().__repr__()
|
220 |
-
|
221 |
-
|
222 |
-
class EmbeddingLayerWithFixes(nn.Module):
    """Embedding layer wrapper that supports external embeddings.

    The design of this class is inspired by
    https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hijack.py#L224  # noqa

    Token ids that are out of range for the wrapped layer (``>= num_embeddings``)
    are looked up as id 0 and the resulting rows are then overwritten by the
    registered external embeddings.

    Args:
        wrapped (nn.Embedding): The embedding layer to be wrapped.
        external_embeddings (Union[dict, List[dict]], optional): The external
            embeddings added to this layer. Defaults to None.
    """

    def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
        super().__init__()
        self.wrapped = wrapped
        self.num_embeddings = wrapped.weight.shape[0]

        self.external_embeddings = []
        # Must exist before `add_embeddings` runs: trainable embeddings are
        # registered into this dict, and creating it afterwards would either
        # fail (attribute missing) or discard what was just registered.
        self.trainable_embeddings = nn.ParameterDict()

        if external_embeddings:
            self.add_embeddings(external_embeddings)

    @property
    def weight(self):
        """Get the weight of wrapped embedding layer."""
        return self.wrapped.weight

    def check_duplicate_names(self, embeddings: List[dict]):
        """Check whether duplicate names exist in a list of external embeddings.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.

        Raises:
            AssertionError: If two embeddings share the same 'name'.
        """
        names = [emb["name"] for emb in embeddings]
        assert len(names) == len(set(names)), (
            "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
        )

    def check_ids_overlap(self, embeddings):
        """Check whether the token id ranges of external embeddings overlap.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.

        Raises:
            AssertionError: If the ``[start, end)`` ranges of two embeddings
                overlap.
        """
        ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
        ids_range.sort()  # sort by 'start'
        # after sorting, only neighbouring ranges can overlap
        for current, following in zip(ids_range, ids_range[1:]):
            name1, name2 = current[-1], following[-1]
            assert current[1] <= following[0], (
                f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
            )

    def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
        """Add external embeddings to this layer.

        Use case (illustrative; exact token positions depend on the tokenizer):

        >>> # 1. Add token to tokenizer and get the token id.
        >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
        >>> # 'how much' in kiswahili
        >>> tokenizer.add_placeholder_token('ngapi', num_vec_per_token=4)
        >>>
        >>> # 2. Add external embeddings to the model.
        >>> new_embedding = {
        >>>     'name': 'ngapi',  # 'how much' in kiswahili
        >>>     'embedding': torch.ones(4, 15) * 2.3,
        >>>     'start': tokenizer.get_token_info('ngapi')['start'],
        >>>     'end': tokenizer.get_token_info('ngapi')['end'],
        >>>     'trainable': False  # if True, will be registered as a parameter
        >>> }
        >>> embedding_layer = nn.Embedding(10, 15)
        >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
        >>> embedding_layer_wrapper.add_embeddings(new_embedding)
        >>>
        >>> # 3. Forward tokenizer and embedding layer!
        >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
        >>> input_ids = tokenizer(
        >>>     input_text, padding='max_length', truncation=True,
        >>>     return_tensors='pt')['input_ids']
        >>> out_feat = embedding_layer_wrapper(input_ids)
        >>>
        >>> # 4. Let's validate the result!
        >>> assert (out_feat[0, 3: 7] == 2.3).all()
        >>> assert (out_feat[1, 5: 9] == 2.3).all()

        Args:
            embeddings (Union[dict, list[dict]]): The external embeddings to
                be added. Each dict must contain the following 4 fields: 'name'
                (the name of this embedding), 'embedding' (the embedding
                tensor), 'start' (the start token id of this embedding), 'end'
                (the end token id of this embedding). For example:
                `{name: NAME, start: START, end: END, embedding: torch.Tensor}`.
                An optional 'trainable' field (default False) registers the
                tensor as an ``nn.Parameter`` of this layer; in that case the
                dict's 'embedding' entry is replaced by the parameter.

        Raises:
            AssertionError: If the new embeddings introduce a duplicated name
                or an overlapping id range. In that case the layer is left
                unchanged.
        """
        if embeddings is None:
            embeddings = []
        elif isinstance(embeddings, dict):
            embeddings = [embeddings]

        # Validate the combined list BEFORE committing so that a failed check
        # does not leave half-registered embeddings behind.
        candidate = self.external_embeddings + embeddings
        self.check_duplicate_names(candidate)
        self.check_ids_overlap(candidate)
        self.external_embeddings.extend(embeddings)

        # set for trainable
        added_trainable_emb_info = []
        for embedding in embeddings:
            if embedding.get("trainable", False):
                name = embedding["name"]
                # Store the Parameter back into the dict so that `forward`
                # uses the very tensor that the optimizer updates.
                embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
                self.trainable_embeddings[name] = embedding["embedding"]
                added_trainable_emb_info.append(name)

        added_emb_info = ", ".join(emb["name"] for emb in embeddings)
        print(f"Successfully add external embeddings: {added_emb_info}.")

        if added_trainable_emb_info:
            added_trainable_emb_info = ", ".join(added_trainable_emb_info)
            print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}")

    def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Replace external input ids to 0.

        Ids that are out of range for the wrapped layer are mapped to 0 so the
        wrapped lookup does not fail; the input tensor itself is not modified.

        Args:
            input_ids (torch.Tensor): The input ids to be replaced.

        Returns:
            torch.Tensor: The replaced input ids.
        """
        input_ids_fwd = input_ids.clone()
        input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
        return input_ids_fwd

    def replace_embeddings(
        self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
    ) -> torch.Tensor:
        """Replace external embedding to the embedding layer. Noted that, in
        this function we use `torch.cat` to avoid inplace modification.

        Args:
            input_ids (torch.Tensor): The original token ids. Shape like
                [LENGTH, ].
            embedding (torch.Tensor): The embedding of token ids after
                `replace_input_ids` function.
            external_embedding (dict): The external embedding to be replaced.

        Returns:
            torch.Tensor: The replaced embedding.

        Raises:
            AssertionError: If an occurrence of the start id is not followed
                by the full ``start .. end - 1`` id sequence.
        """
        name = external_embedding["name"]
        start = external_embedding["start"]
        end = external_embedding["end"]
        ext_emb = external_embedding["embedding"]
        span = end - start
        target_ids_to_replace = list(range(start, end))

        # do not need to replace
        if not (input_ids == start).any():
            return embedding

        pieces = []
        total = len(input_ids)
        seg_start, pos = 0, 0  # seg_start: first position not yet emitted
        while pos < total:
            if input_ids[pos] != start:
                pos += 1
                continue

            # keep the untouched rows that precede this placeholder
            if pos > seg_start:
                pieces.append(embedding[seg_start:pos])

            # check that the ids to be replaced form the full placeholder
            actually_ids_to_replace = [int(i) for i in input_ids[pos : pos + span]]
            assert actually_ids_to_replace == target_ids_to_replace, (
                f"Invalid 'input_ids' in position: {pos} to {pos + span}. "
                f"Expect '{target_ids_to_replace}' for embedding "
                f"'{name}' but found '{actually_ids_to_replace}'."
            )

            pieces.append(ext_emb)

            seg_start = pos + span
            # Resume scanning right after the placeholder so that a directly
            # adjacent placeholder is found too; always advance at least one
            # position to guarantee progress for a degenerate empty range.
            pos = max(seg_start, pos + 1)

        # trailing rows after the last placeholder
        if seg_start < total:
            pieces.append(embedding[seg_start:])

        return torch.cat(pieces, dim=0)

    def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None):
        """The forward function.

        Args:
            input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
                [LENGTH, ]. A 1-D input is treated as a batch of size one.
            external_embeddings (Optional[List[dict]]): The external
                embeddings. If not passed, only `self.external_embeddings`
                will be used. Defaults to None.

        Returns:
            torch.Tensor: The embeddings, with a leading batch dimension.
        """
        assert input_ids.ndim in [1, 2]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        # fast path: nothing to replace
        if external_embeddings is None and not self.external_embeddings:
            return self.wrapped(input_ids)

        input_ids_fwd = self.replace_input_ids(input_ids)
        inputs_embeds = self.wrapped(input_ids_fwd)

        if external_embeddings is None:
            external_embeddings = []
        elif isinstance(external_embeddings, dict):
            external_embeddings = [external_embeddings]
        embeddings = self.external_embeddings + external_embeddings

        vecs = []
        for input_id, embedding in zip(input_ids, inputs_embeds):
            new_embedding = embedding
            for external_embedding in embeddings:
                new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
            vecs.append(new_embedding)

        return torch.stack(vecs)
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
def add_tokens(
    tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None, num_vectors_per_token: int = 1
):
    """Add placeholder tokens for training.

    Registers every placeholder token on ``tokenizer``, wraps the token
    embedding of ``text_encoder`` in an ``EmbeddingLayerWithFixes`` and adds
    one trainable external embedding per placeholder token to it.

    # TODO: support add tokens as dict, then we can load pretrained tokens.

    Args:
        tokenizer: Tokenizer wrapper; this function calls its
            ``add_placeholder_token`` and ``get_token_info`` methods and, when
            ``initialize_tokens`` is given, calls it on each initializer text.
        text_encoder: Text encoder whose
            ``text_model.embeddings.token_embedding`` is replaced in place by
            the wrapping layer.
        placeholder_tokens (list): The placeholder tokens to add.
        initialize_tokens (list, optional): One initializer text per
            placeholder token. The embedding of the token id at index 1 of
            its tokenization (presumably the first token after a
            start-of-text id -- confirm against the tokenizer) is repeated
            ``num_vectors_per_token`` times. If None, embeddings are drawn
            uniformly from [-0.25, 0.25). Defaults to None.
        num_vectors_per_token (int): Number of embedding vectors per
            placeholder token. Defaults to 1.

    Raises:
        AssertionError: If ``initialize_tokens`` and ``placeholder_tokens``
            differ in length, or the text encoder has no token embedding.
    """
    if initialize_tokens is not None:
        assert len(initialize_tokens) == len(
            placeholder_tokens
        ), "placeholder_token should be the same length as initialize_token"
    for placeholder in placeholder_tokens:
        tokenizer.add_placeholder_token(placeholder, num_vec_per_token=num_vectors_per_token)

    # text_encoder.set_embedding_layer()
    embeddings_module = text_encoder.text_model.embeddings
    raw_embedding_layer = embeddings_module.token_embedding
    # Check BEFORE wrapping: once wrapped, the layer can never be None and
    # the check would be vacuous.
    assert raw_embedding_layer is not None, (
        "Do not support get embedding layer for current text encoder. " "Please check your configuration."
    )
    embedding_layer = EmbeddingLayerWithFixes(raw_embedding_layer)
    embeddings_module.token_embedding = embedding_layer

    weight = embedding_layer.weight
    if initialize_tokens is not None:
        initialize_embedding = [
            weight[tokenizer(init_text).input_ids[1]][None, ...].repeat(num_vectors_per_token, 1)
            for init_text in initialize_tokens
        ]
    else:
        len_emb = weight.shape[1]
        # Create the random init on the same device / dtype as the wrapped
        # weights so it can be concatenated with their output later on.
        initialize_embedding = [
            ((torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0).to(device=weight.device, dtype=weight.dtype)
            for _ in placeholder_tokens
        ]

    token_info_all = []
    for placeholder, init_embedding in zip(placeholder_tokens, initialize_embedding):
        token_info = tokenizer.get_token_info(placeholder)
        token_info["embedding"] = init_embedding
        token_info["trainable"] = True
        token_info_all.append(token_info)
    embedding_layer.add_embeddings(token_info_all)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/brushnet/unet_2d_blocks.py
DELETED
The diff for this file is too large to render.
See raw diff
|
|
MagicQuill/brushnet/unet_2d_condition.py
DELETED
@@ -1,1355 +0,0 @@
|
|
1 |
-
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
from dataclasses import dataclass
|
15 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
-
|
17 |
-
import torch
|
18 |
-
import torch.nn as nn
|
19 |
-
import torch.utils.checkpoint
|
20 |
-
|
21 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
-
from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
23 |
-
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
24 |
-
from diffusers.models.activations import get_activation
|
25 |
-
from diffusers.models.attention_processor import (
|
26 |
-
ADDED_KV_ATTENTION_PROCESSORS,
|
27 |
-
CROSS_ATTENTION_PROCESSORS,
|
28 |
-
Attention,
|
29 |
-
AttentionProcessor,
|
30 |
-
AttnAddedKVProcessor,
|
31 |
-
AttnProcessor,
|
32 |
-
)
|
33 |
-
from diffusers.models.embeddings import (
|
34 |
-
GaussianFourierProjection,
|
35 |
-
GLIGENTextBoundingboxProjection,
|
36 |
-
ImageHintTimeEmbedding,
|
37 |
-
ImageProjection,
|
38 |
-
ImageTimeEmbedding,
|
39 |
-
TextImageProjection,
|
40 |
-
TextImageTimeEmbedding,
|
41 |
-
TextTimeEmbedding,
|
42 |
-
TimestepEmbedding,
|
43 |
-
Timesteps,
|
44 |
-
)
|
45 |
-
from diffusers.models.modeling_utils import ModelMixin
|
46 |
-
from .unet_2d_blocks import (
|
47 |
-
get_down_block,
|
48 |
-
get_mid_block,
|
49 |
-
get_up_block,
|
50 |
-
)
|
51 |
-
|
52 |
-
|
53 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
54 |
-
|
55 |
-
|
56 |
-
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # The declared default is None, so the field is optional at construction.
    sample: torch.FloatTensor = None
|
67 |
-
|
68 |
-
|
69 |
-
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
70 |
-
r"""
|
71 |
-
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
72 |
-
shaped output.
|
73 |
-
|
74 |
-
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
75 |
-
for all models (such as downloading or saving).
|
76 |
-
|
77 |
-
Parameters:
|
78 |
-
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
79 |
-
Height and width of input/output sample.
|
80 |
-
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
81 |
-
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
82 |
-
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
83 |
-
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
84 |
-
Whether to flip the sin to cos in the time embedding.
|
85 |
-
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
86 |
-
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
87 |
-
The tuple of downsample blocks to use.
|
88 |
-
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
89 |
-
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
90 |
-
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
91 |
-
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
92 |
-
The tuple of upsample blocks to use.
|
93 |
-
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
94 |
-
Whether to include self-attention in the basic transformer blocks, see
|
95 |
-
[`~models.attention.BasicTransformerBlock`].
|
96 |
-
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
97 |
-
The tuple of output channels for each block.
|
98 |
-
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
99 |
-
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
100 |
-
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
101 |
-
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
102 |
-
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
103 |
-
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
104 |
-
If `None`, normalization and activation layers is skipped in post-processing.
|
105 |
-
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
106 |
-
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
107 |
-
The dimension of the cross attention features.
|
108 |
-
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
109 |
-
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
110 |
-
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
111 |
-
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
112 |
-
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
113 |
-
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
114 |
-
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
115 |
-
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
116 |
-
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
117 |
-
encoder_hid_dim (`int`, *optional*, defaults to None):
|
118 |
-
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
119 |
-
dimension to `cross_attention_dim`.
|
120 |
-
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
121 |
-
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
122 |
-
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
123 |
-
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
124 |
-
num_attention_heads (`int`, *optional*):
|
125 |
-
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
126 |
-
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
127 |
-
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
128 |
-
class_embed_type (`str`, *optional*, defaults to `None`):
|
129 |
-
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
130 |
-
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
131 |
-
addition_embed_type (`str`, *optional*, defaults to `None`):
|
132 |
-
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
133 |
-
"text". "text" will use the `TextTimeEmbedding` layer.
|
134 |
-
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
135 |
-
Dimension for the timestep embeddings.
|
136 |
-
num_class_embeds (`int`, *optional*, defaults to `None`):
|
137 |
-
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
138 |
-
class conditioning with `class_embed_type` equal to `None`.
|
139 |
-
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
140 |
-
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
141 |
-
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
142 |
-
An optional override for the dimension of the projected time embedding.
|
143 |
-
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
144 |
-
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
145 |
-
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
146 |
-
timestep_post_act (`str`, *optional*, defaults to `None`):
|
147 |
-
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
148 |
-
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
149 |
-
The dimension of `cond_proj` layer in the timestep embedding.
|
150 |
-
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
151 |
-
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
152 |
-
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
153 |
-
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
154 |
-
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
155 |
-
embeddings with the class embeddings.
|
156 |
-
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
157 |
-
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
158 |
-
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
159 |
-
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
160 |
-
otherwise.
|
161 |
-
"""
|
162 |
-
|
163 |
-
_supports_gradient_checkpointing = True
|
164 |
-
|
165 |
-
@register_to_config
|
166 |
-
def __init__(
|
167 |
-
self,
|
168 |
-
sample_size: Optional[int] = None,
|
169 |
-
in_channels: int = 4,
|
170 |
-
out_channels: int = 4,
|
171 |
-
center_input_sample: bool = False,
|
172 |
-
flip_sin_to_cos: bool = True,
|
173 |
-
freq_shift: int = 0,
|
174 |
-
down_block_types: Tuple[str] = (
|
175 |
-
"CrossAttnDownBlock2D",
|
176 |
-
"CrossAttnDownBlock2D",
|
177 |
-
"CrossAttnDownBlock2D",
|
178 |
-
"DownBlock2D",
|
179 |
-
),
|
180 |
-
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
181 |
-
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
182 |
-
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
183 |
-
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
184 |
-
layers_per_block: Union[int, Tuple[int]] = 2,
|
185 |
-
downsample_padding: int = 1,
|
186 |
-
mid_block_scale_factor: float = 1,
|
187 |
-
dropout: float = 0.0,
|
188 |
-
act_fn: str = "silu",
|
189 |
-
norm_num_groups: Optional[int] = 32,
|
190 |
-
norm_eps: float = 1e-5,
|
191 |
-
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
192 |
-
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
193 |
-
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
|
194 |
-
encoder_hid_dim: Optional[int] = None,
|
195 |
-
encoder_hid_dim_type: Optional[str] = None,
|
196 |
-
attention_head_dim: Union[int, Tuple[int]] = 8,
|
197 |
-
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
198 |
-
dual_cross_attention: bool = False,
|
199 |
-
use_linear_projection: bool = False,
|
200 |
-
class_embed_type: Optional[str] = None,
|
201 |
-
addition_embed_type: Optional[str] = None,
|
202 |
-
addition_time_embed_dim: Optional[int] = None,
|
203 |
-
num_class_embeds: Optional[int] = None,
|
204 |
-
upcast_attention: bool = False,
|
205 |
-
resnet_time_scale_shift: str = "default",
|
206 |
-
resnet_skip_time_act: bool = False,
|
207 |
-
resnet_out_scale_factor: float = 1.0,
|
208 |
-
time_embedding_type: str = "positional",
|
209 |
-
time_embedding_dim: Optional[int] = None,
|
210 |
-
time_embedding_act_fn: Optional[str] = None,
|
211 |
-
timestep_post_act: Optional[str] = None,
|
212 |
-
time_cond_proj_dim: Optional[int] = None,
|
213 |
-
conv_in_kernel: int = 3,
|
214 |
-
conv_out_kernel: int = 3,
|
215 |
-
projection_class_embeddings_input_dim: Optional[int] = None,
|
216 |
-
attention_type: str = "default",
|
217 |
-
class_embeddings_concat: bool = False,
|
218 |
-
mid_block_only_cross_attention: Optional[bool] = None,
|
219 |
-
cross_attention_norm: Optional[str] = None,
|
220 |
-
addition_embed_type_num_heads: int = 64,
|
221 |
-
):
|
222 |
-
super().__init__()
|
223 |
-
|
224 |
-
self.sample_size = sample_size
|
225 |
-
|
226 |
-
if num_attention_heads is not None:
|
227 |
-
raise ValueError(
|
228 |
-
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
229 |
-
)
|
230 |
-
|
231 |
-
# If `num_attention_heads` is not defined (which is the case for most models)
|
232 |
-
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
233 |
-
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
234 |
-
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
235 |
-
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
236 |
-
# which is why we correct for the naming here.
|
237 |
-
num_attention_heads = num_attention_heads or attention_head_dim
|
238 |
-
|
239 |
-
# Check inputs
|
240 |
-
self._check_config(
|
241 |
-
down_block_types=down_block_types,
|
242 |
-
up_block_types=up_block_types,
|
243 |
-
only_cross_attention=only_cross_attention,
|
244 |
-
block_out_channels=block_out_channels,
|
245 |
-
layers_per_block=layers_per_block,
|
246 |
-
cross_attention_dim=cross_attention_dim,
|
247 |
-
transformer_layers_per_block=transformer_layers_per_block,
|
248 |
-
reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
|
249 |
-
attention_head_dim=attention_head_dim,
|
250 |
-
num_attention_heads=num_attention_heads,
|
251 |
-
)
|
252 |
-
|
253 |
-
# input
|
254 |
-
conv_in_padding = (conv_in_kernel - 1) // 2
|
255 |
-
self.conv_in = nn.Conv2d(
|
256 |
-
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
257 |
-
)
|
258 |
-
|
259 |
-
# time
|
260 |
-
time_embed_dim, timestep_input_dim = self._set_time_proj(
|
261 |
-
time_embedding_type,
|
262 |
-
block_out_channels=block_out_channels,
|
263 |
-
flip_sin_to_cos=flip_sin_to_cos,
|
264 |
-
freq_shift=freq_shift,
|
265 |
-
time_embedding_dim=time_embedding_dim,
|
266 |
-
)
|
267 |
-
|
268 |
-
self.time_embedding = TimestepEmbedding(
|
269 |
-
timestep_input_dim,
|
270 |
-
time_embed_dim,
|
271 |
-
act_fn=act_fn,
|
272 |
-
post_act_fn=timestep_post_act,
|
273 |
-
cond_proj_dim=time_cond_proj_dim,
|
274 |
-
)
|
275 |
-
|
276 |
-
self._set_encoder_hid_proj(
|
277 |
-
encoder_hid_dim_type,
|
278 |
-
cross_attention_dim=cross_attention_dim,
|
279 |
-
encoder_hid_dim=encoder_hid_dim,
|
280 |
-
)
|
281 |
-
|
282 |
-
# class embedding
|
283 |
-
self._set_class_embedding(
|
284 |
-
class_embed_type,
|
285 |
-
act_fn=act_fn,
|
286 |
-
num_class_embeds=num_class_embeds,
|
287 |
-
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
288 |
-
time_embed_dim=time_embed_dim,
|
289 |
-
timestep_input_dim=timestep_input_dim,
|
290 |
-
)
|
291 |
-
|
292 |
-
self._set_add_embedding(
|
293 |
-
addition_embed_type,
|
294 |
-
addition_embed_type_num_heads=addition_embed_type_num_heads,
|
295 |
-
addition_time_embed_dim=addition_time_embed_dim,
|
296 |
-
cross_attention_dim=cross_attention_dim,
|
297 |
-
encoder_hid_dim=encoder_hid_dim,
|
298 |
-
flip_sin_to_cos=flip_sin_to_cos,
|
299 |
-
freq_shift=freq_shift,
|
300 |
-
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
301 |
-
time_embed_dim=time_embed_dim,
|
302 |
-
)
|
303 |
-
|
304 |
-
if time_embedding_act_fn is None:
|
305 |
-
self.time_embed_act = None
|
306 |
-
else:
|
307 |
-
self.time_embed_act = get_activation(time_embedding_act_fn)
|
308 |
-
|
309 |
-
self.down_blocks = nn.ModuleList([])
|
310 |
-
self.up_blocks = nn.ModuleList([])
|
311 |
-
|
312 |
-
if isinstance(only_cross_attention, bool):
|
313 |
-
if mid_block_only_cross_attention is None:
|
314 |
-
mid_block_only_cross_attention = only_cross_attention
|
315 |
-
|
316 |
-
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
317 |
-
|
318 |
-
if mid_block_only_cross_attention is None:
|
319 |
-
mid_block_only_cross_attention = False
|
320 |
-
|
321 |
-
if isinstance(num_attention_heads, int):
|
322 |
-
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
323 |
-
|
324 |
-
if isinstance(attention_head_dim, int):
|
325 |
-
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
326 |
-
|
327 |
-
if isinstance(cross_attention_dim, int):
|
328 |
-
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
329 |
-
|
330 |
-
if isinstance(layers_per_block, int):
|
331 |
-
layers_per_block = [layers_per_block] * len(down_block_types)
|
332 |
-
|
333 |
-
if isinstance(transformer_layers_per_block, int):
|
334 |
-
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
335 |
-
|
336 |
-
if class_embeddings_concat:
|
337 |
-
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
338 |
-
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
339 |
-
# regular time embeddings
|
340 |
-
blocks_time_embed_dim = time_embed_dim * 2
|
341 |
-
else:
|
342 |
-
blocks_time_embed_dim = time_embed_dim
|
343 |
-
|
344 |
-
# down
|
345 |
-
output_channel = block_out_channels[0]
|
346 |
-
for i, down_block_type in enumerate(down_block_types):
|
347 |
-
input_channel = output_channel
|
348 |
-
output_channel = block_out_channels[i]
|
349 |
-
is_final_block = i == len(block_out_channels) - 1
|
350 |
-
|
351 |
-
down_block = get_down_block(
|
352 |
-
down_block_type,
|
353 |
-
num_layers=layers_per_block[i],
|
354 |
-
transformer_layers_per_block=transformer_layers_per_block[i],
|
355 |
-
in_channels=input_channel,
|
356 |
-
out_channels=output_channel,
|
357 |
-
temb_channels=blocks_time_embed_dim,
|
358 |
-
add_downsample=not is_final_block,
|
359 |
-
resnet_eps=norm_eps,
|
360 |
-
resnet_act_fn=act_fn,
|
361 |
-
resnet_groups=norm_num_groups,
|
362 |
-
cross_attention_dim=cross_attention_dim[i],
|
363 |
-
num_attention_heads=num_attention_heads[i],
|
364 |
-
downsample_padding=downsample_padding,
|
365 |
-
dual_cross_attention=dual_cross_attention,
|
366 |
-
use_linear_projection=use_linear_projection,
|
367 |
-
only_cross_attention=only_cross_attention[i],
|
368 |
-
upcast_attention=upcast_attention,
|
369 |
-
resnet_time_scale_shift=resnet_time_scale_shift,
|
370 |
-
attention_type=attention_type,
|
371 |
-
resnet_skip_time_act=resnet_skip_time_act,
|
372 |
-
resnet_out_scale_factor=resnet_out_scale_factor,
|
373 |
-
cross_attention_norm=cross_attention_norm,
|
374 |
-
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
375 |
-
dropout=dropout,
|
376 |
-
)
|
377 |
-
self.down_blocks.append(down_block)
|
378 |
-
|
379 |
-
# mid
|
380 |
-
self.mid_block = get_mid_block(
|
381 |
-
mid_block_type,
|
382 |
-
temb_channels=blocks_time_embed_dim,
|
383 |
-
in_channels=block_out_channels[-1],
|
384 |
-
resnet_eps=norm_eps,
|
385 |
-
resnet_act_fn=act_fn,
|
386 |
-
resnet_groups=norm_num_groups,
|
387 |
-
output_scale_factor=mid_block_scale_factor,
|
388 |
-
transformer_layers_per_block=transformer_layers_per_block[-1],
|
389 |
-
num_attention_heads=num_attention_heads[-1],
|
390 |
-
cross_attention_dim=cross_attention_dim[-1],
|
391 |
-
dual_cross_attention=dual_cross_attention,
|
392 |
-
use_linear_projection=use_linear_projection,
|
393 |
-
mid_block_only_cross_attention=mid_block_only_cross_attention,
|
394 |
-
upcast_attention=upcast_attention,
|
395 |
-
resnet_time_scale_shift=resnet_time_scale_shift,
|
396 |
-
attention_type=attention_type,
|
397 |
-
resnet_skip_time_act=resnet_skip_time_act,
|
398 |
-
cross_attention_norm=cross_attention_norm,
|
399 |
-
attention_head_dim=attention_head_dim[-1],
|
400 |
-
dropout=dropout,
|
401 |
-
)
|
402 |
-
|
403 |
-
# count how many layers upsample the images
|
404 |
-
self.num_upsamplers = 0
|
405 |
-
|
406 |
-
# up
|
407 |
-
reversed_block_out_channels = list(reversed(block_out_channels))
|
408 |
-
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
409 |
-
reversed_layers_per_block = list(reversed(layers_per_block))
|
410 |
-
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
411 |
-
reversed_transformer_layers_per_block = (
|
412 |
-
list(reversed(transformer_layers_per_block))
|
413 |
-
if reverse_transformer_layers_per_block is None
|
414 |
-
else reverse_transformer_layers_per_block
|
415 |
-
)
|
416 |
-
only_cross_attention = list(reversed(only_cross_attention))
|
417 |
-
|
418 |
-
output_channel = reversed_block_out_channels[0]
|
419 |
-
for i, up_block_type in enumerate(up_block_types):
|
420 |
-
is_final_block = i == len(block_out_channels) - 1
|
421 |
-
|
422 |
-
prev_output_channel = output_channel
|
423 |
-
output_channel = reversed_block_out_channels[i]
|
424 |
-
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
425 |
-
|
426 |
-
# add upsample block for all BUT final layer
|
427 |
-
if not is_final_block:
|
428 |
-
add_upsample = True
|
429 |
-
self.num_upsamplers += 1
|
430 |
-
else:
|
431 |
-
add_upsample = False
|
432 |
-
|
433 |
-
up_block = get_up_block(
|
434 |
-
up_block_type,
|
435 |
-
num_layers=reversed_layers_per_block[i] + 1,
|
436 |
-
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
437 |
-
in_channels=input_channel,
|
438 |
-
out_channels=output_channel,
|
439 |
-
prev_output_channel=prev_output_channel,
|
440 |
-
temb_channels=blocks_time_embed_dim,
|
441 |
-
add_upsample=add_upsample,
|
442 |
-
resnet_eps=norm_eps,
|
443 |
-
resnet_act_fn=act_fn,
|
444 |
-
resolution_idx=i,
|
445 |
-
resnet_groups=norm_num_groups,
|
446 |
-
cross_attention_dim=reversed_cross_attention_dim[i],
|
447 |
-
num_attention_heads=reversed_num_attention_heads[i],
|
448 |
-
dual_cross_attention=dual_cross_attention,
|
449 |
-
use_linear_projection=use_linear_projection,
|
450 |
-
only_cross_attention=only_cross_attention[i],
|
451 |
-
upcast_attention=upcast_attention,
|
452 |
-
resnet_time_scale_shift=resnet_time_scale_shift,
|
453 |
-
attention_type=attention_type,
|
454 |
-
resnet_skip_time_act=resnet_skip_time_act,
|
455 |
-
resnet_out_scale_factor=resnet_out_scale_factor,
|
456 |
-
cross_attention_norm=cross_attention_norm,
|
457 |
-
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
458 |
-
dropout=dropout,
|
459 |
-
)
|
460 |
-
self.up_blocks.append(up_block)
|
461 |
-
prev_output_channel = output_channel
|
462 |
-
|
463 |
-
# out
|
464 |
-
if norm_num_groups is not None:
|
465 |
-
self.conv_norm_out = nn.GroupNorm(
|
466 |
-
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
467 |
-
)
|
468 |
-
|
469 |
-
self.conv_act = get_activation(act_fn)
|
470 |
-
|
471 |
-
else:
|
472 |
-
self.conv_norm_out = None
|
473 |
-
self.conv_act = None
|
474 |
-
|
475 |
-
conv_out_padding = (conv_out_kernel - 1) // 2
|
476 |
-
self.conv_out = nn.Conv2d(
|
477 |
-
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
478 |
-
)
|
479 |
-
|
480 |
-
self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
|
481 |
-
|
482 |
-
def _check_config(
    self,
    down_block_types: Tuple[str],
    up_block_types: Tuple[str],
    only_cross_attention: Union[bool, Tuple[bool]],
    block_out_channels: Tuple[int],
    layers_per_block: Union[int, Tuple[int]],
    cross_attention_dim: Union[int, Tuple[int]],
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
    reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]],
    attention_head_dim: Union[int, Tuple[int]],
    num_attention_heads: Optional[Union[int, Tuple[int]]],
):
    """Validate that the per-block configuration values agree in length.

    Every argument that may be given per block must, when it is a sequence,
    have exactly one entry per entry of `down_block_types`. Scalar values are
    accepted as-is. `cross_attention_dim` is only length-checked when it is a
    list or tuple, so `None` passes through untouched.

    Raises:
        ValueError: if any sequence length differs from `len(down_block_types)`,
            or if `transformer_layers_per_block` contains nested sequences while
            `reverse_transformer_layers_per_block` is `None`.
    """
    num_blocks = len(down_block_types)

    if num_blocks != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(attention_head_dim, int) and len(attention_head_dim) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
        )

    # Tuples are validated as well as lists: the parameter is annotated as a
    # tuple, and a list-only check would let a wrongly sized tuple through.
    if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(layers_per_block, int) and len(layers_per_block) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )

    # A nested entry means the layer counts are given per layer of a block; in
    # that case the reversed (up-path) layout has to be supplied explicitly.
    if isinstance(transformer_layers_per_block, (list, tuple)) and reverse_transformer_layers_per_block is None:
        for layer_number_per_block in transformer_layers_per_block:
            if isinstance(layer_number_per_block, (list, tuple)):
                raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
533 |
-
|
534 |
-
def _set_time_proj(
    self,
    time_embedding_type: str,
    block_out_channels: int,
    flip_sin_to_cos: bool,
    freq_shift: float,
    time_embedding_dim: int,
) -> Tuple[int, int]:
    """Create `self.time_proj` and report the embedding dimensions it implies.

    Args:
        time_embedding_type: either `"fourier"` or `"positional"`.
        block_out_channels: indexable channel counts; only entry 0 is read.
        flip_sin_to_cos: forwarded to the projection layer.
        freq_shift: forwarded to `Timesteps` (positional case only).
        time_embedding_dim: explicit embedding size; when falsy a multiple of
            `block_out_channels[0]` is used instead (x2 fourier, x4 positional).

    Returns:
        `(time_embed_dim, timestep_input_dim)`.

    Raises:
        ValueError: for an unknown `time_embedding_type`, or for an odd
            `time_embed_dim` in the fourier case.
    """
    if time_embedding_type not in ("fourier", "positional"):
        raise ValueError(
            f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
        )

    if time_embedding_type == "fourier":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
        if time_embed_dim % 2 != 0:
            raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
        # The projection is built with half of the embedding size.
        self.time_proj = GaussianFourierProjection(
            time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
        )
        return time_embed_dim, time_embed_dim

    # positional
    time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
    first_channels = block_out_channels[0]
    self.time_proj = Timesteps(first_channels, flip_sin_to_cos, freq_shift)
    return time_embed_dim, first_channels
|
561 |
-
|
562 |
-
def _set_encoder_hid_proj(
    self,
    encoder_hid_dim_type: Optional[str],
    cross_attention_dim: Union[int, Tuple[int]],
    encoder_hid_dim: Optional[int],
):
    """Create `self.encoder_hid_proj` according to `encoder_hid_dim_type`.

    When no type is given but `encoder_hid_dim` is, the type defaults to
    `"text_proj"` and that default is written back to the config.

    Args:
        encoder_hid_dim_type: `None`, `"text_proj"`, `"text_image_proj"` or
            `"image_proj"`.
        cross_attention_dim: output size handed to the projection layer.
        encoder_hid_dim: input size of the projection; required whenever a
            type is set.

    Raises:
        ValueError: if a type is set without `encoder_hid_dim`, or if the type
            is not one of the supported values.
    """
    if encoder_hid_dim_type is None and encoder_hid_dim is not None:
        encoder_hid_dim_type = "text_proj"
        self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
        logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

    if encoder_hid_dim is None and encoder_hid_dim_type is not None:
        raise ValueError(
            f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
        )

    if encoder_hid_dim_type == "text_proj":
        self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
    elif encoder_hid_dim_type == "text_image_proj":
        # image_embed_dim does not have to equal `cross_attention_dim`; it is
        # set to that value here to keep __init__ small, since that is the
        # dimension needed by the only current user of this mode (Kandinsky 2.1).
        self.encoder_hid_proj = TextImageProjection(
            text_embed_dim=encoder_hid_dim,
            image_embed_dim=cross_attention_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type == "image_proj":
        # Kandinsky 2.2
        self.encoder_hid_proj = ImageProjection(
            image_embed_dim=encoder_hid_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type is not None:
        # The message lists every value handled above, including 'image_proj'.
        raise ValueError(
            f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
        )
    else:
        self.encoder_hid_proj = None
|
601 |
-
|
602 |
-
def _set_class_embedding(
    self,
    class_embed_type: Optional[str],
    act_fn: str,
    num_class_embeds: Optional[int],
    projection_class_embeddings_input_dim: Optional[int],
    time_embed_dim: int,
    timestep_input_dim: int,
):
    """Create `self.class_embedding` for the requested class-conditioning mode.

    Args:
        class_embed_type: `None`, `"timestep"`, `"identity"`, `"projection"`
            or `"simple_projection"`. Any other value leaves the embedding
            set to `None`.
        act_fn: activation name, used by the `"timestep"` embedding only.
        num_class_embeds: table size for the lookup embedding that is used
            when `class_embed_type` is `None`.
        projection_class_embeddings_input_dim: input size for the two
            projection modes; required by both.
        time_embed_dim: output size of every embedding variant.
        timestep_input_dim: input size of the `"timestep"` embedding.

    Raises:
        ValueError: if a projection mode is requested without
            `projection_class_embeddings_input_dim`.
    """
    embedding = None

    if class_embed_type is None:
        if num_class_embeds is not None:
            embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
    elif class_embed_type == "identity":
        embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type == "projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
            )
        # Same layer type as the "timestep" mode, but fed directly with
        # vectors of an arbitrary input dimension: this method does not set up
        # any sinusoidal conversion for this branch.
        embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif class_embed_type == "simple_projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
            )
        embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)

    self.class_embedding = embedding
|
638 |
-
|
639 |
-
def _set_add_embedding(
    self,
    addition_embed_type: Optional[str],
    addition_embed_type_num_heads: int,
    addition_time_embed_dim: Optional[int],
    flip_sin_to_cos: bool,
    freq_shift: float,
    cross_attention_dim: Optional[int],
    encoder_hid_dim: Optional[int],
    projection_class_embeddings_input_dim: Optional[int],
    time_embed_dim: int,
):
    """Create `self.add_embedding` for the requested additional-conditioning mode.

    Supported values of `addition_embed_type` are `None` (nothing is created),
    `"text"`, `"text_image"`, `"text_time"`, `"image"` and `"image_hint"`.
    The `"text_time"` mode additionally creates `self.add_time_proj`.

    Raises:
        ValueError: if `addition_embed_type` is set to any other value.
    """
    if addition_embed_type == "text":
        # Prefer the encoder hidden size when one is configured.
        if encoder_hid_dim is not None:
            text_time_embedding_from_dim = encoder_hid_dim
        else:
            text_time_embedding_from_dim = cross_attention_dim

        self.add_embedding = TextTimeEmbedding(
            text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
        )
    elif addition_embed_type == "text_image":
        # text_embed_dim and image_embed_dim do not have to equal
        # `cross_attention_dim`; they are set to it here to keep __init__
        # small, since that is the dimension needed by the only current user
        # of this mode (Kandinsky 2.1).
        self.add_embedding = TextImageTimeEmbedding(
            text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
        )
    elif addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif addition_embed_type == "image":
        # Kandinsky 2.2
        self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type == "image_hint":
        # Kandinsky 2.2 ControlNet
        self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type is not None:
        # The message lists every value handled above.
        raise ValueError(
            f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
        )
|
678 |
-
|
679 |
-
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
    """Create `self.position_net` when a gated (GLIGEN) attention type is used.

    Nothing happens for any `attention_type` other than `"gated"` or
    `"gated-text-image"`.

    Args:
        attention_type: attention variant selected for the model.
        cross_attention_dim: an int, or a list/tuple whose first entry is used
            as `positive_len`; any other value falls back to 768.
    """
    if attention_type not in ("gated", "gated-text-image"):
        return

    if isinstance(cross_attention_dim, int):
        positive_len = cross_attention_dim
    elif isinstance(cross_attention_dim, (tuple, list)):
        positive_len = cross_attention_dim[0]
    else:
        positive_len = 768

    if attention_type == "gated":
        feature_type = "text-only"
    else:
        feature_type = "text-image"

    # NOTE(review): `out_dim` receives `cross_attention_dim` unchanged, so it
    # is a sequence whenever the caller passes a list/tuple - confirm that
    # GLIGENTextBoundingboxProjection accepts that.
    self.position_net = GLIGENTextBoundingboxProjection(
        positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
    )
|
691 |
-
|
692 |
-
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect every attention processor used by the model.

    Returns:
        `dict`: maps `"<module path>.processor"` to the processor returned by that module's
        `get_processor`, for every sub-module that defines `get_processor`.
    """
    collected = {}

    def _visit(path: str, module: torch.nn.Module):
        # Record the module's own processor first, then descend into its children.
        if hasattr(module, "get_processor"):
            collected[f"{path}.processor"] = module.get_processor(return_deprecated_lora=True)

        for child_name, child in module.named_children():
            _visit(f"{path}.{child_name}", child)

    for top_name, top_module in self.named_children():
        _visit(top_name, top_module)

    return collected
|
715 |
-
|
716 |
-
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Set the attention processor used to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either a single processor, which is installed on every sub-module that defines
            `set_processor`, or a dict keyed by `"<module path>.processor"` holding one entry per
            attention layer.

            Entries are removed from a passed dict as they are installed, so the dict is consumed
            by this call.

    Raises:
        ValueError: if a dict is passed whose size differs from the number of attention layers.
    """
    num_layers = len(self.attn_processors.keys())
    per_layer = isinstance(processor, dict)

    if per_layer and len(processor) != num_layers:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {num_layers}. Please make sure to pass {num_layers} processor classes."
        )

    def _assign(path: str, module: torch.nn.Module):
        if hasattr(module, "set_processor"):
            if per_layer:
                module.set_processor(processor.pop(f"{path}.processor"))
            else:
                module.set_processor(processor)

        for child_name, child in module.named_children():
            _assign(f"{path}.{child_name}", child)

    for top_name, top_module in self.named_children():
        _assign(top_name, top_module)
|
749 |
-
|
750 |
-
def set_default_attn_processor(self):
    """
    Replace custom attention processors with the default implementation.

    `AttnAddedKVProcessor` is chosen when every current processor class is in
    `ADDED_KV_ATTENTION_PROCESSORS`, `AttnProcessor` when every one is in
    `CROSS_ATTENTION_PROCESSORS`.

    Raises:
        ValueError: if the current processors fit neither group.
    """
    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        default_processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        default_processor = AttnProcessor()
    else:
        first_processor = next(iter(self.attn_processors.values()))
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {first_processor}"
        )

    self.set_attn_processor(default_processor)
|
764 |
-
|
765 |
-
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
    r"""
    Enable sliced attention computation.

    Every sub-module that defines `set_attention_slice` receives one slice size, in traversal
    order (a module first, then its children).

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            `"auto"` uses half of each layer's `sliceable_head_dim`; `"max"` uses `1` for every
            layer; any other non-list value is applied to every layer; a list supplies one value
            per layer and is used as given.

    Raises:
        ValueError: if the number of slice sizes differs from the number of sliceable layers, or if
            a slice size is larger than its layer's `sliceable_head_dim`.
    """
    sliceable_modules = []

    def _gather(module: torch.nn.Module):
        if hasattr(module, "set_attention_slice"):
            sliceable_modules.append(module)

        for child in module.children():
            _gather(child)

    for top_module in self.children():
        _gather(top_module)

    sliceable_head_dims = [module.sliceable_head_dim for module in sliceable_modules]
    num_sliceable_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size per layer
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # smallest slice possible
        slice_size = num_sliceable_layers * [1]

    if not isinstance(slice_size, list):
        slice_size = num_sliceable_layers * [slice_size]

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Hand each sliceable module its slice size, in the order the modules were gathered.
    for module, size in zip(sliceable_modules, slice_size):
        module.set_attention_slice(size)
|
829 |
-
|
830 |
-
def _set_gradient_checkpointing(self, module, value=False):
    """Set `module.gradient_checkpointing` to `value` if the module has that attribute.

    Modules without a `gradient_checkpointing` attribute are left untouched.
    """
    if not hasattr(module, "gradient_checkpointing"):
        return
    module.gradient_checkpointing = value
|
833 |
-
|
834 |
-
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

    The four factors are stored as attributes `s1`, `s2`, `b1` and `b2` on every block in
    `self.up_blocks`. The suffix of each name is the stage the factor applies to.

    See the [official repository](https://github.com/ChenyangSi/FreeU) for value combinations
    suggested for different pipelines.

    Args:
        s1 (`float`): Stage 1 scaling factor for the skip features.
        s2 (`float`): Stage 2 scaling factor for the skip features.
        b1 (`float`): Stage 1 scaling factor for the backbone features.
        b2 (`float`): Stage 2 scaling factor for the backbone features.
    """
    factors = {"s1": s1, "s2": s2, "b1": b1, "b2": b2}
    for up_block in self.up_blocks:
        for attr_name, factor in factors.items():
            setattr(up_block, attr_name, factor)
|
857 |
-
|
858 |
-
def disable_freeu(self):
    """Disables the FreeU mechanism.

    Resets `s1`, `s2`, `b1` and `b2` to `None` on every block in
    `self.up_blocks` that has them; blocks lacking an attribute do not gain it.
    """
    for up_block in self.up_blocks:
        for attr_name in ("s1", "s2", "b1", "b2"):
            if hasattr(up_block, attr_name):
                setattr(up_block, attr_name, None)
|
865 |
-
|
866 |
-
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    The current processors are saved in `self.original_attn_processors` before
    `fuse_projections(fuse=True)` is called on every `Attention` module.

    Warning: this API is experimental.

    Raises:
        ValueError: if any attention processor's class name contains `"Added"`.
    """
    self.original_attn_processors = None

    has_added_kv = any(
        "Added" in str(attn_processor.__class__.__name__) for attn_processor in self.attn_processors.values()
    )
    if has_added_kv:
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    # Remember the processors that were active before fusing.
    self.original_attn_processors = self.attn_processors

    attention_modules = (module for module in self.modules() if isinstance(module, Attention))
    for attention in attention_modules:
        attention.fuse_projections(fuse=True)
|
888 |
-
|
889 |
-
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the processors stored in `self.original_attn_processors` through
    `set_attn_processor`. The call is a no-op when that attribute is `None`
    or was never set.

    Warning: this API is experimental.
    """
    # `getattr` with a default keeps "disable if enabled" safe to call even
    # when nothing has assigned `original_attn_processors` on this instance.
    original_processors = getattr(self, "original_attn_processors", None)
    if original_processors is not None:
        self.set_attn_processor(original_processors)
|
901 |
-
|
902 |
-
def unload_lora(self):
    """Unloads LoRA weights.

    Emits a deprecation notice, then calls `set_lora_layer(None)` on every
    sub-module that defines `set_lora_layer`.
    """
    deprecation_message = "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()."
    deprecate("unload_lora", "0.28.0", deprecation_message)

    lora_modules = (module for module in self.modules() if hasattr(module, "set_lora_layer"))
    for lora_module in lora_modules:
        lora_module.set_lora_layer(None)
|
912 |
-
|
913 |
-
def get_time_embed(
    self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:
    """Project `timestep` with `self.time_proj` for every item in the batch.

    Args:
        sample: tensor whose first dimension is used as the batch size and
            whose device and dtype are used for the result.
        timestep: a Python number, a 0-dim tensor, or a tensor that can be
            expanded to the batch size.

    Returns:
        The output of `self.time_proj`, cast to `sample.dtype`.
    """
    if torch.is_tensor(timestep):
        timesteps = timestep
        if len(timesteps.shape) == 0:
            # Promote a scalar tensor to one dimension on the sample's device.
            timesteps = timesteps[None].to(sample.device)
    else:
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # The narrower dtypes are selected when the sample lives on an "mps" device.
        on_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if on_mps else torch.float64
        else:
            dtype = torch.int32 if on_mps else torch.int64
        timesteps = torch.tensor([timestep], dtype=dtype, device=sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    batch_size = sample.shape[0]
    timesteps = timesteps.expand(batch_size)

    # The projection output is cast explicitly so that it matches the sample's
    # dtype (for example when the model runs in fp16).
    return self.time_proj(timesteps).to(dtype=sample.dtype)
|
938 |
-
|
939 |
-
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Compute the class embedding, or return `None` when the model has none.

    Args:
        sample: tensor whose dtype the result is cast to.
        class_labels: class conditioning input; required whenever
            `self.class_embedding` is set.

    Returns:
        `self.class_embedding(class_labels)` cast to `sample.dtype`, or `None`
        if `self.class_embedding` is `None`.

    Raises:
        ValueError: if a class embedding exists but `class_labels` is `None`.
    """
    if self.class_embedding is None:
        return None

    if class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")

    if self.config.class_embed_type == "timestep":
        # In "timestep" mode the labels go through the time projection first;
        # its output is cast to the sample's dtype before being embedded.
        class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)

    return self.class_embedding(class_labels).to(dtype=sample.dtype)
|
954 |
-
|
955 |
-
def get_aug_embed(
    self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> Optional[torch.Tensor]:
    """Compute the additional embedding selected by `config.addition_embed_type`.

    Args:
        emb: tensor whose dtype is used for the `"text_time"` input.
        encoder_hidden_states: input of the `"text"` mode and fallback text
            input of the `"text_image"` mode.
        added_cond_kwargs: extra conditioning tensors; which keys are required
            depends on the mode (`image_embeds`, `text_embeds`, `time_ids`,
            `hint`).

    Returns:
        The output of `self.add_embedding`, or `None` when the configured mode
        is not one of the handled values.

    Raises:
        ValueError: if a key required by the configured mode is missing.
    """
    embed_type = self.config.addition_embed_type

    if embed_type == "text":
        return self.add_embedding(encoder_hidden_states)

    if embed_type == "text_image":
        # Kandinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )

        image_embs = added_cond_kwargs.get("image_embeds")
        # Text embeddings fall back to the encoder hidden states.
        text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
        return self.add_embedding(text_embs, image_embs)

    if embed_type == "text_time":
        # SDXL - style
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")

        # Project all ids at once, then regroup them into one row per text embedding.
        time_embeds = self.add_time_proj(time_ids.flatten()).reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1).to(emb.dtype)
        return self.add_embedding(add_embeds)

    if embed_type == "image":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        return self.add_embedding(added_cond_kwargs.get("image_embeds"))

    if embed_type == "image_hint":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
            )
        image_embs = added_cond_kwargs.get("image_embeds")
        hint = added_cond_kwargs.get("hint")
        return self.add_embedding(image_embs, hint)

    return None
|
1006 |
-
|
1007 |
-
def process_encoder_hidden_states(
|
1008 |
-
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
1009 |
-
) -> torch.Tensor:
|
1010 |
-
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
1011 |
-
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
1012 |
-
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
1013 |
-
# Kandinsky 2.1 - style
|
1014 |
-
if "image_embeds" not in added_cond_kwargs:
|
1015 |
-
raise ValueError(
|
1016 |
-
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1017 |
-
)
|
1018 |
-
|
1019 |
-
image_embeds = added_cond_kwargs.get("image_embeds")
|
1020 |
-
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
1021 |
-
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
1022 |
-
# Kandinsky 2.2 - style
|
1023 |
-
if "image_embeds" not in added_cond_kwargs:
|
1024 |
-
raise ValueError(
|
1025 |
-
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1026 |
-
)
|
1027 |
-
image_embeds = added_cond_kwargs.get("image_embeds")
|
1028 |
-
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
1029 |
-
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
1030 |
-
if "image_embeds" not in added_cond_kwargs:
|
1031 |
-
raise ValueError(
|
1032 |
-
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1033 |
-
)
|
1034 |
-
image_embeds = added_cond_kwargs.get("image_embeds")
|
1035 |
-
image_embeds = self.encoder_hid_proj(image_embeds)
|
1036 |
-
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
1037 |
-
return encoder_hidden_states
|
1038 |
-
|
1039 |
-
def forward(
|
1040 |
-
self,
|
1041 |
-
sample: torch.FloatTensor,
|
1042 |
-
timestep: Union[torch.Tensor, float, int],
|
1043 |
-
encoder_hidden_states: torch.Tensor,
|
1044 |
-
class_labels: Optional[torch.Tensor] = None,
|
1045 |
-
timestep_cond: Optional[torch.Tensor] = None,
|
1046 |
-
attention_mask: Optional[torch.Tensor] = None,
|
1047 |
-
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1048 |
-
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
1049 |
-
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1050 |
-
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1051 |
-
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1052 |
-
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1053 |
-
return_dict: bool = True,
|
1054 |
-
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
1055 |
-
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
1056 |
-
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
1057 |
-
) -> Union[UNet2DConditionOutput, Tuple]:
|
1058 |
-
r"""
|
1059 |
-
The [`UNet2DConditionModel`] forward method.
|
1060 |
-
|
1061 |
-
Args:
|
1062 |
-
sample (`torch.FloatTensor`):
|
1063 |
-
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
1064 |
-
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
1065 |
-
encoder_hidden_states (`torch.FloatTensor`):
|
1066 |
-
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
1067 |
-
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
1068 |
-
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
1069 |
-
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
1070 |
-
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
1071 |
-
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
1072 |
-
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
1073 |
-
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
1074 |
-
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
1075 |
-
negative values to the attention scores corresponding to "discard" tokens.
|
1076 |
-
cross_attention_kwargs (`dict`, *optional*):
|
1077 |
-
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1078 |
-
`self.processor` in
|
1079 |
-
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1080 |
-
added_cond_kwargs: (`dict`, *optional*):
|
1081 |
-
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
1082 |
-
are passed along to the UNet blocks.
|
1083 |
-
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
1084 |
-
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
1085 |
-
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
1086 |
-
A tensor that if specified is added to the residual of the middle unet block.
|
1087 |
-
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
1088 |
-
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
1089 |
-
encoder_attention_mask (`torch.Tensor`):
|
1090 |
-
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
1091 |
-
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
1092 |
-
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
1093 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
1094 |
-
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
1095 |
-
tuple.
|
1096 |
-
|
1097 |
-
Returns:
|
1098 |
-
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
1099 |
-
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
|
1100 |
-
otherwise a `tuple` is returned where the first element is the sample tensor.
|
1101 |
-
"""
|
1102 |
-
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
1103 |
-
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
1104 |
-
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1105 |
-
# on the fly if necessary.
|
1106 |
-
default_overall_up_factor = 2**self.num_upsamplers
|
1107 |
-
|
1108 |
-
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1109 |
-
forward_upsample_size = False
|
1110 |
-
upsample_size = None
|
1111 |
-
|
1112 |
-
for dim in sample.shape[-2:]:
|
1113 |
-
if dim % default_overall_up_factor != 0:
|
1114 |
-
# Forward upsample size to force interpolation output size.
|
1115 |
-
forward_upsample_size = True
|
1116 |
-
break
|
1117 |
-
|
1118 |
-
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
1119 |
-
# expects mask of shape:
|
1120 |
-
# [batch, key_tokens]
|
1121 |
-
# adds singleton query_tokens dimension:
|
1122 |
-
# [batch, 1, key_tokens]
|
1123 |
-
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
1124 |
-
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
1125 |
-
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
1126 |
-
if attention_mask is not None:
|
1127 |
-
# assume that mask is expressed as:
|
1128 |
-
# (1 = keep, 0 = discard)
|
1129 |
-
# convert mask into a bias that can be added to attention scores:
|
1130 |
-
# (keep = +0, discard = -10000.0)
|
1131 |
-
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
1132 |
-
attention_mask = attention_mask.unsqueeze(1)
|
1133 |
-
|
1134 |
-
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
1135 |
-
if encoder_attention_mask is not None:
|
1136 |
-
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
1137 |
-
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
1138 |
-
|
1139 |
-
# 0. center input if necessary
|
1140 |
-
if self.config.center_input_sample:
|
1141 |
-
sample = 2 * sample - 1.0
|
1142 |
-
|
1143 |
-
# 1. time
|
1144 |
-
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
1145 |
-
emb = self.time_embedding(t_emb, timestep_cond)
|
1146 |
-
aug_emb = None
|
1147 |
-
|
1148 |
-
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
1149 |
-
if class_emb is not None:
|
1150 |
-
if self.config.class_embeddings_concat:
|
1151 |
-
emb = torch.cat([emb, class_emb], dim=-1)
|
1152 |
-
else:
|
1153 |
-
emb = emb + class_emb
|
1154 |
-
|
1155 |
-
aug_emb = self.get_aug_embed(
|
1156 |
-
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
1157 |
-
)
|
1158 |
-
if self.config.addition_embed_type == "image_hint":
|
1159 |
-
aug_emb, hint = aug_emb
|
1160 |
-
sample = torch.cat([sample, hint], dim=1)
|
1161 |
-
|
1162 |
-
emb = emb + aug_emb if aug_emb is not None else emb
|
1163 |
-
|
1164 |
-
if self.time_embed_act is not None:
|
1165 |
-
emb = self.time_embed_act(emb)
|
1166 |
-
|
1167 |
-
encoder_hidden_states = self.process_encoder_hidden_states(
|
1168 |
-
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
1169 |
-
)
|
1170 |
-
|
1171 |
-
# 2. pre-process
|
1172 |
-
sample = self.conv_in(sample)
|
1173 |
-
|
1174 |
-
# 2.5 GLIGEN position net
|
1175 |
-
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
1176 |
-
cross_attention_kwargs = cross_attention_kwargs.copy()
|
1177 |
-
gligen_args = cross_attention_kwargs.pop("gligen")
|
1178 |
-
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
1179 |
-
|
1180 |
-
# 3. down
|
1181 |
-
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
|
1182 |
-
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
|
1183 |
-
if cross_attention_kwargs is not None:
|
1184 |
-
cross_attention_kwargs = cross_attention_kwargs.copy()
|
1185 |
-
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
|
1186 |
-
else:
|
1187 |
-
lora_scale = 1.0
|
1188 |
-
|
1189 |
-
if USE_PEFT_BACKEND:
|
1190 |
-
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
1191 |
-
scale_lora_layers(self, lora_scale)
|
1192 |
-
|
1193 |
-
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
1194 |
-
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
1195 |
-
is_adapter = down_intrablock_additional_residuals is not None
|
1196 |
-
# maintain backward compatibility for legacy usage, where
|
1197 |
-
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
1198 |
-
# but can only use one or the other
|
1199 |
-
is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
|
1200 |
-
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
1201 |
-
deprecate(
|
1202 |
-
"T2I should not use down_block_additional_residuals",
|
1203 |
-
"1.3.0",
|
1204 |
-
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
1205 |
-
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
1206 |
-
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
1207 |
-
standard_warn=False,
|
1208 |
-
)
|
1209 |
-
down_intrablock_additional_residuals = down_block_additional_residuals
|
1210 |
-
is_adapter = True
|
1211 |
-
|
1212 |
-
down_block_res_samples = (sample,)
|
1213 |
-
|
1214 |
-
if is_brushnet:
|
1215 |
-
sample = sample + down_block_add_samples.pop(0)
|
1216 |
-
|
1217 |
-
for downsample_block in self.down_blocks:
|
1218 |
-
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
1219 |
-
# For t2i-adapter CrossAttnDownBlock2D
|
1220 |
-
additional_residuals = {}
|
1221 |
-
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1222 |
-
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
1223 |
-
|
1224 |
-
i = len(down_block_add_samples)
|
1225 |
-
|
1226 |
-
if is_brushnet and len(down_block_add_samples)>0:
|
1227 |
-
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
1228 |
-
for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
|
1229 |
-
|
1230 |
-
sample, res_samples = downsample_block(
|
1231 |
-
hidden_states=sample,
|
1232 |
-
temb=emb,
|
1233 |
-
encoder_hidden_states=encoder_hidden_states,
|
1234 |
-
attention_mask=attention_mask,
|
1235 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
1236 |
-
encoder_attention_mask=encoder_attention_mask,
|
1237 |
-
**additional_residuals,
|
1238 |
-
)
|
1239 |
-
else:
|
1240 |
-
additional_residuals = {}
|
1241 |
-
|
1242 |
-
i = len(down_block_add_samples)
|
1243 |
-
|
1244 |
-
if is_brushnet and len(down_block_add_samples)>0:
|
1245 |
-
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
1246 |
-
for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
|
1247 |
-
|
1248 |
-
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, **additional_residuals)
|
1249 |
-
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1250 |
-
sample += down_intrablock_additional_residuals.pop(0)
|
1251 |
-
|
1252 |
-
down_block_res_samples += res_samples
|
1253 |
-
|
1254 |
-
if is_controlnet:
|
1255 |
-
new_down_block_res_samples = ()
|
1256 |
-
|
1257 |
-
for down_block_res_sample, down_block_additional_residual in zip(
|
1258 |
-
down_block_res_samples, down_block_additional_residuals
|
1259 |
-
):
|
1260 |
-
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
1261 |
-
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
1262 |
-
|
1263 |
-
down_block_res_samples = new_down_block_res_samples
|
1264 |
-
|
1265 |
-
# 4. mid
|
1266 |
-
if self.mid_block is not None:
|
1267 |
-
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
1268 |
-
sample = self.mid_block(
|
1269 |
-
sample,
|
1270 |
-
emb,
|
1271 |
-
encoder_hidden_states=encoder_hidden_states,
|
1272 |
-
attention_mask=attention_mask,
|
1273 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
1274 |
-
encoder_attention_mask=encoder_attention_mask,
|
1275 |
-
)
|
1276 |
-
else:
|
1277 |
-
sample = self.mid_block(sample, emb)
|
1278 |
-
|
1279 |
-
# To support T2I-Adapter-XL
|
1280 |
-
if (
|
1281 |
-
is_adapter
|
1282 |
-
and len(down_intrablock_additional_residuals) > 0
|
1283 |
-
and sample.shape == down_intrablock_additional_residuals[0].shape
|
1284 |
-
):
|
1285 |
-
sample += down_intrablock_additional_residuals.pop(0)
|
1286 |
-
|
1287 |
-
if is_controlnet:
|
1288 |
-
sample = sample + mid_block_additional_residual
|
1289 |
-
|
1290 |
-
if is_brushnet:
|
1291 |
-
sample = sample + mid_block_add_sample
|
1292 |
-
|
1293 |
-
# 5. up
|
1294 |
-
for i, upsample_block in enumerate(self.up_blocks):
|
1295 |
-
is_final_block = i == len(self.up_blocks) - 1
|
1296 |
-
|
1297 |
-
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1298 |
-
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
1299 |
-
|
1300 |
-
# if we have not reached the final block and need to forward the
|
1301 |
-
# upsample size, we do it here
|
1302 |
-
if not is_final_block and forward_upsample_size:
|
1303 |
-
upsample_size = down_block_res_samples[-1].shape[2:]
|
1304 |
-
|
1305 |
-
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
1306 |
-
additional_residuals = {}
|
1307 |
-
|
1308 |
-
i = len(up_block_add_samples)
|
1309 |
-
|
1310 |
-
if is_brushnet and len(up_block_add_samples)>0:
|
1311 |
-
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
1312 |
-
for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
|
1313 |
-
|
1314 |
-
sample = upsample_block(
|
1315 |
-
hidden_states=sample,
|
1316 |
-
temb=emb,
|
1317 |
-
res_hidden_states_tuple=res_samples,
|
1318 |
-
encoder_hidden_states=encoder_hidden_states,
|
1319 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
1320 |
-
upsample_size=upsample_size,
|
1321 |
-
attention_mask=attention_mask,
|
1322 |
-
encoder_attention_mask=encoder_attention_mask,
|
1323 |
-
**additional_residuals,
|
1324 |
-
)
|
1325 |
-
else:
|
1326 |
-
additional_residuals = {}
|
1327 |
-
|
1328 |
-
i = len(up_block_add_samples)
|
1329 |
-
|
1330 |
-
if is_brushnet and len(up_block_add_samples)>0:
|
1331 |
-
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
1332 |
-
for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
|
1333 |
-
|
1334 |
-
sample = upsample_block(
|
1335 |
-
hidden_states=sample,
|
1336 |
-
temb=emb,
|
1337 |
-
res_hidden_states_tuple=res_samples,
|
1338 |
-
upsample_size=upsample_size,
|
1339 |
-
**additional_residuals,
|
1340 |
-
)
|
1341 |
-
|
1342 |
-
# 6. post-process
|
1343 |
-
if self.conv_norm_out:
|
1344 |
-
sample = self.conv_norm_out(sample)
|
1345 |
-
sample = self.conv_act(sample)
|
1346 |
-
sample = self.conv_out(sample)
|
1347 |
-
|
1348 |
-
if USE_PEFT_BACKEND:
|
1349 |
-
# remove `lora_scale` from each PEFT layer
|
1350 |
-
unscale_lora_layers(self, lora_scale)
|
1351 |
-
|
1352 |
-
if not return_dict:
|
1353 |
-
return (sample,)
|
1354 |
-
|
1355 |
-
return UNet2DConditionOutput(sample=sample)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/brushnet_nodes.py
DELETED
@@ -1,1094 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import types
|
3 |
-
from typing import Tuple
|
4 |
-
|
5 |
-
import torch
|
6 |
-
import torchvision.transforms as T
|
7 |
-
import torch.nn.functional as F
|
8 |
-
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
9 |
-
import sys
|
10 |
-
|
11 |
-
import comfy.sd
|
12 |
-
import comfy.utils
|
13 |
-
import comfy.model_management
|
14 |
-
import comfy.sd1_clip
|
15 |
-
import comfy.ldm.models.autoencoder
|
16 |
-
import comfy.supported_models
|
17 |
-
|
18 |
-
import folder_paths
|
19 |
-
|
20 |
-
from .model_patch import add_model_patch_option, patch_model_function_wrapper
|
21 |
-
from .brushnet.brushnet import BrushNetModel
|
22 |
-
from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel
|
23 |
-
from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens
|
24 |
-
|
25 |
-
current_directory = os.path.dirname(os.path.abspath(__file__))
|
26 |
-
brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json')
|
27 |
-
brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json')
|
28 |
-
powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json')
|
29 |
-
|
30 |
-
sd15_scaling_factor = 0.18215
|
31 |
-
sdxl_scaling_factor = 0.13025
|
32 |
-
|
33 |
-
print(sys.path)
|
34 |
-
ModelsToUnload = [comfy.sd1_clip.SD1ClipModel,
|
35 |
-
comfy.ldm.models.autoencoder.AutoencoderKL
|
36 |
-
]
|
37 |
-
|
38 |
-
|
39 |
-
class BrushNetLoader:
|
40 |
-
@classmethod
|
41 |
-
def INPUT_TYPES(self):
|
42 |
-
self.inpaint_files = get_files_with_extension('inpaint')
|
43 |
-
return {"required":
|
44 |
-
{
|
45 |
-
"brushnet": ([file for file in self.inpaint_files], ),
|
46 |
-
"dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
|
47 |
-
},
|
48 |
-
}
|
49 |
-
|
50 |
-
CATEGORY = "inpaint"
|
51 |
-
RETURN_TYPES = ("BRMODEL",)
|
52 |
-
RETURN_NAMES = ("brushnet",)
|
53 |
-
|
54 |
-
FUNCTION = "brushnet_loading"
|
55 |
-
|
56 |
-
def brushnet_loading(self, brushnet, dtype):
|
57 |
-
brushnet_file = os.path.join(self.inpaint_files[brushnet], brushnet)
|
58 |
-
print('BrushNet model file:', brushnet_file)
|
59 |
-
is_SDXL = False
|
60 |
-
is_PP = False
|
61 |
-
sd = comfy.utils.load_torch_file(brushnet_file)
|
62 |
-
brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd)
|
63 |
-
del sd
|
64 |
-
if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
|
65 |
-
is_SDXL = False
|
66 |
-
if keys == 322:
|
67 |
-
is_PP = False
|
68 |
-
print('BrushNet model type: SD1.5')
|
69 |
-
else:
|
70 |
-
is_PP = True
|
71 |
-
print('PowerPaint model type: SD1.5')
|
72 |
-
elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
|
73 |
-
print('BrushNet model type: Loading SDXL')
|
74 |
-
is_SDXL = True
|
75 |
-
is_PP = False
|
76 |
-
else:
|
77 |
-
raise Exception("Unknown BrushNet model")
|
78 |
-
|
79 |
-
with init_empty_weights():
|
80 |
-
if is_SDXL:
|
81 |
-
brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
|
82 |
-
brushnet_model = BrushNetModel.from_config(brushnet_config)
|
83 |
-
elif is_PP:
|
84 |
-
brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
|
85 |
-
brushnet_model = PowerPaintModel.from_config(brushnet_config)
|
86 |
-
else:
|
87 |
-
brushnet_config = BrushNetModel.load_config(brushnet_config_file)
|
88 |
-
brushnet_model = BrushNetModel.from_config(brushnet_config)
|
89 |
-
|
90 |
-
if is_PP:
|
91 |
-
print("PowerPaint model file:", brushnet_file)
|
92 |
-
else:
|
93 |
-
print("BrushNet model file:", brushnet_file)
|
94 |
-
|
95 |
-
if dtype == 'float16':
|
96 |
-
torch_dtype = torch.float16
|
97 |
-
elif dtype == 'bfloat16':
|
98 |
-
torch_dtype = torch.bfloat16
|
99 |
-
elif dtype == 'float32':
|
100 |
-
torch_dtype = torch.float32
|
101 |
-
else:
|
102 |
-
torch_dtype = torch.float64
|
103 |
-
|
104 |
-
brushnet_model = load_checkpoint_and_dispatch(
|
105 |
-
brushnet_model,
|
106 |
-
brushnet_file,
|
107 |
-
device_map="sequential",
|
108 |
-
max_memory=None,
|
109 |
-
offload_folder=None,
|
110 |
-
offload_state_dict=False,
|
111 |
-
dtype=torch_dtype,
|
112 |
-
force_hooks=False,
|
113 |
-
)
|
114 |
-
|
115 |
-
if is_PP:
|
116 |
-
print("PowerPaint model is loaded")
|
117 |
-
elif is_SDXL:
|
118 |
-
print("BrushNet SDXL model is loaded")
|
119 |
-
else:
|
120 |
-
print("BrushNet SD1.5 model is loaded")
|
121 |
-
|
122 |
-
return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, )
|
123 |
-
|
124 |
-
|
125 |
-
class PowerPaintCLIPLoader:
|
126 |
-
|
127 |
-
@classmethod
|
128 |
-
def INPUT_TYPES(self):
|
129 |
-
self.inpaint_files = get_files_with_extension('inpaint', ['.bin'])
|
130 |
-
self.clip_files = get_files_with_extension('clip')
|
131 |
-
return {"required":
|
132 |
-
{
|
133 |
-
"base": ([file for file in self.clip_files], ),
|
134 |
-
"powerpaint": ([file for file in self.inpaint_files], ),
|
135 |
-
},
|
136 |
-
}
|
137 |
-
|
138 |
-
CATEGORY = "inpaint"
|
139 |
-
RETURN_TYPES = ("CLIP",)
|
140 |
-
RETURN_NAMES = ("clip",)
|
141 |
-
|
142 |
-
FUNCTION = "ppclip_loading"
|
143 |
-
|
144 |
-
def ppclip_loading(self, base, powerpaint):
|
145 |
-
base_CLIP_file = os.path.join(self.clip_files[base], base)
|
146 |
-
pp_CLIP_file = os.path.join(self.inpaint_files[powerpaint], powerpaint)
|
147 |
-
|
148 |
-
pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file])
|
149 |
-
|
150 |
-
print('PowerPaint base CLIP file: ', base_CLIP_file)
|
151 |
-
|
152 |
-
pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
|
153 |
-
pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
|
154 |
-
|
155 |
-
add_tokens(
|
156 |
-
tokenizer = pp_tokenizer,
|
157 |
-
text_encoder = pp_text_encoder,
|
158 |
-
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"],
|
159 |
-
initialize_tokens = ["a", "a", "a"],
|
160 |
-
num_vectors_per_token = 10,
|
161 |
-
)
|
162 |
-
|
163 |
-
pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False)
|
164 |
-
|
165 |
-
print('PowerPaint CLIP file: ', pp_CLIP_file)
|
166 |
-
|
167 |
-
pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
|
168 |
-
pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
|
169 |
-
|
170 |
-
return (pp_clip,)
|
171 |
-
|
172 |
-
|
173 |
-
class PowerPaint:
    """ComfyUI node that patches a model clone with a PowerPaint model."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "model": ("MODEL",),
                        "vae": ("VAE", ),
                        "image": ("IMAGE",),
                        "mask": ("MASK",),
                        "powerpaint": ("BRMODEL", ),
                        "clip": ("CLIP", ),
                        "positive": ("CONDITIONING", ),
                        "negative": ("CONDITIONING", ),
                        "fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
                        "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ),
                        "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
                        "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
                        "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                        "save_memory": (['none', 'auto', 'max'], ),
                    },
        }

    CATEGORY = "inpaint"
    RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
    RETURN_NAMES = ("model","positive","negative","latent",)

    FUNCTION = "model_update"

    def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
        """Clone `model`, attach the PowerPaint patch and build an empty latent.

        Returns a tuple `(model, positive, negative, {"samples": latent})`.
        `positive` and `negative` are passed through unchanged; the embeddings
        handed to the patch are built here from fixed task prompts.

        Raises:
            Exception: if `powerpaint` is not a PowerPaint model, plus whatever
                `check_compatibilty` / `prepare_image` raise.
        """
        _, is_PP = check_compatibilty(model, powerpaint)
        if not is_PP:
            raise Exception("BrushNet model was loaded, please use BrushNet node")

        # Make a copy of the model so that we're not patching it everywhere in the workflow.
        model = model.clone()

        # prepare image and mask
        # no batches for original image and mask
        masked_image, mask = prepare_image(image, mask)

        batch = masked_image.shape[0]

        model_config = model.model.model_config
        if hasattr(model_config, 'latent_format') and hasattr(model_config.latent_format, 'scale_factor'):
            scaling_factor = model_config.latent_format.scale_factor
        else:
            scaling_factor = sd15_scaling_factor

        torch_dtype = powerpaint['dtype']
        device = powerpaint['brushnet'].device

        # prepare conditioning latents
        conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
        conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(device)
        conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(device)

        # prepare embeddings
        # Per task: (promptA, promptB, negative_promptA, negative_promptB, hint).
        # Any task not listed (e.g. 'text guided') uses "P_obj" everywhere.
        task_prompts = {
            "object removal": ("P_ctxt", "P_ctxt", "P_obj", "P_obj",
                               'You should add to positive prompt: "empty scene blur"'),
            "context aware": ("P_ctxt", "P_ctxt", "", "",
                              'You should add to positive prompt: "empty scene"'),
            "shape guided": ("P_shape", "P_ctxt", "P_shape", "P_ctxt", None),
            "image outpainting": ("P_ctxt", "P_ctxt", "P_obj", "P_obj",
                                  'You should add to positive prompt: "empty scene"'),
        }
        promptA, promptB, negative_promptA, negative_promptB, hint = task_prompts.get(
            function, ("P_obj", "P_obj", "P_obj", "P_obj", None))
        if hint is not None:
            print(hint)

        prompt_embedsA = clip.encode_from_tokens(clip.tokenize(promptA), return_pooled=False)
        negative_prompt_embedsA = clip.encode_from_tokens(clip.tokenize(negative_promptA), return_pooled=False)
        prompt_embedsB = clip.encode_from_tokens(clip.tokenize(promptB), return_pooled=False)
        negative_prompt_embedsB = clip.encode_from_tokens(clip.tokenize(negative_promptB), return_pooled=False)

        # `fitting` linearly blends the A and B embeddings.
        prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(device)
        negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(device)

        # unload vae and CLIPs
        del vae
        del clip
        # Iterate over a snapshot: removing entries from the live list while
        # looping over it would skip the element following each removal.
        for loaded_model in list(comfy.model_management.current_loaded_models):
            if type(loaded_model.model.model) in ModelsToUnload:
                comfy.model_management.current_loaded_models.remove(loaded_model)
                loaded_model.model_unload()

        # apply patch to model
        if save_memory != 'none':
            powerpaint['brushnet'].set_attention_slice(save_memory)

        # NOTE(review): the negative embeddings are passed in the slot that
        # add_brushnet_patch names `prompt_embeds` and the positive ones in
        # `negative_prompt_embeds`. Kept exactly as in the original; confirm
        # that this ordering is intentional for PowerPaint.
        add_brushnet_patch(model,
                           powerpaint['brushnet'],
                           torch_dtype,
                           conditioning_latents,
                           (scale, start_at, end_at),
                           negative_prompt_embeds_pp, prompt_embeds_pp,
                           None, None, None,
                           False)

        latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=device)

        return (model, positive, negative, {"samples":latent},)
|
310 |
-
|
311 |
-
|
312 |
-
class BrushNet:
    """ComfyUI node that patches a model clone with a BrushNet model."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "model": ("MODEL",),
                        "vae": ("VAE", ),
                        "image": ("IMAGE",),
                        "mask": ("MASK",),
                        "brushnet": ("BRMODEL", ),
                        "positive": ("CONDITIONING", ),
                        "negative": ("CONDITIONING", ),
                        "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
                        "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
                        "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                    },
        }

    CATEGORY = "inpaint"
    RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
    RETURN_NAMES = ("model","positive","negative","latent",)

    FUNCTION = "model_update"

    def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
        """Clone `model`, attach the BrushNet patch and build an empty latent.

        Returns a tuple `(model, positive, negative, {"samples": latent})`;
        `positive` and `negative` are passed through unchanged.

        Raises:
            Exception: if `brushnet` is a PowerPaint model, plus whatever
                `check_compatibilty` / `prepare_image` raise.
        """
        is_SDXL, is_PP = check_compatibilty(model, brushnet)

        if is_PP:
            raise Exception("PowerPaint model was loaded, please use PowerPaint node")

        # Make a copy of the model so that we're not patching it everywhere in the workflow.
        model = model.clone()

        # prepare image and mask
        # no batches for original image and mask
        masked_image, mask = prepare_image(image, mask)

        batch = masked_image.shape[0]
        width = masked_image.shape[2]
        height = masked_image.shape[1]

        model_config = model.model.model_config
        if hasattr(model_config, 'latent_format') and hasattr(model_config.latent_format, 'scale_factor'):
            scaling_factor = model_config.latent_format.scale_factor
        elif is_SDXL:
            scaling_factor = sdxl_scaling_factor
        else:
            scaling_factor = sd15_scaling_factor

        torch_dtype = brushnet['dtype']
        device = brushnet['brushnet'].device

        # prepare conditioning latents
        conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
        conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(device)
        conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(device)

        # unload vae
        del vae
        # Iterate over a snapshot: removing entries from the live list while
        # looping over it would skip the element following each removal.
        for loaded_model in list(comfy.model_management.current_loaded_models):
            if type(loaded_model.model.model) in ModelsToUnload:
                comfy.model_management.current_loaded_models.remove(loaded_model)
                loaded_model.model_unload()

        # prepare embeddings
        prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(device)
        negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(device)

        # Pad the shorter embedding along dim 1 by repeating its last 77
        # entries so both have the same length.
        max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
        if prompt_embeds.shape[1] < max_tokens:
            multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
            prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1)
            print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds')
        if negative_prompt_embeds.shape[1] < max_tokens:
            multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
            negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1)
            print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds')

        if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
            pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(device)
        else:
            print('BrushNet: positive conditioning has not pooled_output')
            if is_SDXL:
                print('BrushNet will not produce correct results')
            # zeros instead of torch.empty: the placeholder must not contain
            # uninitialised memory (possibly NaN/inf).
            pooled_prompt_embeds = torch.zeros([2, 1280], device=device).to(dtype=torch_dtype)

        if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
            negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(device)
        else:
            print('BrushNet: negative conditioning has not pooled_output')
            if is_SDXL:
                print('BrushNet will not produce correct results')
            negative_pooled_prompt_embeds = torch.zeros([1, pooled_prompt_embeds.shape[1]], device=device).to(dtype=torch_dtype)

        time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(device)

        # The pooled embeddings and time ids are only passed on for SDXL.
        if not is_SDXL:
            pooled_prompt_embeds = None
            negative_pooled_prompt_embeds = None
            time_ids = None

        # apply patch to model
        add_brushnet_patch(model,
                           brushnet['brushnet'],
                           torch_dtype,
                           conditioning_latents,
                           (scale, start_at, end_at),
                           prompt_embeds, negative_prompt_embeds,
                           pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                           False)

        latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=device)

        return (model, positive, negative, {"samples":latent},)
|
433 |
-
|
434 |
-
|
435 |
-
class BlendInpaint:
    """ComfyUI node that blends an inpainted image into the original using a blurred mask."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "inpaint": ("IMAGE",),
                        "original": ("IMAGE",),
                        "mask": ("MASK",),
                        "kernel": ("INT", {"default": 10, "min": 1, "max": 1000}),
                        "sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}),
                    },
                "optional":
                    {
                        "origin": ("VECTOR",),
                    },
        }

    CATEGORY = "inpaint"
    RETURN_TYPES = ("IMAGE","MASK",)
    RETURN_NAMES = ("image","MASK",)

    FUNCTION = "blend_inpaint"

    def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma: float, origin=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Blend `inpaint` over `original`, weighted by a Gaussian-blurred `mask`.

        Args:
            inpaint: inpainted image(s); a missing batch dimension is added.
            original: original image(s), normalised by `check_image_mask`.
            mask: blend mask(s), normalised by `check_image_mask`.
            kernel: Gaussian kernel size; even values are bumped to the next odd.
            sigma: Gaussian sigma, used for both axes.
            origin: optional per-image `[x0, y0]` offsets (as produced by
                CutForInpaint). When given, `inpaint` and a smaller-than-original
                `mask` are zero-padded to the original size at that offset;
                otherwise `inpaint` is resized to the original size.

        Returns:
            `(blended_images, blurred_masks)`, both stacked over the batch.
        """
        original, mask = check_image_mask(original, mask, 'Blend Inpaint')

        if len(inpaint.shape) < 4:
            # image tensor shape should be [B, H, W, C], but batch somehow is missing
            inpaint = inpaint[None,:,:,:]

        if inpaint.shape[0] < original.shape[0]:
            print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0]))
            original= original[:inpaint.shape[0],:,:]
            mask = mask[:inpaint.shape[0],:,:]

        if inpaint.shape[0] > original.shape[0]:
            # batch over inpaint: cycle through the originals (and their masks
            # and origins) until there is one for every inpaint image
            source_count = original.shape[0]
            picks = [k % source_count for k in range(inpaint.shape[0])]
            if origin is not None:
                origin = torch.concat([origin[k][None,:] for k in picks], dim=0)
            tiled_original = torch.concat([original[k][None,:,:,:] for k in picks], dim=0)
            tiled_mask = torch.concat([mask[k][None,:,:] for k in picks], dim=0)
            original, mask = tiled_original, tiled_mask

        if kernel % 2 == 0:
            kernel += 1
        transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))

        ret = []
        blurred = []
        for i in range(inpaint.shape[0]):
            if origin is None:
                blurred_mask = transform(mask[i][None,None,:,:]).to(original.device).to(original.dtype)
                # NOTE(review): this appends a [1, H, W] slice while the
                # `origin` branch below appends [H, W], so the stacked mask has
                # a different rank in the two modes. Kept as in the original;
                # confirm which shape downstream nodes expect.
                blurred.append(blurred_mask[0])

                # resize the inpaint result to the original's height/width
                result = torch.nn.functional.interpolate(
                    inpaint[i][None,:,:,:].permute(0, 3, 1, 2),
                    size=(
                        original[i].shape[0],
                        original[i].shape[1],
                    )
                ).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
            else:
                # got mask from CutForInpaint
                height, width, _ = original[i].shape
                x0 = origin[i][0].item()
                y0 = origin[i][1].item()

                if mask[i].shape[0] < height or mask[i].shape[1] < width:
                    padded_mask = F.pad(input=mask[i], pad=(x0, width-x0-mask[i].shape[1],
                                                            y0, height-y0-mask[i].shape[0]), mode='constant', value=0)
                else:
                    padded_mask = mask[i]
                blurred_mask = transform(padded_mask[None,None,:,:]).to(original.device).to(original.dtype)
                blurred.append(blurred_mask[0][0])

                # pad order is last dimension first: channels (none), width, height
                result = F.pad(input=inpaint[i], pad=(0, 0, x0, width-x0-inpaint[i].shape[1],
                                                      y0, height-y0-inpaint[i].shape[0]), mode='constant', value=0)
                result = result[None,:,:,:].to(original.device).to(original.dtype)

            weight = blurred_mask[0][0][:,:,None]
            ret.append(original[i] * (1.0 - weight) + result[0] * weight)

        return (torch.stack(ret), torch.stack(blurred), )
|
531 |
-
|
532 |
-
|
533 |
-
class CutForInpaint:
    """ComfyUI node that crops image/mask pairs to a window placed by `cut_with_mask`."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "image": ("IMAGE",),
                        "mask": ("MASK",),
                        "width": ("INT", {"default": 512, "min": 64, "max": 2048}),
                        "height": ("INT", {"default": 512, "min": 64, "max": 2048}),
                    },
        }

    CATEGORY = "inpaint"
    RETURN_TYPES = ("IMAGE","MASK","VECTOR",)
    RETURN_NAMES = ("image","mask","origin",)

    FUNCTION = "cut_for_inpaint"

    def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, width: int, height: int):
        """Crop every image and mask in the batch to the window chosen by `cut_with_mask`.

        Returns:
            `(images, masks, origins)` stacked over the batch, where each
            origin is an int tensor `[x0, y0]` holding the crop offset.
        """
        # Use this node's own name so batch warnings are attributed correctly.
        image, mask = check_image_mask(image, mask, 'Cut For Inpaint')

        ret = []
        msk = []
        org = []
        for img, msk_i in zip(image, mask):
            x0, y0, w, h = cut_with_mask(msk_i, width, height)
            ret.append(img[y0:y0+h,x0:x0+w,:])
            msk.append(msk_i[y0:y0+h,x0:x0+w])
            org.append(torch.IntTensor([x0,y0]))

        return (torch.stack(ret), torch.stack(msk), torch.stack(org), )
|
566 |
-
|
567 |
-
|
568 |
-
#### Utility function
|
569 |
-
|
570 |
-
def get_files_with_extension(folder_name, extension=None):
    """Map matching file names under `folder_name` to the folder they were found in.

    Args:
        folder_name: name passed to `folder_paths.get_folder_paths`; if that
            yields nothing, `<models_dir>/<folder_name>` and then
            `<base_path>/<folder_name>` are tried.
        extension: list of file extensions to keep; defaults to
            `['.safetensors']`.

    Returns:
        Dict of `{file: folder}`; empty when no usable folder exists.
    """
    if extension is None:
        extension = ['.safetensors']

    try:
        folders = folder_paths.get_folder_paths(folder_name)
    except Exception:
        # An unregistered folder name is not an error here; fall back to the
        # conventional locations below.
        folders = []

    if not folders:
        folders = [os.path.join(folder_paths.models_dir, folder_name)]
    if not os.path.isdir(folders[0]):
        folders = [os.path.join(folder_paths.base_path, folder_name)]
    if not os.path.isdir(folders[0]):
        return {}

    # Keep existing directories only, dropping paths that point at a
    # directory already kept.
    filtered_folders = []
    for x in folders:
        if not os.path.isdir(x):
            continue
        if not any(os.path.samefile(x, y) for y in filtered_folders):
            filtered_folders.append(x)

    if not filtered_folders:
        return {}

    output = {}
    for x in filtered_folders:
        files, _ = folder_paths.recursive_search(x, excluded_dir_names=[".git"])
        for f in folder_paths.filter_files_extensions(files, extension):
            # a file name found in several folders keeps the last folder
            output[f] = x

    return output
|
608 |
-
|
609 |
-
|
610 |
-
# get blocks from state_dict so we could know which model it is
|
611 |
-
def brushnet_blocks(sd):
    """Count state-dict keys that mention each BrushNet block kind.

    Returns `(down, mid, up, total)`: how many keys of `sd` contain
    'brushnet_down_block', 'brushnet_mid_block' and 'brushnet_up_block'
    respectively, followed by `len(sd)`.
    """
    markers = ('brushnet_down_block', 'brushnet_mid_block', 'brushnet_up_block')
    # A key is counted once for every marker it contains.
    down, mid, up = (sum(1 for key in sd if marker in key) for marker in markers)
    return (down, mid, up, len(sd))
|
623 |
-
|
624 |
-
|
625 |
-
# Check models compatibility
|
626 |
-
def check_compatibilty(model, brushnet):
    """Check that the base model and the BrushNet/PowerPaint model match.

    Returns `(is_SDXL, is_PP)`. `is_PP` can only be True for an SD1.5 base
    model and reflects `brushnet["PP"]`.

    Raises:
        Exception: if the SD1.5/SDXL types of the two models disagree, or the
            base model is neither SD1.5 nor SDXL.
    """
    config = model.model.model_config

    if isinstance(config, comfy.supported_models.SD15):
        print('Base model type: SD1.5')
        if brushnet["SDXL"]:
            raise Exception("Base model is SD15, but BrushNet is SDXL type")
        return (False, True if brushnet["PP"] else False)

    if isinstance(config, comfy.supported_models.SDXL):
        print('Base model type: SDXL')
        if not brushnet["SDXL"]:
            raise Exception("Base model is SDXL, but BrushNet is SD15 type")
        return (True, False)

    print('Base model type: ', type(config))
    raise Exception("Unsupported model type: " + str(type(config)))
|
646 |
-
|
647 |
-
|
648 |
-
def check_image_mask(image, mask, name):
    """Normalise an image/mask pair so both are batched and equally long.

    Args:
        image: image tensor; a missing batch dimension is added.
        mask: mask tensor; a 4-D mask is reduced to its first channel and a
            missing batch dimension is added.
        name: label used as a prefix in the diagnostic prints.

    Returns:
        `(image, mask)` with `mask.shape[0] == image.shape[0]`. A single mask
        is repeated for every image; otherwise missing masks are filled with
        zeros and surplus masks are dropped.
    """
    if len(image.shape) < 4:
        # image tensor shape should be [B, H, W, C], but batch somehow is missing
        image = image[None,:,:,:]

    if len(mask.shape) > 3:
        # mask tensor shape should be [B, H, W] but we get [B, H, W, C], image may be?
        # take first mask, red channel
        mask = (mask[:,:,:,0])[:,:,:]
    elif len(mask.shape) < 3:
        # mask tensor shape should be [B, H, W] but batch somehow is missing
        mask = mask[None,:,:]

    images, masks = image.shape[0], mask.shape[0]
    if images > masks:
        print(name, "gets batch of images (%d) but only %d masks" % (images, masks))
        if masks == 1:
            print(name, "will copy the mask to fill batch")
            mask = torch.cat([mask] * images, dim=0)
        else:
            print(name, "will add empty masks to fill batch")
            # Create the padding with the mask's own dtype and device so the
            # concatenation neither fails on a device mismatch nor changes dtype.
            empty_mask = torch.zeros([images - masks, mask.shape[1], mask.shape[2]],
                                     dtype=mask.dtype, device=mask.device)
            mask = torch.cat([mask, empty_mask], dim=0)
    elif images < masks:
        print(name, "gets batch of images (%d) but too many (%d) masks" % (images, masks))
        mask = mask[:images,:,:]

    return (image, mask)
|
675 |
-
|
676 |
-
|
677 |
-
# Prepare image and mask
|
678 |
-
def prepare_image(image, mask):
    """Validate an image/mask pair and zero out the masked pixels.

    The mask is rounded to the nearest integer, then the image is multiplied
    by `1 - mask` (broadcast over the last image dimension).

    Returns:
        `(masked_image, mask)` where `mask` is the rounded mask.

    Raises:
        Exception: if mask and image differ in dimension 1 or 2.
    """
    image, mask = check_image_mask(image, mask, 'BrushNet')

    print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)

    same_size = mask.shape[1] == image.shape[1] and mask.shape[2] == image.shape[2]
    if not same_size:
        raise Exception("Image and mask should be the same size")

    # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
    binary_mask = mask.round()

    keep = 1.0 - binary_mask[:,:,:,None]
    return (image * keep, binary_mask)
|
693 |
-
|
694 |
-
|
695 |
-
# Get origin of the mask
|
696 |
-
def cut_with_mask(mask, width, height):
    """Choose a crop window of at most `width` x `height` around the masked area.

    Args:
        mask: 2-D mask tensor; only elements equal to 1 count as masked.
        width: requested crop width.
        height: requested crop height.

    Returns:
        `(x0, y0, w, h)` as ints. Along an axis where the mask is not larger
        than the request, the whole axis is returned (offset 0). Otherwise the
        window is centred on the masked area's bounding box (or on the mask
        centre when nothing is masked) and shifted to stay inside the mask.

    Raises:
        Exception: if the masked bounding box does not fit into the request.
    """
    iy, ix = (mask == 1).nonzero(as_tuple=True)

    h0, w0 = mask.shape

    if iy.numel() == 0:
        x_c = w0 / 2.0
        y_c = h0 / 2.0
        x_max = y_max = None
    else:
        x_min = ix.min().item()
        x_max = ix.max().item()
        y_min = iy.min().item()
        y_max = iy.max().item()

        # The bounding box spans (max - min + 1) pixels, hence the +1.
        if x_max - x_min + 1 > width or y_max - y_min + 1 > height:
            raise Exception("Masked area is bigger than provided dimensions")

        x_c = (x_min + x_max) / 2.0
        y_c = (y_min + y_max) / 2.0

    def place(full, wanted, centre, last_masked):
        # Return (start, size) of the crop along one axis.
        if full <= wanted:
            return 0, full
        start = max(0, centre - wanted / 2.0)
        if start + wanted > full:
            start = full - wanted
        start = int(start)
        # Truncating a half-pixel centre can leave the last masked pixel just
        # outside the window; shift the window so it is included.
        if last_masked is not None and start + wanted <= last_masked:
            start = last_masked - wanted + 1
        return start, wanted

    x0, w = place(w0, width, x_c, x_max)
    y0, h = place(h0, height, y_c, y_max)

    return (int(x0), int(y0), int(w), int(h))
|
738 |
-
|
739 |
-
|
740 |
-
# Prepare conditioning_latents
|
741 |
-
@torch.inference_mode()
def get_image_latents(masked_image, mask, vae, scaling_factor):
    """Build the conditioning latents: encoded image plus a latent-sized mask.

    The first three channels of `masked_image` are encoded with `vae` and
    multiplied by `scaling_factor`. `1 - mask` is resized to the last two
    dimensions of the encoded latents and moved to their device.

    Returns:
        A list `[image_latents, interpolated_mask]`.
    """
    rgb = masked_image.to(vae.device)[:,:,:,:3]
    image_latents = vae.encode(rgb) * scaling_factor

    # add a channel dimension and invert the mask before resizing
    inverted = 1. - mask[:,None,:,:]
    latent_size = (image_latents.shape[-2], image_latents.shape[-1])
    interpolated_mask = torch.nn.functional.interpolate(inverted, size=latent_size)
    interpolated_mask = interpolated_mask.to(image_latents.device)

    print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape)

    return [image_latents, interpolated_mask]
|
760 |
-
|
761 |
-
|
762 |
-
# Main function where magic happens
|
763 |
-
@torch.inference_mode()
def brushnet_inference(x, timesteps, transformer_options, debug):
    """Run the BrushNet/PowerPaint model for the current sampling call.

    Everything needed is read from
    `transformer_options['model_patch']['brushnet']`: the model, dtype,
    conditioning latents, control settings, prompt embeddings and the running
    `latent_id` counter (which this function updates).

    Args:
        x: incoming latents; a detached copy is cast to the BrushNet dtype/device.
        timesteps: timesteps tensor, handled like `x`.
        transformer_options: sampler options; must contain `model_patch` and
            `cond_or_uncond`, may contain `ad_params['sub_idxs']`.
        debug: when truthy, prints the step and is forwarded to the model.

    Returns:
        The result of calling the BrushNet model with `return_dict=False`, or
        `([], 0, [])` when the patch data is missing or of the wrong type.
    """
    if 'model_patch' not in transformer_options:
        print('BrushNet inference: there is no model_patch key in transformer_options')
        return ([], 0, [])
    mp = transformer_options['model_patch']
    if 'brushnet' not in mp:
        print('BrushNet inference: there is no brushnet key in model_patch')
        return ([], 0, [])
    bo = mp['brushnet']
    if 'model' not in bo:
        print('BrushNet inference: there is no model key in brushnet')
        return ([], 0, [])
    brushnet = bo['model']
    if not isinstance(brushnet, (BrushNetModel, PowerPaintModel)):
        print('BrushNet model is not a BrushNetModel class')
        return ([], 0, [])

    torch_dtype = bo['dtype']
    cl_list = bo['latents']
    brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
    pe = bo['prompt_embeds']
    npe = bo['negative_prompt_embeds']
    ppe, nppe, time_ids = bo['add_embeds']

    # more than one entry in cond_or_uncond means cond and uncond are batched together
    do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1

    x = x.detach().clone()
    x = x.to(torch_dtype).to(brushnet.device)

    timesteps = timesteps.detach().clone()
    timesteps = timesteps.to(torch_dtype).to(brushnet.device)

    total_steps = mp['total_steps']  # read as in the original; not used below
    step = mp['step']

    if do_classifier_free_guidance and step == 0:
        print('BrushNet inference: do_classifier_free_guidance is True')

    sub_idx = None
    if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
        sub_idx = transformer_options['ad_params']['sub_idxs']

    # we have batch input images
    batch = cl_list[0].shape[0]
    # we have incoming latents
    latents_incoming = x.shape[0]
    # and we already got some
    latents_got = bo['latent_id']
    if step == 0 or batch > 1:
        print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
                % (step, batch, latents_incoming, latents_got))

    image_latents = []
    masks = []
    prompt_embeds = []
    negative_prompt_embeds = []
    pooled_prompt_embeds = []
    negative_pooled_prompt_embeds = []
    if sub_idx:
        # AnimateDiff indexes detected
        if step == 0:
            print('BrushNet inference: AnimateDiff indexes detected and applied')

        batch = len(sub_idx)

        if do_classifier_free_guidance:
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                negative_prompt_embeds.append(npe)
                pooled_prompt_embeds.append(ppe)
                negative_pooled_prompt_embeds.append(nppe)
            # second copy of the latents/masks for the other half of the batch
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
        else:
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
    else:
        # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
        continue_batch = True
        for i in range(latents_incoming):
            number = latents_got + i
            if number < batch:
                # 1st pass, cond
                image_latents.append(cl_list[0][number][None,:,:,:])
                masks.append(cl_list[1][number][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
            elif do_classifier_free_guidance and number < batch * 2:
                # 2nd pass, uncond
                image_latents.append(cl_list[0][number-batch][None,:,:,:])
                masks.append(cl_list[1][number-batch][None,:,:,:])
                negative_prompt_embeds.append(npe)
                negative_pooled_prompt_embeds.append(nppe)
            else:
                # latent batch
                image_latents.append(cl_list[0][0][None,:,:,:])
                masks.append(cl_list[1][0][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
                # NOTE(review): this rewinds the counter so the next
                # iteration computes number == 1; kept exactly as in the
                # original, intent not evident from this code.
                latents_got = -i
                continue_batch = False

        if continue_batch:
            # we don't have full batch yet
            if do_classifier_free_guidance:
                if number < batch * 2 - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
            else:
                if number < batch - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
        else:
            bo['latent_id'] = 0

    # join each image latent with its mask along dim 1, then stack along the batch
    cl = []
    for il, m in zip(image_latents, masks):
        cl.append(torch.concat([il, m], dim=1))
    cl2apply = torch.concat(cl, dim=0)

    conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)

    # positive embeddings first, negative ones appended after them
    prompt_embeds.extend(negative_prompt_embeds)
    prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)

    if ppe is not None:
        added_cond_kwargs = {}
        added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)

        pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
        pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
        added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
    else:
        added_cond_kwargs = None

    if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
        if step == 0:
            print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
        conditioning_latents = torch.nn.functional.interpolate(
            conditioning_latents, size=(
                x.shape[2],
                x.shape[3],
            ), mode='bicubic',
        ).to(torch_dtype).to(brushnet.device)

    if step == 0:
        print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)

    if debug: print('BrushNet: step =', step)

    # outside the [start, end] step window the conditioning scale is zero
    if step < control_guidance_start or step > control_guidance_end:
        cond_scale = 0.0
    else:
        cond_scale = brushnet_conditioning_scale

    return brushnet(x,
                    encoder_hidden_states=prompt_embeds,
                    brushnet_cond=conditioning_latents,
                    timestep = timesteps,
                    conditioning_scale=cond_scale,
                    guess_mode=False,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                    debug=debug,
                    )
|
943 |
-
|
944 |
-
|
945 |
-
# This is the main patch function
|
946 |
-
def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
|
947 |
-
controls,
|
948 |
-
prompt_embeds, negative_prompt_embeds,
|
949 |
-
pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
|
950 |
-
debug):
|
951 |
-
|
952 |
-
is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
|
953 |
-
|
954 |
-
if is_SDXL:
|
955 |
-
input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
|
956 |
-
[1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
957 |
-
[2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
958 |
-
[3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
959 |
-
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
960 |
-
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
961 |
-
[6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
962 |
-
[7, comfy.ldm.modules.attention.SpatialTransformer],
|
963 |
-
[8, comfy.ldm.modules.attention.SpatialTransformer]]
|
964 |
-
middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
|
965 |
-
output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
|
966 |
-
[1, comfy.ldm.modules.attention.SpatialTransformer],
|
967 |
-
[2, comfy.ldm.modules.attention.SpatialTransformer],
|
968 |
-
[2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
969 |
-
[3, comfy.ldm.modules.attention.SpatialTransformer],
|
970 |
-
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
971 |
-
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
972 |
-
[5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
973 |
-
[6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
974 |
-
[7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
975 |
-
[8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
|
976 |
-
else:
|
977 |
-
input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
|
978 |
-
[1, comfy.ldm.modules.attention.SpatialTransformer],
|
979 |
-
[2, comfy.ldm.modules.attention.SpatialTransformer],
|
980 |
-
[3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
981 |
-
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
982 |
-
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
983 |
-
[6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
984 |
-
[7, comfy.ldm.modules.attention.SpatialTransformer],
|
985 |
-
[8, comfy.ldm.modules.attention.SpatialTransformer],
|
986 |
-
[9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
987 |
-
[10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
988 |
-
[11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
|
989 |
-
middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
|
990 |
-
output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
991 |
-
[1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
992 |
-
[2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
993 |
-
[2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
994 |
-
[3, comfy.ldm.modules.attention.SpatialTransformer],
|
995 |
-
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
996 |
-
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
997 |
-
[5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
998 |
-
[6, comfy.ldm.modules.attention.SpatialTransformer],
|
999 |
-
[7, comfy.ldm.modules.attention.SpatialTransformer],
|
1000 |
-
[8, comfy.ldm.modules.attention.SpatialTransformer],
|
1001 |
-
[8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
1002 |
-
[9, comfy.ldm.modules.attention.SpatialTransformer],
|
1003 |
-
[10, comfy.ldm.modules.attention.SpatialTransformer],
|
1004 |
-
[11, comfy.ldm.modules.attention.SpatialTransformer]]
|
1005 |
-
|
1006 |
-
def last_layer_index(block, tp):
    """Locate the last layer of exact type ``tp`` inside ``block``.

    Args:
        block: Iterable of layers (for example one UNet block).
        tp: Layer class to look for. Matching uses ``type(layer)``, so
            instances of subclasses of ``tp`` do not count.

    Returns:
        A ``(index, layer_types)`` tuple. When a match exists, ``index`` is
        the position of its last occurrence in ``block`` and ``layer_types``
        holds the layer types in reversed block order (kept as before for
        backward compatibility). When nothing matches, ``index`` is ``-1``
        and ``layer_types`` holds the types in block order so the caller can
        print a useful diagnostic.
    """
    layer_types = [type(layer) for layer in block]
    layer_types.reverse()
    if tp not in layer_types:
        # list.reverse() works in place and returns None; returning its
        # result directly handed callers ``None`` instead of the type list.
        layer_types.reverse()
        return -1, layer_types
    # index() finds the first hit in the reversed list, i.e. the last hit in
    # block order; convert that back to a forward index.
    return len(layer_types) - 1 - layer_types.index(tp), layer_types
|
1014 |
-
|
1015 |
-
def brushnet_forward(model, x, timesteps, transformer_options, control):
    """Stage BrushNet residuals on the UNet layers before the UNet runs.

    When ``transformer_options['model_patch']`` has no ``'brushnet'`` entry,
    every targeted layer gets ``add_sample_after = 0``. Otherwise
    ``brushnet_inference`` is called and its input / middle / output samples
    are attached, in order, to the last layer of the expected type in each
    listed block.

    ``input_blocks``, ``middle_block``, ``output_blocks``, ``debug``,
    ``brushnet_inference`` and ``last_layer_index`` are not defined here;
    presumably they come from the enclosing scope (verify against the caller).

    Args:
        model: Object exposing ``input_blocks``, ``middle_block`` and
            ``output_blocks`` as indexable collections of layers.
        x: Passed through to ``brushnet_inference``.
        timesteps: Passed through to ``brushnet_inference``.
        transformer_options: Mapping that must contain ``'model_patch'``.
        control: Unused here; kept for the wrapper's call signature.
    """
    if 'brushnet' not in transformer_options['model_patch']:
        input_samples = []
        mid_sample = 0
        output_samples = []
    else:
        # brushnet inference
        input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)

    # give additional samples to blocks
    for i, tp in input_blocks:
        idx, layer_list = last_layer_index(model.input_blocks[i], tp)
        if idx < 0:
            print("BrushNet can't find", tp, "layer in", i,"input block:", layer_list)
            continue
        # Fall back to 0 once the sample list is exhausted (or was empty).
        model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0

    idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
    if idx < 0:
        print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
    else:
        # Previously this assignment ran even after the warning above, so
        # idx == -1 silently patched the *last* middle layer regardless of
        # its type. Skip instead, like the input/output loops do.
        model.middle_block[idx].add_sample_after = mid_sample

    for i, tp in output_blocks:
        idx, layer_list = last_layer_index(model.output_blocks[i], tp)
        if idx < 0:
            print("BrushNet can't find", tp, "layer in", i,"outnput block:", layer_list)
            continue
        model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0
|
1043 |
-
|
1044 |
-
patch_model_function_wrapper(model, brushnet_forward)
|
1045 |
-
|
1046 |
-
to = add_model_patch_option(model)
|
1047 |
-
mp = to['model_patch']
|
1048 |
-
if 'brushnet' not in mp:
|
1049 |
-
mp['brushnet'] = {}
|
1050 |
-
bo = mp['brushnet']
|
1051 |
-
|
1052 |
-
bo['model'] = brushnet
|
1053 |
-
bo['dtype'] = torch_dtype
|
1054 |
-
bo['latents'] = conditioning_latents
|
1055 |
-
bo['controls'] = controls
|
1056 |
-
bo['prompt_embeds'] = prompt_embeds
|
1057 |
-
bo['negative_prompt_embeds'] = negative_prompt_embeds
|
1058 |
-
bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
|
1059 |
-
bo['latent_id'] = 0
|
1060 |
-
|
1061 |
-
# patch layers `forward` so we can apply brushnet
|
1062 |
-
def forward_patched_by_brushnet(self, x, *args, **kwargs):
    """Call the layer's original forward, then add the staged BrushNet sample.

    Intended to be bound onto a layer as its ``forward``. The layer must
    provide ``original_forward``; if it also has ``add_sample_after``, that
    value is added in place to the output and then reset to ``0`` so it is
    consumed only once.

    Args:
        x: Input forwarded to ``self.original_forward``.
        *args: Extra positional arguments forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The output of ``original_forward`` with the staged sample added.
    """
    h = self.original_forward(x, *args, **kwargs)
    # The former ``and type(self)`` test was always true and has been dropped.
    if hasattr(self, 'add_sample_after'):
        to_add = self.add_sample_after
        if torch.is_tensor(to_add):
            # interpolate due to RAUNet
            if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
                to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
            h += to_add.to(device=h.device, dtype=h.dtype)
        else:
            # Non-tensor placeholder (the patching code uses 0).
            h += to_add
        self.add_sample_after = 0
    return h
|
1075 |
-
|
1076 |
-
for i, block in enumerate(model.model.diffusion_model.input_blocks):
|
1077 |
-
for j, layer in enumerate(block):
|
1078 |
-
if not hasattr(layer, 'original_forward'):
|
1079 |
-
layer.original_forward = layer.forward
|
1080 |
-
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
|
1081 |
-
layer.add_sample_after = 0
|
1082 |
-
|
1083 |
-
for j, layer in enumerate(model.model.diffusion_model.middle_block):
|
1084 |
-
if not hasattr(layer, 'original_forward'):
|
1085 |
-
layer.original_forward = layer.forward
|
1086 |
-
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
|
1087 |
-
layer.add_sample_after = 0
|
1088 |
-
|
1089 |
-
for i, block in enumerate(model.model.diffusion_model.output_blocks):
|
1090 |
-
for j, layer in enumerate(block):
|
1091 |
-
if not hasattr(layer, 'original_forward'):
|
1092 |
-
layer.original_forward = layer.forward
|
1093 |
-
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
|
1094 |
-
layer.add_sample_after = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/.DS_Store
DELETED
Binary file (6.15 kB)
|
|
MagicQuill/comfy/checkpoint_pickle.py
DELETED
@@ -1,13 +0,0 @@
|
|
1 |
-
import pickle
|
2 |
-
|
3 |
-
load = pickle.load
|
4 |
-
|
5 |
-
class Empty:
    """Placeholder handed out in place of classes that are not loaded."""


class Unpickler(pickle.Unpickler):
    """Unpickler that replaces every ``pytorch_lightning`` global with ``Empty``."""

    def find_class(self, module, name):
        """Resolve ``module.name``, stubbing anything under ``pytorch_lightning``."""
        #TODO: safe unpickle
        # NOTE(review): every other global is still resolved by the stock
        # pickle machinery, so this does not make loading untrusted files safe.
        from_lightning = module.startswith("pytorch_lightning")
        return Empty if from_lightning else super().find_class(module, name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/cldm/__pycache__/cldm.cpython-310.pyc
DELETED
Binary file (6.11 kB)
|
|
MagicQuill/comfy/cldm/cldm.py
DELETED
@@ -1,313 +0,0 @@
|
|
1 |
-
#taken from: https://github.com/lllyasviel/ControlNet
|
2 |
-
#and modified
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch as th
|
6 |
-
import torch.nn as nn
|
7 |
-
|
8 |
-
from ..ldm.modules.diffusionmodules.util import (
|
9 |
-
zero_module,
|
10 |
-
timestep_embedding,
|
11 |
-
)
|
12 |
-
|
13 |
-
from ..ldm.modules.attention import SpatialTransformer
|
14 |
-
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
15 |
-
from ..ldm.util import exists
|
16 |
-
import comfy.ops
|
17 |
-
|
18 |
-
class ControlledUnetModel(UNetModel):
    """Alias of ``UNetModel``; the control handling is implemented in the ldm unet."""
|
21 |
-
|
22 |
-
class ControlNet(nn.Module):
    """Encoder-style control network producing one feature map per block.

    ``forward`` runs ``x`` through ``input_blocks`` and ``middle_block``,
    adds the processed ``hint`` after the first input block, and returns the
    output of a 1x1 convolution ("zero conv"; its weight initialisation is
    not set in this class) applied after every input block and after the
    middle block.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=torch.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        attn_precision=None,
        device=None,
        operations=comfy.ops.disable_weight_init,
        **kwargs,
    ):
        """Build the hint encoder, input blocks, middle block and zero convs.

        Args:
            image_size: Stored on the instance; not otherwise used here.
            in_channels: Channels of ``x`` fed to the first convolution.
            model_channels: Base channel count; each level uses
                ``mult * model_channels``.
            hint_channels: Channels of the ``hint`` input.
            num_res_blocks: Residual blocks per level, either one int for all
                levels or a sequence with ``len(channel_mult)`` entries.
            channel_mult: Per-level channel multipliers.
            num_classes: ``None``, an int (embedding table), ``"continuous"``
                or ``"sequential"`` (needs ``adm_in_channels``).
            num_heads / num_head_channels: Exactly how attention heads are
                sized; at least one must differ from ``-1``.
            use_spatial_transformer: Must be true; ``context_dim`` is then
                required.
            transformer_depth: Transformer depth for each residual block, in
                construction order. A single int is applied to every block.
            transformer_depth_middle: Depth of the middle transformer; a
                negative value omits it. ``None`` falls back to the last
                entry of ``transformer_depth``.
            disable_self_attentions: Optional per-level booleans.
            num_attention_blocks: Optional per-level cap on how many residual
                blocks receive a transformer.
            operations: Namespace providing ``Linear`` and ``conv_nd``.
            **kwargs: Accepted and ignored.

        Other parameters are forwarded to the sub-modules or only stored;
        ``use_new_attention_order`` and ``transformer_depth_output`` are not
        read in this class.

        Raises:
            ValueError: ``num_res_blocks`` has the wrong length, or
                ``num_classes`` is not a supported value.
            AssertionError: An argument combination check fails.
        """
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # from omegaconf.listconfig import ListConfig
            # if type(context_dim) == ListConfig:
            #     context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks

        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))

        # One depth is popped per residual block below, so work on a private
        # list. The old ``transformer_depth[:]`` crashed on the default int
        # and produced an un-poppable copy for tuples.
        if isinstance(transformer_depth, int):
            transformer_depth = [transformer_depth] * sum(self.num_res_blocks)
        else:
            transformer_depth = list(transformer_depth)

        # The default ``None`` used to fail on ``None >= 0`` further down.
        # NOTE(review): falling back to the last per-block depth mirrors the
        # upstream UNet default -- confirm this is the wanted behaviour.
        if transformer_depth_middle is None:
            transformer_depth_middle = transformer_depth[-1]

        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
                raise ValueError(
                    "num_classes must be None, an int, 'continuous' or 'sequential', "
                    f"got {self.num_classes!r}"
                )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])

        # Hint encoder: three stride-2 convolutions, ending at model_channels.
        self.input_hint_block = TimestepEmbedSequential(
            operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
        )

        self._feature_size = model_channels
        ch = model_channels
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    )
                ]
                ch = mult * model_channels
                num_transformers = transformer_depth.pop(0)
                if num_transformers > 0:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer(
                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                self._feature_size += ch
            if level != len(channel_mult) - 1:
                # Every level except the last ends with a downsampling block.
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                )
                ch = out_ch
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        mid_block = [
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            )]
        if transformer_depth_middle >= 0:
            mid_block += [SpatialTransformer(  # always uses a self-attn
                            ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            )]
        self.middle_block = TimestepEmbedSequential(*mid_block)
        self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
        self._feature_size += ch

    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
        """Return a 1x1, channel-preserving convolution wrapped for timestep calls.

        ``operations`` must provide ``conv_nd``; the ``None`` default is not
        usable and exists only to keep the keyword-style signature.
        """
        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        """Compute the control features for ``x`` conditioned on ``hint``.

        Args:
            x: Input passed through the input and middle blocks.
            hint: Conditioning input encoded by ``input_hint_block``.
            timesteps: Values turned into the timestep embedding.
            context: Passed to every block alongside the embedding.
            y: Class conditioning; required when ``num_classes`` is set and
                must match ``x`` in its first dimension.
            **kwargs: Accepted and ignored.

        Returns:
            A list with one tensor per input block followed by one for the
            middle block.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            h = module(h, emb, context)
            if guided_hint is not None:
                # The encoded hint is injected once, right after the first block.
                h += guided_hint
                guided_hint = None
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/cli_args.py
DELETED
@@ -1,143 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import enum
|
3 |
-
import comfy.options
|
4 |
-
|
5 |
-
class EnumAction(argparse.Action):
    """
    Argparse action that turns a command-line value into an Enum member
    """
    def __init__(self, **kwargs):
        # ``type`` carries the Enum class; it is removed so argparse itself
        # never tries to call it on the raw string.
        enum_cls = kwargs.pop("type", None)

        if enum_cls is None:
            raise ValueError("type must be assigned an Enum when using EnumAction")
        if not issubclass(enum_cls, enum.Enum):
            raise TypeError("type must be an Enum when using EnumAction")

        # Default ``choices`` / ``metavar`` come from the Enum's values;
        # anything the caller passed explicitly wins.
        values = tuple(member.value for member in enum_cls)
        default_metavar = "[{}]".format(",".join(values))
        for key, fallback in (("choices", values), ("metavar", default_metavar)):
            kwargs.setdefault(key, fallback)

        super().__init__(**kwargs)

        self._enum = enum_cls

    def __call__(self, parser, namespace, values, option_string=None):
        # Look the member up by value and store it on the namespace.
        member = self._enum(values)
        setattr(namespace, self.dest, member)
|
32 |
-
|
33 |
-
|
34 |
-
parser = argparse.ArgumentParser()
|
35 |
-
|
36 |
-
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
37 |
-
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
38 |
-
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
39 |
-
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
40 |
-
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
41 |
-
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
42 |
-
|
43 |
-
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
44 |
-
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
45 |
-
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
|
46 |
-
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
|
47 |
-
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
48 |
-
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
49 |
-
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
50 |
-
cm_group = parser.add_mutually_exclusive_group()
|
51 |
-
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
52 |
-
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
53 |
-
|
54 |
-
|
55 |
-
fp_group = parser.add_mutually_exclusive_group()
|
56 |
-
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
57 |
-
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
58 |
-
|
59 |
-
fpunet_group = parser.add_mutually_exclusive_group()
|
60 |
-
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
61 |
-
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
|
62 |
-
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
63 |
-
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
64 |
-
|
65 |
-
fpvae_group = parser.add_mutually_exclusive_group()
|
66 |
-
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
67 |
-
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
68 |
-
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
|
69 |
-
|
70 |
-
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
|
71 |
-
|
72 |
-
fpte_group = parser.add_mutually_exclusive_group()
|
73 |
-
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
|
74 |
-
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
|
75 |
-
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
|
76 |
-
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
77 |
-
|
78 |
-
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
79 |
-
|
80 |
-
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
81 |
-
|
82 |
-
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
|
83 |
-
|
84 |
-
class LatentPreviewMethod(enum.Enum):
    """Latent preview methods; each value is the string accepted on the command line."""

    NoPreviews = "none"
    Auto = "auto"
    Latent2RGB = "latent2rgb"
    TAESD = "taesd"
|
89 |
-
|
90 |
-
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
91 |
-
|
92 |
-
attn_group = parser.add_mutually_exclusive_group()
|
93 |
-
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
94 |
-
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
95 |
-
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
96 |
-
|
97 |
-
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
98 |
-
|
99 |
-
upcast = parser.add_mutually_exclusive_group()
|
100 |
-
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
|
101 |
-
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
102 |
-
|
103 |
-
|
104 |
-
vram_group = parser.add_mutually_exclusive_group()
|
105 |
-
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
106 |
-
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
107 |
-
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
108 |
-
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
109 |
-
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
110 |
-
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
111 |
-
|
112 |
-
|
113 |
-
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
114 |
-
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
115 |
-
|
116 |
-
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
117 |
-
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
118 |
-
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
119 |
-
|
120 |
-
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
121 |
-
|
122 |
-
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
123 |
-
|
124 |
-
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
125 |
-
|
126 |
-
|
127 |
-
if comfy.options.args_parsing:
|
128 |
-
args = parser.parse_args()
|
129 |
-
else:
|
130 |
-
args = parser.parse_args([])
|
131 |
-
|
132 |
-
if args.windows_standalone_build:
|
133 |
-
args.auto_launch = True
|
134 |
-
|
135 |
-
if args.disable_auto_launch:
|
136 |
-
args.auto_launch = False
|
137 |
-
|
138 |
-
import logging
|
139 |
-
logging_level = logging.INFO
|
140 |
-
if args.verbose:
|
141 |
-
logging_level = logging.DEBUG
|
142 |
-
|
143 |
-
logging.basicConfig(format="%(message)s", level=logging_level)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/clip_config_bigg.json
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"architectures": [
|
3 |
-
"CLIPTextModel"
|
4 |
-
],
|
5 |
-
"attention_dropout": 0.0,
|
6 |
-
"bos_token_id": 0,
|
7 |
-
"dropout": 0.0,
|
8 |
-
"eos_token_id": 2,
|
9 |
-
"hidden_act": "gelu",
|
10 |
-
"hidden_size": 1280,
|
11 |
-
"initializer_factor": 1.0,
|
12 |
-
"initializer_range": 0.02,
|
13 |
-
"intermediate_size": 5120,
|
14 |
-
"layer_norm_eps": 1e-05,
|
15 |
-
"max_position_embeddings": 77,
|
16 |
-
"model_type": "clip_text_model",
|
17 |
-
"num_attention_heads": 20,
|
18 |
-
"num_hidden_layers": 32,
|
19 |
-
"pad_token_id": 1,
|
20 |
-
"projection_dim": 1280,
|
21 |
-
"torch_dtype": "float32",
|
22 |
-
"vocab_size": 49408
|
23 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/clip_model.py
DELETED
@@ -1,194 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from comfy.ldm.modules.attention import optimized_attention_for_device
|
3 |
-
|
4 |
-
class CLIPAttention(torch.nn.Module):
|
5 |
-
def __init__(self, embed_dim, heads, dtype, device, operations):
|
6 |
-
super().__init__()
|
7 |
-
|
8 |
-
self.heads = heads
|
9 |
-
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
10 |
-
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
11 |
-
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
12 |
-
|
13 |
-
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
14 |
-
|
15 |
-
def forward(self, x, mask=None, optimized_attention=None):
|
16 |
-
q = self.q_proj(x)
|
17 |
-
k = self.k_proj(x)
|
18 |
-
v = self.v_proj(x)
|
19 |
-
|
20 |
-
out = optimized_attention(q, k, v, self.heads, mask)
|
21 |
-
return self.out_proj(out)
|
22 |
-
|
23 |
-
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
24 |
-
"gelu": torch.nn.functional.gelu,
|
25 |
-
}
|
26 |
-
|
27 |
-
class CLIPMLP(torch.nn.Module):
|
28 |
-
def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
|
29 |
-
super().__init__()
|
30 |
-
self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
|
31 |
-
self.activation = ACTIVATIONS[activation]
|
32 |
-
self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
|
33 |
-
|
34 |
-
def forward(self, x):
|
35 |
-
x = self.fc1(x)
|
36 |
-
x = self.activation(x)
|
37 |
-
x = self.fc2(x)
|
38 |
-
return x
|
39 |
-
|
40 |
-
class CLIPLayer(torch.nn.Module):
|
41 |
-
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
|
42 |
-
super().__init__()
|
43 |
-
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
44 |
-
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
|
45 |
-
self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
46 |
-
self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
|
47 |
-
|
48 |
-
def forward(self, x, mask=None, optimized_attention=None):
|
49 |
-
x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
|
50 |
-
x += self.mlp(self.layer_norm2(x))
|
51 |
-
return x
|
52 |
-
|
53 |
-
|
54 |
-
class CLIPEncoder(torch.nn.Module):
|
55 |
-
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
|
56 |
-
super().__init__()
|
57 |
-
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
|
58 |
-
|
59 |
-
def forward(self, x, mask=None, intermediate_output=None):
|
60 |
-
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
61 |
-
|
62 |
-
if intermediate_output is not None:
|
63 |
-
if intermediate_output < 0:
|
64 |
-
intermediate_output = len(self.layers) + intermediate_output
|
65 |
-
|
66 |
-
intermediate = None
|
67 |
-
for i, l in enumerate(self.layers):
|
68 |
-
x = l(x, mask, optimized_attention)
|
69 |
-
if i == intermediate_output:
|
70 |
-
intermediate = x.clone()
|
71 |
-
return x, intermediate
|
72 |
-
|
73 |
-
class CLIPEmbeddings(torch.nn.Module):
|
74 |
-
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
|
75 |
-
super().__init__()
|
76 |
-
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
|
77 |
-
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
78 |
-
|
79 |
-
def forward(self, input_tokens):
|
80 |
-
return self.token_embedding(input_tokens) + self.position_embedding.weight
|
81 |
-
|
82 |
-
|
83 |
-
class CLIPTextModel_(torch.nn.Module):
|
84 |
-
def __init__(self, config_dict, dtype, device, operations):
|
85 |
-
num_layers = config_dict["num_hidden_layers"]
|
86 |
-
embed_dim = config_dict["hidden_size"]
|
87 |
-
heads = config_dict["num_attention_heads"]
|
88 |
-
intermediate_size = config_dict["intermediate_size"]
|
89 |
-
intermediate_activation = config_dict["hidden_act"]
|
90 |
-
|
91 |
-
super().__init__()
|
92 |
-
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
93 |
-
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
94 |
-
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
95 |
-
|
96 |
-
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
|
97 |
-
x = self.embeddings(input_tokens)
|
98 |
-
mask = None
|
99 |
-
if attention_mask is not None:
|
100 |
-
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
101 |
-
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
102 |
-
|
103 |
-
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
104 |
-
if mask is not None:
|
105 |
-
mask += causal_mask
|
106 |
-
else:
|
107 |
-
mask = causal_mask
|
108 |
-
|
109 |
-
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
|
110 |
-
x = self.final_layer_norm(x)
|
111 |
-
if i is not None and final_layer_norm_intermediate:
|
112 |
-
i = self.final_layer_norm(i)
|
113 |
-
|
114 |
-
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
|
115 |
-
return x, i, pooled_output
|
116 |
-
|
117 |
-
class CLIPTextModel(torch.nn.Module):
|
118 |
-
def __init__(self, config_dict, dtype, device, operations):
|
119 |
-
super().__init__()
|
120 |
-
self.num_layers = config_dict["num_hidden_layers"]
|
121 |
-
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
122 |
-
embed_dim = config_dict["hidden_size"]
|
123 |
-
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
124 |
-
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
125 |
-
self.dtype = dtype
|
126 |
-
|
127 |
-
def get_input_embeddings(self):
|
128 |
-
return self.text_model.embeddings.token_embedding
|
129 |
-
|
130 |
-
def set_input_embeddings(self, embeddings):
|
131 |
-
self.text_model.embeddings.token_embedding = embeddings
|
132 |
-
|
133 |
-
def forward(self, *args, **kwargs):
|
134 |
-
x = self.text_model(*args, **kwargs)
|
135 |
-
out = self.text_projection(x[2])
|
136 |
-
return (x[0], x[1], out, x[2])
|
137 |
-
|
138 |
-
|
139 |
-
class CLIPVisionEmbeddings(torch.nn.Module):
|
140 |
-
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
|
141 |
-
super().__init__()
|
142 |
-
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
143 |
-
|
144 |
-
self.patch_embedding = operations.Conv2d(
|
145 |
-
in_channels=num_channels,
|
146 |
-
out_channels=embed_dim,
|
147 |
-
kernel_size=patch_size,
|
148 |
-
stride=patch_size,
|
149 |
-
bias=False,
|
150 |
-
dtype=dtype,
|
151 |
-
device=device
|
152 |
-
)
|
153 |
-
|
154 |
-
num_patches = (image_size // patch_size) ** 2
|
155 |
-
num_positions = num_patches + 1
|
156 |
-
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
157 |
-
|
158 |
-
def forward(self, pixel_values):
|
159 |
-
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
160 |
-
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
|
161 |
-
|
162 |
-
|
163 |
-
class CLIPVision(torch.nn.Module):
|
164 |
-
def __init__(self, config_dict, dtype, device, operations):
|
165 |
-
super().__init__()
|
166 |
-
num_layers = config_dict["num_hidden_layers"]
|
167 |
-
embed_dim = config_dict["hidden_size"]
|
168 |
-
heads = config_dict["num_attention_heads"]
|
169 |
-
intermediate_size = config_dict["intermediate_size"]
|
170 |
-
intermediate_activation = config_dict["hidden_act"]
|
171 |
-
|
172 |
-
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
|
173 |
-
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
174 |
-
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
175 |
-
self.post_layernorm = operations.LayerNorm(embed_dim)
|
176 |
-
|
177 |
-
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
178 |
-
x = self.embeddings(pixel_values)
|
179 |
-
x = self.pre_layrnorm(x)
|
180 |
-
#TODO: attention_mask?
|
181 |
-
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
182 |
-
pooled_output = self.post_layernorm(x[:, 0, :])
|
183 |
-
return x, i, pooled_output
|
184 |
-
|
185 |
-
class CLIPVisionModelProjection(torch.nn.Module):
|
186 |
-
def __init__(self, config_dict, dtype, device, operations):
|
187 |
-
super().__init__()
|
188 |
-
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
189 |
-
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
190 |
-
|
191 |
-
def forward(self, *args, **kwargs):
|
192 |
-
x = self.vision_model(*args, **kwargs)
|
193 |
-
out = self.visual_projection(x[2])
|
194 |
-
return (x[0], x[1], out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/clip_vision.py
DELETED
@@ -1,117 +0,0 @@
|
|
1 |
-
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
2 |
-
import os
|
3 |
-
import torch
|
4 |
-
import json
|
5 |
-
import logging
|
6 |
-
|
7 |
-
import comfy.ops
|
8 |
-
import comfy.model_patcher
|
9 |
-
import comfy.model_management
|
10 |
-
import comfy.utils
|
11 |
-
import comfy.clip_model
|
12 |
-
|
13 |
-
class Output:
|
14 |
-
def __getitem__(self, key):
|
15 |
-
return getattr(self, key)
|
16 |
-
def __setitem__(self, key, item):
|
17 |
-
setattr(self, key, item)
|
18 |
-
|
19 |
-
def clip_preprocess(image, size=224):
|
20 |
-
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
21 |
-
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
22 |
-
image = image.movedim(-1, 1)
|
23 |
-
if not (image.shape[2] == size and image.shape[3] == size):
|
24 |
-
scale = (size / min(image.shape[2], image.shape[3]))
|
25 |
-
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
|
26 |
-
h = (image.shape[2] - size)//2
|
27 |
-
w = (image.shape[3] - size)//2
|
28 |
-
image = image[:,:,h:h+size,w:w+size]
|
29 |
-
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
30 |
-
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
31 |
-
|
32 |
-
class ClipVisionModel():
|
33 |
-
def __init__(self, json_config):
|
34 |
-
with open(json_config) as f:
|
35 |
-
config = json.load(f)
|
36 |
-
|
37 |
-
self.load_device = comfy.model_management.text_encoder_device()
|
38 |
-
offload_device = comfy.model_management.text_encoder_offload_device()
|
39 |
-
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
40 |
-
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
41 |
-
self.model.eval()
|
42 |
-
|
43 |
-
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
44 |
-
|
45 |
-
def load_sd(self, sd):
|
46 |
-
return self.model.load_state_dict(sd, strict=False)
|
47 |
-
|
48 |
-
def get_sd(self):
|
49 |
-
return self.model.state_dict()
|
50 |
-
|
51 |
-
def encode_image(self, image):
|
52 |
-
comfy.model_management.load_model_gpu(self.patcher)
|
53 |
-
pixel_values = clip_preprocess(image.to(self.load_device)).float()
|
54 |
-
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
55 |
-
|
56 |
-
outputs = Output()
|
57 |
-
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
58 |
-
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
59 |
-
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
60 |
-
return outputs
|
61 |
-
|
62 |
-
def convert_to_transformers(sd, prefix):
|
63 |
-
sd_k = sd.keys()
|
64 |
-
if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
|
65 |
-
keys_to_replace = {
|
66 |
-
"{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
|
67 |
-
"{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
|
68 |
-
"{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
|
69 |
-
"{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
|
70 |
-
"{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
|
71 |
-
"{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
|
72 |
-
"{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
|
73 |
-
}
|
74 |
-
|
75 |
-
for x in keys_to_replace:
|
76 |
-
if x in sd_k:
|
77 |
-
sd[keys_to_replace[x]] = sd.pop(x)
|
78 |
-
|
79 |
-
if "{}proj".format(prefix) in sd_k:
|
80 |
-
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
|
81 |
-
|
82 |
-
sd = transformers_convert(sd, prefix, "vision_model.", 48)
|
83 |
-
else:
|
84 |
-
replace_prefix = {prefix: ""}
|
85 |
-
sd = state_dict_prefix_replace(sd, replace_prefix)
|
86 |
-
return sd
|
87 |
-
|
88 |
-
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
89 |
-
if convert_keys:
|
90 |
-
sd = convert_to_transformers(sd, prefix)
|
91 |
-
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
|
92 |
-
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
|
93 |
-
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
94 |
-
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
95 |
-
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
96 |
-
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
97 |
-
else:
|
98 |
-
return None
|
99 |
-
|
100 |
-
clip = ClipVisionModel(json_config)
|
101 |
-
m, u = clip.load_sd(sd)
|
102 |
-
if len(m) > 0:
|
103 |
-
logging.warning("missing clip vision: {}".format(m))
|
104 |
-
u = set(u)
|
105 |
-
keys = list(sd.keys())
|
106 |
-
for k in keys:
|
107 |
-
if k not in u:
|
108 |
-
t = sd.pop(k)
|
109 |
-
del t
|
110 |
-
return clip
|
111 |
-
|
112 |
-
def load(ckpt_path):
|
113 |
-
sd = load_torch_file(ckpt_path)
|
114 |
-
if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
|
115 |
-
return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
|
116 |
-
else:
|
117 |
-
return load_clipvision_from_sd(sd)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/clip_vision_config_g.json
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"attention_dropout": 0.0,
|
3 |
-
"dropout": 0.0,
|
4 |
-
"hidden_act": "gelu",
|
5 |
-
"hidden_size": 1664,
|
6 |
-
"image_size": 224,
|
7 |
-
"initializer_factor": 1.0,
|
8 |
-
"initializer_range": 0.02,
|
9 |
-
"intermediate_size": 8192,
|
10 |
-
"layer_norm_eps": 1e-05,
|
11 |
-
"model_type": "clip_vision_model",
|
12 |
-
"num_attention_heads": 16,
|
13 |
-
"num_channels": 3,
|
14 |
-
"num_hidden_layers": 48,
|
15 |
-
"patch_size": 14,
|
16 |
-
"projection_dim": 1280,
|
17 |
-
"torch_dtype": "float32"
|
18 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/clip_vision_config_h.json
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"attention_dropout": 0.0,
|
3 |
-
"dropout": 0.0,
|
4 |
-
"hidden_act": "gelu",
|
5 |
-
"hidden_size": 1280,
|
6 |
-
"image_size": 224,
|
7 |
-
"initializer_factor": 1.0,
|
8 |
-
"initializer_range": 0.02,
|
9 |
-
"intermediate_size": 5120,
|
10 |
-
"layer_norm_eps": 1e-05,
|
11 |
-
"model_type": "clip_vision_model",
|
12 |
-
"num_attention_heads": 16,
|
13 |
-
"num_channels": 3,
|
14 |
-
"num_hidden_layers": 32,
|
15 |
-
"patch_size": 14,
|
16 |
-
"projection_dim": 1024,
|
17 |
-
"torch_dtype": "float32"
|
18 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/clip_vision_config_vitl.json
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"attention_dropout": 0.0,
|
3 |
-
"dropout": 0.0,
|
4 |
-
"hidden_act": "quick_gelu",
|
5 |
-
"hidden_size": 1024,
|
6 |
-
"image_size": 224,
|
7 |
-
"initializer_factor": 1.0,
|
8 |
-
"initializer_range": 0.02,
|
9 |
-
"intermediate_size": 4096,
|
10 |
-
"layer_norm_eps": 1e-05,
|
11 |
-
"model_type": "clip_vision_model",
|
12 |
-
"num_attention_heads": 16,
|
13 |
-
"num_channels": 3,
|
14 |
-
"num_hidden_layers": 24,
|
15 |
-
"patch_size": 14,
|
16 |
-
"projection_dim": 768,
|
17 |
-
"torch_dtype": "float32"
|
18 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/conds.py
DELETED
@@ -1,83 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import math
|
3 |
-
import comfy.utils
|
4 |
-
|
5 |
-
|
6 |
-
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
7 |
-
return abs(a*b) // math.gcd(a, b)
|
8 |
-
|
9 |
-
class CONDRegular:
|
10 |
-
def __init__(self, cond):
|
11 |
-
self.cond = cond
|
12 |
-
|
13 |
-
def _copy_with(self, cond):
|
14 |
-
return self.__class__(cond)
|
15 |
-
|
16 |
-
def process_cond(self, batch_size, device, **kwargs):
|
17 |
-
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
18 |
-
|
19 |
-
def can_concat(self, other):
|
20 |
-
if self.cond.shape != other.cond.shape:
|
21 |
-
return False
|
22 |
-
return True
|
23 |
-
|
24 |
-
def concat(self, others):
|
25 |
-
conds = [self.cond]
|
26 |
-
for x in others:
|
27 |
-
conds.append(x.cond)
|
28 |
-
return torch.cat(conds)
|
29 |
-
|
30 |
-
class CONDNoiseShape(CONDRegular):
|
31 |
-
def process_cond(self, batch_size, device, area, **kwargs):
|
32 |
-
data = self.cond
|
33 |
-
if area is not None:
|
34 |
-
dims = len(area) // 2
|
35 |
-
for i in range(dims):
|
36 |
-
data = data.narrow(i + 2, area[i + dims], area[i])
|
37 |
-
|
38 |
-
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
39 |
-
|
40 |
-
|
41 |
-
class CONDCrossAttn(CONDRegular):
|
42 |
-
def can_concat(self, other):
|
43 |
-
s1 = self.cond.shape
|
44 |
-
s2 = other.cond.shape
|
45 |
-
if s1 != s2:
|
46 |
-
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
47 |
-
return False
|
48 |
-
|
49 |
-
mult_min = lcm(s1[1], s2[1])
|
50 |
-
diff = mult_min // min(s1[1], s2[1])
|
51 |
-
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
52 |
-
return False
|
53 |
-
return True
|
54 |
-
|
55 |
-
def concat(self, others):
|
56 |
-
conds = [self.cond]
|
57 |
-
crossattn_max_len = self.cond.shape[1]
|
58 |
-
for x in others:
|
59 |
-
c = x.cond
|
60 |
-
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
61 |
-
conds.append(c)
|
62 |
-
|
63 |
-
out = []
|
64 |
-
for c in conds:
|
65 |
-
if c.shape[1] < crossattn_max_len:
|
66 |
-
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
67 |
-
out.append(c)
|
68 |
-
return torch.cat(out)
|
69 |
-
|
70 |
-
class CONDConstant(CONDRegular):
|
71 |
-
def __init__(self, cond):
|
72 |
-
self.cond = cond
|
73 |
-
|
74 |
-
def process_cond(self, batch_size, device, **kwargs):
|
75 |
-
return self._copy_with(self.cond)
|
76 |
-
|
77 |
-
def can_concat(self, other):
|
78 |
-
if self.cond != other.cond:
|
79 |
-
return False
|
80 |
-
return True
|
81 |
-
|
82 |
-
def concat(self, others):
|
83 |
-
return self.cond
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/controlnet.py
DELETED
@@ -1,554 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import math
|
3 |
-
import os
|
4 |
-
import logging
|
5 |
-
import comfy.utils
|
6 |
-
import comfy.model_management
|
7 |
-
import comfy.model_detection
|
8 |
-
import comfy.model_patcher
|
9 |
-
import comfy.ops
|
10 |
-
|
11 |
-
import comfy.cldm.cldm
|
12 |
-
import comfy.t2i_adapter.adapter
|
13 |
-
import comfy.ldm.cascade.controlnet
|
14 |
-
|
15 |
-
|
16 |
-
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
17 |
-
current_batch_size = tensor.shape[0]
|
18 |
-
#print(current_batch_size, target_batch_size)
|
19 |
-
if current_batch_size == 1:
|
20 |
-
return tensor
|
21 |
-
|
22 |
-
per_batch = target_batch_size // batched_number
|
23 |
-
tensor = tensor[:per_batch]
|
24 |
-
|
25 |
-
if per_batch > tensor.shape[0]:
|
26 |
-
tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
|
27 |
-
|
28 |
-
current_batch_size = tensor.shape[0]
|
29 |
-
if current_batch_size == target_batch_size:
|
30 |
-
return tensor
|
31 |
-
else:
|
32 |
-
return torch.cat([tensor] * batched_number, dim=0)
|
33 |
-
|
34 |
-
class ControlBase:
|
35 |
-
def __init__(self, device=None):
|
36 |
-
self.cond_hint_original = None
|
37 |
-
self.cond_hint = None
|
38 |
-
self.strength = 1.0
|
39 |
-
self.timestep_percent_range = (0.0, 1.0)
|
40 |
-
self.global_average_pooling = False
|
41 |
-
self.timestep_range = None
|
42 |
-
self.compression_ratio = 8
|
43 |
-
self.upscale_algorithm = 'nearest-exact'
|
44 |
-
|
45 |
-
if device is None:
|
46 |
-
device = comfy.model_management.get_torch_device()
|
47 |
-
self.device = device
|
48 |
-
self.previous_controlnet = None
|
49 |
-
|
50 |
-
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
|
51 |
-
self.cond_hint_original = cond_hint
|
52 |
-
self.strength = strength
|
53 |
-
self.timestep_percent_range = timestep_percent_range
|
54 |
-
return self
|
55 |
-
|
56 |
-
def pre_run(self, model, percent_to_timestep_function):
|
57 |
-
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
|
58 |
-
if self.previous_controlnet is not None:
|
59 |
-
self.previous_controlnet.pre_run(model, percent_to_timestep_function)
|
60 |
-
|
61 |
-
def set_previous_controlnet(self, controlnet):
|
62 |
-
self.previous_controlnet = controlnet
|
63 |
-
return self
|
64 |
-
|
65 |
-
def cleanup(self):
|
66 |
-
if self.previous_controlnet is not None:
|
67 |
-
self.previous_controlnet.cleanup()
|
68 |
-
if self.cond_hint is not None:
|
69 |
-
del self.cond_hint
|
70 |
-
self.cond_hint = None
|
71 |
-
self.timestep_range = None
|
72 |
-
|
73 |
-
def get_models(self):
|
74 |
-
out = []
|
75 |
-
if self.previous_controlnet is not None:
|
76 |
-
out += self.previous_controlnet.get_models()
|
77 |
-
return out
|
78 |
-
|
79 |
-
def copy_to(self, c):
|
80 |
-
c.cond_hint_original = self.cond_hint_original
|
81 |
-
c.strength = self.strength
|
82 |
-
c.timestep_percent_range = self.timestep_percent_range
|
83 |
-
c.global_average_pooling = self.global_average_pooling
|
84 |
-
c.compression_ratio = self.compression_ratio
|
85 |
-
c.upscale_algorithm = self.upscale_algorithm
|
86 |
-
|
87 |
-
def inference_memory_requirements(self, dtype):
|
88 |
-
if self.previous_controlnet is not None:
|
89 |
-
return self.previous_controlnet.inference_memory_requirements(dtype)
|
90 |
-
return 0
|
91 |
-
|
92 |
-
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
93 |
-
out = {'input':[], 'middle':[], 'output': []}
|
94 |
-
|
95 |
-
if control_input is not None:
|
96 |
-
for i in range(len(control_input)):
|
97 |
-
key = 'input'
|
98 |
-
x = control_input[i]
|
99 |
-
if x is not None:
|
100 |
-
x *= self.strength
|
101 |
-
if x.dtype != output_dtype:
|
102 |
-
x = x.to(output_dtype)
|
103 |
-
out[key].insert(0, x)
|
104 |
-
|
105 |
-
if control_output is not None:
|
106 |
-
for i in range(len(control_output)):
|
107 |
-
if i == (len(control_output) - 1):
|
108 |
-
key = 'middle'
|
109 |
-
index = 0
|
110 |
-
else:
|
111 |
-
key = 'output'
|
112 |
-
index = i
|
113 |
-
x = control_output[i]
|
114 |
-
if x is not None:
|
115 |
-
if self.global_average_pooling:
|
116 |
-
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
117 |
-
|
118 |
-
x *= self.strength
|
119 |
-
if x.dtype != output_dtype:
|
120 |
-
x = x.to(output_dtype)
|
121 |
-
|
122 |
-
out[key].append(x)
|
123 |
-
if control_prev is not None:
|
124 |
-
for x in ['input', 'middle', 'output']:
|
125 |
-
o = out[x]
|
126 |
-
for i in range(len(control_prev[x])):
|
127 |
-
prev_val = control_prev[x][i]
|
128 |
-
if i >= len(o):
|
129 |
-
o.append(prev_val)
|
130 |
-
elif prev_val is not None:
|
131 |
-
if o[i] is None:
|
132 |
-
o[i] = prev_val
|
133 |
-
else:
|
134 |
-
if o[i].shape[0] < prev_val.shape[0]:
|
135 |
-
o[i] = prev_val + o[i]
|
136 |
-
else:
|
137 |
-
o[i] += prev_val
|
138 |
-
return out
|
139 |
-
|
140 |
-
class ControlNet(ControlBase):
|
141 |
-
def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
142 |
-
super().__init__(device)
|
143 |
-
self.control_model = control_model
|
144 |
-
self.load_device = load_device
|
145 |
-
if control_model is not None:
|
146 |
-
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
147 |
-
|
148 |
-
self.global_average_pooling = global_average_pooling
|
149 |
-
self.model_sampling_current = None
|
150 |
-
self.manual_cast_dtype = manual_cast_dtype
|
151 |
-
|
152 |
-
def get_control(self, x_noisy, t, cond, batched_number):
|
153 |
-
control_prev = None
|
154 |
-
if self.previous_controlnet is not None:
|
155 |
-
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
156 |
-
|
157 |
-
if self.timestep_range is not None:
|
158 |
-
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
159 |
-
if control_prev is not None:
|
160 |
-
return control_prev
|
161 |
-
else:
|
162 |
-
return None
|
163 |
-
|
164 |
-
dtype = self.control_model.dtype
|
165 |
-
if self.manual_cast_dtype is not None:
|
166 |
-
dtype = self.manual_cast_dtype
|
167 |
-
|
168 |
-
output_dtype = x_noisy.dtype
|
169 |
-
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
170 |
-
if self.cond_hint is not None:
|
171 |
-
del self.cond_hint
|
172 |
-
self.cond_hint = None
|
173 |
-
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio, self.upscale_algorithm, "center").to(dtype).to(self.device)
|
174 |
-
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
175 |
-
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
176 |
-
|
177 |
-
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
178 |
-
y = cond.get('y', None)
|
179 |
-
if y is not None:
|
180 |
-
y = y.to(dtype)
|
181 |
-
timestep = self.model_sampling_current.timestep(t)
|
182 |
-
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
183 |
-
|
184 |
-
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
185 |
-
return self.control_merge(None, control, control_prev, output_dtype)
|
186 |
-
|
187 |
-
def copy(self):
|
188 |
-
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
189 |
-
c.control_model = self.control_model
|
190 |
-
c.control_model_wrapped = self.control_model_wrapped
|
191 |
-
self.copy_to(c)
|
192 |
-
return c
|
193 |
-
|
194 |
-
def get_models(self):
|
195 |
-
out = super().get_models()
|
196 |
-
out.append(self.control_model_wrapped)
|
197 |
-
return out
|
198 |
-
|
199 |
-
def pre_run(self, model, percent_to_timestep_function):
|
200 |
-
super().pre_run(model, percent_to_timestep_function)
|
201 |
-
self.model_sampling_current = model.model_sampling
|
202 |
-
|
203 |
-
def cleanup(self):
|
204 |
-
self.model_sampling_current = None
|
205 |
-
super().cleanup()
|
206 |
-
|
207 |
-
class ControlLoraOps:
|
208 |
-
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
209 |
-
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
210 |
-
device=None, dtype=None) -> None:
|
211 |
-
factory_kwargs = {'device': device, 'dtype': dtype}
|
212 |
-
super().__init__()
|
213 |
-
self.in_features = in_features
|
214 |
-
self.out_features = out_features
|
215 |
-
self.weight = None
|
216 |
-
self.up = None
|
217 |
-
self.down = None
|
218 |
-
self.bias = None
|
219 |
-
|
220 |
-
def forward(self, input):
|
221 |
-
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
222 |
-
if self.up is not None:
|
223 |
-
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
224 |
-
else:
|
225 |
-
return torch.nn.functional.linear(input, weight, bias)
|
226 |
-
|
227 |
-
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
228 |
-
def __init__(
|
229 |
-
self,
|
230 |
-
in_channels,
|
231 |
-
out_channels,
|
232 |
-
kernel_size,
|
233 |
-
stride=1,
|
234 |
-
padding=0,
|
235 |
-
dilation=1,
|
236 |
-
groups=1,
|
237 |
-
bias=True,
|
238 |
-
padding_mode='zeros',
|
239 |
-
device=None,
|
240 |
-
dtype=None
|
241 |
-
):
|
242 |
-
super().__init__()
|
243 |
-
self.in_channels = in_channels
|
244 |
-
self.out_channels = out_channels
|
245 |
-
self.kernel_size = kernel_size
|
246 |
-
self.stride = stride
|
247 |
-
self.padding = padding
|
248 |
-
self.dilation = dilation
|
249 |
-
self.transposed = False
|
250 |
-
self.output_padding = 0
|
251 |
-
self.groups = groups
|
252 |
-
self.padding_mode = padding_mode
|
253 |
-
|
254 |
-
self.weight = None
|
255 |
-
self.bias = None
|
256 |
-
self.up = None
|
257 |
-
self.down = None
|
258 |
-
|
259 |
-
|
260 |
-
def forward(self, input):
|
261 |
-
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
262 |
-
if self.up is not None:
|
263 |
-
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
264 |
-
else:
|
265 |
-
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
266 |
-
|
267 |
-
|
268 |
-
class ControlLora(ControlNet):
|
269 |
-
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
270 |
-
ControlBase.__init__(self, device)
|
271 |
-
self.control_weights = control_weights
|
272 |
-
self.global_average_pooling = global_average_pooling
|
273 |
-
|
274 |
-
def pre_run(self, model, percent_to_timestep_function):
|
275 |
-
super().pre_run(model, percent_to_timestep_function)
|
276 |
-
controlnet_config = model.model_config.unet_config.copy()
|
277 |
-
controlnet_config.pop("out_channels")
|
278 |
-
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
279 |
-
self.manual_cast_dtype = model.manual_cast_dtype
|
280 |
-
dtype = model.get_dtype()
|
281 |
-
if self.manual_cast_dtype is None:
|
282 |
-
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
|
283 |
-
pass
|
284 |
-
else:
|
285 |
-
class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
|
286 |
-
pass
|
287 |
-
dtype = self.manual_cast_dtype
|
288 |
-
|
289 |
-
controlnet_config["operations"] = control_lora_ops
|
290 |
-
controlnet_config["dtype"] = dtype
|
291 |
-
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
292 |
-
self.control_model.to(comfy.model_management.get_torch_device())
|
293 |
-
diffusion_model = model.diffusion_model
|
294 |
-
sd = diffusion_model.state_dict()
|
295 |
-
cm = self.control_model.state_dict()
|
296 |
-
|
297 |
-
for k in sd:
|
298 |
-
weight = sd[k]
|
299 |
-
try:
|
300 |
-
comfy.utils.set_attr_param(self.control_model, k, weight)
|
301 |
-
except:
|
302 |
-
pass
|
303 |
-
|
304 |
-
for k in self.control_weights:
|
305 |
-
if k not in {"lora_controlnet"}:
|
306 |
-
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
307 |
-
|
308 |
-
def copy(self):
|
309 |
-
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
310 |
-
self.copy_to(c)
|
311 |
-
return c
|
312 |
-
|
313 |
-
def cleanup(self):
|
314 |
-
del self.control_model
|
315 |
-
self.control_model = None
|
316 |
-
super().cleanup()
|
317 |
-
|
318 |
-
def get_models(self):
|
319 |
-
out = ControlBase.get_models(self)
|
320 |
-
return out
|
321 |
-
|
322 |
-
def inference_memory_requirements(self, dtype):
|
323 |
-
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
324 |
-
|
325 |
-
def load_controlnet(ckpt_path, model=None):
|
326 |
-
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
327 |
-
if "lora_controlnet" in controlnet_data:
|
328 |
-
return ControlLora(controlnet_data)
|
329 |
-
|
330 |
-
controlnet_config = None
|
331 |
-
supported_inference_dtypes = None
|
332 |
-
|
333 |
-
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
334 |
-
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
|
335 |
-
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
336 |
-
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
337 |
-
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
338 |
-
|
339 |
-
count = 0
|
340 |
-
loop = True
|
341 |
-
while loop:
|
342 |
-
suffix = [".weight", ".bias"]
|
343 |
-
for s in suffix:
|
344 |
-
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
345 |
-
k_out = "zero_convs.{}.0{}".format(count, s)
|
346 |
-
if k_in not in controlnet_data:
|
347 |
-
loop = False
|
348 |
-
break
|
349 |
-
diffusers_keys[k_in] = k_out
|
350 |
-
count += 1
|
351 |
-
|
352 |
-
count = 0
|
353 |
-
loop = True
|
354 |
-
while loop:
|
355 |
-
suffix = [".weight", ".bias"]
|
356 |
-
for s in suffix:
|
357 |
-
if count == 0:
|
358 |
-
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
359 |
-
else:
|
360 |
-
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
361 |
-
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
362 |
-
if k_in not in controlnet_data:
|
363 |
-
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
364 |
-
loop = False
|
365 |
-
diffusers_keys[k_in] = k_out
|
366 |
-
count += 1
|
367 |
-
|
368 |
-
new_sd = {}
|
369 |
-
for k in diffusers_keys:
|
370 |
-
if k in controlnet_data:
|
371 |
-
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
372 |
-
|
373 |
-
leftover_keys = controlnet_data.keys()
|
374 |
-
if len(leftover_keys) > 0:
|
375 |
-
logging.warning("leftover keys: {}".format(leftover_keys))
|
376 |
-
controlnet_data = new_sd
|
377 |
-
|
378 |
-
pth_key = 'control_model.zero_convs.0.0.weight'
|
379 |
-
pth = False
|
380 |
-
key = 'zero_convs.0.0.weight'
|
381 |
-
if pth_key in controlnet_data:
|
382 |
-
pth = True
|
383 |
-
key = pth_key
|
384 |
-
prefix = "control_model."
|
385 |
-
elif key in controlnet_data:
|
386 |
-
prefix = ""
|
387 |
-
else:
|
388 |
-
net = load_t2i_adapter(controlnet_data)
|
389 |
-
if net is None:
|
390 |
-
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
391 |
-
return net
|
392 |
-
|
393 |
-
if controlnet_config is None:
|
394 |
-
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
395 |
-
supported_inference_dtypes = model_config.supported_inference_dtypes
|
396 |
-
controlnet_config = model_config.unet_config
|
397 |
-
|
398 |
-
load_device = comfy.model_management.get_torch_device()
|
399 |
-
if supported_inference_dtypes is None:
|
400 |
-
unet_dtype = comfy.model_management.unet_dtype()
|
401 |
-
else:
|
402 |
-
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
403 |
-
|
404 |
-
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
405 |
-
if manual_cast_dtype is not None:
|
406 |
-
controlnet_config["operations"] = comfy.ops.manual_cast
|
407 |
-
controlnet_config["dtype"] = unet_dtype
|
408 |
-
controlnet_config.pop("out_channels")
|
409 |
-
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
410 |
-
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
411 |
-
|
412 |
-
if pth:
|
413 |
-
if 'difference' in controlnet_data:
|
414 |
-
if model is not None:
|
415 |
-
comfy.model_management.load_models_gpu([model])
|
416 |
-
model_sd = model.model_state_dict()
|
417 |
-
for x in controlnet_data:
|
418 |
-
c_m = "control_model."
|
419 |
-
if x.startswith(c_m):
|
420 |
-
sd_key = "diffusion_model.{}".format(x[len(c_m):])
|
421 |
-
if sd_key in model_sd:
|
422 |
-
cd = controlnet_data[x]
|
423 |
-
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
424 |
-
else:
|
425 |
-
logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
426 |
-
|
427 |
-
class WeightsLoader(torch.nn.Module):
|
428 |
-
pass
|
429 |
-
w = WeightsLoader()
|
430 |
-
w.control_model = control_model
|
431 |
-
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
432 |
-
else:
|
433 |
-
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
434 |
-
|
435 |
-
if len(missing) > 0:
|
436 |
-
logging.warning("missing controlnet keys: {}".format(missing))
|
437 |
-
|
438 |
-
if len(unexpected) > 0:
|
439 |
-
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
440 |
-
|
441 |
-
global_average_pooling = False
|
442 |
-
filename = os.path.splitext(ckpt_path)[0]
|
443 |
-
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
444 |
-
global_average_pooling = True
|
445 |
-
|
446 |
-
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
447 |
-
return control
|
448 |
-
|
449 |
-
class T2IAdapter(ControlBase):
|
450 |
-
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
451 |
-
super().__init__(device)
|
452 |
-
self.t2i_model = t2i_model
|
453 |
-
self.channels_in = channels_in
|
454 |
-
self.control_input = None
|
455 |
-
self.compression_ratio = compression_ratio
|
456 |
-
self.upscale_algorithm = upscale_algorithm
|
457 |
-
|
458 |
-
def scale_image_to(self, width, height):
|
459 |
-
unshuffle_amount = self.t2i_model.unshuffle_amount
|
460 |
-
width = math.ceil(width / unshuffle_amount) * unshuffle_amount
|
461 |
-
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
462 |
-
return width, height
|
463 |
-
|
464 |
-
def get_control(self, x_noisy, t, cond, batched_number):
|
465 |
-
control_prev = None
|
466 |
-
if self.previous_controlnet is not None:
|
467 |
-
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
468 |
-
|
469 |
-
if self.timestep_range is not None:
|
470 |
-
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
471 |
-
if control_prev is not None:
|
472 |
-
return control_prev
|
473 |
-
else:
|
474 |
-
return None
|
475 |
-
|
476 |
-
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
477 |
-
if self.cond_hint is not None:
|
478 |
-
del self.cond_hint
|
479 |
-
self.control_input = None
|
480 |
-
self.cond_hint = None
|
481 |
-
width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
|
482 |
-
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
|
483 |
-
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
484 |
-
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
485 |
-
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
486 |
-
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
487 |
-
if self.control_input is None:
|
488 |
-
self.t2i_model.to(x_noisy.dtype)
|
489 |
-
self.t2i_model.to(self.device)
|
490 |
-
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
491 |
-
self.t2i_model.cpu()
|
492 |
-
|
493 |
-
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
|
494 |
-
mid = None
|
495 |
-
if self.t2i_model.xl == True:
|
496 |
-
mid = control_input[-1:]
|
497 |
-
control_input = control_input[:-1]
|
498 |
-
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
|
499 |
-
|
500 |
-
def copy(self):
|
501 |
-
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
|
502 |
-
self.copy_to(c)
|
503 |
-
return c
|
504 |
-
|
505 |
-
def load_t2i_adapter(t2i_data):
|
506 |
-
compression_ratio = 8
|
507 |
-
upscale_algorithm = 'nearest-exact'
|
508 |
-
|
509 |
-
if 'adapter' in t2i_data:
|
510 |
-
t2i_data = t2i_data['adapter']
|
511 |
-
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
|
512 |
-
prefix_replace = {}
|
513 |
-
for i in range(4):
|
514 |
-
for j in range(2):
|
515 |
-
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
|
516 |
-
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
|
517 |
-
prefix_replace["adapter."] = ""
|
518 |
-
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
|
519 |
-
keys = t2i_data.keys()
|
520 |
-
|
521 |
-
if "body.0.in_conv.weight" in keys:
|
522 |
-
cin = t2i_data['body.0.in_conv.weight'].shape[1]
|
523 |
-
model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
|
524 |
-
elif 'conv_in.weight' in keys:
|
525 |
-
cin = t2i_data['conv_in.weight'].shape[1]
|
526 |
-
channel = t2i_data['conv_in.weight'].shape[0]
|
527 |
-
ksize = t2i_data['body.0.block2.weight'].shape[2]
|
528 |
-
use_conv = False
|
529 |
-
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
|
530 |
-
if len(down_opts) > 0:
|
531 |
-
use_conv = True
|
532 |
-
xl = False
|
533 |
-
if cin == 256 or cin == 768:
|
534 |
-
xl = True
|
535 |
-
model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
|
536 |
-
elif "backbone.0.0.weight" in keys:
|
537 |
-
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
538 |
-
compression_ratio = 32
|
539 |
-
upscale_algorithm = 'bilinear'
|
540 |
-
elif "backbone.10.blocks.0.weight" in keys:
|
541 |
-
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
542 |
-
compression_ratio = 1
|
543 |
-
upscale_algorithm = 'nearest-exact'
|
544 |
-
else:
|
545 |
-
return None
|
546 |
-
|
547 |
-
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
548 |
-
if len(missing) > 0:
|
549 |
-
logging.warning("t2i missing {}".format(missing))
|
550 |
-
|
551 |
-
if len(unexpected) > 0:
|
552 |
-
logging.debug("t2i unexpected {}".format(unexpected))
|
553 |
-
|
554 |
-
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/diffusers_convert.py
DELETED
@@ -1,281 +0,0 @@
|
|
1 |
-
import re
|
2 |
-
import torch
|
3 |
-
import logging
|
4 |
-
|
5 |
-
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
6 |
-
|
7 |
-
# =================#
|
8 |
-
# UNet Conversion #
|
9 |
-
# =================#
|
10 |
-
|
11 |
-
unet_conversion_map = [
|
12 |
-
# (stable-diffusion, HF Diffusers)
|
13 |
-
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
14 |
-
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
15 |
-
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
16 |
-
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
17 |
-
("input_blocks.0.0.weight", "conv_in.weight"),
|
18 |
-
("input_blocks.0.0.bias", "conv_in.bias"),
|
19 |
-
("out.0.weight", "conv_norm_out.weight"),
|
20 |
-
("out.0.bias", "conv_norm_out.bias"),
|
21 |
-
("out.2.weight", "conv_out.weight"),
|
22 |
-
("out.2.bias", "conv_out.bias"),
|
23 |
-
]
|
24 |
-
|
25 |
-
unet_conversion_map_resnet = [
|
26 |
-
# (stable-diffusion, HF Diffusers)
|
27 |
-
("in_layers.0", "norm1"),
|
28 |
-
("in_layers.2", "conv1"),
|
29 |
-
("out_layers.0", "norm2"),
|
30 |
-
("out_layers.3", "conv2"),
|
31 |
-
("emb_layers.1", "time_emb_proj"),
|
32 |
-
("skip_connection", "conv_shortcut"),
|
33 |
-
]
|
34 |
-
|
35 |
-
unet_conversion_map_layer = []
|
36 |
-
# hardcoded number of downblocks and resnets/attentions...
|
37 |
-
# would need smarter logic for other networks.
|
38 |
-
for i in range(4):
|
39 |
-
# loop over downblocks/upblocks
|
40 |
-
|
41 |
-
for j in range(2):
|
42 |
-
# loop over resnets/attentions for downblocks
|
43 |
-
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
44 |
-
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
45 |
-
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
46 |
-
|
47 |
-
if i < 3:
|
48 |
-
# no attention layers in down_blocks.3
|
49 |
-
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
50 |
-
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
51 |
-
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
52 |
-
|
53 |
-
for j in range(3):
|
54 |
-
# loop over resnets/attentions for upblocks
|
55 |
-
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
56 |
-
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
57 |
-
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
58 |
-
|
59 |
-
if i > 0:
|
60 |
-
# no attention layers in up_blocks.0
|
61 |
-
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
62 |
-
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
63 |
-
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
64 |
-
|
65 |
-
if i < 3:
|
66 |
-
# no downsample in down_blocks.3
|
67 |
-
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
68 |
-
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
69 |
-
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
70 |
-
|
71 |
-
# no upsample in up_blocks.3
|
72 |
-
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
73 |
-
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
74 |
-
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
75 |
-
|
76 |
-
hf_mid_atn_prefix = "mid_block.attentions.0."
|
77 |
-
sd_mid_atn_prefix = "middle_block.1."
|
78 |
-
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
79 |
-
|
80 |
-
for j in range(2):
|
81 |
-
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
82 |
-
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
83 |
-
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
84 |
-
|
85 |
-
|
86 |
-
def convert_unet_state_dict(unet_state_dict):
|
87 |
-
# buyer beware: this is a *brittle* function,
|
88 |
-
# and correct output requires that all of these pieces interact in
|
89 |
-
# the exact order in which I have arranged them.
|
90 |
-
mapping = {k: k for k in unet_state_dict.keys()}
|
91 |
-
for sd_name, hf_name in unet_conversion_map:
|
92 |
-
mapping[hf_name] = sd_name
|
93 |
-
for k, v in mapping.items():
|
94 |
-
if "resnets" in k:
|
95 |
-
for sd_part, hf_part in unet_conversion_map_resnet:
|
96 |
-
v = v.replace(hf_part, sd_part)
|
97 |
-
mapping[k] = v
|
98 |
-
for k, v in mapping.items():
|
99 |
-
for sd_part, hf_part in unet_conversion_map_layer:
|
100 |
-
v = v.replace(hf_part, sd_part)
|
101 |
-
mapping[k] = v
|
102 |
-
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
103 |
-
return new_state_dict
|
104 |
-
|
105 |
-
|
106 |
-
# ================#
|
107 |
-
# VAE Conversion #
|
108 |
-
# ================#
|
109 |
-
|
110 |
-
vae_conversion_map = [
|
111 |
-
# (stable-diffusion, HF Diffusers)
|
112 |
-
("nin_shortcut", "conv_shortcut"),
|
113 |
-
("norm_out", "conv_norm_out"),
|
114 |
-
("mid.attn_1.", "mid_block.attentions.0."),
|
115 |
-
]
|
116 |
-
|
117 |
-
for i in range(4):
|
118 |
-
# down_blocks have two resnets
|
119 |
-
for j in range(2):
|
120 |
-
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
121 |
-
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
122 |
-
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
123 |
-
|
124 |
-
if i < 3:
|
125 |
-
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
126 |
-
sd_downsample_prefix = f"down.{i}.downsample."
|
127 |
-
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
128 |
-
|
129 |
-
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
130 |
-
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
131 |
-
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
132 |
-
|
133 |
-
# up_blocks have three resnets
|
134 |
-
# also, up blocks in hf are numbered in reverse from sd
|
135 |
-
for j in range(3):
|
136 |
-
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
137 |
-
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
138 |
-
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
139 |
-
|
140 |
-
# this part accounts for mid blocks in both the encoder and the decoder
|
141 |
-
for i in range(2):
|
142 |
-
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
143 |
-
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
144 |
-
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
145 |
-
|
146 |
-
vae_conversion_map_attn = [
|
147 |
-
# (stable-diffusion, HF Diffusers)
|
148 |
-
("norm.", "group_norm."),
|
149 |
-
("q.", "query."),
|
150 |
-
("k.", "key."),
|
151 |
-
("v.", "value."),
|
152 |
-
("q.", "to_q."),
|
153 |
-
("k.", "to_k."),
|
154 |
-
("v.", "to_v."),
|
155 |
-
("proj_out.", "to_out.0."),
|
156 |
-
("proj_out.", "proj_attn."),
|
157 |
-
]
|
158 |
-
|
159 |
-
|
160 |
-
def reshape_weight_for_sd(w):
|
161 |
-
# convert HF linear weights to SD conv2d weights
|
162 |
-
return w.reshape(*w.shape, 1, 1)
|
163 |
-
|
164 |
-
|
165 |
-
def convert_vae_state_dict(vae_state_dict):
|
166 |
-
mapping = {k: k for k in vae_state_dict.keys()}
|
167 |
-
for k, v in mapping.items():
|
168 |
-
for sd_part, hf_part in vae_conversion_map:
|
169 |
-
v = v.replace(hf_part, sd_part)
|
170 |
-
mapping[k] = v
|
171 |
-
for k, v in mapping.items():
|
172 |
-
if "attentions" in k:
|
173 |
-
for sd_part, hf_part in vae_conversion_map_attn:
|
174 |
-
v = v.replace(hf_part, sd_part)
|
175 |
-
mapping[k] = v
|
176 |
-
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
177 |
-
weights_to_convert = ["q", "k", "v", "proj_out"]
|
178 |
-
for k, v in new_state_dict.items():
|
179 |
-
for weight_name in weights_to_convert:
|
180 |
-
if f"mid.attn_1.{weight_name}.weight" in k:
|
181 |
-
logging.debug(f"Reshaping {k} for SD format")
|
182 |
-
new_state_dict[k] = reshape_weight_for_sd(v)
|
183 |
-
return new_state_dict
|
184 |
-
|
185 |
-
|
186 |
-
# =========================#
|
187 |
-
# Text Encoder Conversion #
|
188 |
-
# =========================#
|
189 |
-
|
190 |
-
|
191 |
-
textenc_conversion_lst = [
|
192 |
-
# (stable-diffusion, HF Diffusers)
|
193 |
-
("resblocks.", "text_model.encoder.layers."),
|
194 |
-
("ln_1", "layer_norm1"),
|
195 |
-
("ln_2", "layer_norm2"),
|
196 |
-
(".c_fc.", ".fc1."),
|
197 |
-
(".c_proj.", ".fc2."),
|
198 |
-
(".attn", ".self_attn"),
|
199 |
-
("ln_final.", "transformer.text_model.final_layer_norm."),
|
200 |
-
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
201 |
-
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
202 |
-
]
|
203 |
-
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
204 |
-
textenc_pattern = re.compile("|".join(protected.keys()))
|
205 |
-
|
206 |
-
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
207 |
-
code2idx = {"q": 0, "k": 1, "v": 2}
|
208 |
-
|
209 |
-
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
210 |
-
def cat_tensors(tensors):
|
211 |
-
x = 0
|
212 |
-
for t in tensors:
|
213 |
-
x += t.shape[0]
|
214 |
-
|
215 |
-
shape = [x] + list(tensors[0].shape)[1:]
|
216 |
-
out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
|
217 |
-
|
218 |
-
x = 0
|
219 |
-
for t in tensors:
|
220 |
-
out[x:x + t.shape[0]] = t
|
221 |
-
x += t.shape[0]
|
222 |
-
|
223 |
-
return out
|
224 |
-
|
225 |
-
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
226 |
-
new_state_dict = {}
|
227 |
-
capture_qkv_weight = {}
|
228 |
-
capture_qkv_bias = {}
|
229 |
-
for k, v in text_enc_dict.items():
|
230 |
-
if not k.startswith(prefix):
|
231 |
-
continue
|
232 |
-
if (
|
233 |
-
k.endswith(".self_attn.q_proj.weight")
|
234 |
-
or k.endswith(".self_attn.k_proj.weight")
|
235 |
-
or k.endswith(".self_attn.v_proj.weight")
|
236 |
-
):
|
237 |
-
k_pre = k[: -len(".q_proj.weight")]
|
238 |
-
k_code = k[-len("q_proj.weight")]
|
239 |
-
if k_pre not in capture_qkv_weight:
|
240 |
-
capture_qkv_weight[k_pre] = [None, None, None]
|
241 |
-
capture_qkv_weight[k_pre][code2idx[k_code]] = v
|
242 |
-
continue
|
243 |
-
|
244 |
-
if (
|
245 |
-
k.endswith(".self_attn.q_proj.bias")
|
246 |
-
or k.endswith(".self_attn.k_proj.bias")
|
247 |
-
or k.endswith(".self_attn.v_proj.bias")
|
248 |
-
):
|
249 |
-
k_pre = k[: -len(".q_proj.bias")]
|
250 |
-
k_code = k[-len("q_proj.bias")]
|
251 |
-
if k_pre not in capture_qkv_bias:
|
252 |
-
capture_qkv_bias[k_pre] = [None, None, None]
|
253 |
-
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
254 |
-
continue
|
255 |
-
|
256 |
-
text_proj = "transformer.text_projection.weight"
|
257 |
-
if k.endswith(text_proj):
|
258 |
-
new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
|
259 |
-
else:
|
260 |
-
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
261 |
-
new_state_dict[relabelled_key] = v
|
262 |
-
|
263 |
-
for k_pre, tensors in capture_qkv_weight.items():
|
264 |
-
if None in tensors:
|
265 |
-
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
266 |
-
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
267 |
-
new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
|
268 |
-
|
269 |
-
for k_pre, tensors in capture_qkv_bias.items():
|
270 |
-
if None in tensors:
|
271 |
-
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
272 |
-
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
273 |
-
new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
|
274 |
-
|
275 |
-
return new_state_dict
|
276 |
-
|
277 |
-
|
278 |
-
def convert_text_enc_state_dict(text_enc_dict):
|
279 |
-
return text_enc_dict
|
280 |
-
|
281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/diffusers_load.py
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
-
import comfy.sd
|
4 |
-
|
5 |
-
def first_file(path, filenames):
|
6 |
-
for f in filenames:
|
7 |
-
p = os.path.join(path, f)
|
8 |
-
if os.path.exists(p):
|
9 |
-
return p
|
10 |
-
return None
|
11 |
-
|
12 |
-
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
    """Load the UNet, and optionally CLIP and VAE, from a diffusers-style directory.

    The weights are searched for in the ``unet``, ``vae``, ``text_encoder`` and
    ``text_encoder_2`` sub-directories of ``model_path``, preferring fp16
    safetensors files over full-precision and ``.bin`` files.

    Args:
        model_path: Root directory of the diffusers model.
        output_vae: When true, load and return the VAE; otherwise ``vae`` is None.
        output_clip: When true, load and return the text encoder(s); otherwise
            ``clip`` is None.
        embedding_directory: Forwarded unchanged to ``comfy.sd.load_clip``.

    Returns:
        A ``(unet, clip, vae)`` tuple.

    Raises:
        FileNotFoundError: If a weight file that is actually needed (UNet
            always; first text encoder / VAE only when requested) is missing.
    """
    # Import explicitly: ``comfy.utils`` is used below and should not depend on
    # being loaded as a side effect of importing ``comfy.sd``.
    import comfy.sd
    import comfy.utils

    diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
    text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]

    unet_dir = os.path.join(model_path, "unet")
    unet_path = first_file(unet_dir, diffusion_model_names)
    if unet_path is None:
        # Fail here with a clear message instead of passing None to the loader.
        raise FileNotFoundError("No UNet weights found in {}".format(unet_dir))
    unet = comfy.sd.load_unet(unet_path)

    clip = None
    if output_clip:
        text_encoder_dir = os.path.join(model_path, "text_encoder")
        text_encoder1_path = first_file(text_encoder_dir, text_encoder_model_names)
        if text_encoder1_path is None:
            raise FileNotFoundError("No text encoder weights found in {}".format(text_encoder_dir))
        text_encoder_paths = [text_encoder1_path]
        # The second text encoder is optional and only added when present.
        text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
        if text_encoder2_path is not None:
            text_encoder_paths.append(text_encoder2_path)
        clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)

    vae = None
    if output_vae:
        vae_dir = os.path.join(model_path, "vae")
        vae_path = first_file(vae_dir, diffusion_model_names)
        if vae_path is None:
            raise FileNotFoundError("No VAE weights found in {}".format(vae_dir))
        sd = comfy.utils.load_torch_file(vae_path)
        vae = comfy.sd.VAE(sd=sd)

    return (unet, clip, vae)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/extra_samplers/__pycache__/uni_pc.cpython-310.pyc
DELETED
Binary file (28.5 kB)
|
|
MagicQuill/comfy/extra_samplers/uni_pc.py
DELETED
@@ -1,875 +0,0 @@
|
|
1 |
-
#code taken from: https://github.com/wl-zhao/UniPC and modified
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn.functional as F
|
5 |
-
import math
|
6 |
-
|
7 |
-
from tqdm.auto import trange, tqdm
|
8 |
-
|
9 |
-
|
10 |
-
class NoiseScheduleVP:
    """Wrapper for the forward SDE (VP type) exposing alpha_t, sigma_t and lambda_t."""

    def __init__(
            self,
            schedule='discrete',
            betas=None,
            alphas_cumprod=None,
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
        ):
        r"""Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
                We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
        ***

        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)

            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
            If both are given, `betas` takes precedence.

            **Important**:  Please pay special attention for the args for `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).

        2. For continuous-time DPMs:

            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM).

            Args:
                continuous_beta_0: A `float` number. The smallest beta for the linear schedule.
                continuous_beta_1: A `float` number. The largest beta for the linear schedule.

            The cosine hyperparameters are fixed inside this constructor
            (cosine_s = 0.008, cosine_beta_max = 999.), and the ending time is
            T = 0.9946 for 'cosine' and T = 1. otherwise.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).
        Raises:
            ValueError: If `schedule` is not 'discrete', 'linear' or 'cosine'.

        ===============================================================

        Example:

        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)

        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)

        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                # log(alpha_n) = 0.5 * sum_{i<=n} log(1 - beta_i)
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            # t_i = (i + 1) / N for i = 0..N-1, stored as a (1, N) row.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            # Both arrays are flipped before being handed to interpolate_fn.
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
|
179 |
-
|
180 |
-
|
181 |
-
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

        1. "noise": noise prediction model. (Trained by predicting noise).

        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].

            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follows a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```

    We support three types of guided sampling by DPMs by setting `guidance_type`:
        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

            The input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``

            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
            ``
            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.

            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).


    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).

    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
                    `None` (the default) means no extra inputs.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
                    `None` (the default) means no extra inputs.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    Raises:
        AssertionError: If `model_type` or `guidance_type` is not one of the supported values.
    """
    # "score" is documented above and handled in noise_pred_fn, so it must be accepted here.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]

    # Avoid sharing one dict object across calls through a mutable default argument.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """Evaluate `model` and convert its output to a noise prediction."""
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            # Classifier-free guidance: the condition must reach the model,
            # as in the documented `model(x, t_input, cond, **model_kwargs)` format.
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Run the unconditional and conditional passes as one doubled batch.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
|
350 |
-
|
351 |
-
|
352 |
-
class UniPC:
|
353 |
-
def __init__(
    self,
    model_fn,
    noise_schedule,
    predict_x0=True,
    thresholding=False,
    max_val=1.,
    variant='bh1',
):
    """Construct a UniPC.

    We support both data_prediction and noise_prediction.

    Args:
        model_fn: Callable ``model_fn(x, t)``; stored as ``self.model``.
        noise_schedule: Noise schedule object; stored as ``self.noise_schedule``.
        predict_x0: When true, ``self.model_fn`` dispatches to the data
            prediction function, otherwise to the noise prediction function.
        thresholding: Enables the thresholding step in ``data_prediction_fn``.
        max_val: Lower bound for the thresholding scale.
        variant: Solver variant name, stored as ``self.variant``.
    """
    self.model = model_fn
    self.noise_schedule = noise_schedule
    self.variant = variant
    self.predict_x0 = predict_x0
    self.thresholding = thresholding
    self.max_val = max_val
    # dynamic_thresholding_fn reads these two attributes; without them it
    # raises AttributeError. The values mirror the inline thresholding in
    # data_prediction_fn (quantile 0.995, floor of max_val).
    self.dynamic_thresholding_ratio = 0.995
    self.thresholding_max_val = max_val
|
372 |
-
|
373 |
-
def dynamic_thresholding_fn(self, x0, t=None):
    """
    The dynamic thresholding method.

    Clamps ``x0`` to ``[-s, s]`` and divides by ``s``, where ``s`` is a
    per-sample quantile of ``|x0|`` (flattened over all non-batch dimensions),
    bounded below by a maximum-value floor.

    Args:
        x0: Tensor whose first dimension indexes samples.
        t: Unused; kept for interface compatibility.

    Returns:
        The thresholded tensor, same shape as ``x0``.
    """
    dims = x0.dim()
    # The visible __init__ never sets these attributes, so fall back to the
    # values data_prediction_fn uses inline instead of raising AttributeError.
    p = getattr(self, 'dynamic_thresholding_ratio', 0.995)
    max_val = getattr(self, 'thresholding_max_val', self.max_val)
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    s = expand_dims(torch.maximum(s, max_val * torch.ones_like(s).to(s.device)), dims)
    x0 = torch.clamp(x0, -s, s) / s
    return x0
|
383 |
-
|
384 |
-
def noise_prediction_fn(self, x, t):
    """
    Evaluate the wrapped model ``self.model`` at ``(x, t)`` and return its output.
    """
    predict = self.model
    return predict(x, t)
|
389 |
-
|
390 |
-
def data_prediction_fn(self, x, t):
    """
    Convert the noise prediction at ``(x, t)`` into a data (x0) prediction.

    Computes ``x0 = (x - sigma_t * noise) / alpha_t`` and, when
    ``self.thresholding`` is set, clamps it to a per-sample scale and
    normalises by that scale.
    """
    schedule = self.noise_schedule
    eps = self.noise_prediction_fn(x, t)
    ndim = x.dim()
    alpha_t = schedule.marginal_alpha(t)
    sigma_t = schedule.marginal_std(t)
    x0 = (x - expand_dims(sigma_t, ndim) * eps) / expand_dims(alpha_t, ndim)
    if not self.thresholding:
        return x0
    quantile = 0.995  # A hyperparameter in the paper of "Imagen" [1].
    flat_abs = torch.abs(x0).reshape((x0.shape[0], -1))
    scale = torch.quantile(flat_abs, quantile, dim=1)
    # The scale is never allowed to fall below self.max_val.
    floor = self.max_val * torch.ones_like(scale).to(scale.device)
    scale = expand_dims(torch.maximum(scale, floor), ndim)
    return torch.clamp(x0, -scale, scale) / scale
|
404 |
-
|
405 |
-
def model_fn(self, x, t):
    """
    Dispatch to the data prediction function when ``self.predict_x0`` is
    truthy, otherwise to the noise prediction function.
    """
    predictor = self.data_prediction_fn if self.predict_x0 else self.noise_prediction_fn
    return predictor(x, t)
|
413 |
-
|
414 |
-
def get_time_steps(self, skip_type, t_T, t_0, N, device):
    """Return the ``N + 1`` sampling time steps running from ``t_T`` to ``t_0``.

    ``skip_type`` selects the spacing: 'time_uniform' (linear in t),
    'time_quadratic' (linear in sqrt(t)) or 'logSNR' (linear in the noise
    schedule's lambda, mapped back through ``inverse_lambda``). Any other
    value raises ValueError.
    """
    if skip_type == 'time_uniform':
        return torch.linspace(t_T, t_0, N + 1).to(device)
    if skip_type == 'time_quadratic':
        exponent = 2
        roots = torch.linspace(t_T**(1. / exponent), t_0**(1. / exponent), N + 1)
        return roots.pow(exponent).to(device)
    if skip_type == 'logSNR':
        schedule = self.noise_schedule
        lam_start = schedule.marginal_lambda(torch.tensor(t_T).to(device))
        lam_end = schedule.marginal_lambda(torch.tensor(t_0).to(device))
        lam_grid = torch.linspace(lam_start.cpu().item(), lam_end.cpu().item(), N + 1).to(device)
        return schedule.inverse_lambda(lam_grid)
    raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
430 |
-
|
431 |
-
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
    """
    Split ``steps`` model evaluations into solver steps of order at most ``order``.

    Returns ``(timesteps_outer, orders)``: ``orders`` lists the order of each
    solver step, and ``timesteps_outer`` holds the boundary times of those
    steps. ``order`` must be 1, 2 or 3, otherwise ValueError is raised.
    """
    if order == 3:
        full, rem = divmod(steps, 3)
        K = full + 1
        if rem == 0:
            orders = [3,] * (K - 2) + [2, 1]
        elif rem == 1:
            orders = [3,] * (K - 1) + [1]
        else:
            orders = [3,] * (K - 1) + [2]
    elif order == 2:
        full, rem = divmod(steps, 2)
        if rem == 0:
            K = full
            orders = [2,] * K
        else:
            K = full + 1
            orders = [2,] * (K - 1) + [1]
    elif order == 1:
        K = steps
        orders = [1,] * steps
    else:
        raise ValueError("'order' must be '1' or '2' or '3'.")

    if skip_type == 'logSNR':
        # To reproduce the results in DPM-Solver paper
        return self.get_time_steps(skip_type, t_T, t_0, K, device), orders

    # Otherwise take the fine grid of `steps` intervals and keep only the
    # entries at the cumulative order boundaries.
    fine_grid = self.get_time_steps(skip_type, t_T, t_0, steps, device)
    boundaries = torch.cumsum(torch.tensor([0,] + orders), 0).to(device)
    return fine_grid[boundaries], orders
|
461 |
-
|
462 |
-
def denoise_to_zero_fn(self, x, s):
    """
    Final denoising step: return the data prediction at ``(x, s)``, which is
    equivalent to solving the ODE from lambda_s to infinity by first-order
    discretization.
    """
    denoised = self.data_prediction_fn(x, s)
    return denoised
|
467 |
-
|
468 |
-
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
    """Route one multistep update to the solver selected by ``self.variant``.

    A zero-dimensional ``t`` is first reshaped to a one-element vector. Any
    variant containing 'bh' goes to the B(h) update; the only other accepted
    variant is 'vary_coeff'. Extra keyword arguments are forwarded unchanged.
    """
    if len(t.shape) < 1:
        t = t.view(-1)
    if 'bh' not in self.variant:
        assert self.variant == 'vary_coeff'
        return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
    return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
476 |
-
|
477 |
-
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
    """One unified predictor-corrector step of the 'vary coeff' solver variant.

    Args:
        x: Current sample at the time ``t_prev_list[-1]``.
        model_prev_list: Previous model outputs, most recent last. Must hold
            at least ``order`` entries.
        t_prev_list: Times matching ``model_prev_list``, most recent last.
        t: Target time of this step.
        order: Number of previous model outputs used by the update.
        use_corrector: When true, ``self.model_fn`` is evaluated at the
            predicted point and the corrector is applied.

    Returns:
        ``(x_t, model_t)``: the updated sample, and the model output computed
        for the corrector (``None`` when ``use_corrector`` is false).

    NOTE(review): unlike ``multistep_uni_pc_bh_update``, ``rk`` is not indexed
    with ``[0]`` and the schedule coefficients are not passed through
    ``expand_dims``; this presumably only works when the time tensors hold a
    single element -- confirm before relying on batched inputs.
    """
    # Progress output goes to stdout on every call.
    print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
    ns = self.noise_schedule
    assert order <= len(model_prev_list)

    # first compute rks
    t_prev_0 = t_prev_list[-1]
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    model_prev_0 = model_prev_list[-1]
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    log_alpha_t = ns.marginal_log_mean_coeff(t)
    alpha_t = torch.exp(log_alpha_t)

    # Step size measured in lambda (half-logSNR).
    h = lambda_t - lambda_prev_0

    rks = []
    D1s = []
    for i in range(1, order):
        t_prev_i = t_prev_list[-(i + 1)]
        model_prev_i = model_prev_list[-(i + 1)]
        lambda_prev_i = ns.marginal_lambda(t_prev_i)
        # rk: lambda offset of the i-th older point, relative to the step size h.
        rk = (lambda_prev_i - lambda_prev_0) / h
        rks.append(rk)
        # Difference of model outputs scaled by 1 / rk.
        D1s.append((model_prev_i - model_prev_0) / rk)

    # The final entry (ratio 1) corresponds to the target time t itself.
    rks.append(1.)
    rks = torch.tensor(rks, device=x.device)

    K = len(rks)
    # build C matrix
    C = []

    # Column k of C is rks**(k-1) / k!, built up iteratively.
    col = torch.ones_like(rks)
    for k in range(1, K + 1):
        C.append(col)
        col = col * rks / (k + 1)
    C = torch.stack(C, dim=1)

    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1) # (B, K)
        # The predictor only uses the older points, so the last row/column is dropped.
        C_inv_p = torch.linalg.inv(C[:-1, :-1])
        A_p = C_inv_p

    if use_corrector:
        print('using corrector')
        C_inv = torch.linalg.inv(C)
        A_c = C_inv

    # Data prediction integrates with the sign of h flipped.
    hh = -h if self.predict_x0 else h
    h_phi_1 = torch.expm1(hh)
    h_phi_ks = []
    factorial_k = 1
    h_phi_k = h_phi_1
    # Recurrence: next = current / hh - 1 / k!, starting from expm1(hh).
    for k in range(1, K + 2):
        h_phi_ks.append(h_phi_k)
        h_phi_k = h_phi_k / hh - 1 / factorial_k
        factorial_k *= (k + 1)

    model_t = None
    if self.predict_x0:
        # First-order part of the update, shared by predictor and corrector.
        x_t_ = (
            sigma_t / sigma_prev_0 * x
            - alpha_t * h_phi_1 * model_prev_0
        )
        # now predictor
        x_t = x_t_
        if len(D1s) > 0:
            # compute the residuals for predictor
            for k in range(K - 1):
                x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
        # now corrector
        if use_corrector:
            model_t = self.model_fn(x_t, t)
            D1_t = (model_t - model_prev_0)
            # The corrector restarts from the first-order part, not from the predictor result.
            x_t = x_t_
            k = 0
            for k in range(K - 1):
                x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
            # `k` keeps its last loop value here (or the 0 assigned above when the loop body never runs).
            x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
    else:
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        # Noise-prediction counterpart of the first-order part above.
        x_t_ = (
            (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
            - (sigma_t * h_phi_1) * model_prev_0
        )
        # now predictor
        x_t = x_t_
        if len(D1s) > 0:
            # compute the residuals for predictor
            for k in range(K - 1):
                x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
        # now corrector
        if use_corrector:
            model_t = self.model_fn(x_t, t)
            D1_t = (model_t - model_prev_0)
            x_t = x_t_
            k = 0
            for k in range(K - 1):
                x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
            # Same use of the leftover loop variable `k` as in the branch above.
            x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
    return x_t, model_t
|
579 |
-
|
580 |
-
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
    """One unified predictor-corrector step of the B(h) solver variants.

    Args:
        x: Current sample at the time ``t_prev_list[-1]``.
        model_prev_list: Previous model outputs, most recent last. Must hold
            at least ``order`` entries.
        t_prev_list: Times matching ``model_prev_list``, most recent last.
        t: Target time of this step; element ``[0]`` of the derived step size
            is used for the scalar coefficients.
        order: Number of previous model outputs used by the update.
        x_t: Optional precomputed predictor result. When given, the predictor
            stage is skipped and this value is fed to the corrector.
        use_corrector: When true, ``self.model_fn`` is evaluated at ``x_t``
            and the corrector is applied.

    Returns:
        ``(x_t, model_t)``: the updated sample, and the model output computed
        for the corrector (``None`` when ``use_corrector`` is false).

    Raises:
        NotImplementedError: If ``self.variant`` is neither 'bh1' nor 'bh2'.
    """
    # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
    ns = self.noise_schedule
    assert order <= len(model_prev_list)
    dims = x.dim()

    # first compute rks
    t_prev_0 = t_prev_list[-1]
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    model_prev_0 = model_prev_list[-1]
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
    alpha_t = torch.exp(log_alpha_t)

    # Step size measured in lambda (half-logSNR).
    h = lambda_t - lambda_prev_0

    rks = []
    D1s = []
    for i in range(1, order):
        t_prev_i = t_prev_list[-(i + 1)]
        model_prev_i = model_prev_list[-(i + 1)]
        lambda_prev_i = ns.marginal_lambda(t_prev_i)
        # Only element [0] is kept, so one ratio is shared by the whole batch.
        rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
        rks.append(rk)
        # Difference of model outputs scaled by 1 / rk.
        D1s.append((model_prev_i - model_prev_0) / rk)

    # The final entry (ratio 1) corresponds to the target time t itself.
    rks.append(1.)
    rks = torch.tensor(rks, device=x.device)

    R = []
    b = []

    # Data prediction integrates with the sign of h flipped.
    hh = -h[0] if self.predict_x0 else h[0]
    h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    # B(h) is the only difference between the 'bh1' and 'bh2' variants.
    if self.variant == 'bh1':
        B_h = hh
    elif self.variant == 'bh2':
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    # Row i of R holds rks**(i-1); b holds the matching right-hand side.
    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= (i + 1)
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=x.device)

    # now predictor
    # The predictor runs only when there is history and no x_t was supplied.
    use_predictor = len(D1s) > 0 and x_t is None
    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1) # (B, K)
        if x_t is None:
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
    else:
        # No history: later stages test `D1s is not None`.
        D1s = None

    if use_corrector:
        # print('using corrector')
        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            rhos_c = torch.linalg.solve(R, b)

    model_t = None
    if self.predict_x0:
        # First-order part of the update, shared by predictor and corrector.
        x_t_ = (
            expand_dims(sigma_t / sigma_prev_0, dims) * x
            - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
        )

        if x_t is None:
            if use_predictor:
                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res

        if use_corrector:
            model_t = self.model_fn(x_t, t)
            if D1s is not None:
                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = (model_t - model_prev_0)
            # The corrector restarts from the first-order part and adds the new model output's term.
            x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
    else:
        # Noise-prediction counterpart of the first-order part above.
        x_t_ = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
            - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
        )
        if x_t is None:
            if use_predictor:
                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

        if use_corrector:
            model_t = self.model_fn(x_t, t)
            if D1s is not None:
                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = (model_t - model_prev_0)
            x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
    return x_t, model_t
|
699 |
-
|
700 |
-
|
701 |
-
def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
702 |
-
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
703 |
-
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
|
704 |
-
):
|
705 |
-
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
706 |
-
# t_T = self.noise_schedule.T if t_start is None else t_start
|
707 |
-
device = x.device
|
708 |
-
steps = len(timesteps) - 1
|
709 |
-
if method == 'multistep':
|
710 |
-
assert steps >= order
|
711 |
-
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
712 |
-
assert timesteps.shape[0] - 1 == steps
|
713 |
-
# with torch.no_grad():
|
714 |
-
for step_index in trange(steps, disable=disable_pbar):
|
715 |
-
if step_index == 0:
|
716 |
-
vec_t = timesteps[0].expand((x.shape[0]))
|
717 |
-
model_prev_list = [self.model_fn(x, vec_t)]
|
718 |
-
t_prev_list = [vec_t]
|
719 |
-
elif step_index < order:
|
720 |
-
init_order = step_index
|
721 |
-
# Init the first `order` values by lower order multistep DPM-Solver.
|
722 |
-
# for init_order in range(1, order):
|
723 |
-
vec_t = timesteps[init_order].expand(x.shape[0])
|
724 |
-
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
725 |
-
if model_x is None:
|
726 |
-
model_x = self.model_fn(x, vec_t)
|
727 |
-
model_prev_list.append(model_x)
|
728 |
-
t_prev_list.append(vec_t)
|
729 |
-
else:
|
730 |
-
extra_final_step = 0
|
731 |
-
if step_index == (steps - 1):
|
732 |
-
extra_final_step = 1
|
733 |
-
for step in range(step_index, step_index + 1 + extra_final_step):
|
734 |
-
vec_t = timesteps[step].expand(x.shape[0])
|
735 |
-
if lower_order_final:
|
736 |
-
step_order = min(order, steps + 1 - step)
|
737 |
-
else:
|
738 |
-
step_order = order
|
739 |
-
# print('this step order:', step_order)
|
740 |
-
if step == steps:
|
741 |
-
# print('do not run corrector at the last step')
|
742 |
-
use_corrector = False
|
743 |
-
else:
|
744 |
-
use_corrector = True
|
745 |
-
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
746 |
-
for i in range(order - 1):
|
747 |
-
t_prev_list[i] = t_prev_list[i + 1]
|
748 |
-
model_prev_list[i] = model_prev_list[i + 1]
|
749 |
-
t_prev_list[-1] = vec_t
|
750 |
-
# We do not need to evaluate the final model value.
|
751 |
-
if step < steps:
|
752 |
-
if model_x is None:
|
753 |
-
model_x = self.model_fn(x, vec_t)
|
754 |
-
model_prev_list[-1] = model_x
|
755 |
-
if callback is not None:
|
756 |
-
callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
|
757 |
-
else:
|
758 |
-
raise NotImplementedError()
|
759 |
-
# if denoise_to_zero:
|
760 |
-
# x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
761 |
-
return x
|
762 |
-
|
763 |
-
|
764 |
-
#############################################################
|
765 |
-
# other utility functions
|
766 |
-
#############################################################
|
767 |
-
|
768 |
-
def interpolate_fn(x, xp, yp):
|
769 |
-
"""
|
770 |
-
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
771 |
-
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
772 |
-
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
|
773 |
-
|
774 |
-
Args:
|
775 |
-
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
776 |
-
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
777 |
-
yp: PyTorch tensor with shape [C, K].
|
778 |
-
Returns:
|
779 |
-
The function values f(x), with shape [N, C].
|
780 |
-
"""
|
781 |
-
N, K = x.shape[0], xp.shape[1]
|
782 |
-
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
783 |
-
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
784 |
-
x_idx = torch.argmin(x_indices, dim=2)
|
785 |
-
cand_start_idx = x_idx - 1
|
786 |
-
start_idx = torch.where(
|
787 |
-
torch.eq(x_idx, 0),
|
788 |
-
torch.tensor(1, device=x.device),
|
789 |
-
torch.where(
|
790 |
-
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
791 |
-
),
|
792 |
-
)
|
793 |
-
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
794 |
-
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
795 |
-
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
796 |
-
start_idx2 = torch.where(
|
797 |
-
torch.eq(x_idx, 0),
|
798 |
-
torch.tensor(0, device=x.device),
|
799 |
-
torch.where(
|
800 |
-
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
801 |
-
),
|
802 |
-
)
|
803 |
-
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
804 |
-
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
805 |
-
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
806 |
-
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
807 |
-
return cand
|
808 |
-
|
809 |
-
|
810 |
-
def expand_dims(v, dims):
|
811 |
-
"""
|
812 |
-
Expand the tensor `v` to the dim `dims`.
|
813 |
-
|
814 |
-
Args:
|
815 |
-
`v`: a PyTorch tensor with shape [N].
|
816 |
-
`dim`: a `int`.
|
817 |
-
Returns:
|
818 |
-
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
819 |
-
"""
|
820 |
-
return v[(...,) + (None,)*(dims - 1)]
|
821 |
-
|
822 |
-
|
823 |
-
class SigmaConvert:
|
824 |
-
schedule = ""
|
825 |
-
def marginal_log_mean_coeff(self, sigma):
|
826 |
-
return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
|
827 |
-
|
828 |
-
def marginal_alpha(self, t):
|
829 |
-
return torch.exp(self.marginal_log_mean_coeff(t))
|
830 |
-
|
831 |
-
def marginal_std(self, t):
|
832 |
-
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
833 |
-
|
834 |
-
def marginal_lambda(self, t):
|
835 |
-
"""
|
836 |
-
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
837 |
-
"""
|
838 |
-
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
839 |
-
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
840 |
-
return log_mean_coeff - log_std
|
841 |
-
|
842 |
-
def predict_eps_sigma(model, input, sigma_in, **kwargs):
|
843 |
-
sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
|
844 |
-
input = input * ((sigma ** 2 + 1.0) ** 0.5)
|
845 |
-
return (input - model(input, sigma_in, **kwargs)) / sigma
|
846 |
-
|
847 |
-
|
848 |
-
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
|
849 |
-
timesteps = sigmas.clone()
|
850 |
-
if sigmas[-1] == 0:
|
851 |
-
timesteps = sigmas[:]
|
852 |
-
timesteps[-1] = 0.001
|
853 |
-
else:
|
854 |
-
timesteps = sigmas.clone()
|
855 |
-
ns = SigmaConvert()
|
856 |
-
|
857 |
-
noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
|
858 |
-
model_type = "noise"
|
859 |
-
|
860 |
-
model_fn = model_wrapper(
|
861 |
-
lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
|
862 |
-
ns,
|
863 |
-
model_type=model_type,
|
864 |
-
guidance_type="uncond",
|
865 |
-
model_kwargs=extra_args,
|
866 |
-
)
|
867 |
-
|
868 |
-
order = min(3, len(timesteps) - 2)
|
869 |
-
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
|
870 |
-
x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
871 |
-
x /= ns.marginal_alpha(timesteps[-1])
|
872 |
-
return x
|
873 |
-
|
874 |
-
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
|
875 |
-
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/gligen.py
DELETED
@@ -1,343 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
from .ldm.modules.attention import CrossAttention
|
4 |
-
from inspect import isfunction
|
5 |
-
import comfy.ops
|
6 |
-
ops = comfy.ops.manual_cast
|
7 |
-
|
8 |
-
def exists(val):
|
9 |
-
return val is not None
|
10 |
-
|
11 |
-
|
12 |
-
def uniq(arr):
|
13 |
-
return{el: True for el in arr}.keys()
|
14 |
-
|
15 |
-
|
16 |
-
def default(val, d):
|
17 |
-
if exists(val):
|
18 |
-
return val
|
19 |
-
return d() if isfunction(d) else d
|
20 |
-
|
21 |
-
|
22 |
-
# feedforward
|
23 |
-
class GEGLU(nn.Module):
|
24 |
-
def __init__(self, dim_in, dim_out):
|
25 |
-
super().__init__()
|
26 |
-
self.proj = ops.Linear(dim_in, dim_out * 2)
|
27 |
-
|
28 |
-
def forward(self, x):
|
29 |
-
x, gate = self.proj(x).chunk(2, dim=-1)
|
30 |
-
return x * torch.nn.functional.gelu(gate)
|
31 |
-
|
32 |
-
|
33 |
-
class FeedForward(nn.Module):
|
34 |
-
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
35 |
-
super().__init__()
|
36 |
-
inner_dim = int(dim * mult)
|
37 |
-
dim_out = default(dim_out, dim)
|
38 |
-
project_in = nn.Sequential(
|
39 |
-
ops.Linear(dim, inner_dim),
|
40 |
-
nn.GELU()
|
41 |
-
) if not glu else GEGLU(dim, inner_dim)
|
42 |
-
|
43 |
-
self.net = nn.Sequential(
|
44 |
-
project_in,
|
45 |
-
nn.Dropout(dropout),
|
46 |
-
ops.Linear(inner_dim, dim_out)
|
47 |
-
)
|
48 |
-
|
49 |
-
def forward(self, x):
|
50 |
-
return self.net(x)
|
51 |
-
|
52 |
-
|
53 |
-
class GatedCrossAttentionDense(nn.Module):
|
54 |
-
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
55 |
-
super().__init__()
|
56 |
-
|
57 |
-
self.attn = CrossAttention(
|
58 |
-
query_dim=query_dim,
|
59 |
-
context_dim=context_dim,
|
60 |
-
heads=n_heads,
|
61 |
-
dim_head=d_head,
|
62 |
-
operations=ops)
|
63 |
-
self.ff = FeedForward(query_dim, glu=True)
|
64 |
-
|
65 |
-
self.norm1 = ops.LayerNorm(query_dim)
|
66 |
-
self.norm2 = ops.LayerNorm(query_dim)
|
67 |
-
|
68 |
-
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
69 |
-
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
70 |
-
|
71 |
-
# this can be useful: we can externally change magnitude of tanh(alpha)
|
72 |
-
# for example, when it is set to 0, then the entire model is same as
|
73 |
-
# original one
|
74 |
-
self.scale = 1
|
75 |
-
|
76 |
-
def forward(self, x, objs):
|
77 |
-
|
78 |
-
x = x + self.scale * \
|
79 |
-
torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
|
80 |
-
x = x + self.scale * \
|
81 |
-
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
82 |
-
|
83 |
-
return x
|
84 |
-
|
85 |
-
|
86 |
-
class GatedSelfAttentionDense(nn.Module):
|
87 |
-
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
88 |
-
super().__init__()
|
89 |
-
|
90 |
-
# we need a linear projection since we need cat visual feature and obj
|
91 |
-
# feature
|
92 |
-
self.linear = ops.Linear(context_dim, query_dim)
|
93 |
-
|
94 |
-
self.attn = CrossAttention(
|
95 |
-
query_dim=query_dim,
|
96 |
-
context_dim=query_dim,
|
97 |
-
heads=n_heads,
|
98 |
-
dim_head=d_head,
|
99 |
-
operations=ops)
|
100 |
-
self.ff = FeedForward(query_dim, glu=True)
|
101 |
-
|
102 |
-
self.norm1 = ops.LayerNorm(query_dim)
|
103 |
-
self.norm2 = ops.LayerNorm(query_dim)
|
104 |
-
|
105 |
-
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
106 |
-
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
107 |
-
|
108 |
-
# this can be useful: we can externally change magnitude of tanh(alpha)
|
109 |
-
# for example, when it is set to 0, then the entire model is same as
|
110 |
-
# original one
|
111 |
-
self.scale = 1
|
112 |
-
|
113 |
-
def forward(self, x, objs):
|
114 |
-
|
115 |
-
N_visual = x.shape[1]
|
116 |
-
objs = self.linear(objs)
|
117 |
-
|
118 |
-
x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
|
119 |
-
self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
|
120 |
-
x = x + self.scale * \
|
121 |
-
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
122 |
-
|
123 |
-
return x
|
124 |
-
|
125 |
-
|
126 |
-
class GatedSelfAttentionDense2(nn.Module):
|
127 |
-
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
128 |
-
super().__init__()
|
129 |
-
|
130 |
-
# we need a linear projection since we need cat visual feature and obj
|
131 |
-
# feature
|
132 |
-
self.linear = ops.Linear(context_dim, query_dim)
|
133 |
-
|
134 |
-
self.attn = CrossAttention(
|
135 |
-
query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
|
136 |
-
self.ff = FeedForward(query_dim, glu=True)
|
137 |
-
|
138 |
-
self.norm1 = ops.LayerNorm(query_dim)
|
139 |
-
self.norm2 = ops.LayerNorm(query_dim)
|
140 |
-
|
141 |
-
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
142 |
-
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
143 |
-
|
144 |
-
# this can be useful: we can externally change magnitude of tanh(alpha)
|
145 |
-
# for example, when it is set to 0, then the entire model is same as
|
146 |
-
# original one
|
147 |
-
self.scale = 1
|
148 |
-
|
149 |
-
def forward(self, x, objs):
|
150 |
-
|
151 |
-
B, N_visual, _ = x.shape
|
152 |
-
B, N_ground, _ = objs.shape
|
153 |
-
|
154 |
-
objs = self.linear(objs)
|
155 |
-
|
156 |
-
# sanity check
|
157 |
-
size_v = math.sqrt(N_visual)
|
158 |
-
size_g = math.sqrt(N_ground)
|
159 |
-
assert int(size_v) == size_v, "Visual tokens must be square rootable"
|
160 |
-
assert int(size_g) == size_g, "Grounding tokens must be square rootable"
|
161 |
-
size_v = int(size_v)
|
162 |
-
size_g = int(size_g)
|
163 |
-
|
164 |
-
# select grounding token and resize it to visual token size as residual
|
165 |
-
out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
|
166 |
-
:, N_visual:, :]
|
167 |
-
out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
|
168 |
-
out = torch.nn.functional.interpolate(
|
169 |
-
out, (size_v, size_v), mode='bicubic')
|
170 |
-
residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)
|
171 |
-
|
172 |
-
# add residual to visual feature
|
173 |
-
x = x + self.scale * torch.tanh(self.alpha_attn) * residual
|
174 |
-
x = x + self.scale * \
|
175 |
-
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
176 |
-
|
177 |
-
return x
|
178 |
-
|
179 |
-
|
180 |
-
class FourierEmbedder():
|
181 |
-
def __init__(self, num_freqs=64, temperature=100):
|
182 |
-
|
183 |
-
self.num_freqs = num_freqs
|
184 |
-
self.temperature = temperature
|
185 |
-
self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
|
186 |
-
|
187 |
-
@torch.no_grad()
|
188 |
-
def __call__(self, x, cat_dim=-1):
|
189 |
-
"x: arbitrary shape of tensor. dim: cat dim"
|
190 |
-
out = []
|
191 |
-
for freq in self.freq_bands:
|
192 |
-
out.append(torch.sin(freq * x))
|
193 |
-
out.append(torch.cos(freq * x))
|
194 |
-
return torch.cat(out, cat_dim)
|
195 |
-
|
196 |
-
|
197 |
-
class PositionNet(nn.Module):
|
198 |
-
def __init__(self, in_dim, out_dim, fourier_freqs=8):
|
199 |
-
super().__init__()
|
200 |
-
self.in_dim = in_dim
|
201 |
-
self.out_dim = out_dim
|
202 |
-
|
203 |
-
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
204 |
-
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
|
205 |
-
|
206 |
-
self.linears = nn.Sequential(
|
207 |
-
ops.Linear(self.in_dim + self.position_dim, 512),
|
208 |
-
nn.SiLU(),
|
209 |
-
ops.Linear(512, 512),
|
210 |
-
nn.SiLU(),
|
211 |
-
ops.Linear(512, out_dim),
|
212 |
-
)
|
213 |
-
|
214 |
-
self.null_positive_feature = torch.nn.Parameter(
|
215 |
-
torch.zeros([self.in_dim]))
|
216 |
-
self.null_position_feature = torch.nn.Parameter(
|
217 |
-
torch.zeros([self.position_dim]))
|
218 |
-
|
219 |
-
def forward(self, boxes, masks, positive_embeddings):
|
220 |
-
B, N, _ = boxes.shape
|
221 |
-
masks = masks.unsqueeze(-1)
|
222 |
-
positive_embeddings = positive_embeddings
|
223 |
-
|
224 |
-
# embedding position (it may includes padding as placeholder)
|
225 |
-
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
226 |
-
|
227 |
-
# learnable null embedding
|
228 |
-
positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
229 |
-
xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
230 |
-
|
231 |
-
# replace padding with learnable null embedding
|
232 |
-
positive_embeddings = positive_embeddings * \
|
233 |
-
masks + (1 - masks) * positive_null
|
234 |
-
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
|
235 |
-
|
236 |
-
objs = self.linears(
|
237 |
-
torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
|
238 |
-
assert objs.shape == torch.Size([B, N, self.out_dim])
|
239 |
-
return objs
|
240 |
-
|
241 |
-
|
242 |
-
class Gligen(nn.Module):
|
243 |
-
def __init__(self, modules, position_net, key_dim):
|
244 |
-
super().__init__()
|
245 |
-
self.module_list = nn.ModuleList(modules)
|
246 |
-
self.position_net = position_net
|
247 |
-
self.key_dim = key_dim
|
248 |
-
self.max_objs = 30
|
249 |
-
self.current_device = torch.device("cpu")
|
250 |
-
|
251 |
-
def _set_position(self, boxes, masks, positive_embeddings):
|
252 |
-
objs = self.position_net(boxes, masks, positive_embeddings)
|
253 |
-
def func(x, extra_options):
|
254 |
-
key = extra_options["transformer_index"]
|
255 |
-
module = self.module_list[key]
|
256 |
-
return module(x, objs.to(device=x.device, dtype=x.dtype))
|
257 |
-
return func
|
258 |
-
|
259 |
-
def set_position(self, latent_image_shape, position_params, device):
|
260 |
-
batch, c, h, w = latent_image_shape
|
261 |
-
masks = torch.zeros([self.max_objs], device="cpu")
|
262 |
-
boxes = []
|
263 |
-
positive_embeddings = []
|
264 |
-
for p in position_params:
|
265 |
-
x1 = (p[4]) / w
|
266 |
-
y1 = (p[3]) / h
|
267 |
-
x2 = (p[4] + p[2]) / w
|
268 |
-
y2 = (p[3] + p[1]) / h
|
269 |
-
masks[len(boxes)] = 1.0
|
270 |
-
boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
|
271 |
-
positive_embeddings += [p[0]]
|
272 |
-
append_boxes = []
|
273 |
-
append_conds = []
|
274 |
-
if len(boxes) < self.max_objs:
|
275 |
-
append_boxes = [torch.zeros(
|
276 |
-
[self.max_objs - len(boxes), 4], device="cpu")]
|
277 |
-
append_conds = [torch.zeros(
|
278 |
-
[self.max_objs - len(boxes), self.key_dim], device="cpu")]
|
279 |
-
|
280 |
-
box_out = torch.cat(
|
281 |
-
boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
|
282 |
-
masks = masks.unsqueeze(0).repeat(batch, 1)
|
283 |
-
conds = torch.cat(positive_embeddings +
|
284 |
-
append_conds).unsqueeze(0).repeat(batch, 1, 1)
|
285 |
-
return self._set_position(
|
286 |
-
box_out.to(device),
|
287 |
-
masks.to(device),
|
288 |
-
conds.to(device))
|
289 |
-
|
290 |
-
def set_empty(self, latent_image_shape, device):
|
291 |
-
batch, c, h, w = latent_image_shape
|
292 |
-
masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
|
293 |
-
box_out = torch.zeros([self.max_objs, 4],
|
294 |
-
device="cpu").repeat(batch, 1, 1)
|
295 |
-
conds = torch.zeros([self.max_objs, self.key_dim],
|
296 |
-
device="cpu").repeat(batch, 1, 1)
|
297 |
-
return self._set_position(
|
298 |
-
box_out.to(device),
|
299 |
-
masks.to(device),
|
300 |
-
conds.to(device))
|
301 |
-
|
302 |
-
|
303 |
-
def load_gligen(sd):
|
304 |
-
sd_k = sd.keys()
|
305 |
-
output_list = []
|
306 |
-
key_dim = 768
|
307 |
-
for a in ["input_blocks", "middle_block", "output_blocks"]:
|
308 |
-
for b in range(20):
|
309 |
-
k_temp = filter(lambda k: "{}.{}.".format(a, b)
|
310 |
-
in k and ".fuser." in k, sd_k)
|
311 |
-
k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)
|
312 |
-
|
313 |
-
n_sd = {}
|
314 |
-
for k in k_temp:
|
315 |
-
n_sd[k[1]] = sd[k[0]]
|
316 |
-
if len(n_sd) > 0:
|
317 |
-
query_dim = n_sd["linear.weight"].shape[0]
|
318 |
-
key_dim = n_sd["linear.weight"].shape[1]
|
319 |
-
|
320 |
-
if key_dim == 768: # SD1.x
|
321 |
-
n_heads = 8
|
322 |
-
d_head = query_dim // n_heads
|
323 |
-
else:
|
324 |
-
d_head = 64
|
325 |
-
n_heads = query_dim // d_head
|
326 |
-
|
327 |
-
gated = GatedSelfAttentionDense(
|
328 |
-
query_dim, key_dim, n_heads, d_head)
|
329 |
-
gated.load_state_dict(n_sd, strict=False)
|
330 |
-
output_list.append(gated)
|
331 |
-
|
332 |
-
if "position_net.null_positive_feature" in sd_k:
|
333 |
-
in_dim = sd["position_net.null_positive_feature"].shape[0]
|
334 |
-
out_dim = sd["position_net.linears.4.weight"].shape[0]
|
335 |
-
|
336 |
-
class WeightsLoader(torch.nn.Module):
|
337 |
-
pass
|
338 |
-
w = WeightsLoader()
|
339 |
-
w.position_net = PositionNet(in_dim, out_dim)
|
340 |
-
w.load_state_dict(sd, strict=False)
|
341 |
-
|
342 |
-
gligen = Gligen(output_list, w.position_net, key_dim)
|
343 |
-
return gligen
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/k_diffusion/__pycache__/sampling.cpython-310.pyc
DELETED
Binary file (28.2 kB)
|
|
MagicQuill/comfy/k_diffusion/__pycache__/utils.cpython-310.pyc
DELETED
Binary file (14 kB)
|
|
MagicQuill/comfy/k_diffusion/sampling.py
DELETED
@@ -1,843 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
from scipy import integrate
|
4 |
-
import torch
|
5 |
-
from torch import nn
|
6 |
-
import torchsde
|
7 |
-
from tqdm.auto import trange, tqdm
|
8 |
-
|
9 |
-
from . import utils
|
10 |
-
|
11 |
-
|
12 |
-
def append_zero(x):
|
13 |
-
return torch.cat([x, x.new_zeros([1])])
|
14 |
-
|
15 |
-
|
16 |
-
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
|
17 |
-
"""Constructs the noise schedule of Karras et al. (2022)."""
|
18 |
-
ramp = torch.linspace(0, 1, n, device=device)
|
19 |
-
min_inv_rho = sigma_min ** (1 / rho)
|
20 |
-
max_inv_rho = sigma_max ** (1 / rho)
|
21 |
-
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
22 |
-
return append_zero(sigmas).to(device)
|
23 |
-
|
24 |
-
|
25 |
-
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
|
26 |
-
"""Constructs an exponential noise schedule."""
|
27 |
-
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
|
28 |
-
return append_zero(sigmas)
|
29 |
-
|
30 |
-
|
31 |
-
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
|
32 |
-
"""Constructs an polynomial in log sigma noise schedule."""
|
33 |
-
ramp = torch.linspace(1, 0, n, device=device) ** rho
|
34 |
-
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
|
35 |
-
return append_zero(sigmas)
|
36 |
-
|
37 |
-
|
38 |
-
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
39 |
-
"""Constructs a continuous VP noise schedule."""
|
40 |
-
t = torch.linspace(1, eps_s, n, device=device)
|
41 |
-
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
|
42 |
-
return append_zero(sigmas)
|
43 |
-
|
44 |
-
|
45 |
-
def to_d(x, sigma, denoised):
|
46 |
-
"""Converts a denoiser output to a Karras ODE derivative."""
|
47 |
-
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
48 |
-
|
49 |
-
|
50 |
-
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
51 |
-
"""Calculates the noise level (sigma_down) to step down to and the amount
|
52 |
-
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
53 |
-
if not eta:
|
54 |
-
return sigma_to, 0.
|
55 |
-
sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
|
56 |
-
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
|
57 |
-
return sigma_down, sigma_up
|
58 |
-
|
59 |
-
|
60 |
-
def default_noise_sampler(x):
|
61 |
-
return lambda sigma, sigma_next: torch.randn_like(x)
|
62 |
-
|
63 |
-
|
64 |
-
class BatchedBrownianTree:
|
65 |
-
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
66 |
-
|
67 |
-
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
68 |
-
self.cpu_tree = True
|
69 |
-
if "cpu" in kwargs:
|
70 |
-
self.cpu_tree = kwargs.pop("cpu")
|
71 |
-
t0, t1, self.sign = self.sort(t0, t1)
|
72 |
-
w0 = kwargs.get('w0', torch.zeros_like(x))
|
73 |
-
if seed is None:
|
74 |
-
seed = torch.randint(0, 2 ** 63 - 1, []).item()
|
75 |
-
self.batched = True
|
76 |
-
try:
|
77 |
-
assert len(seed) == x.shape[0]
|
78 |
-
w0 = w0[0]
|
79 |
-
except TypeError:
|
80 |
-
seed = [seed]
|
81 |
-
self.batched = False
|
82 |
-
if self.cpu_tree:
|
83 |
-
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
|
84 |
-
else:
|
85 |
-
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
86 |
-
|
87 |
-
@staticmethod
|
88 |
-
def sort(a, b):
|
89 |
-
return (a, b, 1) if a < b else (b, a, -1)
|
90 |
-
|
91 |
-
def __call__(self, t0, t1):
|
92 |
-
t0, t1, sign = self.sort(t0, t1)
|
93 |
-
if self.cpu_tree:
|
94 |
-
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
|
95 |
-
else:
|
96 |
-
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
97 |
-
|
98 |
-
return w if self.batched else w[0]
|
99 |
-
|
100 |
-
|
101 |
-
class BrownianTreeNoiseSampler:
|
102 |
-
"""A noise sampler backed by a torchsde.BrownianTree.
|
103 |
-
|
104 |
-
Args:
|
105 |
-
x (Tensor): The tensor whose shape, device and dtype to use to generate
|
106 |
-
random samples.
|
107 |
-
sigma_min (float): The low end of the valid interval.
|
108 |
-
sigma_max (float): The high end of the valid interval.
|
109 |
-
seed (int or List[int]): The random seed. If a list of seeds is
|
110 |
-
supplied instead of a single integer, then the noise sampler will
|
111 |
-
use one BrownianTree per batch item, each with its own seed.
|
112 |
-
transform (callable): A function that maps sigma to the sampler's
|
113 |
-
internal timestep.
|
114 |
-
"""
|
115 |
-
|
116 |
-
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
|
117 |
-
self.transform = transform
|
118 |
-
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
|
119 |
-
self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
|
120 |
-
|
121 |
-
def __call__(self, sigma, sigma_next):
|
122 |
-
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
|
123 |
-
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
124 |
-
|
125 |
-
|
126 |
-
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    Args:
        model: Callable invoked as ``model(x, sigma_hat * s_in, **extra_args)``;
            its result is treated as the denoised prediction.
        x (Tensor): Starting sample. ``x.shape[0]`` sizes the vector of ones
            used to broadcast the current sigma to the model.
        sigmas: Indexable schedule; one step is taken per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        s_churn: When positive, enables the churn branch below.
        s_tmin, s_tmax: Churn is only applied while ``s_tmin <= sigma <= s_tmax``.
        s_noise: Multiplier for the freshly drawn churn noise.

    Returns:
        Tensor: The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma = sigmas[i]
        if s_churn > 0:
            in_window = s_tmin <= sigma <= s_tmax
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if in_window else 0.
            sigma_hat = sigma * (gamma + 1)
        else:
            gamma, sigma_hat = 0, sigma

        if gamma > 0:
            # Raise the noise level from sigma to sigma_hat with fresh noise.
            fresh_noise = torch.randn_like(x) * s_noise
            x = x + fresh_noise * (sigma_hat ** 2 - sigma ** 2) ** 0.5

        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})

        # Euler method
        x = x + d * (sigmas[i + 1] - sigma_hat)
    return x
|
150 |
-
|
151 |
-
|
152 |
-
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps.

    Each step moves deterministically to ``sigma_down`` and then, unless the
    next sigma is zero, adds ``noise_sampler`` output scaled by
    ``s_noise * sigma_up``. ``sigma_down``/``sigma_up`` come from
    ``get_ancestral_step(sigma, sigma_next, eta=eta)``.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample.
        sigmas: Indexable schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        eta: Passed to ``get_ancestral_step``.
        s_noise: Multiplier for the injected noise.
        noise_sampler (callable, optional): Called as
            ``noise_sampler(sigma, sigma_next)``; defaults to
            ``default_noise_sampler(x)``.

    Returns:
        Tensor: The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})

        # Euler method
        x = x + to_d(x, sigma, denoised) * (sigma_down - sigma)
        if sigma_next > 0:
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up
    return x
|
170 |
-
|
171 |
-
|
172 |
-
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``;
            its result is treated as the denoised prediction.
        x (Tensor): Starting sample.
        sigmas: Indexable schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        s_churn: When positive, enables the churn branch below.
        s_tmin, s_tmax: Churn is only applied while ``s_tmin <= sigma <= s_tmax``.
        s_noise: Multiplier for the freshly drawn churn noise.

    Returns:
        Tensor: The sample after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        # sigma_hat is fixed once per step here. The previous revision
        # recomputed ``sigmas[i] * (gamma + 1)`` a second time right after
        # this if/else, which made the branch assignments dead code.
        if s_churn > 0:
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]

        if gamma > 0:
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method: average the slope at the start with the slope
            # at the Euler-predicted end point.
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x
|
205 |
-
|
206 |
-
|
207 |
-
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample.
        sigmas: Indexable schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        s_churn: When positive, enables the churn branch below.
        s_tmin, s_tmax: Churn is only applied while ``s_tmin <= sigma <= s_tmax``.
        s_noise: Multiplier for the freshly drawn churn noise.

    Returns:
        Tensor: The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma = sigmas[i]
        if s_churn > 0:
            in_window = s_tmin <= sigma <= s_tmax
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if in_window else 0.
            sigma_hat = sigma * (gamma + 1)
        else:
            gamma, sigma_hat = 0, sigma

        if gamma > 0:
            fresh_noise = torch.randn_like(x) * s_noise
            x = x + fresh_noise * (sigma_hat ** 2 - sigma ** 2) ** 0.5

        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})

        sigma_next = sigmas[i + 1]
        if sigma_next == 0:
            # Euler method
            x = x + d * (sigma_next - sigma_hat)
            continue

        # DPM-Solver-2: evaluate the slope at the log-space midpoint of
        # sigma_hat and sigma_next, then use it for the full step.
        sigma_mid = sigma_hat.log().lerp(sigma_next.log(), 0.5).exp()
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigma_next - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    return x
|
241 |
-
|
242 |
-
|
243 |
-
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample.
        sigmas: Indexable schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        eta: Passed to ``get_ancestral_step``.
        s_noise: Multiplier for the injected noise.
        noise_sampler (callable, optional): Called as
            ``noise_sampler(sigma, sigma_next)``; defaults to
            ``default_noise_sampler(x)``.

    Returns:
        Tensor: The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
        d = to_d(x, sigma, denoised)
        if sigma_down == 0:
            # Euler method
            x = x + d * (sigma_down - sigma)
        else:
            # DPM-Solver-2: slope taken at the log-space midpoint of sigma
            # and sigma_down, applied over the full step.
            sigma_mid = sigma.log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigma
            dt_2 = sigma_down - sigma
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
            # Noise is only injected on this branch, never after the Euler step.
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up
    return x
|
270 |
-
|
271 |
-
|
272 |
-
def linear_multistep_coeff(order, t, i, j):
    """Integrate one interpolation basis polynomial over ``[t[i], t[i + 1]]``.

    The integrand is the product, over ``k`` in ``range(order)`` with
    ``k != j``, of ``(tau - t[i - k]) / (t[i - j] - t[i - k])``.

    Args:
        order (int): Number of past nodes taking part in the product.
        t: Indexable sequence of node positions.
        i (int): Index of the current node.
        j (int): Which of the ``order`` most recent nodes the basis belongs to.

    Returns:
        float: The integral estimate from ``scipy.integrate.quad``.

    Raises:
        ValueError: If ``order - 1 > i`` (not enough history for this order).
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    anchor = t[i - j]
    other_nodes = [t[i - k] for k in range(order) if k != j]

    def basis(tau):
        value = 1.
        for node in other_nodes:
            value *= (tau - node) / (anchor - node)
        return value

    integral, _abs_error = integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)
    return integral
|
283 |
-
|
284 |
-
|
285 |
-
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler.

    Keeps at most ``order`` of the most recent derivatives and combines them
    with coefficients from ``linear_multistep_coeff``.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample.
        sigmas (Tensor): Schedule; also copied to a NumPy array for the
            coefficient integration.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        order (int): Maximum number of past derivatives to combine.

    Returns:
        Tensor: The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    sigmas_cpu = sigmas.detach().cpu().numpy()
    history = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        history.append(to_d(x, sigmas[i], denoised))
        if len(history) > order:
            history.pop(0)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # The first steps have less history, so the effective order ramps up.
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        # Newest derivative pairs with coefficient j == 0.
        step = sum(coeff * d for coeff, d in zip(coeffs, reversed(history)))
        x = x + step
    return x
|
303 |
-
|
304 |
-
|
305 |
-
class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control."""

    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        # Exponents applied to the current and the two previous inverse errors.
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []

    def limiter(self, x):
        """Squash the raw step factor with ``1 + atan(x - 1)``."""
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        """Update the step size ``h`` from ``error`` and report acceptance.

        ``h`` is rescaled by the limited factor whether or not the step is
        accepted; the error history only shifts on acceptance.

        Returns:
            bool: True when the limited factor is at least ``accept_safety``.
        """
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            # First call: seed the whole history with the current value.
            self.errs = [inv_error] * 3
        self.errs[0] = inv_error
        current, previous, oldest = self.errs
        raw_factor = current ** self.b1 * previous ** self.b2 * oldest ** self.b3
        factor = self.limiter(raw_factor)
        accept = factor >= self.accept_safety
        if accept:
            self.errs[2] = previous
            self.errs[1] = current
        self.h *= factor
        return accept
|
332 |
-
|
333 |
-
|
334 |
-
class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927.

    The solver works in the time variable ``t = -log(sigma)`` and calls the
    wrapped ``model`` as ``model(x, sigma, *args, **extra_args, **kwargs)``,
    treating the result as the denoised prediction.

    Args:
        model: The denoiser callable.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        eps_callback (callable, optional): Called with no arguments after
            every uncached model evaluation.
        info_callback (callable, optional): Called with a dict of per-step
            information by the sampling loops.
    """

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        super().__init__()
        self.model = model
        if extra_args is None:
            extra_args = {}
        self.extra_args = extra_args
        self.eps_callback = eps_callback
        self.info_callback = info_callback

    def t(self, sigma):
        """Map a noise level to solver time: ``-log(sigma)``."""
        return -sigma.log()

    def sigma(self, t):
        """Map solver time back to a noise level: ``exp(-t)``."""
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        """Return the noise prediction at ``(x, t)``, cached under ``key``.

        Returns:
            tuple: ``(eps, cache)``. On a cache hit the input dict is returned
            as is; on a miss a new dict containing the fresh entry is built
            and the input dict is left untouched.
        """
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
        denoised = self.model(x, sigma, *args, **self.extra_args, **kwargs)
        eps = (x - denoised) / self.sigma(t)
        if self.eps_callback is not None:
            self.eps_callback()
        return eps, {key: eps, **eps_cache}

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        """Take one first-order step from ``t`` to ``t_next``."""
        if eps_cache is None:
            eps_cache = {}
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        x_1 = x - self.sigma(t_next) * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        """Take one second-order step using an intermediate point at ``t + r1 * h``."""
        if eps_cache is None:
            eps_cache = {}
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        first_order = self.sigma(t_next) * h.expm1() * eps
        correction = self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
        x_2 = x - first_order - correction
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        """Take one third-order step using intermediate points at ``r1`` and ``r2``."""
        if eps_cache is None:
            eps_cache = {}
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        u2_first = self.sigma(s2) * (r2 * h).expm1() * eps
        u2_correction = self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
        u2 = x - u2_first - u2_correction
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        x3_first = self.sigma(t_next) * h.expm1() * eps
        x3_correction = self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
        x_3 = x - x3_first - x3_correction
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        """Integrate from ``t_start`` to ``t_end`` on a fixed, evenly spaced time grid.

        The per-step solver orders are chosen from ``nfe`` so that the orders
        sum to ``nfe``.

        Raises:
            ValueError: If ``eta`` is non-zero while ``t_end`` is not greater
                than ``t_start``.
        """
        if noise_sampler is None:
            noise_sampler = default_noise_sampler(x)
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        remainder = nfe % 3
        if remainder == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [remainder]

        for i, order in enumerate(orders):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            if order == 1:
                step = self.dpm_solver_1_step
            elif order == 2:
                step = self.dpm_solver_2_step
            else:
                step = self.dpm_solver_3_step
            x, eps_cache = step(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        """Integrate from ``t_start`` to ``t_end`` with PID-controlled step sizes.

        Every iteration computes a lower- and a higher-order estimate; their
        scaled difference is handed to ``PIDStepSizeController`` to decide
        whether the step is accepted.

        Returns:
            tuple: ``(x, info)`` where ``info`` counts ``steps``, ``nfe``,
            ``n_accept`` and ``n_reject``.

        Raises:
            ValueError: If ``order`` is not 2 or 3, or if ``eta`` is non-zero
                while ``t_end`` is not greater than ``t_start``.
        """
        if noise_sampler is None:
            noise_sampler = default_noise_sampler(x)
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while (s < t_end - 1e-5) if forward else (s > t_end + 1e-5):
            eps_cache = {}
            # Clamp the proposed time so the step never overshoots t_end.
            if forward:
                t = torch.minimum(t_end, s + pid.h)
            else:
                t = torch.maximum(t_end, s + pid.h)
            if eta:
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info
|
480 |
-
|
481 |
-
|
482 |
-
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.

    Runs ``DPMSolver.dpm_solver_fast`` from ``sigma_max`` down to
    ``sigma_min`` with ``n`` passed as its ``nfe`` argument. A progress bar
    with ``total=n`` is advanced through the solver's ``eps_callback``.

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is less than or equal
            to zero.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)

        def forward_info(info):
            # Translate the solver's time values back into sigmas for the caller.
            callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})

        if callback is not None:
            dpm_solver.info_callback = forward_info
        t_start = dpm_solver.t(torch.tensor(sigma_max))
        t_end = dpm_solver.t(torch.tensor(sigma_min))
        return dpm_solver.dpm_solver_fast(x, t_start, t_end, n, eta, s_noise, noise_sampler)
|
492 |
-
|
493 |
-
|
494 |
-
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.

    Runs ``DPMSolver.dpm_solver_adaptive`` from ``sigma_max`` down to
    ``sigma_min``; the tolerance and controller arguments are forwarded
    unchanged.

    Returns:
        The final sample, or ``(sample, info)`` when ``return_info`` is truthy.

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is less than or equal
            to zero.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)

        def forward_info(info):
            # Translate the solver's time values back into sigmas for the caller.
            callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})

        if callback is not None:
            dpm_solver.info_callback = forward_info
        t_start = dpm_solver.t(torch.tensor(sigma_max))
        t_end = dpm_solver.t(torch.tensor(sigma_min))
        x, info = dpm_solver.dpm_solver_adaptive(x, t_start, t_end, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    return (x, info) if return_info else x
|
507 |
-
|
508 |
-
|
509 |
-
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample.
        sigmas: Indexable schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        eta: Passed to ``get_ancestral_step``.
        s_noise: Multiplier for the injected noise.
        noise_sampler (callable, optional): Called as
            ``noise_sampler(sigma, sigma_next)``; defaults to
            ``default_noise_sampler(x)``.

    Returns:
        Tensor: The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        return t.neg().exp()

    def t_fn(sigma):
        return sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigma, denoised)
            x = x + d * (sigma_down - sigma)
        else:
            # DPM-Solver++(2S): one extra model call at the midpoint in t.
            t, t_next = t_fn(sigma), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition
        if sigma_next > 0:
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up
    return x
|
541 |
-
|
542 |
-
|
543 |
-
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic).

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample; returned unchanged when ``sigmas`` has
            at most one entry.
        sigmas (Tensor): Schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
            Its ``"seed"`` entry, if any, seeds the default noise sampler.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        eta: Passed to ``get_ancestral_step``.
        s_noise: Multiplier for the injected noise.
        noise_sampler (callable, optional): Called as
            ``noise_sampler(sigma, sigma_next)``. When omitted, a CPU-backed
            ``BrownianTreeNoiseSampler`` spanning the positive sigmas is built.
        r: Relative position of the intermediate evaluation inside each step.

    Returns:
        Tensor: The sample after the final step.
    """
    if len(sigmas) <= 1:
        return x

    # Default extra_args *before* reading the seed from it: the previous order
    # called ``.get`` on None and raised AttributeError for the default call.
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        # The sigma range is only needed to build the default sampler.
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        seed = extra_args.get("seed", None)
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True)
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)

            # Step 1: move to the intermediate time s and re-evaluate the model.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2: full step using a blend of both denoised predictions.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x
|
587 |
-
|
588 |
-
|
589 |
-
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M).

    Multistep variant: after the first step, the previous step's denoised
    prediction is reused, so only one model call is made per step.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample.
        sigmas: Indexable schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.

    Returns:
        Tensor: The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        return t.neg().exp()

    def t_fn(sigma):
        return sigma.log().neg()

    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        use_first_order = old_denoised is None or sigmas[i + 1] == 0
        if use_first_order:
            denoised_d = denoised
        else:
            # Extrapolate from the current and previous denoised predictions.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
        x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
|
613 |
-
|
614 |
-
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample; returned unchanged when ``sigmas`` has
            at most one entry.
        sigmas (Tensor): Schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
            Its ``"seed"`` entry, if any, seeds the default noise sampler.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        eta: Scales the stochastic part of each step; ``0`` disables the
            noise injection entirely.
        s_noise: Multiplier for the injected noise.
        noise_sampler (callable, optional): Called as
            ``noise_sampler(sigma, sigma_next)``. When omitted, a CPU-backed
            ``BrownianTreeNoiseSampler`` spanning the positive sigmas is built.
        solver_type (str): ``'heun'`` or ``'midpoint'``; selects the
            second-order correction term.

    Returns:
        Tensor: The sample after the final step.

    Raises:
        ValueError: If ``solver_type`` is neither ``'heun'`` nor ``'midpoint'``.
    """
    if len(sigmas) <= 1:
        return x

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    # Default extra_args *before* reading the seed from it: the previous order
    # called ``.get`` on None and raised AttributeError for the default call.
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        # The sigma range is only needed to build the default sampler.
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        seed = extra_args.get("seed", None)
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True)
    s_in = x.new_ones([x.shape[0]])

    old_denoised = None
    h_last = None
    h = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            eta_h = eta * h

            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            if old_denoised is not None:
                # Second-order correction from the previous denoised prediction.
                r = h_last / h
                if solver_type == 'heun':
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

        old_denoised = denoised
        h_last = h
    return x
|
661 |
-
|
662 |
-
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample; returned unchanged when ``sigmas`` has
            at most one entry.
        sigmas (Tensor): Schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
            Its ``"seed"`` entry, if any, seeds the default noise sampler.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        eta: Scales the stochastic part of each step; ``0`` disables the
            noise injection entirely.
        s_noise: Multiplier for the injected noise.
        noise_sampler (callable, optional): Called as
            ``noise_sampler(sigma, sigma_next)``. When omitted, a CPU-backed
            ``BrownianTreeNoiseSampler`` spanning the positive sigmas is built.

    Returns:
        Tensor: The sample after the final step.
    """

    if len(sigmas) <= 1:
        return x

    # Default extra_args *before* reading the seed from it: the previous order
    # called ``.get`` on None and raised AttributeError for the default call.
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        # The sigma range is only needed to build the default sampler.
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        seed = extra_args.get("seed", None)
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True)
    s_in = x.new_ones([x.shape[0]])

    denoised_1, denoised_2 = None, None
    h, h_1, h_2 = None, None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            h_eta = h * (eta + 1)

            x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised

            if h_2 is not None:
                # Two previous steps available: third-order correction.
                r0 = h_1 / h
                r1 = h_2 / h
                d1_0 = (denoised - denoised_1) / r0
                d1_1 = (denoised_1 - denoised_2) / r1
                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
                d2 = (d1_0 - d1_1) / (r0 + r1)
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                phi_3 = phi_2 / h_eta - 0.5
                x = x + phi_2 * d1 - phi_3 * d2
            elif h_1 is not None:
                # Only one previous step available: second-order correction.
                r = h_1 / h
                d = (denoised - denoised_1) / r
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                x = x + phi_2 * d

            if eta:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

        # Shift the history of denoised predictions and step sizes.
        denoised_1, denoised_2 = denoised, denoised_1
        h_1, h_2 = h, h_1
    return x
|
714 |
-
|
715 |
-
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Run ``sample_dpmpp_3m_sde`` with a default noise sampler built with ``cpu=False``.

    All arguments are forwarded to ``sample_dpmpp_3m_sde``. When
    ``noise_sampler`` is omitted, a ``BrownianTreeNoiseSampler`` with
    ``cpu=False`` is created here, seeded from ``extra_args["seed"]`` if that
    entry exists. ``x`` is returned unchanged when ``sigmas`` has at most one
    entry.
    """
    if len(sigmas) <= 1:
        return x

    # Default extra_args first: the previous code called ``.get`` on None when
    # building the default sampler. The defaulted dict is also what gets
    # forwarded, so the wrapped sampler never receives None.
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)
    return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
723 |
-
|
724 |
-
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """Run ``sample_dpmpp_2m_sde`` with a default noise sampler built with ``cpu=False``.

    All arguments are forwarded to ``sample_dpmpp_2m_sde``. When
    ``noise_sampler`` is omitted, a ``BrownianTreeNoiseSampler`` with
    ``cpu=False`` is created here, seeded from ``extra_args["seed"]`` if that
    entry exists. ``x`` is returned unchanged when ``sigmas`` has at most one
    entry.
    """
    if len(sigmas) <= 1:
        return x

    # Default extra_args first: the previous code called ``.get`` on None when
    # building the default sampler. The defaulted dict is also what gets
    # forwarded, so the wrapped sampler never receives None.
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)
    return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
732 |
-
|
733 |
-
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """Run ``sample_dpmpp_sde`` with a default noise sampler built with ``cpu=False``.

    All arguments are forwarded to ``sample_dpmpp_sde``. When
    ``noise_sampler`` is omitted, a ``BrownianTreeNoiseSampler`` with
    ``cpu=False`` is created here, seeded from ``extra_args["seed"]`` if that
    entry exists. ``x`` is returned unchanged when ``sigmas`` has at most one
    entry.
    """
    if len(sigmas) <= 1:
        return x

    # Default extra_args first: the previous code called ``.get`` on None when
    # building the default sampler. The defaulted dict is also what gets
    # forwarded, so the wrapped sampler never receives None.
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)
    return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
741 |
-
|
742 |
-
|
743 |
-
def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
    """Compute one DDPM-style update from ``sigma`` to ``sigma_prev``.

    Both sigmas are converted to cumulative alphas via ``1 / (sigma**2 + 1)``.

    Args:
        x: Current sample.
        sigma: Current noise level (needs tensor methods such as ``.sqrt()``
            on the values derived from it).
        sigma_prev: Noise level being stepped to.
        noise: Noise estimate that is removed from ``x``.
        noise_sampler (callable): Called as ``noise_sampler(sigma, sigma_prev)``,
            but only when ``sigma_prev > 0``.

    Returns:
        The updated sample, with fresh noise added only when ``sigma_prev > 0``.
    """
    acp = 1 / ((sigma * sigma) + 1)
    acp_prev = 1 / ((sigma_prev * sigma_prev) + 1)
    alpha = (acp / acp_prev)

    removed = (1 - alpha) * noise / (1 - acp).sqrt()
    mu = (1.0 / alpha).sqrt() * (x - removed)
    if sigma_prev > 0:
        noise_scale = ((1 - alpha) * (1. - acp_prev) / (1. - acp)).sqrt()
        mu += noise_scale * noise_sampler(sigma, sigma_prev)
    return mu
|
752 |
-
|
753 |
-
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
    """Drive a sampling loop whose per-step update is ``step_function``.

    Before each call ``x`` is divided by ``sqrt(1 + sigma**2)``; after it the
    result is multiplied in place by ``sqrt(1 + sigma_next**2)`` unless
    ``sigma_next`` is zero.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``.
        x (Tensor): Starting sample.
        sigmas: Indexable schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        noise_sampler (callable, optional): Passed through to
            ``step_function``; defaults to ``default_noise_sampler(x)``.
        step_function (callable): Called as ``step_function(scaled_x, sigma,
            sigma_next, noise_estimate, noise_sampler)``.

    Returns:
        Tensor: The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    s_in = x.new_ones([x.shape[0]])

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        scaled_x = x / torch.sqrt(1.0 + sigmas[i] ** 2.0)
        noise_estimate = (x - denoised) / sigmas[i]
        x = step_function(scaled_x, sigmas[i], sigmas[i + 1], noise_estimate, noise_sampler)
        if sigmas[i + 1] != 0:
            x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
    return x
|
766 |
-
|
767 |
-
|
768 |
-
@torch.no_grad()
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    """Run ``generic_step_sampler`` with ``DDPMSampler_step`` as the per-step update."""
    return generic_step_sampler(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        noise_sampler=noise_sampler,
        step_function=DDPMSampler_step,
    )
|
771 |
-
|
772 |
-
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    """Sampler that jumps to the denoised prediction and then re-noises it.

    Each step replaces ``x`` with the model's denoised output. Unless the
    next sigma is zero, the result is re-noised through
    ``model.inner_model.inner_model.model_sampling.noise_scaling``.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``;
            must also expose the ``inner_model.inner_model.model_sampling``
            attribute chain used for re-noising.
        x (Tensor): Starting sample.
        sigmas: Indexable schedule; one step per consecutive pair.
        extra_args (dict, optional): Extra keyword arguments for ``model``.
        callback (callable, optional): Called once per step with a dict
            holding ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as ``disable``.
        noise_sampler (callable, optional): Called as
            ``noise_sampler(sigma, sigma_next)``; defaults to
            ``default_noise_sampler(x)``.

    Returns:
        Tensor: The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})

        x = denoised
        if sigma_next > 0:
            fresh_noise = noise_sampler(sigma, sigma_next)
            x = model.inner_model.inner_model.model_sampling.noise_scaling(sigma_next, fresh_noise, x)
    return x
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Sample with a per-step choice between Euler, Heun and "Heun++" updates.

    The update used at step ``i`` depends on how close the step is to the end
    of the schedule (``sigmas[-1]``):

    * ``sigmas[i + 1]`` equals the final sigma: a single Euler step.
    * ``sigmas[i + 2]`` equals the final sigma: a two-evaluation step whose
      derivatives are blended with weights derived from ``2 * sigmas[0]``.
    * otherwise: a three-evaluation step with weights derived from
      ``3 * sigmas[0]``.

    Args:
        model: Callable invoked as ``model(x, sigma * ones, **extra_args)``.
        x: Starting tensor; its first dimension is used as the batch size.
        sigmas: Indexable sequence of sigma values.
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        s_churn: Controls the sigma inflation factor ``gamma``.
        s_tmin: Lower bound of the sigma range in which ``gamma`` is non-zero.
        s_tmax: Upper bound of the sigma range in which ``gamma`` is non-zero.
        s_noise: Multiplier applied to the noise drawn at every step.

    Returns:
        The value of ``x`` after the last step.
    """
    # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    final_sigma = sigmas[-1]
    num_steps = len(sigmas) - 1

    for step in trange(num_steps, disable=disable):
        sigma = sigmas[step]
        sigma_next = sigmas[step + 1]

        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / num_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Noise is drawn on every step, even when gamma is zero and it goes unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5

        denoised = model(x, sigma_hat * batch_ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigma_next - sigma_hat

        if sigma_next == final_sigma:
            # Euler method
            x = x + d * dt
            continue

        sigma_after = sigmas[step + 2]
        second_to_last = sigma_after == final_sigma

        # Both remaining variants start with the same second evaluation.
        x_2 = x + d * dt
        denoised_2 = model(x_2, sigma_next * batch_ones, **extra_args)
        d_2 = to_d(x_2, sigma_next, denoised_2)

        if second_to_last:
            # Heun's method
            w = 2 * sigmas[0]
            w2 = sigma_next / w
            w1 = 1 - w2
            d_prime = d * w1 + d_2 * w2
        else:
            # Heun++
            dt_2 = sigma_after - sigma_next
            x_3 = x_2 + d_2 * dt_2
            denoised_3 = model(x_3, sigma_after * batch_ones, **extra_args)
            d_3 = to_d(x_3, sigma_after, denoised_3)

            w = 3 * sigmas[0]
            w2 = sigma_next / w
            w3 = sigma_after / w
            w1 = 1 - w2 - w3
            d_prime = w1 * d + w2 * d_2 + w3 * d_3

        x = x + d_prime * dt
    return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/k_diffusion/utils.py
DELETED
@@ -1,313 +0,0 @@
|
|
1 |
-
from contextlib import contextmanager
|
2 |
-
import hashlib
|
3 |
-
import math
|
4 |
-
from pathlib import Path
|
5 |
-
import shutil
|
6 |
-
import urllib
|
7 |
-
import warnings
|
8 |
-
|
9 |
-
from PIL import Image
|
10 |
-
import torch
|
11 |
-
from torch import nn, optim
|
12 |
-
from torch.utils import data
|
13 |
-
|
14 |
-
|
15 |
-
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Convert and transform every image stored under ``image_key``.

    Args:
        examples: Mapping whose ``image_key`` entry is an iterable of objects
            exposing ``convert(mode)``.
        transform: Callable applied to each converted image.
        image_key: Key to read from ``examples`` and to use in the result.
        mode: Argument passed to each image's ``convert`` method.

    Returns:
        A one-entry dict mapping ``image_key`` to the list of transformed images.
    """
    transformed = []
    for image in examples[image_key]:
        transformed.append(transform(image.convert(mode)))
    return {image_key: transformed}
|
19 |
-
|
20 |
-
|
21 |
-
def append_dims(x, target_dims):
    """Append trailing singleton dimensions until ``x`` has ``target_dims`` dimensions.

    Args:
        x: Tensor to expand.
        target_dims: Desired number of dimensions.

    Returns:
        ``x`` indexed with trailing ``None`` axes; on MPS devices a detached
        clone of that result.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    index = (Ellipsis,) + (None,) * missing
    expanded = x[index]
    # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
    # https://github.com/pytorch/pytorch/issues/84364
    if expanded.device.type == 'mps':
        return expanded.detach().clone()
    return expanded
|
30 |
-
|
31 |
-
|
32 |
-
def n_params(module):
    """Return the total element count over ``module.parameters()``.

    Every parameter yielded by ``module.parameters()`` is counted; no
    ``requires_grad`` filter is applied.
    """
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
|
35 |
-
|
36 |
-
|
37 |
-
def download_file(path, url, digest=None):
    """Downloads a file if it does not exist, optionally checking its SHA-256 hash.

    Args:
        path: Destination path; missing parent directories are created.
        url: URL fetched only when ``path`` does not exist yet.
        digest: Optional expected SHA-256 hex digest of the file at ``path``.
            The check runs whether or not a download took place.

    Returns:
        ``path`` as a ``pathlib.Path``.

    Raises:
        OSError: If ``digest`` is given and does not match the file contents.
    """
    # A bare ``import urllib`` does not guarantee that the ``request``
    # submodule is loaded, so import it explicitly where it is used.
    import urllib.request

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
            shutil.copyfileobj(response, f)
    if digest is not None:
        # Hash in chunks so large files are not loaded into memory at once,
        # and close the handle deterministically.
        sha256 = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                sha256.update(chunk)
        if digest != sha256.hexdigest():
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
    return path
|
49 |
-
|
50 |
-
|
51 |
-
@contextmanager
def train_mode(model, mode=True):
    """Context manager that switches ``model`` to ``mode`` and restores it on exit.

    The ``training`` flag of every submodule is recorded on entry and written
    back, submodule by submodule, when the context exits (also on error).

    Args:
        model: Module to switch.
        mode: Value passed to ``model.train``.

    Yields:
        The result of ``model.train(mode)``.
    """
    previous_flags = [submodule.training for submodule in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for index, submodule in enumerate(model.modules()):
            submodule.training = previous_flags[index]
|
61 |
-
|
62 |
-
|
63 |
-
def eval_mode(model):
    """Context manager that switches ``model`` to evaluation mode and restores
    the previous mode on exit.

    Equivalent to ``train_mode(model, False)``.
    """
    return train_mode(model, mode=False)
|
67 |
-
|
68 |
-
|
69 |
-
@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Fold the current weights of ``model`` into ``averaged_model`` in place.

    Each parameter of ``averaged_model`` becomes
    ``decay * averaged + (1 - decay) * current``; buffers are copied over
    verbatim. It should be called after each optimizer step.

    Args:
        model: Module providing the current parameters and buffers.
        averaged_model: Module holding the running average; modified in place.
        decay: Weight given to the existing averaged value.

    Raises:
        AssertionError: If the two modules do not share the same parameter
            names or the same buffer names.
    """
    current_params = dict(model.named_parameters())
    running_params = dict(averaged_model.named_parameters())
    assert current_params.keys() == running_params.keys()

    for name in current_params:
        running_params[name].mul_(decay).add_(current_params[name], alpha=1 - decay)

    current_buffers = dict(model.named_buffers())
    running_buffers = dict(averaged_model.named_buffers())
    assert current_buffers.keys() == running_buffers.keys()

    for name in current_buffers:
        running_buffers[name].copy_(current_buffers[name])
|
86 |
-
|
87 |
-
|
88 |
-
class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.
    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).
    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        # A shallow copy, so callers cannot mutate the instance through it.
        return dict(self.__dict__)

    def load_state_dict(self, state_dict):
        """Loads the class's state.
        Args:
            state_dict (dict): scaler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate.

        The number of epochs elapsed since ``start_at`` is clamped to zero, so
        before ``start_at`` the raw schedule value is 0 and the result is that
        value clamped into ``[min_value, max_value]``.
        """
        epoch = max(0, self.last_epoch - self.start_at)
        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
        # ``epoch`` can never be negative after the clamp above, so the former
        # ``0. if epoch < 0`` special case was unreachable and has been removed.
        return min(self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
|
134 |
-
|
135 |
-
|
136 |
-
class InverseLR(optim.lr_scheduler._LRScheduler):
    """Inverse-decay learning rate schedule with an optional exponential warmup.

    When last_epoch=-1, sets initial lr as lr. inv_gamma is the number of
    steps/epochs required for the learning rate to decay to (1 / 2)**power of
    its original value.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): Forwarded to the base scheduler. Default: ``False``.

    Raises:
        ValueError: If ``warmup`` is outside ``[0, 1)``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        # The schedule settings must exist before the base constructor runs,
        # because they are read when the learning rates are computed.
        self.inv_gamma = inv_gamma
        self.power = power
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the current learning rates, warning when called outside ``step``."""
        inside_step = self._get_lr_called_within_step
        if not inside_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        """Compute one learning rate per base rate directly from ``last_epoch``."""
        epoch = self.last_epoch
        warmup_factor = 1 - self.warmup ** (epoch + 1)
        decay_factor = (1 + epoch / self.inv_gamma) ** -self.power
        rates = []
        for base_lr in self.base_lrs:
            rates.append(warmup_factor * max(self.min_lr, base_lr * decay_factor))
        return rates
|
175 |
-
|
176 |
-
|
177 |
-
class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Exponential learning rate schedule with an optional exponential warmup.

    When last_epoch=-1, sets initial lr as lr. Decays the learning rate
    continuously by decay (default 0.5) every num_steps steps.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps to decay the learning rate by decay in.
        decay (float): The factor by which to decay the learning rate every num_steps
            steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): Forwarded to the base scheduler. Default: ``False``.

    Raises:
        ValueError: If ``warmup`` is outside ``[0, 1)``.
    """

    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        # The schedule settings must exist before the base constructor runs,
        # because they are read when the learning rates are computed.
        self.num_steps = num_steps
        self.decay = decay
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the current learning rates, warning when called outside ``step``."""
        inside_step = self._get_lr_called_within_step
        if not inside_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        """Compute one learning rate per base rate directly from ``last_epoch``."""
        epoch = self.last_epoch
        warmup_factor = 1 - self.warmup ** (epoch + 1)
        per_step_decay = self.decay ** (1 / self.num_steps)
        decay_factor = per_step_decay ** epoch
        rates = []
        for base_lr in self.base_lrs:
            rates.append(warmup_factor * max(self.min_lr, base_lr * decay_factor))
        return rates
|
216 |
-
|
217 |
-
|
218 |
-
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draw lognormal samples: ``exp(randn(shape) * scale + loc)``.

    Args:
        shape: Shape of the returned tensor.
        loc: Offset added to the scaled normal draw before exponentiation.
        scale: Multiplier applied to the normal draw.
        device: Device on which the draw is created.
        dtype: Dtype of the draw.
    """
    normal = torch.randn(shape, device=device, dtype=dtype)
    return torch.exp(normal * scale + loc)
|
221 |
-
|
222 |
-
|
223 |
-
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draw samples from an optionally truncated log-logistic distribution.

    The bounds are mapped to CDF values, a uniform draw is taken between
    them in float64, and the result is mapped back through the inverse CDF
    before being cast to ``dtype``.

    Args:
        shape: Shape of the returned tensor.
        loc: Location of the underlying logistic distribution (log space).
        scale: Scale of the underlying logistic distribution (log space).
        min_value: Lower truncation bound.
        max_value: Upper truncation bound.
        device: Device on which the tensors are created.
        dtype: Dtype of the returned tensor.
    """
    lower = torch.as_tensor(min_value, device=device, dtype=torch.float64)
    upper = torch.as_tensor(max_value, device=device, dtype=torch.float64)
    cdf_lower = torch.sigmoid((lower.log() - loc) / scale)
    cdf_upper = torch.sigmoid((upper.log() - loc) / scale)
    uniform = torch.rand(shape, device=device, dtype=torch.float64)
    u = uniform * (cdf_upper - cdf_lower) + cdf_lower
    return torch.exp(torch.logit(u) * scale + loc).to(dtype)
|
231 |
-
|
232 |
-
|
233 |
-
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draw samples whose logarithm is uniform between ``log(min_value)`` and ``log(max_value)``.

    Args:
        shape: Shape of the returned tensor.
        min_value: Lower bound; passed to ``math.log``.
        max_value: Upper bound; passed to ``math.log``.
        device: Device on which the draw is created.
        dtype: Dtype of the draw.
    """
    log_low = math.log(min_value)
    log_high = math.log(max_value)
    uniform = torch.rand(shape, device=device, dtype=dtype)
    return torch.exp(uniform * (log_high - log_low) + log_low)
|
238 |
-
|
239 |
-
|
240 |
-
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draw samples from a truncated v-diffusion training timestep distribution.

    The bounds are mapped through ``atan(value / sigma_data) * 2 / pi``, a
    uniform draw is taken between the two results, and it is mapped back with
    ``tan(u * pi / 2) * sigma_data``.

    Args:
        shape: Shape of the returned tensor.
        sigma_data: Scale used in both the forward and inverse mapping.
        min_value: Lower truncation bound.
        max_value: Upper truncation bound.
        device: Device on which the draw is created.
        dtype: Dtype of the draw.
    """
    two_over_pi_scale = 2 / math.pi
    cdf_lower = math.atan(min_value / sigma_data) * two_over_pi_scale
    cdf_upper = math.atan(max_value / sigma_data) * two_over_pi_scale
    uniform = torch.rand(shape, device=device, dtype=dtype)
    u = uniform * (cdf_upper - cdf_lower) + cdf_lower
    return torch.tan(u * math.pi / 2) * sigma_data
|
246 |
-
|
247 |
-
|
248 |
-
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
    """Draw samples from a split lognormal distribution.

    A half-normal magnitude is drawn once. With probability
    ``scale_1 / (scale_1 + scale_2)`` it is placed below ``loc`` scaled by
    ``scale_1``, otherwise above ``loc`` scaled by ``scale_2``; the result is
    exponentiated.

    Args:
        shape: Shape of the returned tensor.
        loc: Split point in log space.
        scale_1: Scale of the lower side.
        scale_2: Scale of the upper side.
        device: Device on which the draws are created.
        dtype: Dtype of the draws.
    """
    magnitude = torch.randn(shape, device=device, dtype=dtype).abs()
    selector = torch.rand(shape, device=device, dtype=dtype)
    below = magnitude * -scale_1 + loc
    above = magnitude * scale_2 + loc
    lower_probability = scale_1 / (scale_1 + scale_2)
    log_sample = torch.where(selector < lower_probability, below, above)
    return log_sample.exp()
|
256 |
-
|
257 |
-
|
258 |
-
class FolderOfImages(data.Dataset):
    """Dataset over every image file found recursively under a directory.

    Files are selected by (case-insensitive) suffix and kept in sorted order.
    It does not support classes/targets.
    """

    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}

    def __init__(self, root, transform=None):
        """Collect image paths under ``root``.

        Args:
            root: Directory to scan recursively.
            transform: Callable applied to each loaded image; defaults to
                ``nn.Identity()``.
        """
        super().__init__()
        self.root = Path(root)
        self.transform = nn.Identity() if transform is None else transform
        matches = [
            candidate
            for candidate in self.root.rglob('*')
            if candidate.suffix.lower() in self.IMG_EXTENSIONS
        ]
        matches.sort()
        self.paths = matches

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        """Load the image at index ``key`` as RGB and return it in a 1-tuple."""
        with open(self.paths[key], 'rb') as handle:
            # ``convert`` is called while the file is still open.
            image = Image.open(handle).convert('RGB')
        return (self.transform(image),)
|
282 |
-
|
283 |
-
|
284 |
-
class CSVLogger:
    """Minimal comma-separated logger that appends one row per ``write`` call.

    If the target file already exists it is opened for appending and no header
    is written; otherwise it is created and ``columns`` is written as the
    first row. Values are joined with ``,`` without any quoting or escaping.

    The file stays open until ``close`` is called; the logger can also be used
    as a context manager to close it automatically.
    """

    def __init__(self, filename, columns):
        """Open ``filename`` and write the header row when the file is new.

        Args:
            filename: Path of the CSV file.
            columns: Iterable of column names written as the header.
        """
        self.filename = Path(filename)
        self.columns = columns
        if self.filename.exists():
            self.file = open(self.filename, 'a')
        else:
            self.file = open(self.filename, 'w')
            self.write(*self.columns)

    def write(self, *args):
        """Write ``args`` as one comma-separated row and flush immediately."""
        print(*args, sep=',', file=self.file, flush=True)

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
|
296 |
-
|
297 |
-
|
298 |
-
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """Context manager that temporarily sets the cuDNN and/or matmul TF32 flags.

    Args:
        cudnn: New value for ``torch.backends.cudnn.allow_tf32``, or ``None``
            to leave that flag untouched.
        matmul: New value for ``torch.backends.cuda.matmul.allow_tf32``, or
            ``None`` to leave that flag untouched.

    On exit, each flag that was overridden is reset to the value it had on entry.
    """
    cudnn_backend = torch.backends.cudnn
    matmul_backend = torch.backends.cuda.matmul
    saved_cudnn = cudnn_backend.allow_tf32
    saved_matmul = matmul_backend.allow_tf32
    override_cudnn = cudnn is not None
    override_matmul = matmul is not None
    try:
        if override_cudnn:
            cudnn_backend.allow_tf32 = cudnn
        if override_matmul:
            matmul_backend.allow_tf32 = matmul
        yield
    finally:
        if override_cudnn:
            cudnn_backend.allow_tf32 = saved_cudnn
        if override_matmul:
            matmul_backend.allow_tf32 = saved_matmul
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/latent_formats.py
DELETED
@@ -1,141 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
class LatentFormat:
    """Base description of a latent space: scaling, channel count and preview data.

    ``process_in`` multiplies a latent by ``scale_factor`` and ``process_out``
    divides by it.
    """

    # Multiplier applied by process_in and undone by process_out.
    scale_factor = 1.0
    # Number of channels in a latent of this format.
    latent_channels = 4
    # Optional per-channel RGB weights (one [R, G, B] row per latent channel).
    latent_rgb_factors = None
    # Optional name of a TAESD decoder associated with this format.
    taesd_decoder_name = None

    def process_in(self, latent):
        """Return ``latent`` multiplied by ``scale_factor``."""
        factor = self.scale_factor
        return latent * factor

    def process_out(self, latent):
        """Return ``latent`` divided by ``scale_factor``."""
        factor = self.scale_factor
        return latent / factor
|
14 |
-
|
15 |
-
class SD15(LatentFormat):
    """Latent format settings for the ``SD15`` configuration (4 channels)."""

    def __init__(self, scale_factor=0.18215):
        """Store the scale factor, the RGB preview weights and the decoder name.

        Args:
            scale_factor: Multiplier used by ``process_in`` / ``process_out``.
        """
        rgb_factors = [
            #   R        G        B
            [ 0.3512,  0.2297,  0.3227],
            [ 0.3250,  0.4974,  0.2350],
            [-0.2829,  0.1762,  0.2721],
            [-0.2120, -0.2616, -0.7177]
        ]
        self.scale_factor = scale_factor
        self.latent_rgb_factors = rgb_factors
        self.taesd_decoder_name = "taesd_decoder"
|
26 |
-
|
27 |
-
class SDXL(LatentFormat):
    """Latent format settings for the ``SDXL`` configuration (4 channels)."""

    # Class-level scale factor; instances do not override it.
    scale_factor = 0.13025

    def __init__(self):
        """Store the RGB preview weights and the decoder name."""
        rgb_factors = [
            #   R        G        B
            [ 0.3920,  0.4054,  0.4549],
            [-0.2634, -0.0196,  0.0653],
            [ 0.0568,  0.1687, -0.0755],
            [-0.3112, -0.2359, -0.2076]
        ]
        self.latent_rgb_factors = rgb_factors
        self.taesd_decoder_name = "taesdxl_decoder"
|
39 |
-
|
40 |
-
class SDXL_Playground_2_5(LatentFormat):
    """Latent format that normalizes with per-channel mean and std.

    ``process_in`` computes ``(latent - mean) * scale_factor / std`` and
    ``process_out`` computes ``latent * std / scale_factor + mean``, with the
    statistics moved to the latent's device and dtype on every call.
    """

    def __init__(self):
        """Store the scale factor, channel statistics, RGB weights and decoder name."""
        self.scale_factor = 0.5
        # Per-channel statistics, shaped (1, 4, 1, 1) by the ``view`` calls below.
        mean_values = [-1.6574, 1.886, -1.383, 2.5155]
        std_values = [8.4927, 5.9022, 6.5498, 5.2299]
        self.latents_mean = torch.tensor(mean_values).view(1, 4, 1, 1)
        self.latents_std = torch.tensor(std_values).view(1, 4, 1, 1)

        self.latent_rgb_factors = [
            #   R        G        B
            [ 0.3920,  0.4054,  0.4549],
            [-0.2634, -0.0196,  0.0653],
            [ 0.0568,  0.1687, -0.0755],
            [-0.3112, -0.2359, -0.2076]
        ]
        self.taesd_decoder_name = "taesdxl_decoder"

    def process_in(self, latent):
        """Center ``latent`` by the stored mean, then scale and divide by the stored std."""
        mean = self.latents_mean.to(latent.device, latent.dtype)
        std = self.latents_std.to(latent.device, latent.dtype)
        centered = latent - mean
        return centered * self.scale_factor / std

    def process_out(self, latent):
        """Invert ``process_in``: multiply by std, divide by the scale, add the mean."""
        mean = self.latents_mean.to(latent.device, latent.dtype)
        std = self.latents_std.to(latent.device, latent.dtype)
        unscaled = latent * std / self.scale_factor
        return unscaled + mean
|
64 |
-
|
65 |
-
|
66 |
-
class SD_X4(LatentFormat):
    """Latent format settings for the ``SD_X4`` configuration (4 channels)."""

    def __init__(self):
        """Store the scale factor and the RGB preview weights."""
        rgb_factors = [
            [-0.2340, -0.3863, -0.3257],
            [ 0.0994,  0.0885, -0.0908],
            [-0.2833, -0.2349, -0.3741],
            [ 0.2523, -0.0055, -0.1651]
        ]
        self.scale_factor = 0.08333
        self.latent_rgb_factors = rgb_factors
|
75 |
-
|
76 |
-
class SC_Prior(LatentFormat):
    """Latent format settings for the ``SC_Prior`` configuration (16 channels)."""

    latent_channels = 16

    def __init__(self):
        """Store the scale factor and one RGB weight row per latent channel."""
        rgb_factors = [
            [-0.0326, -0.0204, -0.0127],
            [-0.1592, -0.0427,  0.0216],
            [ 0.0873,  0.0638, -0.0020],
            [-0.0602,  0.0442,  0.1304],
            [ 0.0800, -0.0313, -0.1796],
            [-0.0810, -0.0638, -0.1581],
            [ 0.1791,  0.1180,  0.0967],
            [ 0.0740,  0.1416,  0.0432],
            [-0.1745, -0.1888, -0.1373],
            [ 0.2412,  0.1577,  0.0928],
            [ 0.1908,  0.0998,  0.0682],
            [ 0.0209,  0.0365, -0.0092],
            [ 0.0448, -0.0650, -0.1728],
            [-0.1658, -0.1045, -0.1308],
            [ 0.0542,  0.1545,  0.1325],
            [-0.0352, -0.1672, -0.2541]
        ]
        self.scale_factor = 1.0
        self.latent_rgb_factors = rgb_factors
|
98 |
-
|
99 |
-
class SC_B(LatentFormat):
    """Latent format settings for the ``SC_B`` configuration (4 channels)."""

    def __init__(self):
        """Store the scale factor (``1.0 / 0.43``) and the RGB preview weights."""
        rgb_factors = [
            [ 0.1121,  0.2006,  0.1023],
            [-0.2093, -0.0222, -0.0195],
            [-0.3087, -0.1535,  0.0366],
            [ 0.0290, -0.1574, -0.4078]
        ]
        self.scale_factor = 1.0 / 0.43
        self.latent_rgb_factors = rgb_factors
|
108 |
-
|
109 |
-
class SD3(LatentFormat):
    """Latent format with a shift in addition to the scale (16 channels).

    ``process_in`` computes ``(latent - shift_factor) * scale_factor`` and
    ``process_out`` computes ``latent / scale_factor + shift_factor``.
    """

    latent_channels = 16

    def __init__(self):
        """Store the scale and shift factors, RGB weights and decoder name."""
        self.scale_factor = 1.5305
        self.shift_factor = 0.0609
        rgb_factors = [
            [-0.0645,  0.0177,  0.1052],
            [ 0.0028,  0.0312,  0.0650],
            [ 0.1848,  0.0762,  0.0360],
            [ 0.0944,  0.0360,  0.0889],
            [ 0.0897,  0.0506, -0.0364],
            [-0.0020,  0.1203,  0.0284],
            [ 0.0855,  0.0118,  0.0283],
            [-0.0539,  0.0658,  0.1047],
            [-0.0057,  0.0116,  0.0700],
            [-0.0412,  0.0281, -0.0039],
            [ 0.1106,  0.1171,  0.1220],
            [-0.0248,  0.0682, -0.0481],
            [ 0.0815,  0.0846,  0.1207],
            [-0.0120, -0.0055, -0.0867],
            [-0.0749, -0.0634, -0.0456],
            [-0.1418, -0.1457, -0.1259]
        ]
        self.latent_rgb_factors = rgb_factors
        self.taesd_decoder_name = "taesd3_decoder"

    def process_in(self, latent):
        """Subtract the shift, then multiply by the scale."""
        shifted = latent - self.shift_factor
        return shifted * self.scale_factor

    def process_out(self, latent):
        """Divide by the scale, then add the shift back."""
        unscaled = latent / self.scale_factor
        return unscaled + self.shift_factor
|
139 |
-
|
140 |
-
class StableAudio1(LatentFormat):
    """Latent format with 64 latent channels; everything else is inherited."""

    latent_channels = 64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/ldm/.DS_Store
DELETED
Binary file (6.15 kB)
|
|
MagicQuill/comfy/ldm/__pycache__/util.cpython-310.pyc
DELETED
Binary file (6.19 kB)
|
|
MagicQuill/comfy/ldm/audio/__pycache__/autoencoder.cpython-310.pyc
DELETED
Binary file (8.08 kB)
|
|
MagicQuill/comfy/ldm/audio/__pycache__/dit.cpython-310.pyc
DELETED
Binary file (18.7 kB)
|
|
MagicQuill/comfy/ldm/audio/__pycache__/embedders.cpython-310.pyc
DELETED
Binary file (4.34 kB)
|
|
MagicQuill/comfy/ldm/audio/autoencoder.py
DELETED
@@ -1,282 +0,0 @@
|
|
1 |
-
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
from typing import Literal, Dict, Any
|
6 |
-
import math
|
7 |
-
import comfy.ops
|
8 |
-
ops = comfy.ops.disable_weight_init
|
9 |
-
|
10 |
-
def vae_sample(mean, scale):
    """Sample latents around ``mean`` and compute a KL term.

    The standard deviation is ``softplus(scale) + 1e-4``. The KL term is
    ``mean**2 + var - log(var) - 1`` summed over dimension 1 and then averaged
    over the remaining elements.

    Args:
        mean: Tensor of means.
        scale: Tensor of raw (pre-softplus) scale values.

    Returns:
        Tuple ``(latents, kl)``.
    """
    stdev = nn.functional.softplus(scale) + 1e-4
    variance = stdev * stdev
    log_variance = torch.log(variance)
    noise = torch.randn_like(mean)
    latents = noise * stdev + mean

    kl_terms = mean * mean + variance - log_variance - 1
    kl = kl_terms.sum(1).mean()

    return latents, kl
|
19 |
-
|
20 |
-
class VAEBottleneck(nn.Module):
    """Bottleneck that samples latents from a mean/scale encoding.

    ``encode`` splits its input in two halves along dimension 1 and samples
    with ``vae_sample``; ``decode`` returns its input unchanged.
    """

    def __init__(self):
        super().__init__()
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        """Sample latents from ``x``.

        Args:
            x: Tensor whose dimension 1 holds the means followed by the scales.
            return_info: When true, also return a dict with the ``"kl"`` term.
            **kwargs: Accepted and ignored.

        Returns:
            The sampled latents, or ``(latents, info)`` when ``return_info`` is true.
        """
        mean, scale = x.chunk(2, dim=1)
        latents, kl = vae_sample(mean, scale)
        if not return_info:
            return latents
        return latents, {"kl": kl}

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
|
41 |
-
|
42 |
-
|
43 |
-
def snake_beta(x, alpha, beta):
    """Return ``x + sin(x * alpha)**2 / (beta + 1e-9)``.

    The small constant added to ``beta`` guards the division against zero.
    """
    periodic = torch.sin(x * alpha)
    inverse_beta = 1.0 / (beta + 0.000000001)
    return x + inverse_beta * pow(periodic, 2)
|
45 |
-
|
46 |
-
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
47 |
-
class SnakeBeta(nn.Module):
    """Activation module applying ``snake_beta`` with learnable per-feature ``alpha`` and ``beta``.

    With ``alpha_logscale`` the parameters start at zero and are exponentiated
    in ``forward``; otherwise they start at ``alpha`` and are used directly.
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        """Create the ``alpha`` and ``beta`` parameters.

        Args:
            in_features: Length of the ``alpha`` and ``beta`` vectors.
            alpha: Multiplier applied to the initial parameter values.
            alpha_trainable: Accepted for interface compatibility; it is not
                applied to the parameters here.
            alpha_logscale: Whether the parameters are stored in log scale.
        """
        super().__init__()
        self.in_features = in_features

        # Log-scale parameters start from zeros, linear-scale ones from ones.
        self.alpha_logscale = alpha_logscale
        initial = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = nn.Parameter(initial(in_features) * alpha)
        self.beta = nn.Parameter(initial(in_features) * alpha)

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply the activation to ``x``."""
        device = x.device
        # Add a leading and a trailing axis so the vectors line up with x as [B, C, T].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(device)
        beta = self.beta.unsqueeze(0).unsqueeze(-1).to(device)
        if self.alpha_logscale:
            alpha, beta = torch.exp(alpha), torch.exp(beta)
        return snake_beta(x, alpha, beta)
|
76 |
-
|
77 |
-
def WNConv1d(*args, **kwargs):
    """Build an ``ops.Conv1d`` wrapped in weight normalization.

    All arguments are forwarded to ``ops.Conv1d``. The parametrization-based
    ``weight_norm`` is used when this PyTorch build provides it, otherwise the
    older ``torch.nn.utils.weight_norm`` (support pytorch 2.1 and older).
    """
    conv = ops.Conv1d(*args, **kwargs)
    try:
        weight_norm = torch.nn.utils.parametrizations.weight_norm
    except AttributeError:
        # Only a missing API triggers the fallback. The previous bare
        # ``except`` also swallowed genuine errors (bad arguments, Ctrl-C)
        # and rebuilt the layer a second time.
        weight_norm = torch.nn.utils.weight_norm
    return weight_norm(conv)
|
82 |
-
|
83 |
-
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``ops.ConvTranspose1d`` wrapped in weight normalization.

    All arguments are forwarded to ``ops.ConvTranspose1d``. The
    parametrization-based ``weight_norm`` is used when this PyTorch build
    provides it, otherwise the older ``torch.nn.utils.weight_norm``
    (support pytorch 2.1 and older).
    """
    conv = ops.ConvTranspose1d(*args, **kwargs)
    try:
        weight_norm = torch.nn.utils.parametrizations.weight_norm
    except AttributeError:
        # Only a missing API triggers the fallback. The previous bare
        # ``except`` also swallowed genuine errors (bad arguments, Ctrl-C)
        # and rebuilt the layer a second time.
        weight_norm = torch.nn.utils.weight_norm
    return weight_norm(conv)
|
88 |
-
|
89 |
-
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build the activation module selected by name.

    Args:
        activation: ``"elu"`` for ``torch.nn.ELU``, ``"snake"`` for
            ``SnakeBeta(channels)``, ``"none"`` for ``torch.nn.Identity``.
        antialias: When true, the module is wrapped in ``Activation1d``.
        channels: Feature count passed to ``SnakeBeta``; unused otherwise.

    Returns:
        The constructed module.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    builders = (
        ("elu", torch.nn.ELU),
        ("snake", lambda: SnakeBeta(channels)),
        ("none", torch.nn.Identity),
    )
    for name, build in builders:
        if activation == name:
            act = build()
            break
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act
|
103 |
-
|
104 |
-
|
105 |
-
class ResidualUnit(nn.Module):
    """Residual block: activation, dilated kernel-7 conv, activation, kernel-1 conv.

    The block output is added to its input in ``forward``.
    """

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        """Build the layer stack.

        Args:
            in_channels: Input channels of the first convolution.
            out_channels: Output channels of both convolutions; also the
                channel count given to both activations.
            dilation: Dilation of the first convolution.
            use_snake: Use the ``"snake"`` activation instead of ``"elu"``.
            antialias_activation: Forwarded to ``get_activation`` as ``antialias``.
        """
        super().__init__()

        self.dilation = dilation

        kernel_size = 7
        # Padding that keeps the length unchanged for the dilated kernel.
        padding = (dilation * (kernel_size - 1)) // 2
        activation_name = "snake" if use_snake else "elu"

        first_act = get_activation(activation_name, antialias=antialias_activation, channels=out_channels)
        first_conv = WNConv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, dilation=dilation, padding=padding)
        second_act = get_activation(activation_name, antialias=antialias_activation, channels=out_channels)
        second_conv = WNConv1d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=1)
        self.layers = nn.Sequential(first_act, first_conv, second_act, second_conv)

    def forward(self, x):
        """Return ``layers(x) + x``."""
        return self.layers(x) + x
|
129 |
-
|
130 |
-
class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9), an activation and a strided convolution."""

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        """Build the layer stack.

        Args:
            in_channels: Channel count of the residual units and of the
                final convolution's input.
            out_channels: Output channels of the final convolution.
            stride: Stride of the final convolution; its kernel size is
                ``2 * stride`` and its padding ``ceil(stride / 2)``.
            use_snake: Use the ``"snake"`` activation instead of ``"elu"``.
            antialias_activation: Forwarded to ``get_activation`` for the
                activation before the final convolution only; the residual
                units are built with their default.
        """
        super().__init__()

        residual_units = [
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=dilation, use_snake=use_snake)
            for dilation in (1, 3, 9)
        ]
        activation = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)
        downsample = WNConv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
        self.layers = nn.Sequential(*residual_units, activation, downsample)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.layers(x)
|
148 |
-
|
149 |
-
class DecoderBlock(nn.Module):
    """Activation, upsampling layer, then three residual units (dilations 1, 3, 9).

    The upsampling layer is either nearest-neighbour ``nn.Upsample`` followed
    by a bias-free ``WNConv1d`` with ``padding='same'``
    (``use_nearest_upsample=True``) or a ``WNConvTranspose1d`` with
    ``kernel_size=2*stride``, ``stride=stride`` and
    ``padding=ceil(stride/2)``.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if not use_nearest_upsample:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
        else:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2*stride,
                         stride=1,
                         bias=False,
                         padding='same'),
            )

        # NOTE(review): antialias_activation only reaches the leading
        # activation; the residual units use their default for it.
        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu",
                           antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            *[
                ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                             dilation=rate, use_snake=use_snake)
                for rate in (1, 3, 9)
            ],
        )

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.layers(x)
|
181 |
-
|
182 |
-
class OobleckEncoder(nn.Module):
    """Convolutional encoder: input conv, one ``EncoderBlock`` per stride, output conv.

    Channel widths are ``channels`` times each entry of ``[1] + c_mults``;
    the last convolution projects to ``latent_dim`` channels.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
        ):
        super().__init__()

        # Builds a new list, so the (mutable) default argument is never modified.
        c_mults = [1] + c_mults

        self.depth = len(c_mults)
        widths = [mult * channels for mult in c_mults]

        stack = [
            WNConv1d(in_channels=in_channels, out_channels=widths[0], kernel_size=7, padding=3)
        ]

        # One down-sampling block per consecutive pair of widths.
        stack.extend(
            EncoderBlock(in_channels=widths[idx], out_channels=widths[idx + 1],
                         stride=strides[idx], use_snake=use_snake)
            for idx in range(self.depth - 1)
        )

        stack.append(
            get_activation("snake" if use_snake else "elu",
                           antialias=antialias_activation, channels=widths[-1])
        )
        stack.append(
            WNConv1d(in_channels=widths[-1], out_channels=latent_dim, kernel_size=3, padding=1)
        )

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.layers(x)
|
214 |
-
|
215 |
-
|
216 |
-
class OobleckDecoder(nn.Module):
    """Convolutional decoder: input conv, one ``DecoderBlock`` per stride, output conv.

    Blocks are built from the widest channel count down to the narrowest,
    walking ``[1] + c_mults`` and ``strides`` in reverse. The final
    bias-free convolution emits ``out_channels`` channels and is followed by
    ``nn.Tanh`` when ``final_tanh`` is set (``nn.Identity`` otherwise).
    """

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        # Builds a new list, so the (mutable) default argument is never modified.
        c_mults = [1] + c_mults

        self.depth = len(c_mults)
        widths = [mult * channels for mult in c_mults]

        stack = [
            WNConv1d(in_channels=latent_dim, out_channels=widths[-1], kernel_size=7, padding=3),
        ]

        # idx runs depth-1, ..., 1: each block maps widths[idx] -> widths[idx-1].
        for idx in reversed(range(1, self.depth)):
            stack.append(
                DecoderBlock(
                    in_channels=widths[idx],
                    out_channels=widths[idx - 1],
                    stride=strides[idx - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample,
                )
            )

        stack.append(
            get_activation("snake" if use_snake else "elu",
                           antialias=antialias_activation, channels=widths[0])
        )
        stack.append(
            WNConv1d(in_channels=widths[0], out_channels=out_channels,
                     kernel_size=7, padding=3, bias=False)
        )
        stack.append(nn.Tanh() if final_tanh else nn.Identity())

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.layers(x)
|
258 |
-
|
259 |
-
|
260 |
-
class AudioOobleckVAE(nn.Module):
    """Pairs an ``OobleckEncoder`` and ``OobleckDecoder`` around a ``VAEBottleneck``.

    The encoder is built with ``latent_dim * 2`` output channels while the
    decoder consumes ``latent_dim`` channels and emits ``in_channels``
    channels.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults = [1, 2, 4, 8, 16],
                 strides = [2, 4, 4, 8, 8],
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        # NOTE(review): the doubled latent width is presumably split by the
        # bottleneck (e.g. mean / scale) - confirm against VAEBottleneck.
        self.encoder = OobleckEncoder(
            in_channels=in_channels,
            channels=channels,
            latent_dim=latent_dim * 2,
            c_mults=c_mults,
            strides=strides,
            use_snake=use_snake,
            antialias_activation=antialias_activation,
        )
        self.decoder = OobleckDecoder(
            out_channels=in_channels,
            channels=channels,
            latent_dim=latent_dim,
            c_mults=c_mults,
            strides=strides,
            use_snake=use_snake,
            antialias_activation=antialias_activation,
            use_nearest_upsample=use_nearest_upsample,
            final_tanh=final_tanh,
        )
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        """Encode ``x`` and pass the result through ``bottleneck.encode``."""
        encoded = self.encoder(x)
        return self.bottleneck.encode(encoded)

    def decode(self, x):
        """Pass ``x`` through ``bottleneck.decode`` and then the decoder."""
        latent = self.bottleneck.decode(x)
        return self.decoder(latent)
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/ldm/audio/dit.py
DELETED
@@ -1,888 +0,0 @@
|
|
1 |
-
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
2 |
-
|
3 |
-
from comfy.ldm.modules.attention import optimized_attention
|
4 |
-
import typing as tp
|
5 |
-
|
6 |
-
import torch
|
7 |
-
|
8 |
-
from einops import rearrange
|
9 |
-
from torch import nn
|
10 |
-
from torch.nn import functional as F
|
11 |
-
import math
|
12 |
-
|
13 |
-
class FourierFeatures(nn.Module):
    """Project the input with a weight matrix and return cos/sin of the result.

    ``forward`` computes ``f = 2*pi * input @ weight.T`` and returns
    ``cat([cos(f), sin(f)])`` on the last axis, so the output has
    ``out_features`` entries there. ``std`` is accepted but not used here.
    """

    def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
        super().__init__()
        assert out_features % 2 == 0
        half = out_features // 2
        # Allocated with torch.empty: values are uninitialised until assigned
        # or loaded from a state dict.
        self.weight = nn.Parameter(
            torch.empty([half, in_features], dtype=dtype, device=device)
        )

    def forward(self, input):
        # Follow the input's dtype/device so the module works under casting.
        weight_t = self.weight.T.to(dtype=input.dtype, device=input.device)
        angles = 2 * math.pi * input @ weight_t
        return torch.cat([angles.cos(), angles.sin()], dim=-1)
|
23 |
-
|
24 |
-
# norms
|
25 |
-
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learned scale and optional bias."""

    def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less

        ``fix_scale`` is accepted but not used in this implementation.
        """
        super().__init__()

        # Both parameters are allocated uninitialised (torch.empty).
        self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) if bias else None

    def forward(self, x):
        # Cast the parameters to the input's dtype/device before normalising.
        target = dict(dtype=x.dtype, device=x.device)
        shift = None if self.beta is None else self.beta.to(**target)
        scale = self.gamma.to(**target)
        return F.layer_norm(x, x.shape[-1:], weight=scale, bias=shift)
|
44 |
-
|
45 |
-
class GLU(nn.Module):
    """Gated linear unit: project to ``2 * dim_out``, split, gate one half with ``activation``.

    The projection is ``operations.Linear`` by default, or
    ``operations.Conv1d`` (kernel ``conv_kernel_size``, padding
    ``conv_kernel_size // 2``) when ``use_conv`` is set. ``operations`` must
    be supplied by the caller; it is the namespace providing those layers.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        activation,
        use_conv = False,
        conv_kernel_size = 3,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.act = activation
        if use_conv:
            self.proj = operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
        else:
            self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
        self.use_conv = use_conv

    def forward(self, x):
        if not self.use_conv:
            projected = self.proj(x)
        else:
            # The conv runs on the middle axis, so swap the last two axes
            # around it and swap back afterwards.
            projected = rearrange(x, 'b n d -> b d n')
            projected = self.proj(projected)
            projected = rearrange(projected, 'b d n -> b n d')

        value, gate = projected.chunk(2, dim = -1)
        return value * self.act(gate)
|
72 |
-
|
73 |
-
class AbsolutePositionalEmbedding(nn.Module):
    """Learned position embedding table, scaled by ``dim ** -0.5``."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return scaled embeddings for the positions of ``x``.

        Only ``x.shape[1]`` (sequence length) and ``x.device`` are read from
        ``x``. When ``pos`` is omitted, positions ``0..seq_len-1`` are used;
        ``seq_start_pos`` is subtracted from them, clamping at zero.
        """
        seq_len = x.shape[1]
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = x.device)

        if seq_start_pos is not None:
            # Trailing axis added so each start offset broadcasts over positions.
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        return self.emb(pos) * self.scale
|
93 |
-
|
94 |
-
class ScaledSinusoidalEmbedding(nn.Module):
    """Sinusoidal position embedding multiplied by a learned scalar.

    The scalar starts at ``dim ** -0.5``. Frequencies are
    ``theta ** -(i / half_dim)`` for ``i`` in ``0..half_dim-1`` and are kept
    in a non-persistent buffer (not saved in the state dict).
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        exponents = torch.arange(half_dim).float() / half_dim
        self.register_buffer('inv_freq', theta ** -exponents, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return ``cat(sin, cos)`` of position times frequency, times ``scale``.

        Only ``x.shape[1]`` and ``x.device`` are read from ``x``; ``pos``
        defaults to ``0..seq_len-1`` and ``seq_start_pos`` is subtracted
        from it when given.
        """
        if pos is None:
            pos = torch.arange(x.shape[1], device = x.device)

        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        # Outer product of positions and inverse frequencies.
        angles = torch.einsum('i, j -> i j', pos, self.inv_freq)
        table = torch.cat((angles.sin(), angles.cos()), dim = -1)
        return table * self.scale
|
117 |
-
|
118 |
-
class RotaryEmbedding(nn.Module):
    """Rotary position embedding frequencies, with optional xpos scaling.

    ``forward`` returns ``(freqs, scale)``. ``freqs`` is the outer product of
    the positions and the inverse frequencies, duplicated along the last
    axis. ``scale`` is the float ``1.`` unless ``use_xpos`` was set, in which
    case it is a tensor shaped like ``freqs``.
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            # Registered as None so forward() can detect the non-xpos case.
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len, device, dtype):
        """Call ``forward`` on positions ``0..seq_len-1`` created with the given device/dtype."""
        t = torch.arange(seq_len, device=device, dtype=dtype)
        return self.forward(t)

    def forward(self, t):
        """Return ``(freqs, scale)`` for the 1-D position tensor ``t``."""
        device = t.device
        dtype = t.dtype

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # The original referenced an undefined ``seq_len`` here, so the xpos
        # path always raised NameError. ``t`` is 1-D (the einsum above
        # requires it), so its length is the sequence length.
        seq_len = t.shape[0]
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        # Trailing axis lets every position's exponent broadcast over the
        # per-frequency scale vector.
        scale = self.scale.to(dtype=dtype, device=device) ** power.unsqueeze(-1)
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
|
175 |
-
|
176 |
-
def rotate_half(x):
    """Split the last axis into two halves ``(a, b)`` and return ``cat(-b, a)``."""
    halves = rearrange(x, '... (j d) -> ... j d', j = 2)
    first, second = halves.unbind(dim = -2)
    return torch.cat((-second, first), dim = -1)
|
180 |
-
|
181 |
-
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Rotate the leading ``freqs.shape[-1]`` features of ``t`` by ``freqs``.

    Features beyond that width are passed through untouched (partial rotary
    embeddings, Wang et al. GPT-J). ``freqs`` is cast to ``t``'s dtype and
    sliced to its last ``t.shape[-2]`` rows; the result is returned in
    ``t``'s original dtype.
    """
    target_dtype = t.dtype
    rot_dim = freqs.shape[-1]
    seq_len = t.shape[-2]

    freqs = freqs.to(target_dtype)
    freqs = freqs[-seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        # Insert a singleton axis so freqs broadcasts against 4-D ``t``.
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    rotary_part = t[..., :rot_dim]
    passthrough = t[..., rot_dim:]

    cos_term = rotary_part * freqs.cos() * scale
    sin_term = rotate_half(rotary_part) * freqs.sin() * scale
    rotated = cos_term + sin_term

    # ``scale`` may be a tensor of another dtype; cast back before joining.
    rotated = rotated.to(target_dtype)
    passthrough = passthrough.to(target_dtype)

    return torch.cat((rotated, passthrough), dim = -1)
|
200 |
-
|
201 |
-
class FeedForward(nn.Module):
    """Feed-forward block: input projection (GLU by default), then output projection.

    Args:
        dim: input feature width.
        dim_out: output width; defaults to ``dim``.
        mult: the hidden width is ``int(dim * mult)``.
        no_bias: build the non-GLU input projection and the output
            projection without bias.
        glu: use a SiLU-gated ``GLU`` input projection instead of a plain
            projection followed by SiLU.
        use_conv: use ``operations.Conv1d`` instead of ``operations.Linear``
            for the plain input projection and the output projection, with
            axis swaps around them. It is not forwarded to ``GLU``.
        conv_kernel_size: kernel size for the conv variant (padding is
            ``conv_kernel_size // 2``).
        zero_init_output: accepted but currently unused (the zero-init code
            is disabled).
        dtype, device: forwarded to the created layers.
        operations: namespace providing ``Linear`` / ``Conv1d``; required.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        if use_conv:
            # ``Rearrange`` was referenced without ever being imported, so
            # use_conv=True raised NameError. einops is already a dependency
            # of this module; import lazily because only this path needs it.
            from einops.layers.torch import Rearrange

        inner_dim = int(dim * mult)

        # Default to SwiGLU
        activation = nn.SiLU()

        dim_out = dim if dim_out is None else dim_out

        def swap_axes(pattern):
            # Axis swap for the conv variant; a no-op module otherwise so
            # the Sequential indices stay the same in both variants.
            return Rearrange(pattern) if use_conv else nn.Identity()

        def projection(width_in, width_out):
            # Linear or Conv1d projection sharing the bias/dtype/device options.
            if use_conv:
                return operations.Conv1d(width_in, width_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)
            return operations.Linear(width_in, width_out, bias = not no_bias, dtype=dtype, device=device)

        if glu:
            linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
        else:
            linear_in = nn.Sequential(
                swap_axes('b n d -> b d n'),
                projection(dim, inner_dim),
                swap_axes('b n d -> b d n'),
                activation
            )

        linear_out = projection(inner_dim, dim_out)

        self.ff = nn.Sequential(
            linear_in,
            swap_axes('b d n -> b n d'),
            linear_out,
            swap_axes('b n d -> b d n'),
        )

    def forward(self, x):
        """Run ``x`` through the feed-forward stack."""
        return self.ff(x)
|
253 |
-
|
254 |
-
class Attention(nn.Module):
    """Multi-head attention with a fused QKV projection or separate Q / KV projections.

    When ``dim_context`` is given, queries come from ``to_q`` and keys/values
    from ``to_kv`` (cross-attention layout); otherwise one ``to_qkv``
    projection is used.

    Args:
        dim: model width; ``num_heads = dim // dim_heads``.
        dim_heads: width of each head.
        dim_context: width of the key/value input, if different from ``dim``.
        causal: stored default for the ``causal`` argument of ``forward``.
        zero_init_output: accepted but unused (zero-init is disabled).
        qk_norm: L2-normalise q and k on the last axis before attention.
        natten_kernel_size: accepted but unused in this implementation.
        dtype, device: forwarded to the created layers.
        operations: namespace providing ``Linear``; required.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        dim_context = None,
        causal = False,
        zero_init_output=True,
        qk_norm = False,
        natten_kernel_size = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.causal = causal

        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
            self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
        else:
            self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)

        self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.qk_norm = qk_norm

    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None,
        causal = None
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when it is None).

        ``rotary_pos_emb`` is a ``(freqs, scale)`` pair and is only applied
        for self-attention. When ``mask`` is given, output positions where
        it is False are zeroed.
        """
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Use separate linear projections for q and k/v
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)

            k, v = self.to_kv(kv_input).chunk(2, dim=-1)

            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
        else:
            # Use fused linear projection
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Normalize q and k for cosine sim attention
        if self.qk_norm:
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)

        if rotary_pos_emb is not None and not has_context:
            freqs, _ = rotary_pos_emb

            q_dtype = q.dtype
            k_dtype = k.dtype

            # Apply the rotation in float32, then restore the original dtypes.
            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)

            q = apply_rotary_pos_emb(q, freqs)
            k = apply_rotary_pos_emb(k, freqs)

            q = q.to(q_dtype)
            k = k.to(k_dtype)

        input_mask = context_mask

        if input_mask is None and not has_context:
            input_mask = mask

        # determine masking
        masks = []
        final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account

        if input_mask is not None:
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        # Other masks will be added here later

        if masks:
            # OR the "blocked" masks together and invert back to a keep-mask.
            # The original called an undefined ``or_reduce`` here, raising
            # NameError whenever any mask was supplied.
            blocked = masks[0]
            for extra in masks[1:]:
                blocked = blocked | extra
            final_attn_mask = ~blocked

        # NOTE(review): neither ``final_attn_mask`` nor ``causal`` is passed
        # to ``optimized_attention`` below, so they do not currently affect
        # the attention weights - confirm whether that is intended.
        n = q.shape[-2]

        causal = self.causal if causal is None else causal

        if n == 1 and causal:
            causal = False

        if h != kv_h:
            # Repeat interleave kv_heads to match q_heads
            heads_per_kv_head = h // kv_h
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

        out = optimized_attention(q, k, v, h, skip_reshape=True)
        out = self.to_out(out)

        if mask is not None:
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        return out
|
377 |
-
|
378 |
-
class ConformerModule(nn.Module):
    """Conformer-style convolution module.

    Pipeline: LayerNorm, pointwise conv, GLU, depthwise conv (kernel 17),
    LayerNorm, SiLU, pointwise conv. The input and output are laid out as
    ``(b, n, d)``; the last two axes are swapped around each convolution.
    """

    def __init__(
        self,
        dim,
        norm_kwargs = {},
    ):

        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        # GLU builds its projection from ``operations`` and has no fallback,
        # so omitting it failed on ``None.Linear``. The rest of this module
        # uses plain torch.nn layers, so torch.nn is passed here as well.
        self.glu = GLU(dim, dim, nn.SiLU(), operations=nn)
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply the module to ``x`` and return a tensor in the same ``(b, n, d)`` layout."""
        x = self.in_norm(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.glu(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.depthwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.mid_norm(x)
        x = self.swish(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv_2(x)
        x = rearrange(x, 'b d n -> b n d')

        return x
|
413 |
-
|
414 |
-
class TransformerBlock(nn.Module):
    """Transformer layer: self-attention, optional cross-attention, optional conformer, feed-forward.

    Each branch is added residually. When ``global_cond_dim`` is a positive
    number and ``forward`` receives a ``global_cond`` tensor, the
    self-attention and feed-forward branches are modulated by
    scale / shift / gate values projected from ``global_cond``.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        cross_attend = False,
        dim_context = None,
        global_cond_dim = None,
        causal = False,
        zero_init_branch_outputs = True,
        conformer = False,
        layer_ix = -1,
        remove_norms = False,
        attn_kwargs = {},
        ff_kwargs = {},
        norm_kwargs = {},
        dtype=None,
        device=None,
        operations=None,
    ):

        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        def make_norm():
            # Identity stands in for every norm when remove_norms is set.
            if remove_norms:
                return nn.Identity()
            return LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs)

        # Options shared by the self- and cross-attention modules.
        shared_attn = dict(
            dim_heads = dim_heads,
            causal = causal,
            zero_init_output = zero_init_branch_outputs,
            dtype = dtype,
            device = device,
            operations = operations,
        )

        self.pre_norm = make_norm()
        self.self_attn = Attention(dim, **shared_attn, **attn_kwargs)

        if cross_attend:
            self.cross_attend_norm = make_norm()
            self.cross_attn = Attention(dim, dim_context=dim_context, **shared_attn, **attn_kwargs)

        self.ff_norm = make_norm()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations, **ff_kwargs)

        self.layer_ix = layer_ix

        self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None

        self.global_cond_dim = global_cond_dim

        if global_cond_dim is not None:
            # One projection yields all six modulation tensors; its weight
            # starts at zero.
            self.to_scale_shift_gate = nn.Sequential(
                nn.SiLU(),
                nn.Linear(global_cond_dim, dim * 6, bias=False)
            )

            nn.init.zeros_(self.to_scale_shift_gate[1].weight)

    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None
    ):
        """Apply the layer to ``x``.

        ``context`` / ``context_mask`` feed the cross-attention branch, which
        only runs when ``context`` is given. ``mask`` and ``rotary_pos_emb``
        are passed to self-attention.
        """
        modulated = (
            self.global_cond_dim is not None
            and self.global_cond_dim > 0
            and global_cond is not None
        )

        if modulated:
            (scale_self, shift_self, gate_self,
             scale_ff, shift_ff, gate_ff) = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)

            # self-attention with adaLN
            attn_in = self.pre_norm(x)
            attn_in = attn_in * (1 + scale_self) + shift_self
            attn_out = self.self_attn(attn_in, mask = mask, rotary_pos_emb = rotary_pos_emb)
            x = attn_out * torch.sigmoid(1 - gate_self) + x
        else:
            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)

        # Cross-attention and conformer are the same in both modes.
        if context is not None:
            x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

        if self.conformer is not None:
            x = x + self.conformer(x)

        if modulated:
            # feedforward with adaLN
            ff_in = self.ff_norm(x)
            ff_in = ff_in * (1 + scale_ff) + shift_ff
            ff_out = self.ff(ff_in)
            x = ff_out * torch.sigmoid(1 - gate_ff) + x
        else:
            x = x + self.ff(self.ff_norm(x))

        return x
|
534 |
-
|
535 |
-
class ContinuousTransformer(nn.Module):
|
536 |
-
def __init__(
|
537 |
-
self,
|
538 |
-
dim,
|
539 |
-
depth,
|
540 |
-
*,
|
541 |
-
dim_in = None,
|
542 |
-
dim_out = None,
|
543 |
-
dim_heads = 64,
|
544 |
-
cross_attend=False,
|
545 |
-
cond_token_dim=None,
|
546 |
-
global_cond_dim=None,
|
547 |
-
causal=False,
|
548 |
-
rotary_pos_emb=True,
|
549 |
-
zero_init_branch_outputs=True,
|
550 |
-
conformer=False,
|
551 |
-
use_sinusoidal_emb=False,
|
552 |
-
use_abs_pos_emb=False,
|
553 |
-
abs_pos_emb_max_length=10000,
|
554 |
-
dtype=None,
|
555 |
-
device=None,
|
556 |
-
operations=None,
|
557 |
-
**kwargs
|
558 |
-
):
|
559 |
-
|
560 |
-
super().__init__()
|
561 |
-
|
562 |
-
self.dim = dim
|
563 |
-
self.depth = depth
|
564 |
-
self.causal = causal
|
565 |
-
self.layers = nn.ModuleList([])
|
566 |
-
|
567 |
-
self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
|
568 |
-
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
|
569 |
-
|
570 |
-
if rotary_pos_emb:
|
571 |
-
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
572 |
-
else:
|
573 |
-
self.rotary_pos_emb = None
|
574 |
-
|
575 |
-
self.use_sinusoidal_emb = use_sinusoidal_emb
|
576 |
-
if use_sinusoidal_emb:
|
577 |
-
self.pos_emb = ScaledSinusoidalEmbedding(dim)
|
578 |
-
|
579 |
-
self.use_abs_pos_emb = use_abs_pos_emb
|
580 |
-
if use_abs_pos_emb:
|
581 |
-
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
|
582 |
-
|
583 |
-
for i in range(depth):
|
584 |
-
self.layers.append(
|
585 |
-
TransformerBlock(
|
586 |
-
dim,
|
587 |
-
dim_heads = dim_heads,
|
588 |
-
cross_attend = cross_attend,
|
589 |
-
dim_context = cond_token_dim,
|
590 |
-
global_cond_dim = global_cond_dim,
|
591 |
-
causal = causal,
|
592 |
-
zero_init_branch_outputs = zero_init_branch_outputs,
|
593 |
-
conformer=conformer,
|
594 |
-
layer_ix=i,
|
595 |
-
dtype=dtype,
|
596 |
-
device=device,
|
597 |
-
operations=operations,
|
598 |
-
**kwargs
|
599 |
-
)
|
600 |
-
)
|
601 |
-
|
602 |
-
def forward(
|
603 |
-
self,
|
604 |
-
x,
|
605 |
-
mask = None,
|
606 |
-
prepend_embeds = None,
|
607 |
-
prepend_mask = None,
|
608 |
-
global_cond = None,
|
609 |
-
return_info = False,
|
610 |
-
**kwargs
|
611 |
-
):
|
612 |
-
batch, seq, device = *x.shape[:2], x.device
|
613 |
-
|
614 |
-
info = {
|
615 |
-
"hidden_states": [],
|
616 |
-
}
|
617 |
-
|
618 |
-
x = self.project_in(x)
|
619 |
-
|
620 |
-
if prepend_embeds is not None:
|
621 |
-
prepend_length, prepend_dim = prepend_embeds.shape[1:]
|
622 |
-
|
623 |
-
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
|
624 |
-
|
625 |
-
x = torch.cat((prepend_embeds, x), dim = -2)
|
626 |
-
|
627 |
-
if prepend_mask is not None or mask is not None:
|
628 |
-
mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
|
629 |
-
prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
|
630 |
-
|
631 |
-
mask = torch.cat((prepend_mask, mask), dim = -1)
|
632 |
-
|
633 |
-
# Attention layers
|
634 |
-
|
635 |
-
if self.rotary_pos_emb is not None:
|
636 |
-
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
|
637 |
-
else:
|
638 |
-
rotary_pos_emb = None
|
639 |
-
|
640 |
-
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
641 |
-
x = x + self.pos_emb(x)
|
642 |
-
|
643 |
-
# Iterate over the transformer layers
|
644 |
-
for layer in self.layers:
|
645 |
-
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
646 |
-
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
647 |
-
|
648 |
-
if return_info:
|
649 |
-
info["hidden_states"].append(x)
|
650 |
-
|
651 |
-
x = self.project_out(x)
|
652 |
-
|
653 |
-
if return_info:
|
654 |
-
return x, info
|
655 |
-
|
656 |
-
return x
|
657 |
-
|
658 |
-
class AudioDiffusionTransformer(nn.Module):
|
659 |
-
def __init__(self,
|
660 |
-
io_channels=64,
|
661 |
-
patch_size=1,
|
662 |
-
embed_dim=1536,
|
663 |
-
cond_token_dim=768,
|
664 |
-
project_cond_tokens=False,
|
665 |
-
global_cond_dim=1536,
|
666 |
-
project_global_cond=True,
|
667 |
-
input_concat_dim=0,
|
668 |
-
prepend_cond_dim=0,
|
669 |
-
depth=24,
|
670 |
-
num_heads=24,
|
671 |
-
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
|
672 |
-
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
|
673 |
-
audio_model="",
|
674 |
-
dtype=None,
|
675 |
-
device=None,
|
676 |
-
operations=None,
|
677 |
-
**kwargs):
|
678 |
-
|
679 |
-
super().__init__()
|
680 |
-
|
681 |
-
self.dtype = dtype
|
682 |
-
self.cond_token_dim = cond_token_dim
|
683 |
-
|
684 |
-
# Timestep embeddings
|
685 |
-
timestep_features_dim = 256
|
686 |
-
|
687 |
-
self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)
|
688 |
-
|
689 |
-
self.to_timestep_embed = nn.Sequential(
|
690 |
-
operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
|
691 |
-
nn.SiLU(),
|
692 |
-
operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
|
693 |
-
)
|
694 |
-
|
695 |
-
if cond_token_dim > 0:
|
696 |
-
# Conditioning tokens
|
697 |
-
|
698 |
-
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
|
699 |
-
self.to_cond_embed = nn.Sequential(
|
700 |
-
operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
|
701 |
-
nn.SiLU(),
|
702 |
-
operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
|
703 |
-
)
|
704 |
-
else:
|
705 |
-
cond_embed_dim = 0
|
706 |
-
|
707 |
-
if global_cond_dim > 0:
|
708 |
-
# Global conditioning
|
709 |
-
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
|
710 |
-
self.to_global_embed = nn.Sequential(
|
711 |
-
operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
|
712 |
-
nn.SiLU(),
|
713 |
-
operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
|
714 |
-
)
|
715 |
-
|
716 |
-
if prepend_cond_dim > 0:
|
717 |
-
# Prepend conditioning
|
718 |
-
self.to_prepend_embed = nn.Sequential(
|
719 |
-
operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
|
720 |
-
nn.SiLU(),
|
721 |
-
operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
722 |
-
)
|
723 |
-
|
724 |
-
self.input_concat_dim = input_concat_dim
|
725 |
-
|
726 |
-
dim_in = io_channels + self.input_concat_dim
|
727 |
-
|
728 |
-
self.patch_size = patch_size
|
729 |
-
|
730 |
-
# Transformer
|
731 |
-
|
732 |
-
self.transformer_type = transformer_type
|
733 |
-
|
734 |
-
self.global_cond_type = global_cond_type
|
735 |
-
|
736 |
-
if self.transformer_type == "continuous_transformer":
|
737 |
-
|
738 |
-
global_dim = None
|
739 |
-
|
740 |
-
if self.global_cond_type == "adaLN":
|
741 |
-
# The global conditioning is projected to the embed_dim already at this point
|
742 |
-
global_dim = embed_dim
|
743 |
-
|
744 |
-
self.transformer = ContinuousTransformer(
|
745 |
-
dim=embed_dim,
|
746 |
-
depth=depth,
|
747 |
-
dim_heads=embed_dim // num_heads,
|
748 |
-
dim_in=dim_in * patch_size,
|
749 |
-
dim_out=io_channels * patch_size,
|
750 |
-
cross_attend = cond_token_dim > 0,
|
751 |
-
cond_token_dim = cond_embed_dim,
|
752 |
-
global_cond_dim=global_dim,
|
753 |
-
dtype=dtype,
|
754 |
-
device=device,
|
755 |
-
operations=operations,
|
756 |
-
**kwargs
|
757 |
-
)
|
758 |
-
else:
|
759 |
-
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
|
760 |
-
|
761 |
-
self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
|
762 |
-
self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)
|
763 |
-
|
764 |
-
def _forward(
|
765 |
-
self,
|
766 |
-
x,
|
767 |
-
t,
|
768 |
-
mask=None,
|
769 |
-
cross_attn_cond=None,
|
770 |
-
cross_attn_cond_mask=None,
|
771 |
-
input_concat_cond=None,
|
772 |
-
global_embed=None,
|
773 |
-
prepend_cond=None,
|
774 |
-
prepend_cond_mask=None,
|
775 |
-
return_info=False,
|
776 |
-
**kwargs):
|
777 |
-
|
778 |
-
if cross_attn_cond is not None:
|
779 |
-
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
|
780 |
-
|
781 |
-
if global_embed is not None:
|
782 |
-
# Project the global conditioning to the embedding dimension
|
783 |
-
global_embed = self.to_global_embed(global_embed)
|
784 |
-
|
785 |
-
prepend_inputs = None
|
786 |
-
prepend_mask = None
|
787 |
-
prepend_length = 0
|
788 |
-
if prepend_cond is not None:
|
789 |
-
# Project the prepend conditioning to the embedding dimension
|
790 |
-
prepend_cond = self.to_prepend_embed(prepend_cond)
|
791 |
-
|
792 |
-
prepend_inputs = prepend_cond
|
793 |
-
if prepend_cond_mask is not None:
|
794 |
-
prepend_mask = prepend_cond_mask
|
795 |
-
|
796 |
-
if input_concat_cond is not None:
|
797 |
-
|
798 |
-
# Interpolate input_concat_cond to the same length as x
|
799 |
-
if input_concat_cond.shape[2] != x.shape[2]:
|
800 |
-
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
801 |
-
|
802 |
-
x = torch.cat([x, input_concat_cond], dim=1)
|
803 |
-
|
804 |
-
# Get the batch of timestep embeddings
|
805 |
-
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)
|
806 |
-
|
807 |
-
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
808 |
-
if global_embed is not None:
|
809 |
-
global_embed = global_embed + timestep_embed
|
810 |
-
else:
|
811 |
-
global_embed = timestep_embed
|
812 |
-
|
813 |
-
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
|
814 |
-
if self.global_cond_type == "prepend":
|
815 |
-
if prepend_inputs is None:
|
816 |
-
# Prepend inputs are just the global embed, and the mask is all ones
|
817 |
-
prepend_inputs = global_embed.unsqueeze(1)
|
818 |
-
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
|
819 |
-
else:
|
820 |
-
# Prepend inputs are the prepend conditioning + the global embed
|
821 |
-
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
|
822 |
-
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
|
823 |
-
|
824 |
-
prepend_length = prepend_inputs.shape[1]
|
825 |
-
|
826 |
-
x = self.preprocess_conv(x) + x
|
827 |
-
|
828 |
-
x = rearrange(x, "b c t -> b t c")
|
829 |
-
|
830 |
-
extra_args = {}
|
831 |
-
|
832 |
-
if self.global_cond_type == "adaLN":
|
833 |
-
extra_args["global_cond"] = global_embed
|
834 |
-
|
835 |
-
if self.patch_size > 1:
|
836 |
-
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
|
837 |
-
|
838 |
-
if self.transformer_type == "x-transformers":
|
839 |
-
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
|
840 |
-
elif self.transformer_type == "continuous_transformer":
|
841 |
-
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
|
842 |
-
|
843 |
-
if return_info:
|
844 |
-
output, info = output
|
845 |
-
elif self.transformer_type == "mm_transformer":
|
846 |
-
output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
|
847 |
-
|
848 |
-
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
|
849 |
-
|
850 |
-
if self.patch_size > 1:
|
851 |
-
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
|
852 |
-
|
853 |
-
output = self.postprocess_conv(output) + output
|
854 |
-
|
855 |
-
if return_info:
|
856 |
-
return output, info
|
857 |
-
|
858 |
-
return output
|
859 |
-
|
860 |
-
def forward(
|
861 |
-
self,
|
862 |
-
x,
|
863 |
-
timestep,
|
864 |
-
context=None,
|
865 |
-
context_mask=None,
|
866 |
-
input_concat_cond=None,
|
867 |
-
global_embed=None,
|
868 |
-
negative_global_embed=None,
|
869 |
-
prepend_cond=None,
|
870 |
-
prepend_cond_mask=None,
|
871 |
-
mask=None,
|
872 |
-
return_info=False,
|
873 |
-
control=None,
|
874 |
-
transformer_options={},
|
875 |
-
**kwargs):
|
876 |
-
return self._forward(
|
877 |
-
x,
|
878 |
-
timestep,
|
879 |
-
cross_attn_cond=context,
|
880 |
-
cross_attn_cond_mask=context_mask,
|
881 |
-
input_concat_cond=input_concat_cond,
|
882 |
-
global_embed=global_embed,
|
883 |
-
prepend_cond=prepend_cond,
|
884 |
-
prepend_cond_mask=prepend_cond_mask,
|
885 |
-
mask=mask,
|
886 |
-
return_info=return_info,
|
887 |
-
**kwargs
|
888 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/ldm/audio/embedders.py
DELETED
@@ -1,108 +0,0 @@
|
|
1 |
-
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
from torch import Tensor, einsum
|
6 |
-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
7 |
-
from einops import rearrange
|
8 |
-
import math
|
9 |
-
import comfy.ops
|
10 |
-
|
11 |
-
class LearnedPositionalEmbedding(nn.Module):
|
12 |
-
"""Used for continuous time"""
|
13 |
-
|
14 |
-
def __init__(self, dim: int):
|
15 |
-
super().__init__()
|
16 |
-
assert (dim % 2) == 0
|
17 |
-
half_dim = dim // 2
|
18 |
-
self.weights = nn.Parameter(torch.empty(half_dim))
|
19 |
-
|
20 |
-
def forward(self, x: Tensor) -> Tensor:
|
21 |
-
x = rearrange(x, "b -> b 1")
|
22 |
-
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
|
23 |
-
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
24 |
-
fouriered = torch.cat((x, fouriered), dim=-1)
|
25 |
-
return fouriered
|
26 |
-
|
27 |
-
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
|
28 |
-
return nn.Sequential(
|
29 |
-
LearnedPositionalEmbedding(dim),
|
30 |
-
comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
|
31 |
-
)
|
32 |
-
|
33 |
-
|
34 |
-
class NumberEmbedder(nn.Module):
|
35 |
-
def __init__(
|
36 |
-
self,
|
37 |
-
features: int,
|
38 |
-
dim: int = 256,
|
39 |
-
):
|
40 |
-
super().__init__()
|
41 |
-
self.features = features
|
42 |
-
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
|
43 |
-
|
44 |
-
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
|
45 |
-
if not torch.is_tensor(x):
|
46 |
-
device = next(self.embedding.parameters()).device
|
47 |
-
x = torch.tensor(x, device=device)
|
48 |
-
assert isinstance(x, Tensor)
|
49 |
-
shape = x.shape
|
50 |
-
x = rearrange(x, "... -> (...)")
|
51 |
-
embedding = self.embedding(x)
|
52 |
-
x = embedding.view(*shape, self.features)
|
53 |
-
return x # type: ignore
|
54 |
-
|
55 |
-
|
56 |
-
class Conditioner(nn.Module):
|
57 |
-
def __init__(
|
58 |
-
self,
|
59 |
-
dim: int,
|
60 |
-
output_dim: int,
|
61 |
-
project_out: bool = False
|
62 |
-
):
|
63 |
-
|
64 |
-
super().__init__()
|
65 |
-
|
66 |
-
self.dim = dim
|
67 |
-
self.output_dim = output_dim
|
68 |
-
self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
|
69 |
-
|
70 |
-
def forward(self, x):
|
71 |
-
raise NotImplementedError()
|
72 |
-
|
73 |
-
class NumberConditioner(Conditioner):
|
74 |
-
'''
|
75 |
-
Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
|
76 |
-
'''
|
77 |
-
def __init__(self,
|
78 |
-
output_dim: int,
|
79 |
-
min_val: float=0,
|
80 |
-
max_val: float=1
|
81 |
-
):
|
82 |
-
super().__init__(output_dim, output_dim)
|
83 |
-
|
84 |
-
self.min_val = min_val
|
85 |
-
self.max_val = max_val
|
86 |
-
|
87 |
-
self.embedder = NumberEmbedder(features=output_dim)
|
88 |
-
|
89 |
-
def forward(self, floats, device=None):
|
90 |
-
# Cast the inputs to floats
|
91 |
-
floats = [float(x) for x in floats]
|
92 |
-
|
93 |
-
if device is None:
|
94 |
-
device = next(self.embedder.parameters()).device
|
95 |
-
|
96 |
-
floats = torch.tensor(floats).to(device)
|
97 |
-
|
98 |
-
floats = floats.clamp(self.min_val, self.max_val)
|
99 |
-
|
100 |
-
normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
|
101 |
-
|
102 |
-
# Cast floats to same type as embedder
|
103 |
-
embedder_dtype = next(self.embedder.parameters()).dtype
|
104 |
-
normalized_floats = normalized_floats.to(embedder_dtype)
|
105 |
-
|
106 |
-
float_embeds = self.embedder(normalized_floats).unsqueeze(1)
|
107 |
-
|
108 |
-
return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/ldm/cascade/__pycache__/common.cpython-310.pyc
DELETED
Binary file (7.69 kB)
|
|
MagicQuill/comfy/ldm/cascade/__pycache__/controlnet.cpython-310.pyc
DELETED
Binary file (3.77 kB)
|
|
MagicQuill/comfy/ldm/cascade/__pycache__/stage_a.cpython-310.pyc
DELETED
Binary file (9.41 kB)
|
|
MagicQuill/comfy/ldm/cascade/__pycache__/stage_b.cpython-310.pyc
DELETED
Binary file (7.77 kB)
|
|
MagicQuill/comfy/ldm/cascade/__pycache__/stage_c.cpython-310.pyc
DELETED
Binary file (8.58 kB)
|
|
MagicQuill/comfy/ldm/cascade/__pycache__/stage_c_coder.cpython-310.pyc
DELETED
Binary file (3.5 kB)
|
|
MagicQuill/comfy/ldm/cascade/common.py
DELETED
@@ -1,161 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
This file is part of ComfyUI.
|
3 |
-
Copyright (C) 2024 Stability AI
|
4 |
-
|
5 |
-
This program is free software: you can redistribute it and/or modify
|
6 |
-
it under the terms of the GNU General Public License as published by
|
7 |
-
the Free Software Foundation, either version 3 of the License, or
|
8 |
-
(at your option) any later version.
|
9 |
-
|
10 |
-
This program is distributed in the hope that it will be useful,
|
11 |
-
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12 |
-
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13 |
-
GNU General Public License for more details.
|
14 |
-
|
15 |
-
You should have received a copy of the GNU General Public License
|
16 |
-
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
17 |
-
"""
|
18 |
-
|
19 |
-
import torch
|
20 |
-
import torch.nn as nn
|
21 |
-
from comfy.ldm.modules.attention import optimized_attention
|
22 |
-
|
23 |
-
class Linear(torch.nn.Linear):
|
24 |
-
def reset_parameters(self):
|
25 |
-
return None
|
26 |
-
|
27 |
-
class Conv2d(torch.nn.Conv2d):
|
28 |
-
def reset_parameters(self):
|
29 |
-
return None
|
30 |
-
|
31 |
-
class OptimizedAttention(nn.Module):
|
32 |
-
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
33 |
-
super().__init__()
|
34 |
-
self.heads = nhead
|
35 |
-
|
36 |
-
self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
37 |
-
self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
38 |
-
self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
39 |
-
|
40 |
-
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
41 |
-
|
42 |
-
def forward(self, q, k, v):
|
43 |
-
q = self.to_q(q)
|
44 |
-
k = self.to_k(k)
|
45 |
-
v = self.to_v(v)
|
46 |
-
|
47 |
-
out = optimized_attention(q, k, v, self.heads)
|
48 |
-
|
49 |
-
return self.out_proj(out)
|
50 |
-
|
51 |
-
class Attention2D(nn.Module):
|
52 |
-
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
53 |
-
super().__init__()
|
54 |
-
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
55 |
-
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
56 |
-
|
57 |
-
def forward(self, x, kv, self_attn=False):
|
58 |
-
orig_shape = x.shape
|
59 |
-
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
60 |
-
if self_attn:
|
61 |
-
kv = torch.cat([x, kv], dim=1)
|
62 |
-
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
63 |
-
x = self.attn(x, kv, kv)
|
64 |
-
x = x.permute(0, 2, 1).view(*orig_shape)
|
65 |
-
return x
|
66 |
-
|
67 |
-
|
68 |
-
def LayerNorm2d_op(operations):
|
69 |
-
class LayerNorm2d(operations.LayerNorm):
|
70 |
-
def __init__(self, *args, **kwargs):
|
71 |
-
super().__init__(*args, **kwargs)
|
72 |
-
|
73 |
-
def forward(self, x):
|
74 |
-
return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
75 |
-
return LayerNorm2d
|
76 |
-
|
77 |
-
class GlobalResponseNorm(nn.Module):
|
78 |
-
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
|
79 |
-
def __init__(self, dim, dtype=None, device=None):
|
80 |
-
super().__init__()
|
81 |
-
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
82 |
-
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
83 |
-
|
84 |
-
def forward(self, x):
|
85 |
-
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
86 |
-
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
87 |
-
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
|
88 |
-
|
89 |
-
|
90 |
-
class ResBlock(nn.Module):
|
91 |
-
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
|
92 |
-
super().__init__()
|
93 |
-
self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
|
94 |
-
# self.depthwise = SAMBlock(c, num_heads, expansion)
|
95 |
-
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
96 |
-
self.channelwise = nn.Sequential(
|
97 |
-
operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
|
98 |
-
nn.GELU(),
|
99 |
-
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
|
100 |
-
nn.Dropout(dropout),
|
101 |
-
operations.Linear(c * 4, c, dtype=dtype, device=device)
|
102 |
-
)
|
103 |
-
|
104 |
-
def forward(self, x, x_skip=None):
|
105 |
-
x_res = x
|
106 |
-
x = self.norm(self.depthwise(x))
|
107 |
-
if x_skip is not None:
|
108 |
-
x = torch.cat([x, x_skip], dim=1)
|
109 |
-
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
110 |
-
return x + x_res
|
111 |
-
|
112 |
-
|
113 |
-
class AttnBlock(nn.Module):
|
114 |
-
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
|
115 |
-
super().__init__()
|
116 |
-
self.self_attn = self_attn
|
117 |
-
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
118 |
-
self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
|
119 |
-
self.kv_mapper = nn.Sequential(
|
120 |
-
nn.SiLU(),
|
121 |
-
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
122 |
-
)
|
123 |
-
|
124 |
-
def forward(self, x, kv):
|
125 |
-
kv = self.kv_mapper(kv)
|
126 |
-
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
127 |
-
return x
|
128 |
-
|
129 |
-
|
130 |
-
class FeedForwardBlock(nn.Module):
|
131 |
-
def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
|
132 |
-
super().__init__()
|
133 |
-
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
134 |
-
self.channelwise = nn.Sequential(
|
135 |
-
operations.Linear(c, c * 4, dtype=dtype, device=device),
|
136 |
-
nn.GELU(),
|
137 |
-
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
|
138 |
-
nn.Dropout(dropout),
|
139 |
-
operations.Linear(c * 4, c, dtype=dtype, device=device)
|
140 |
-
)
|
141 |
-
|
142 |
-
def forward(self, x):
|
143 |
-
x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
144 |
-
return x
|
145 |
-
|
146 |
-
|
147 |
-
class TimestepBlock(nn.Module):
|
148 |
-
def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
|
149 |
-
super().__init__()
|
150 |
-
self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
|
151 |
-
self.conds = conds
|
152 |
-
for cname in conds:
|
153 |
-
setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
|
154 |
-
|
155 |
-
def forward(self, x, t):
|
156 |
-
t = t.chunk(len(self.conds) + 1, dim=1)
|
157 |
-
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
|
158 |
-
for i, c in enumerate(self.conds):
|
159 |
-
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
|
160 |
-
a, b = a + ac, b + bc
|
161 |
-
return x * (1 + a) + b
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/ldm/cascade/controlnet.py
DELETED
@@ -1,93 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
This file is part of ComfyUI.
|
3 |
-
Copyright (C) 2024 Stability AI
|
4 |
-
|
5 |
-
This program is free software: you can redistribute it and/or modify
|
6 |
-
it under the terms of the GNU General Public License as published by
|
7 |
-
the Free Software Foundation, either version 3 of the License, or
|
8 |
-
(at your option) any later version.
|
9 |
-
|
10 |
-
This program is distributed in the hope that it will be useful,
|
11 |
-
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12 |
-
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13 |
-
GNU General Public License for more details.
|
14 |
-
|
15 |
-
You should have received a copy of the GNU General Public License
|
16 |
-
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
17 |
-
"""
|
18 |
-
|
19 |
-
import torch
|
20 |
-
import torchvision
|
21 |
-
from torch import nn
|
22 |
-
from .common import LayerNorm2d_op
|
23 |
-
|
24 |
-
|
25 |
-
class CNetResBlock(nn.Module):
|
26 |
-
def __init__(self, c, dtype=None, device=None, operations=None):
|
27 |
-
super().__init__()
|
28 |
-
self.blocks = nn.Sequential(
|
29 |
-
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
|
30 |
-
nn.GELU(),
|
31 |
-
operations.Conv2d(c, c, kernel_size=3, padding=1),
|
32 |
-
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
|
33 |
-
nn.GELU(),
|
34 |
-
operations.Conv2d(c, c, kernel_size=3, padding=1),
|
35 |
-
)
|
36 |
-
|
37 |
-
def forward(self, x):
|
38 |
-
return x + self.blocks(x)
|
39 |
-
|
40 |
-
|
41 |
-
class ControlNet(nn.Module):
|
42 |
-
def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
|
43 |
-
super().__init__()
|
44 |
-
if bottleneck_mode is None:
|
45 |
-
bottleneck_mode = 'effnet'
|
46 |
-
self.proj_blocks = proj_blocks
|
47 |
-
if bottleneck_mode == 'effnet':
|
48 |
-
embd_channels = 1280
|
49 |
-
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
50 |
-
if c_in != 3:
|
51 |
-
in_weights = self.backbone[0][0].weight.data
|
52 |
-
self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
|
53 |
-
if c_in > 3:
|
54 |
-
# nn.init.constant_(self.backbone[0][0].weight, 0)
|
55 |
-
self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
|
56 |
-
else:
|
57 |
-
self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
|
58 |
-
elif bottleneck_mode == 'simple':
|
59 |
-
embd_channels = c_in
|
60 |
-
self.backbone = nn.Sequential(
|
61 |
-
operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
|
62 |
-
nn.LeakyReLU(0.2, inplace=True),
|
63 |
-
operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
64 |
-
)
|
65 |
-
elif bottleneck_mode == 'large':
|
66 |
-
self.backbone = nn.Sequential(
|
67 |
-
operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
|
68 |
-
nn.LeakyReLU(0.2, inplace=True),
|
69 |
-
operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
|
70 |
-
*[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
|
71 |
-
operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
|
72 |
-
)
|
73 |
-
embd_channels = 1280
|
74 |
-
else:
|
75 |
-
raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
|
76 |
-
self.projections = nn.ModuleList()
|
77 |
-
for _ in range(len(proj_blocks)):
|
78 |
-
self.projections.append(nn.Sequential(
|
79 |
-
operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
|
80 |
-
nn.LeakyReLU(0.2, inplace=True),
|
81 |
-
operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
|
82 |
-
))
|
83 |
-
# nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
|
84 |
-
self.xl = False
|
85 |
-
self.input_channels = c_in
|
86 |
-
self.unshuffle_amount = 8
|
87 |
-
|
88 |
-
def forward(self, x):
|
89 |
-
x = self.backbone(x)
|
90 |
-
proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
|
91 |
-
for i, idx in enumerate(self.proj_blocks):
|
92 |
-
proj_outputs[idx] = self.projections[i](x)
|
93 |
-
return proj_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MagicQuill/comfy/ldm/cascade/stage_a.py
DELETED
@@ -1,255 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
This file is part of ComfyUI.
|
3 |
-
Copyright (C) 2024 Stability AI
|
4 |
-
|
5 |
-
This program is free software: you can redistribute it and/or modify
|
6 |
-
it under the terms of the GNU General Public License as published by
|
7 |
-
the Free Software Foundation, either version 3 of the License, or
|
8 |
-
(at your option) any later version.
|
9 |
-
|
10 |
-
This program is distributed in the hope that it will be useful,
|
11 |
-
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12 |
-
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13 |
-
GNU General Public License for more details.
|
14 |
-
|
15 |
-
You should have received a copy of the GNU General Public License
|
16 |
-
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
17 |
-
"""
|
18 |
-
|
19 |
-
import torch
|
20 |
-
from torch import nn
|
21 |
-
from torch.autograd import Function
|
22 |
-
|
23 |
-
class vector_quantize(Function):
    """Nearest-neighbour codebook lookup with a straight-through gradient.

    ``forward`` maps every row of ``x`` to the closest row of ``codebook``
    (squared Euclidean distance). ``backward`` copies the incoming gradient
    to ``x`` unchanged and scatters it onto the selected codebook rows.
    """

    @staticmethod
    def forward(ctx, x, codebook):
        """Return ``(quantized, indices)`` for the 2-D inputs ``x`` and ``codebook``.

        ``x`` has one vector per row and ``codebook`` one code per row; both
        are used with ``dim=1`` reductions and ``addmm``, so they must be 2-D
        and share their last dimension.
        """
        with torch.no_grad():
            code_sq_norm = torch.sum(codebook ** 2, dim=1)
            input_sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)

            # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x . c, for every (x, c) pair.
            distances = torch.addmm(
                code_sq_norm + input_sq_norm,
                x,
                codebook.t(),
                alpha=-2.0,
                beta=1.0,
            )
            _, indices = distances.min(dim=1)

            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)

            nearest = torch.index_select(codebook, 0, indices)
            return nearest, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        """Return gradients for ``x`` and ``codebook`` (``None`` where not needed)."""
        grad_inputs = None
        grad_codebook = None

        if ctx.needs_input_grad[0]:
            # Straight-through: the quantized output's gradient is passed as-is.
            grad_inputs = grad_output.clone()
        if ctx.needs_input_grad[1]:
            # Gradient w.r.t. the codebook: accumulate per selected code row.
            indices, codebook = ctx.saved_tensors

            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output)

        return (grad_inputs, grad_codebook)
|
53 |
-
|
54 |
-
|
55 |
-
class VectorQuantize(nn.Module):
    """Vector-quantization layer backed by an ``nn.Embedding`` codebook."""

    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
        with the same size as the input, vq and commitment components for the loss as a tuple
        in the second output and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices

        ``k`` is the number of codebook entries. When ``ema_loss`` is true the
        EMA statistics buffers are registered and the codebook is updated by
        exponential moving average during training.
        """
        super().__init__()

        self.codebook = nn.Embedding(k, embedding_size)
        bound = 1. / k
        self.codebook.weight.data.uniform_(-bound, bound)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            self.register_buffer('ema_element_count', torch.ones(k))
            self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        """Add ``epsilon`` to every count in ``x`` and rescale so the total is preserved."""
        total = torch.sum(x)
        denominator = total + x.size(0) * epsilon
        return (x + epsilon) / denominator * total

    def _updateEMA(self, z_e_x, indices):
        """Update the EMA buffers from one batch and overwrite the codebook weights."""
        num_codes = self.ema_element_count.size(0)
        assignment = nn.functional.one_hot(indices, num_codes).float()
        batch_count = assignment.sum(dim=0)
        batch_sum = torch.mm(assignment.t(), z_e_x)

        keep = self.ema_decay
        fresh = 1 - self.ema_decay
        self.ema_element_count = (keep * self.ema_element_count) + (fresh * batch_count)
        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
        self.ema_weight_sum = (keep * self.ema_weight_sum) + (fresh * batch_sum)

        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        """Look up codebook vectors for ``idx``; move the embedding axis to ``dim`` if given."""
        vectors = self.codebook(idx)
        if dim == -1:
            return vectors
        return vectors.movedim(-1, dim)

    def forward(self, x, get_losses=True, dim=-1):
        """Quantize ``x`` along axis ``dim``.

        Returns ``(quantized, (vq_loss, commit_loss), indices)``; both losses
        are ``None`` when ``get_losses`` is false.
        """
        channels_last = dim == -1
        if not channels_last:
            x = x.movedim(dim, -1)

        # Flatten everything except the embedding axis into rows.
        if len(x.shape) > 2:
            z_e_x = x.contiguous().view(-1, x.size(-1))
        else:
            z_e_x = x

        z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
        vq_loss = None
        commit_loss = None

        if self.ema_loss and self.training:
            self._updateEMA(z_e_x.detach(), indices.detach())

        # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
        z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
        if get_losses:
            vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
            commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()

        z_q_x = z_q_x.view(x.shape)
        if not channels_last:
            z_q_x = z_q_x.movedim(-1, dim)

        return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
|
115 |
-
|
116 |
-
|
117 |
-
class ResBlock(nn.Module):
    """Residual block: a depthwise 3x3 convolution followed by a per-pixel MLP.

    Both sub-layers are preceded by a non-affine ``LayerNorm`` over the
    channel axis and are scaled/shifted by the six learnable scalars in
    ``gammas``. Because ``gammas`` starts at zero, a freshly constructed block
    is the identity function.

    Args:
        c: Number of input/output channels.
        c_hidden: Width of the hidden layer of the channelwise MLP.
    """

    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )

        # [scale1, shift1, gate1, scale2, shift2, gate2] -- see forward().
        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights
        def _basic_init(module):
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        """Apply ``norm`` over the channel axis of a channels-first 4-D tensor."""
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        """Run both residual sub-layers on a ``(N, C, H, W)`` tensor and return the same shape."""
        mods = self.gammas

        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        try:
            x = x + self.depthwise(x_temp) * mods[2]
        except RuntimeError:  # operation not implemented for bf16
            # Only the padding is retried in float32; PyTorch reports an
            # unsupported-dtype kernel as RuntimeError, so nothing broader
            # (e.g. KeyboardInterrupt) is swallowed here.
            x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
            x = x + self.depthwise[1](x_temp) * mods[2]

        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]

        return x
|
163 |
-
|
164 |
-
|
165 |
-
class StageA(nn.Module):
    """Convolutional autoencoder with an optional vector-quantized bottleneck.

    The encoder pixel-unshuffles the 3-channel input by 2, runs ``levels``
    groups of ``ResBlock`` (each group after the first is preceded by a
    stride-2 convolution) and projects to ``c_latent`` batch-normalized
    channels. The decoder mirrors this with transposed convolutions and a
    final pixel-shuffle.

    Args:
        levels: Number of resolution levels in encoder and decoder.
        bottleneck_blocks: Number of ``ResBlock`` layers at the lowest
            decoder resolution (every other decoder level uses one).
        c_hidden: Channel width at the lowest resolution; each shallower
            level halves it.
        c_latent: Number of latent channels.
        codebook_size: Number of entries in the quantizer codebook.
    """

    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
        super().__init__()
        self.c_latent = c_latent
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            for _ in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                       padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        """Encode ``x`` into the latent space.

        Returns:
            The latent tensor when ``quantize`` is false. Otherwise the tuple
            ``(quantized, latent, indices, loss)`` where ``loss`` is
            ``vq_loss + 0.25 * commit_loss``.
        """
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer(x, dim=1)
            return qe, x, indices, vq_loss + commit_loss * 0.25
        else:
            return x

    def decode(self, x):
        """Decode a latent tensor back to image space."""
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        """Encode then decode ``x``.

        Returns:
            ``(reconstruction, vq_loss)``. ``vq_loss`` is ``None`` when
            ``quantize`` is false, since no quantization loss is computed.
        """
        if quantize:
            latent, _, _, vq_loss = self.encode(x, quantize=True)
        else:
            # encode() returns a bare tensor in this mode; unpacking it into
            # four names would fail (or split along the batch axis).
            latent, vq_loss = self.encode(x, quantize=False), None
        return self.decode(latent), vq_loss
|
228 |
-
|
229 |
-
|
230 |
-
class Discriminator(nn.Module):
    """Strided convolutional discriminator producing a per-location probability map.

    Args:
        c_in: Channels of the input image.
        c_cond: Size of the optional conditioning vector that is concatenated
            to the encoder features before the final 1x1 convolution.
        c_hidden: Channel count the encoder works up to.
        depth: Total number of stride-2 convolutions in the encoder.
    """

    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)

        # Stem: first spectral-normalized conv has no normalization layer.
        stem_out = c_hidden // (2 ** d)
        stages = [
            nn.utils.spectral_norm(nn.Conv2d(c_in, stem_out, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]

        # Remaining stages double the width until it reaches c_hidden.
        for step in range(depth - 1):
            stage_in = c_hidden // (2 ** max((d - step), 0))
            stage_out = c_hidden // (2 ** max((d - 1 - step), 0))
            conv = nn.Conv2d(stage_in, stage_out, kernel_size=3, stride=2, padding=1)
            stages.append(nn.utils.spectral_norm(conv))
            stages.append(nn.InstanceNorm2d(stage_out))
            stages.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*stages)

        if c_cond > 0:
            head_in = c_hidden + c_cond
        else:
            head_in = c_hidden
        self.shuffle = nn.Conv2d(head_in, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        """Return sigmoid scores for ``x``, optionally conditioned on ``cond``."""
        features = self.encoder(x)
        if cond is not None:
            # Broadcast the conditioning vector over every spatial location.
            height, width = features.size(-2), features.size(-1)
            cond_map = cond.view(cond.size(0), cond.size(1), 1, 1).expand(-1, -1, height, width)
            features = torch.cat([features, cond_map], dim=1)
        scores = self.shuffle(features)
        return self.logits(scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|