|
|
|
""" |
|
This module implements mutual self-attention (reference attention) for diffusion
UNet models.

It provides torch_dfs, a depth-first traversal over a model's submodules, and
ReferenceAttentionControl, which patches the self-attention of
BasicTransformerBlock and TemporalBasicTransformerBlock modules so that a
reference ("write") UNet can store its attention features in per-block banks
and a denoising ("read") UNet can attend over those banked features, e.g. for
reference-guided image and video generation.
|
""" |
|
|
|
from typing import Any, Dict, Optional |
|
|
|
import torch |
|
from einops import rearrange |
|
|
|
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock |
|
|
|
|
|
def torch_dfs(model: torch.nn.Module): |
|
""" |
|
Perform a depth-first search (DFS) traversal on a PyTorch model's neural network architecture. |
|
|
|
This function recursively traverses all the children modules of a given PyTorch model and returns a list |
|
containing all the modules in the model's architecture. The DFS approach starts with the input model and |
|
explores its children modules depth-wise before backtracking and exploring other branches. |
|
|
|
Args: |
|
model (torch.nn.Module): The root module of the neural network to traverse. |
|
|
|
Returns: |
|
list: A list of all the modules in the model's architecture. |
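
    Example:
        A minimal illustration on a toy model (the Sequential below is purely
        hypothetical and not part of this module):

            >>> import torch.nn as nn
            >>> net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
            >>> [type(m).__name__ for m in torch_dfs(net)]
            ['Sequential', 'Linear', 'ReLU']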
|
""" |
|
result = [model] |
|
for child in model.children(): |
|
result += torch_dfs(child) |
|
return result |
|
|
|
|
|
class ReferenceAttentionControl: |
|
""" |
|
    This class controls the reference attention mechanism in a UNet model.
    It patches the self-attention of the UNet's transformer blocks so that
    reference features can be written to, or read from, per-block feature banks,
    and it provides methods for registering the reference hooks and for updating
    or clearing that banked state.
|
|
|
Attributes: |
|
unet: The UNet model associated with this attention control object. |
|
mode: The operating mode of the attention control object, either 'write' or 'read'. |
|
do_classifier_free_guidance: Whether to use classifier-free guidance in the attention mechanism. |
|
attention_auto_machine_weight: The weight assigned to the attention auto-machine. |
|
gn_auto_machine_weight: The weight assigned to the group normalization auto-machine. |
|
style_fidelity: The style fidelity parameter for the attention mechanism. |
|
reference_attn: Whether to use reference attention in the model. |
|
reference_adain: Whether to use reference AdaIN in the model. |
|
        fusion_blocks: Which blocks to patch ('midup' for the mid and up blocks, or 'full' for the whole UNet).
|
batch_size: The batch size used for processing video frames. |
|
|
|
Methods: |
|
register_reference_hooks: Registers the reference hooks for the attention control object. |
|
hacked_basic_transformer_inner_forward: The modified inner forward method for the basic transformer block. |
|
update: Updates the internal state of the attention control object using the provided writer and dtype. |
|
clear: Clears the internal state of the attention control object. |
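
    Example:
        A sketch of a typical write/read pairing; variable names such as
        reference_unet and denoising_unet are illustrative and not defined by
        this module:

            writer = ReferenceAttentionControl(reference_unet, mode="write", fusion_blocks="full")
            reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="full")
            # ... run the reference UNet so the writer's banks are filled ...
            reader.update(writer, dtype=torch.float16)
            # ... run the denoising UNet, which now reads the banked features ...
            reader.clear()
            writer.clear()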
|
""" |
|
def __init__( |
|
self, |
|
unet, |
|
mode="write", |
|
do_classifier_free_guidance=False, |
|
attention_auto_machine_weight=float("inf"), |
|
gn_auto_machine_weight=1.0, |
|
style_fidelity=1.0, |
|
reference_attn=True, |
|
reference_adain=False, |
|
fusion_blocks="midup", |
|
batch_size=1, |
|
) -> None: |
|
""" |
|
Initializes the ReferenceAttentionControl class. |
|
|
|
Args: |
|
unet (torch.nn.Module): The UNet model. |
|
mode (str, optional): The mode of operation. Defaults to "write". |
|
do_classifier_free_guidance (bool, optional): Whether to do classifier-free guidance. Defaults to False. |
|
attention_auto_machine_weight (float, optional): The weight for attention auto-machine. Defaults to infinity. |
|
gn_auto_machine_weight (float, optional): The weight for group-norm auto-machine. Defaults to 1.0. |
|
style_fidelity (float, optional): The style fidelity. Defaults to 1.0. |
|
reference_attn (bool, optional): Whether to use reference attention. Defaults to True. |
|
reference_adain (bool, optional): Whether to use reference AdaIN. Defaults to False. |
|
fusion_blocks (str, optional): The fusion blocks to use. Defaults to "midup". |
|
batch_size (int, optional): The batch size. Defaults to 1. |
|
|
|
Raises: |
|
            AssertionError: If the mode is not 'read' or 'write'.
            AssertionError: If the fusion blocks are not 'midup' or 'full'.
|
""" |
|
|
|
self.unet = unet |
|
assert mode in ["read", "write"] |
|
assert fusion_blocks in ["midup", "full"] |
|
self.reference_attn = reference_attn |
|
self.reference_adain = reference_adain |
|
self.fusion_blocks = fusion_blocks |
|
self.register_reference_hooks( |
|
mode, |
|
do_classifier_free_guidance, |
|
attention_auto_machine_weight, |
|
gn_auto_machine_weight, |
|
style_fidelity, |
|
reference_attn, |
|
reference_adain, |
|
fusion_blocks, |
|
batch_size=batch_size, |
|
) |
|
|
|
def register_reference_hooks( |
|
self, |
|
mode, |
|
do_classifier_free_guidance, |
|
_attention_auto_machine_weight, |
|
_gn_auto_machine_weight, |
|
_style_fidelity, |
|
_reference_attn, |
|
_reference_adain, |
|
_dtype=torch.float16, |
|
batch_size=1, |
|
num_images_per_prompt=1, |
|
device=torch.device("cpu"), |
|
_fusion_blocks="midup", |
|
): |
|
""" |
|
Registers reference hooks for the model. |
|
|
|
        This method replaces the forward pass of the UNet's transformer blocks
        with a patched version. In 'write' mode each patched block stores its
        pre-attention hidden states in a per-block bank; in 'read' mode each
        block additionally attends over the banked reference features.
|
|
|
Args: |
|
            mode: The mode of operation, either 'write' or 'read'.
            do_classifier_free_guidance: Whether classifier-free guidance is used (this determines the shape of the guidance mask).
            _attention_auto_machine_weight: Weight for the attention auto-machine (unused in this implementation).
            _gn_auto_machine_weight: Weight for the group-normalization auto-machine (unused in this implementation).
            _style_fidelity: Style fidelity value (unused in this implementation).
            _reference_attn: Whether to use reference attention (unused in this implementation).
            _reference_adain: Whether to use reference AdaIN (unused in this implementation).
            _dtype: Data type for banked features (unused in this implementation).
            batch_size: Batch size used to build the classifier-free-guidance mask.
            num_images_per_prompt: Number of images per prompt used to build the mask.
            device: Device on which the guidance mask is created.
            _fusion_blocks: Fusion-block selection (unused in this implementation).
|
|
|
Returns: |
|
None |
|
""" |
|
MODE = mode |
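        # Build a boolean mask over the batch for classifier-free guidance:
        # True marks the unconditional half, which at read time falls back to
        # plain self-attention instead of attending over the reference bank.
        # The hard-coded 16 presumably corresponds to the number of frames
        # processed per prompt; the mask is rebuilt on the fly if its length
        # does not match the actual batch size.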
|
if do_classifier_free_guidance: |
|
uc_mask = ( |
|
torch.Tensor( |
|
[1] * batch_size * num_images_per_prompt * 16 |
|
+ [0] * batch_size * num_images_per_prompt * 16 |
|
) |
|
.to(device) |
|
.bool() |
|
) |
|
else: |
|
uc_mask = ( |
|
torch.Tensor([0] * batch_size * num_images_per_prompt * 2) |
|
.to(device) |
|
.bool() |
|
) |
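
        # Replacement forward for (Temporal)BasicTransformerBlock. It mirrors the
        # stock forward, except that in 'write' mode the pre-attention hidden
        # states are appended to the block's `bank`, and in 'read' mode the banked
        # reference features are concatenated into the tokens that attn1 attends
        # over (mutual self-attention).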
|
|
|
def hacked_basic_transformer_inner_forward( |
|
self, |
|
hidden_states: torch.FloatTensor, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
encoder_attention_mask: Optional[torch.FloatTensor] = None, |
|
timestep: Optional[torch.LongTensor] = None, |
|
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
class_labels: Optional[torch.LongTensor] = None, |
|
video_length=None, |
|
): |
|
gate_msa = None |
|
shift_mlp = None |
|
scale_mlp = None |
|
gate_mlp = None |
|
|
|
if self.use_ada_layer_norm: |
|
norm_hidden_states = self.norm1(hidden_states, timestep) |
|
elif self.use_ada_layer_norm_zero: |
|
( |
|
norm_hidden_states, |
|
gate_msa, |
|
shift_mlp, |
|
scale_mlp, |
|
gate_mlp, |
|
) = self.norm1( |
|
hidden_states, |
|
timestep, |
|
class_labels, |
|
hidden_dtype=hidden_states.dtype, |
|
) |
|
else: |
|
norm_hidden_states = self.norm1(hidden_states) |
|
|
|
|
|
|
|
cross_attention_kwargs = ( |
|
cross_attention_kwargs if cross_attention_kwargs is not None else {} |
|
) |
|
if self.only_cross_attention: |
|
attn_output = self.attn1( |
|
norm_hidden_states, |
|
encoder_hidden_states=( |
|
encoder_hidden_states if self.only_cross_attention else None |
|
), |
|
attention_mask=attention_mask, |
|
**cross_attention_kwargs, |
|
) |
|
else: |
|
if MODE == "write": |
|
self.bank.append(norm_hidden_states.clone()) |
|
attn_output = self.attn1( |
|
norm_hidden_states, |
|
encoder_hidden_states=( |
|
encoder_hidden_states if self.only_cross_attention else None |
|
), |
|
attention_mask=attention_mask, |
|
**cross_attention_kwargs, |
|
) |
|
if MODE == "read": |
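                    # Each banked tensor has shape (b*s, l, c), where s is the
                    # number of stored frames per sample (index 0 appears to hold
                    # the reference frame, the rest the motion frames). The
                    # reference frame's features are repeated across all
                    # video_length frames so every denoised frame can attend to
                    # them; the remaining entries are returned separately below
                    # as motion_frames_fea.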
|
|
|
bank_fea = [ |
|
rearrange( |
|
rearrange( |
|
d, |
|
"(b s) l c -> b s l c", |
|
b=norm_hidden_states.shape[0] // video_length, |
|
)[:, 0, :, :] |
|
|
|
.repeat(1, video_length, 1, 1), |
|
"b t l c -> (b t) l c", |
|
) |
|
for d in self.bank |
|
] |
|
motion_frames_fea = [rearrange( |
|
d, |
|
"(b s) l c -> b s l c", |
|
b=norm_hidden_states.shape[0] // video_length, |
|
)[:, 1:, :, :] for d in self.bank] |
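                    # Concatenate the repeated reference features to the current
                    # tokens along the sequence dimension; attn1 then attends
                    # jointly over both (the core of mutual self-attention).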
|
modify_norm_hidden_states = torch.cat( |
|
[norm_hidden_states] + bank_fea, dim=1 |
|
) |
|
hidden_states_uc = ( |
|
self.attn1( |
|
norm_hidden_states, |
|
encoder_hidden_states=modify_norm_hidden_states, |
|
attention_mask=attention_mask, |
|
) |
|
+ hidden_states |
|
) |
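                    # Under classifier-free guidance, recompute the unconditional
                    # half with plain self-attention so the reference features
                    # only influence the conditional branch.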
|
if do_classifier_free_guidance: |
|
hidden_states_c = hidden_states_uc.clone() |
|
_uc_mask = uc_mask.clone() |
|
if hidden_states.shape[0] != _uc_mask.shape[0]: |
|
_uc_mask = ( |
|
torch.Tensor( |
|
[1] * (hidden_states.shape[0] // 2) |
|
+ [0] * (hidden_states.shape[0] // 2) |
|
) |
|
.to(device) |
|
.bool() |
|
) |
|
hidden_states_c[_uc_mask] = ( |
|
self.attn1( |
|
norm_hidden_states[_uc_mask], |
|
encoder_hidden_states=norm_hidden_states[_uc_mask], |
|
attention_mask=attention_mask, |
|
) |
|
+ hidden_states[_uc_mask] |
|
) |
|
hidden_states = hidden_states_c.clone() |
|
else: |
|
hidden_states = hidden_states_uc |
|
|
|
|
|
if self.attn2 is not None: |
|
|
|
norm_hidden_states = ( |
|
self.norm2(hidden_states, timestep) |
|
if self.use_ada_layer_norm |
|
else self.norm2(hidden_states) |
|
) |
|
hidden_states = ( |
|
self.attn2( |
|
norm_hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
attention_mask=attention_mask, |
|
) |
|
+ hidden_states |
|
) |
|
|
|
|
|
hidden_states = self.ff(self.norm3( |
|
hidden_states)) + hidden_states |
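                    # Optional temporal self-attention: fold the spatial tokens
                    # into the batch dimension and attend across frames.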
|
|
|
|
|
if self.unet_use_temporal_attention: |
|
d = hidden_states.shape[1] |
|
hidden_states = rearrange( |
|
hidden_states, "(b f) d c -> (b d) f c", f=video_length |
|
) |
|
norm_hidden_states = ( |
|
self.norm_temp(hidden_states, timestep) |
|
if self.use_ada_layer_norm |
|
else self.norm_temp(hidden_states) |
|
) |
|
hidden_states = ( |
|
self.attn_temp(norm_hidden_states) + hidden_states |
|
) |
|
hidden_states = rearrange( |
|
hidden_states, "(b d) f c -> (b f) d c", d=d |
|
) |
|
|
|
return hidden_states, motion_frames_fea |
|
|
|
if self.use_ada_layer_norm_zero: |
|
attn_output = gate_msa.unsqueeze(1) * attn_output |
|
hidden_states = attn_output + hidden_states |
|
|
|
if self.attn2 is not None: |
|
norm_hidden_states = ( |
|
self.norm2(hidden_states, timestep) |
|
if self.use_ada_layer_norm |
|
else self.norm2(hidden_states) |
|
) |
|
|
|
|
|
tmp = norm_hidden_states.shape[0] // encoder_hidden_states.shape[0] |
|
attn_output = self.attn2( |
|
norm_hidden_states, |
|
|
|
encoder_hidden_states=encoder_hidden_states.repeat( |
|
tmp, 1, 1), |
|
attention_mask=encoder_attention_mask, |
|
**cross_attention_kwargs, |
|
) |
|
hidden_states = attn_output + hidden_states |
|
|
|
|
|
norm_hidden_states = self.norm3(hidden_states) |
|
|
|
if self.use_ada_layer_norm_zero: |
|
norm_hidden_states = ( |
|
norm_hidden_states * |
|
(1 + scale_mlp[:, None]) + shift_mlp[:, None] |
|
) |
|
|
|
ff_output = self.ff(norm_hidden_states) |
|
|
|
if self.use_ada_layer_norm_zero: |
|
ff_output = gate_mlp.unsqueeze(1) * ff_output |
|
|
|
hidden_states = ff_output + hidden_states |
|
|
|
return hidden_states |
|
|
|
if self.reference_attn: |
|
if self.fusion_blocks == "midup": |
|
attn_modules = [ |
|
module |
|
for module in ( |
|
torch_dfs(self.unet.mid_block) + |
|
torch_dfs(self.unet.up_blocks) |
|
) |
|
if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) |
|
] |
|
elif self.fusion_blocks == "full": |
|
attn_modules = [ |
|
module |
|
for module in torch_dfs(self.unet) |
|
if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) |
|
] |
|
attn_modules = sorted( |
|
attn_modules, key=lambda x: -x.norm1.normalized_shape[0] |
|
) |
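            # Monkey-patch each collected block's forward with the hacked version
            # above; __get__ binds the function to the module instance so that
            # `self` inside refers to that block.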
|
|
|
for i, module in enumerate(attn_modules): |
|
module._original_inner_forward = module.forward |
|
if isinstance(module, BasicTransformerBlock): |
|
module.forward = hacked_basic_transformer_inner_forward.__get__( |
|
module, |
|
BasicTransformerBlock) |
|
if isinstance(module, TemporalBasicTransformerBlock): |
|
module.forward = hacked_basic_transformer_inner_forward.__get__( |
|
module, |
|
TemporalBasicTransformerBlock) |
|
|
|
module.bank = [] |
|
module.attn_weight = float(i) / float(len(attn_modules)) |
|
|
|
def update(self, writer, dtype=torch.float16): |
|
""" |
|
        Copies the attention banks from the writer's transformer blocks into the
        corresponding reader blocks of this object's UNet.

        Args:
            writer (ReferenceAttentionControl): The 'write'-mode control object whose UNet holds the banked reference features.
            dtype (torch.dtype, optional): The data type the copied features are cast to. Defaults to torch.float16.
|
|
|
Returns: |
|
None. |
|
""" |
|
if self.reference_attn: |
|
if self.fusion_blocks == "midup": |
|
reader_attn_modules = [ |
|
module |
|
for module in ( |
|
torch_dfs(self.unet.mid_block) + |
|
torch_dfs(self.unet.up_blocks) |
|
) |
|
if isinstance(module, TemporalBasicTransformerBlock) |
|
] |
|
writer_attn_modules = [ |
|
module |
|
for module in ( |
|
torch_dfs(writer.unet.mid_block) |
|
+ torch_dfs(writer.unet.up_blocks) |
|
) |
|
if isinstance(module, BasicTransformerBlock) |
|
] |
|
elif self.fusion_blocks == "full": |
|
reader_attn_modules = [ |
|
module |
|
for module in torch_dfs(self.unet) |
|
if isinstance(module, TemporalBasicTransformerBlock) |
|
] |
|
writer_attn_modules = [ |
|
module |
|
for module in torch_dfs(writer.unet) |
|
if isinstance(module, BasicTransformerBlock) |
|
] |
|
|
|
assert len(reader_attn_modules) == len(writer_attn_modules) |
|
reader_attn_modules = sorted( |
|
reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] |
|
) |
|
writer_attn_modules = sorted( |
|
writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] |
|
) |
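            # Sorting both lists by channel width (norm1.normalized_shape[0])
            # pairs corresponding blocks across the two UNets before copying the
            # banked features (cast to `dtype`) from writer to reader.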
|
for r, w in zip(reader_attn_modules, writer_attn_modules): |
|
r.bank = [v.clone().to(dtype) for v in w.bank] |
|
|
|
|
|
def clear(self): |
|
""" |
|
        Clears the attention banks of all reader attention modules.

        If reference_attn is enabled, this collects the patched transformer blocks
        from the UNet (the mid and up blocks for fusion_blocks="midup", or the
        whole UNet for fusion_blocks="full"), sorts them by channel width for a
        deterministic order, and empties each module's bank, releasing the stored
        reference features.
|
""" |
|
if self.reference_attn: |
|
if self.fusion_blocks == "midup": |
|
reader_attn_modules = [ |
|
module |
|
for module in ( |
|
torch_dfs(self.unet.mid_block) + |
|
torch_dfs(self.unet.up_blocks) |
|
) |
|
if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) |
|
] |
|
elif self.fusion_blocks == "full": |
|
reader_attn_modules = [ |
|
module |
|
for module in torch_dfs(self.unet) |
|
if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) |
|
] |
|
reader_attn_modules = sorted( |
|
reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] |
|
) |
|
for r in reader_attn_modules: |
|
r.bank.clear() |
|
|