|
from typing import Any, Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from torch.utils import checkpoint |
|
|
|
from comfy.ldm.modules.diffusionmodules.mmdit import ( |
|
Mlp, |
|
TimestepEmbedder, |
|
PatchEmbed, |
|
RMSNorm, |
|
) |
|
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding |
|
from .poolers import AttentionPool |
|
|
|
import comfy.latent_formats |
|
from .models import HunYuanDiTBlock, calc_rope |
|
|
|
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop |
|
|
|
|
|
class HunYuanControlNet(nn.Module):
    """
    ControlNet branch for the HunYuanDiT diffusion transformer.

    The noisy latent and the control hint are both embedded with the same
    ``PatchEmbed``; a linear projection of the embedded hint is added to the
    latent tokens, the sum is run through a stack of 19 ``HunYuanDiTBlock``
    layers, and the output of every block is passed through its own linear
    projection and returned as a control residual.

    Parameters
    ----------
    input_size: int or tuple
        Size of the input latent, forwarded to ``PatchEmbed``.
    patch_size: int
        The size of the patch.
    in_channels: int
        The number of input latent channels.
    hidden_size: int
        The hidden size of the transformer blocks.
    depth: int
        Stored as ``self.depth`` only; the number of control blocks is fixed
        at 19 and does not depend on this value.
    num_heads: int
        The number of attention heads.
    mlp_ratio: float
        Ratio of the MLP hidden size to ``hidden_size`` in each block.
    text_states_dim: int
        Feature size of the CLIP text states; T5 states are projected to it.
    text_states_dim_t5: int
        Feature size of the T5 text states.
    text_len: int
        Number of learned padding rows used for masked-out CLIP tokens.
    text_len_t5: int
        Number of learned padding rows used for masked-out T5 tokens.
    qk_norm: bool
        Forwarded to each block.
    size_cond: bool
        If True, a 6 x 256 sinusoidal embedding of ``image_meta_size`` is
        appended to the extra conditioning vector.
    use_style_cond: bool
        If True, a learned style embedding is appended to the extra
        conditioning vector.
    learn_sigma: bool
        Only affects ``self.out_channels``, which this module does not use.
    norm: str
        Normalization type forwarded to each block.
    log_fn: callable
        Logging function; stored but not called by this module.
    attn_precision, dtype, device, operations:
        Forwarded to the submodules. ``operations`` must provide ``Linear``
        (it is called directly for several projections).
    """

    def __init__(
        self,
        input_size: tuple = 128,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1408,
        depth: int = 40,
        num_heads: int = 16,
        mlp_ratio: float = 4.3637,
        text_states_dim=1024,
        text_states_dim_t5=2048,
        text_len=77,
        text_len_t5=256,
        qk_norm=True,
        size_cond=False,
        use_style_cond=False,
        learn_sigma=True,
        norm="layer",
        log_fn: callable = print,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.log_fn = log_fn
        self.depth = depth
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.text_states_dim = text_states_dim
        self.text_states_dim_t5 = text_states_dim_t5
        self.text_len = text_len
        self.text_len_t5 = text_len_t5
        self.size_cond = size_cond
        self.use_style_cond = use_style_cond
        self.norm = norm
        self.dtype = dtype
        self.latent_format = comfy.latent_formats.SDXL

        # Projects T5 features (text_states_dim_t5) down to the CLIP width
        # (text_states_dim) so both token sets can be concatenated.
        self.mlp_t5 = nn.Sequential(
            nn.Linear(
                self.text_states_dim_t5,
                self.text_states_dim_t5 * 4,
                bias=True,
                dtype=dtype,
                device=device,
            ),
            nn.SiLU(),
            nn.Linear(
                self.text_states_dim_t5 * 4,
                self.text_states_dim,
                bias=True,
                dtype=dtype,
                device=device,
            ),
        )

        # Learned replacement rows for masked-out tokens: the first text_len
        # rows are used for CLIP tokens, the remaining text_len_t5 for T5.
        self.text_embedding_padding = nn.Parameter(
            torch.randn(
                self.text_len + self.text_len_t5,
                self.text_states_dim,
                dtype=dtype,
                device=device,
            )
        )

        pooler_out_dim = 1024
        self.pooler = AttentionPool(
            self.text_len_t5,
            self.text_states_dim_t5,
            num_heads=8,
            output_dim=pooler_out_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Width of the extra conditioning vector: pooled text, then the
        # optional size embedding, then the optional style embedding.
        # _build_extra_vec must concatenate in this same order.
        self.extra_in_dim = pooler_out_dim
        if self.size_cond:
            self.extra_in_dim += 6 * 256
        if self.use_style_cond:
            self.style_embedder = nn.Embedding(
                1, hidden_size, dtype=dtype, device=device
            )
            self.extra_in_dim += hidden_size

        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            hidden_size,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations
        )
        self.extra_embedder = nn.Sequential(
            operations.Linear(
                self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
            ),
            nn.SiLU(),
            operations.Linear(
                hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
            ),
        )

        # The control branch always has 19 blocks, independent of `depth`.
        self.blocks = nn.ModuleList(
            [
                HunYuanDiTBlock(
                    hidden_size=hidden_size,
                    c_emb_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    text_states_dim=self.text_states_dim,
                    qk_norm=qk_norm,
                    norm_type=self.norm,
                    skip=False,
                    attn_precision=attn_precision,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(19)
            ]
        )

        # Projection of the embedded hint that is added to the latent tokens.
        self.before_proj = operations.Linear(
            self.hidden_size, self.hidden_size, dtype=dtype, device=device
        )

        # One output projection per block; each yields one control residual.
        self.after_proj_list = nn.ModuleList(
            [
                operations.Linear(
                    self.hidden_size, self.hidden_size, dtype=dtype, device=device
                )
                for _ in range(len(self.blocks))
            ]
        )

    def _build_text_states(
        self,
        context,
        text_embedding_mask,
        encoder_hidden_states_t5,
        text_embedding_mask_t5,
    ):
        """Project T5 states, pad masked-out tokens, and concatenate CLIP + T5."""
        # Work on a copy so the learned padding is never written into the
        # caller's `context` tensor.
        text_states = context.clone()
        text_states_mask = text_embedding_mask.bool()
        text_states_t5_mask = text_embedding_mask_t5.bool()

        b_t5, l_t5, c_t5 = encoder_hidden_states_t5.shape
        text_states_t5 = self.mlp_t5(
            encoder_hidden_states_t5.reshape(-1, c_t5)
        ).view(b_t5, l_t5, -1)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

        # Where the mask is False, replace the token with its learned padding row.
        text_states[:, -self.text_len :] = torch.where(
            text_states_mask[:, -self.text_len :].unsqueeze(2),
            text_states[:, -self.text_len :],
            padding[: self.text_len],
        )
        text_states_t5[:, -self.text_len_t5 :] = torch.where(
            text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
            text_states_t5[:, -self.text_len_t5 :],
            padding[self.text_len :],
        )

        return torch.cat([text_states, text_states_t5], dim=1)

    def _build_extra_vec(self, encoder_hidden_states_t5, image_meta_size, style):
        """Pool the T5 states and append the enabled size / style embeddings."""
        extra_vec = self.pooler(encoder_hidden_states_t5)

        # Size tokens come first, matching the extra_in_dim layout in __init__.
        if self.size_cond:
            if image_meta_size is None:
                raise ValueError("image_meta_size is required when size_cond=True")
            # Each of the 6 meta values gets a 256-dim sinusoidal embedding.
            meta = timestep_embedding(image_meta_size.reshape(-1), 256)
            meta = meta.to(extra_vec.dtype).reshape(-1, 6 * 256)
            extra_vec = torch.cat([extra_vec, meta], dim=1)

        # Gate on the constructor flag: style_embedder only exists when it is
        # set, and extra_embedder expects the extra width exactly then.
        if self.use_style_cond:
            if style is None:
                # The embedding table has a single entry, so index 0 is the
                # only valid style.
                style = torch.zeros(
                    (extra_vec.shape[0],), dtype=torch.long, device=extra_vec.device
                )
            style_embedding = self.style_embedder(style).to(extra_vec.dtype)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        return extra_vec

    def forward(
        self,
        x,
        hint,
        timesteps,
        context,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        return_dict=False,
        **kwarg,
    ):
        """
        Compute the per-block control residuals.

        Parameters
        ----------
        x: torch.Tensor
            Noisy latent, presumably (B, C, H, W); it is patch-embedded and
            passed to ``calc_rope``.
        hint: torch.Tensor
            Control latent embedded with the same patch embedder as ``x``.
            A batch size of 1 is repeated to match ``x``.
        timesteps: torch.Tensor
            Diffusion timesteps, presumably (B,).
        context: torch.Tensor
            CLIP text states (B, L_clip, text_states_dim). Not modified.
        text_embedding_mask: torch.Tensor
            CLIP token mask (B, L_clip); required despite the None default.
        encoder_hidden_states_t5: torch.Tensor
            T5 text states (B, L_t5, text_states_dim_t5); required.
        text_embedding_mask_t5: torch.Tensor
            T5 token mask (B, L_t5); required.
        image_meta_size: torch.Tensor
            (B, 6) size metadata. Used only when ``size_cond`` is True, in
            which case it is required.
        style: torch.Tensor
            (B,) style indices. Used only when ``use_style_cond`` is True;
            defaults to index 0 when omitted.
        return_dict: bool
            Ignored; kept for call-site compatibility.

        Returns
        -------
        dict
            ``{"output": controls}``, where ``controls`` is a list with one
            tensor per block.
        """
        condition = hint
        if condition.shape[0] == 1:
            condition = torch.repeat_interleave(condition, x.shape[0], dim=0)

        text_states = self._build_text_states(
            context,
            text_embedding_mask,
            encoder_hidden_states_t5,
            text_embedding_mask_t5,
        )

        # Rotary position embedding for the image tokens, computed from the
        # un-embedded latent before x is replaced by its tokens.
        freqs_cis_img = calc_rope(
            x, self.patch_size, self.hidden_size // self.num_heads
        )

        t = self.t_embedder(timesteps, dtype=self.dtype)
        x = self.x_embedder(x)

        extra_vec = self._build_extra_vec(encoder_hidden_states_t5, image_meta_size, style)
        c = t + self.extra_embedder(extra_vec)

        condition = self.x_embedder(condition)

        controls = []
        x = x + self.before_proj(condition)
        for block, after_proj in zip(self.blocks, self.after_proj_list):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(after_proj(x))

        return {"output": controls}
|
|