from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from torch import nn


def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


class FP32LayerNorm(nn.LayerNorm):
    # LayerNorm that always computes in float32 and casts the result back to the input dtype.
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        if hasattr(self, "weight") and self.weight is not None:
            return F.layer_norm(
                inputs.float(),
                self.normalized_shape,
                self.weight.float(),
                self.bias.float() if self.bias is not None else None,
                self.eps,
            ).to(origin_dtype)
        else:
            return F.layer_norm(
                inputs.float(), self.normalized_shape, None, None, self.eps
            ).to(origin_dtype)


class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            # Zero-init the output projections so the size conditioning starts as a no-op.
            self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
            self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        if self.use_additional_conditions:
            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning


class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
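

# Minimal usage sketch (assumed/illustrative values only: the batch size, embedding width,
# resolution, and aspect ratio below are not taken from any particular model config).
# AdaLayerNormSingle maps the timestep (plus optional size conditions) to a single
# (batch_size, 6 * embedding_dim) modulation tensor, which downstream transformer blocks
# typically split into shift/scale/gate chunks.
if __name__ == "__main__":
    batch_size = 2
    embedding_dim = 1152  # example width; must be divisible by 3 for the size embeddings

    ada_norm = AdaLayerNormSingle(embedding_dim, use_additional_conditions=True)

    timestep = torch.randint(0, 1000, (batch_size,))
    added_cond_kwargs = {
        "resolution": torch.tensor([[512.0, 512.0]] * batch_size),  # (height, width) per sample
        "aspect_ratio": torch.tensor([[1.0]] * batch_size),
    }

    conditioning, embedded_timestep = ada_norm(
        timestep,
        added_cond_kwargs=added_cond_kwargs,
        batch_size=batch_size,
        hidden_dtype=torch.float32,
    )
    print(conditioning.shape)       # torch.Size([2, 6912]) == (batch_size, 6 * embedding_dim)
    print(embedded_timestep.shape)  # torch.Size([2, 1152]) == (batch_size, embedding_dim)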