# OminiControl / src/transformer.py
# Captured from the Hugging Face file view (author: Yuanshi; commit 6ed1db6, "add all"; 9.88 kB).
import torch
from diffusers.pipelines import FluxPipeline
from typing import List, Union, Optional, Dict, Any, Callable
from .block import block_forward, single_block_forward
from .lora_controller import enable_lora
from diffusers.models.transformers.transformer_flux import (
FluxTransformer2DModel,
Transformer2DModelOutput,
USE_PEFT_BACKEND,
is_torch_version,
scale_lora_layers,
unscale_lora_layers,
logger,
)
import numpy as np
def prepare_params(
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    **kwargs: dict,
):
    """Bind ``FluxTransformer2DModel.forward``-style arguments to a fixed order.

    Every supported argument is matched by name (or position), omitted ones
    fall back to the defaults in the signature, and any additional keyword
    arguments are accepted and discarded.

    Returns:
        An 11-element tuple, always in this order: ``hidden_states``,
        ``encoder_hidden_states``, ``pooled_projections``, ``timestep``,
        ``img_ids``, ``txt_ids``, ``guidance``, ``joint_attention_kwargs``,
        ``controlnet_block_samples``, ``controlnet_single_block_samples``,
        ``return_dict``.
    """
    # Unrecognised keyword arguments are deliberately dropped.
    del kwargs
    ordered = [
        hidden_states,
        encoder_hidden_states,
        pooled_projections,
        timestep,
        img_ids,
        txt_ids,
        guidance,
        joint_attention_kwargs,
        controlnet_block_samples,
        controlnet_single_block_samples,
        return_dict,
    ]
    return tuple(ordered)
def _omini_checkpoint_kwargs() -> Dict[str, Any]:
    """Return extra kwargs for ``torch.utils.checkpoint.checkpoint``.

    Non-reentrant checkpointing is requested when torch >= 1.11; older
    versions get no extra kwargs.
    """
    return {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}


def _omini_controlnet_interval(num_blocks: int, samples) -> int:
    """Number of consecutive blocks that share one ControlNet residual sample."""
    return int(np.ceil(num_blocks / len(samples)))


def _omini_dual_block_fn(block, model_config):
    """Return a positional-argument callable running ``block_forward`` on ``block``.

    A factory is used (instead of a closure created inside the loop) so that,
    when gradient checkpointing recomputes the forward during backward, the
    callable still refers to *this* block rather than the last loop value.
    """

    def run(
        hidden_states,
        encoder_hidden_states,
        condition_latents,
        temb,
        cond_temb,
        cond_rotary_emb,
        image_rotary_emb,
    ):
        return block_forward(
            block,
            model_config=model_config,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            condition_latents=condition_latents,
            temb=temb,
            cond_temb=cond_temb,
            cond_rotary_emb=cond_rotary_emb,
            image_rotary_emb=image_rotary_emb,
        )

    return run


def _omini_single_block_fn(block, model_config, with_condition):
    """Return a positional-argument callable running ``single_block_forward``.

    The condition tensors are only forwarded when ``with_condition`` is true,
    so the callee's return value is a ``(hidden_states, condition_latents)``
    pair in that case and a single value otherwise.
    """

    def run(
        hidden_states,
        temb,
        image_rotary_emb,
        condition_latents=None,
        cond_temb=None,
        cond_rotary_emb=None,
    ):
        extra = (
            {
                "condition_latents": condition_latents,
                "cond_temb": cond_temb,
                "cond_rotary_emb": cond_rotary_emb,
            }
            if with_condition
            else {}
        )
        return single_block_forward(
            block,
            model_config=model_config,
            hidden_states=hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            **extra,
        )

    return run


