Stable-X commited on
Commit
44168ee
1 Parent(s): ec00fa5

Upload controlnetvae.py

Browse files
Files changed (1) hide show
  1. controlnet/controlnetvae.py +250 -0
controlnet/controlnetvae.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.unets.unet_2d_blocks import (
34
+ CrossAttnDownBlock2D,
35
+ DownBlock2D,
36
+ UNetMidBlock2D,
37
+ UNetMidBlock2DCrossAttn,
38
+ get_down_block,
39
+ )
40
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
41
+ from diffusers.models.controlnet import ControlNetOutput
42
+ from diffusers.models import ControlNetModel
43
+
44
+ import pdb
45
+
46
+
47
+ class ControlNetVAEModel(ControlNetModel):
48
+ def forward(
49
+ self,
50
+ sample: torch.Tensor,
51
+ timestep: Union[torch.Tensor, float, int],
52
+ encoder_hidden_states: torch.Tensor,
53
+ controlnet_cond: torch.Tensor = None,
54
+ conditioning_scale: float = 1.0,
55
+ class_labels: Optional[torch.Tensor] = None,
56
+ timestep_cond: Optional[torch.Tensor] = None,
57
+ attention_mask: Optional[torch.Tensor] = None,
58
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
59
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
60
+ guess_mode: bool = False,
61
+ return_dict: bool = True,
62
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
63
+ """
64
+ The [`ControlNetVAEModel`] forward method.
65
+
66
+ Args:
67
+ sample (`torch.Tensor`):
68
+ The noisy input tensor.
69
+ timestep (`Union[torch.Tensor, float, int]`):
70
+ The number of timesteps to denoise an input.
71
+ encoder_hidden_states (`torch.Tensor`):
72
+ The encoder hidden states.
73
+ controlnet_cond (`torch.Tensor`):
74
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
75
+ conditioning_scale (`float`, defaults to `1.0`):
76
+ The scale factor for ControlNet outputs.
77
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
78
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
79
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
80
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
81
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
82
+ embeddings.
83
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
84
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
85
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
86
+ negative values to the attention scores corresponding to "discard" tokens.
87
+ added_cond_kwargs (`dict`):
88
+ Additional conditions for the Stable Diffusion XL UNet.
89
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
90
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
91
+ guess_mode (`bool`, defaults to `False`):
92
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
93
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
94
+ return_dict (`bool`, defaults to `True`):
95
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
96
+
97
+ Returns:
98
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
99
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
100
+ returned where the first element is the sample tensor.
101
+ """
102
+ # check channel order
103
+
104
+
105
+ channel_order = self.config.controlnet_conditioning_channel_order
106
+
107
+ if channel_order == "rgb":
108
+ # in rgb order by default
109
+ ...
110
+ elif channel_order == "bgr":
111
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
112
+ else:
113
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
114
+
115
+ # prepare attention_mask
116
+ if attention_mask is not None:
117
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
118
+ attention_mask = attention_mask.unsqueeze(1)
119
+
120
+ # 1. time
121
+ timesteps = timestep
122
+ if not torch.is_tensor(timesteps):
123
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
124
+ # This would be a good case for the `match` statement (Python 3.10+)
125
+ is_mps = sample.device.type == "mps"
126
+ if isinstance(timestep, float):
127
+ dtype = torch.float32 if is_mps else torch.float64
128
+ else:
129
+ dtype = torch.int32 if is_mps else torch.int64
130
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
131
+ elif len(timesteps.shape) == 0:
132
+ timesteps = timesteps[None].to(sample.device)
133
+
134
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
135
+ timesteps = timesteps.expand(sample.shape[0])
136
+
137
+ t_emb = self.time_proj(timesteps)
138
+
139
+ # timesteps does not contain any weights and will always return f32 tensors
140
+ # but time_embedding might actually be running in fp16. so we need to cast here.
141
+ # there might be better ways to encapsulate this.
142
+ t_emb = t_emb.to(dtype=sample.dtype)
143
+
144
+ emb = self.time_embedding(t_emb, timestep_cond)
145
+ aug_emb = None
146
+
147
+ if self.class_embedding is not None:
148
+ if class_labels is None:
149
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
150
+
151
+ if self.config.class_embed_type == "timestep":
152
+ class_labels = self.time_proj(class_labels)
153
+
154
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
155
+ emb = emb + class_emb
156
+
157
+ if self.config.addition_embed_type is not None:
158
+ if self.config.addition_embed_type == "text":
159
+ aug_emb = self.add_embedding(encoder_hidden_states)
160
+
161
+ elif self.config.addition_embed_type == "text_time":
162
+ if "text_embeds" not in added_cond_kwargs:
163
+ raise ValueError(
164
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
165
+ )
166
+ text_embeds = added_cond_kwargs.get("text_embeds")
167
+ if "time_ids" not in added_cond_kwargs:
168
+ raise ValueError(
169
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
170
+ )
171
+ time_ids = added_cond_kwargs.get("time_ids")
172
+ time_embeds = self.add_time_proj(time_ids.flatten())
173
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
174
+
175
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
176
+ add_embeds = add_embeds.to(emb.dtype)
177
+ aug_emb = self.add_embedding(add_embeds)
178
+
179
+
180
+ emb = emb + aug_emb if aug_emb is not None else emb
181
+ # 2. pre-process
182
+ sample = self.conv_in(sample)
183
+
184
+ # 3. down
185
+ down_block_res_samples = (sample,)
186
+ for downsample_block in self.down_blocks:
187
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
188
+ sample, res_samples = downsample_block(
189
+ hidden_states=sample,
190
+ temb=emb,
191
+ encoder_hidden_states=encoder_hidden_states,
192
+ attention_mask=attention_mask,
193
+ cross_attention_kwargs=cross_attention_kwargs,
194
+ )
195
+ else:
196
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
197
+
198
+ down_block_res_samples += res_samples
199
+
200
+ # 4. mid
201
+ if self.mid_block is not None:
202
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
203
+ sample = self.mid_block(
204
+ sample,
205
+ emb,
206
+ encoder_hidden_states=encoder_hidden_states,
207
+ attention_mask=attention_mask,
208
+ cross_attention_kwargs=cross_attention_kwargs,
209
+ )
210
+ else:
211
+ sample = self.mid_block(sample, emb)
212
+
213
+ # 5. Control net blocks
214
+
215
+ controlnet_down_block_res_samples = ()
216
+
217
+ # NOTE that controlnet downblock is zeroconv, we discard
218
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
219
+ down_block_res_sample = down_block_res_sample
220
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
221
+
222
+ down_block_res_samples = controlnet_down_block_res_samples
223
+
224
+ mid_block_res_sample = sample
225
+
226
+ # 6. scaling
227
+ if guess_mode and not self.config.global_pool_conditions:
228
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
229
+ scales = scales * conditioning_scale
230
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
231
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
232
+ else:
233
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
234
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
235
+
236
+ if self.config.global_pool_conditions:
237
+ down_block_res_samples = [
238
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
239
+ ]
240
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
241
+
242
+ if not return_dict:
243
+ return (down_block_res_samples, mid_block_res_sample)
244
+
245
+ return ControlNetOutput(
246
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
247
+ )
248
+
249
+
250
+