|
import torch |
|
from typing import Dict, Optional |
|
import comfy.ldm.modules.diffusionmodules.mmdit |
|
|
|
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
    """ControlNet branch built on top of the MMDiT backbone.

    The backbone is constructed with ``final_layer=False``. Each joint block
    gets a matching ``Linear(hidden_size, hidden_size)`` projection
    (``controlnet_blocks``) whose output is emitted as a control residual. An
    extra patch embedder (``pos_embed_input``) embeds the control hint and adds
    it to the embedded latent tokens.
    """

    def __init__(
        self,
        num_blocks=None,
        control_latent_channels=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        """Build the MMDiT backbone plus the ControlNet-specific layers.

        Args:
            num_blocks: Forwarded to ``MMDiT`` as ``num_blocks``.
            control_latent_channels: Channel count of the control hint fed to
                ``pos_embed_input``. When None, ``self.in_channels`` is used.
            dtype: Forwarded to the backbone and to every layer built here.
            device: Forwarded to the backbone and to every layer built here.
            operations: Layer factory namespace. It must provide ``Linear``
                and is also forwarded to ``MMDiT`` and ``PatchEmbed``.
            **kwargs: Remaining ``MMDiT`` constructor arguments.
        """
        super().__init__(
            dtype=dtype,
            device=device,
            operations=operations,
            final_layer=False,
            num_blocks=num_blocks,
            **kwargs,
        )

        # One projection per joint block; forward() applies
        # controlnet_blocks[i] to the token output of joint_blocks[i].
        self.controlnet_blocks = torch.nn.ModuleList(
            [
                operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype)
                for _ in range(len(self.joint_blocks))
            ]
        )

        if control_latent_channels is None:
            control_latent_channels = self.in_channels

        # Patch-embeds the control hint into hidden_size-wide tokens; forward()
        # adds the result to the embedded latent. strict_img_size=False is
        # presumably what lets hints of arbitrary spatial size through - confirm
        # against PatchEmbed.
        self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
            None,
            self.patch_size,
            control_latent_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=False,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint: Optional[torch.Tensor] = None,
    ) -> Dict[str, list]:
        """Run the ControlNet and return the per-block control residuals.

        Args:
            x: Latent input. Its last two dims are used as the spatial size
                for the cropped positional embedding (presumably (B, C, H, W)
                - confirm with callers).
            timesteps: Passed to ``t_embedder``.
            y: Optional pooled conditioning. When given, it is replaced by
                zeros of the same shape before embedding, which is specific to
                the SD3 ControlNet. When None, no pooled embedding is added.
            context: Optional context tokens. They are run through
                ``context_processor`` (if set) and then ``context_embedder``.
            hint: Control input. It is patch-embedded by ``pos_embed_input``
                and added to the embedded latent.

        Returns:
            ``{"output": residuals}``. ``residuals`` is a list holding the
            projection of block ``i`` repeated ``depth // len(joint_blocks)``
            times, except the last block, which is repeated one time fewer.
        """
        # SD3 ControlNet specific: the pooled conditioning is zeroed out.
        # Guarded because y defaults to None and torch.zeros_like(None) raises.
        if y is not None:
            y = torch.zeros_like(y)

        if self.context_processor is not None:
            context = self.context_processor(context)

        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
        x += self.pos_embed_input(hint)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            c = c + self.y_embedder(y)

        if context is not None:
            context = self.context_embedder(context)

        num_blocks = len(self.joint_blocks)
        # Loop-invariant repeat factor. The guard keeps an empty block list
        # returning [] instead of raising ZeroDivisionError.
        repeats = self.depth // num_blocks if num_blocks else 0

        output = []
        for i, joint_block in enumerate(self.joint_blocks):
            context, x = joint_block(
                context,
                x,
                c=c,
                use_checkpoint=self.use_checkpoint,
            )

            out = self.controlnet_blocks[i](x)
            # The last block contributes one fewer copy, presumably so the list
            # length matches the number of base-model blocks that consume a
            # residual - confirm against the consumer of "output".
            count = repeats - 1 if i == num_blocks - 1 else repeats
            output.extend([out] * count)

        return {"output": output}
|
|