def tranformer_forward(
    transformer: FluxTransformer2DModel,
    condition_latents: torch.Tensor,
    condition_ids: torch.Tensor,
    condition_type_ids: torch.Tensor,
    model_config: Optional[Dict[str, Any]] = None,
    return_conditional_latents: bool = False,
    c_t=0,
    **params: dict,
):
    """Run a Flux transformer forward pass with an optional condition branch.

    Follows the structure of ``FluxTransformer2DModel.forward`` but routes
    every dual-stream block through ``block_forward`` and every single-stream
    block through ``single_block_forward`` so that condition tokens can be
    processed alongside the image tokens. (The function name keeps its
    original spelling for existing callers.)

    Args:
        transformer: Flux transformer whose sub-modules are used.
        condition_latents: Condition tokens, embedded with ``x_embedder``.
            ``None`` disables the condition branch entirely.
        condition_ids: Position ids for the condition tokens, fed to
            ``pos_embed``; only used when a condition is given.
        condition_type_ids: Optional ids; the first entry is embedded with
            ``cond_type_embed`` (if the transformer has that attribute) and
            added to the condition time embedding.
        model_config: Options read here are ``use_condition_in_single_blocks``
            (default ``True``) and ``latent_lora`` (default ``False``); the
            dict is also forwarded to the block helpers. ``None`` means ``{}``.
        return_conditional_latents: Also project and return the condition
            tokens. Requires a condition and ``use_condition_in_single_blocks``.
        c_t: Timestep value (before the ``* 1000`` scaling) used for the
            condition's time embedding.
        **params: Keyword arguments in the style of
            ``FluxTransformer2DModel.forward`` (see ``prepare_params``);
            unknown keys are ignored. Apart from ``"scale"``, entries of
            ``joint_attention_kwargs`` are not forwarded to the blocks here.

    Returns:
        If ``return_dict`` (default ``True``): ``Transformer2DModelOutput``
        with ``sample=output``. The condition output is not included in this
        form. Otherwise ``(output,)``, or ``(output, condition_output)`` when
        ``return_conditional_latents`` is set.
    """
    self = transformer
    if model_config is None:
        model_config = {}
    use_condition = condition_latents is not None
    use_condition_in_single_blocks = model_config.get(
        "use_condition_in_single_blocks", True
    )
    condition_in_single_blocks = use_condition and use_condition_in_single_blocks
    # if return_conditional_latents is True, use_condition and use_condition_in_single_blocks must be True
    assert not return_conditional_latents or condition_in_single_blocks, (
        "`return_conditional_latents` is True, `use_condition` and "
        "`use_condition_in_single_blocks` must be True"
    )
    (
        hidden_states,
        encoder_hidden_states,
        pooled_projections,
        timestep,
        img_ids,
        txt_ids,
        guidance,
        joint_attention_kwargs,
        controlnet_block_samples,
        controlnet_single_block_samples,
        return_dict,
    ) = prepare_params(**params)

    # Decide whether a scale was supplied *before* popping it; checking after
    # the pop (as before) made the non-PEFT warning unreachable.
    scale_passed = (
        joint_attention_kwargs is not None
        and joint_attention_kwargs.get("scale", None) is not None
    )
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    elif scale_passed:
        logger.warning(
            "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
        )

    # Everything after scaling runs in try/finally so that an exception cannot
    # leave the LoRA layers permanently scaled.
    try:
        # LoRA on the noisy-latent embedding is gated by `latent_lora`; the
        # condition embedding runs outside that context.
        with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
            hidden_states = self.x_embedder(hidden_states)
        condition_latents = self.x_embedder(condition_latents) if use_condition else None

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        def embed_time(t):
            # time_text_embed takes guidance as a middle argument only when it is provided.
            if guidance is None:
                return self.time_text_embed(t, pooled_projections)
            return self.time_text_embed(t, guidance, pooled_projections)

        temb = embed_time(timestep)

        # Condition-only embeddings; left as None when there is no condition.
        cond_temb = None
        cond_rotary_emb = None
        if use_condition:
            cond_temb = embed_time(torch.ones_like(timestep) * c_t * 1000)
            if hasattr(self, "cond_type_embed") and condition_type_ids is not None:
                cond_type_proj = self.time_text_embed.time_proj(condition_type_ids[0])
                cond_type_emb = self.cond_type_embed(
                    cond_type_proj.to(dtype=cond_temb.dtype)
                )
                cond_temb = cond_temb + cond_type_emb
            cond_rotary_emb = self.pos_embed(condition_ids)

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]
        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        # Gradient checkpointing must run the *same* block functions as the
        # normal path; otherwise the condition branch would be skipped in training.
        use_checkpointing = self.training and self.gradient_checkpointing
        ckpt_kwargs = _omini_checkpoint_kwargs() if use_checkpointing else {}

        def call_block(fn, *args):
            if use_checkpointing:
                return torch.utils.checkpoint.checkpoint(fn, *args, **ckpt_kwargs)
            return fn(*args)

        # Dual-stream blocks.
        if controlnet_block_samples is not None:
            dual_interval = _omini_controlnet_interval(
                len(self.transformer_blocks), controlnet_block_samples
            )
        for index_block, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states, condition_latents = call_block(
                _omini_dual_block_fn(block, model_config),
                hidden_states,
                encoder_hidden_states,
                condition_latents,
                temb,
                cond_temb,
                cond_rotary_emb,
                image_rotary_emb,
            )
            # controlnet residual
            if controlnet_block_samples is not None:
                hidden_states = (
                    hidden_states
                    + controlnet_block_samples[index_block // dual_interval]
                )

        # Single-stream blocks operate on the concatenated [text, image] sequence.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        txt_len = encoder_hidden_states.shape[1]
        if controlnet_single_block_samples is not None:
            single_interval = _omini_controlnet_interval(
                len(self.single_transformer_blocks), controlnet_single_block_samples
            )
        for index_block, block in enumerate(self.single_transformer_blocks):
            result = call_block(
                _omini_single_block_fn(block, model_config, condition_in_single_blocks),
                hidden_states,
                temb,
                image_rotary_emb,
                condition_latents,
                cond_temb,
                cond_rotary_emb,
            )
            if condition_in_single_blocks:
                hidden_states, condition_latents = result
            else:
                hidden_states = result
            # controlnet residual (applied to the image tokens only)
            if controlnet_single_block_samples is not None:
                hidden_states[:, txt_len:, ...] = (
                    hidden_states[:, txt_len:, ...]
                    + controlnet_single_block_samples[index_block // single_interval]
                )

        # Drop the text tokens before the output head.
        hidden_states = hidden_states[:, txt_len:, ...]
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        condition_output = None
        if return_conditional_latents:
            condition_latents = (
                self.norm_out(condition_latents, cond_temb) if use_condition else None
            )
            condition_output = self.proj_out(condition_latents) if use_condition else None
    finally:
        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output, condition_output) if return_conditional_latents else (output,)
    return Transformer2DModelOutput(sample=output)