# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
import torch
import torch.nn as nn
from timm.models.layers import DropPath

from diffusion.model.nets.basic_modules import DWMlp, MBConvPreGLU, Mlp
from diffusion.model.nets.fastlinear.modules import TritonLiteMLA
from diffusion.model.nets.sana_blocks import Attention, FlashAttention, MultiHeadCrossAttention, t2i_modulate

# NOTE(review): SlimLiteLA and SlimGLUMBConv are used by SanaMSPABlock but are not
# imported above; presumably they are defined elsewhere in this module -- confirm.

# Known self-attention / FFN types that SanaMSPABlock's fused input projection does
# not support. They raise NotImplementedError; any other unknown type raises ValueError.
_UNSUPPORTED_ATTN_TYPES = ("flash", "triton_linear", "vanilla")
_UNSUPPORTED_FFN_TYPES = ("dwmlp", "mlp", "mbconvpreglu")


class SanaMSPABlock(nn.Module):
    """
    A Sana block with adaptive layer norm zero (adaLN-Zero) conditioning and
    parallel self-attention / FFN branches.

    reference VIT-22B https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L224

    One linear projection (``in_proj``) of the LayerNorm-ed input produces both
    the self-attention input (``3 * hidden_size`` channels, used as ``qkv``) and
    the FFN input (``int(hidden_size * mlp_ratio * 2)`` channels). Each slice is
    LayerNorm-ed and modulated by the timestep conditioning, and the two branches
    run independently. Cross-attention is applied to the gated self-attention
    output, not to the residual stream. The two gated branch outputs are summed
    into a single residual update.

    Only ``attn_type="linear"`` (``SlimLiteLA``) combined with
    ``ffn_type="glumbconv"`` (``SlimGLUMBConv``) is supported. Note that the
    default ``attn_type``/``ffn_type`` values are NOT supported, so callers must
    pass both explicitly.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        drop_path=0.0,
        input_size=None,
        sampling=None,
        sr_ratio=1,
        qk_norm=False,
        attn_type="flash",
        ffn_type="mlp",
        mlp_acts=("silu", "silu", None),
        **block_kwargs,
    ):
        """
        Args:
            hidden_size: channel width C of the block input and output.
            num_heads: number of heads of the cross-attention.
            mlp_ratio: FFN hidden width as a multiple of ``hidden_size``. The FFN
                slice of ``in_proj`` has ``int(hidden_size * mlp_ratio * 2)``
                channels, while the modulation is tiled ``int(mlp_ratio * 2)``
                times in ``forward``. ``mlp_ratio * 2`` must therefore be an
                integer, or the shapes will not match.
            drop_path: stochastic-depth rate applied to the summed branch output.
            input_size, sampling, sr_ratio, qk_norm: unused by the supported
                ("linear" / "glumbconv") configuration.
            attn_type: self-attention type; only ``"linear"`` is supported.
            ffn_type: FFN type; only ``"glumbconv"`` is supported.
            mlp_acts: activations forwarded to ``SlimGLUMBConv``.
            **block_kwargs: forwarded to ``MultiHeadCrossAttention``.

        Raises:
            NotImplementedError: if ``attn_type`` is one of flash/triton_linear/vanilla
                or ``ffn_type`` is one of dwmlp/mlp/mbconvpreglu.
            ValueError: for any other unknown ``attn_type`` or ``ffn_type``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        mlp_in_dim = int(hidden_size * mlp_ratio * 2)

        # Validate before building any submodule. Previously unsupported types built
        # a throw-away module, printed, and called exit(), killing the host process.
        # The attn check still takes precedence over the ffn check.
        if attn_type in _UNSUPPORTED_ATTN_TYPES:
            raise NotImplementedError(
                f"currently not support parallel attn: attn_type={attn_type!r} (only 'linear' is supported)"
            )
        if attn_type != "linear":
            raise ValueError(f"{attn_type} type is not defined.")
        if ffn_type in _UNSUPPORTED_FFN_TYPES:
            raise NotImplementedError(
                f"currently not support parallel attn: ffn_type={ffn_type!r} (only 'glumbconv' is supported)"
            )
        if ffn_type != "glumbconv":
            raise ValueError(f"{ffn_type} type is not defined.")

        # Submodules are created in the same order as before, so seeded initialization
        # and the state_dict layout are unchanged.
        self.norm1 = nn.LayerNorm(hidden_size * 3, elementwise_affine=False, eps=1e-6)
        # Linear self-attention. TODO: the head count is a temporary choice of
        # hidden_size // 32 (e.g. 36 heads for hidden_size=1152), independent of num_heads.
        self.attn = SlimLiteLA(hidden_size, hidden_size, heads=hidden_size // 32, eps=1e-8)
        self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
        self.norm2 = nn.LayerNorm(mlp_in_dim, elementwise_affine=False, eps=1e-6)
        self.mlp = SlimGLUMBConv(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            use_bias=(True, True, False),
            norm=(None, None, None),
            act=mlp_acts,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)

        # parallel layers: one projection yields [qkv | ffn input] along the last dim
        self.mlp_ratio = mlp_ratio
        self.in_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.in_proj = nn.Linear(hidden_size, hidden_size * 3 + mlp_in_dim)
        self.in_split = [hidden_size * 3, mlp_in_dim]

    def forward(self, x, y, t, mask=None, HW=None, **kwargs):
        """
        Args:
            x: block input, unpacked as ``(B, N, C)``.
            y: conditioning sequence passed to ``cross_attn``.
            t: timestep conditioning, reshaped to ``(B, 6, -1)`` and added to
                ``scale_shift_table`` to give the six shift/scale/gate vectors.
            mask: forwarded to ``cross_attn``.
            HW: forwarded to ``attn`` and ``mlp``.
            **kwargs: ignored.

        Returns:
            ``x`` plus the drop-path-scaled sum of the two gated branch outputs.
        """
        B, _, _ = x.shape
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)

        # Fused input projection (replaces separate qkv and FFN fc1 projections).
        qkv, x_mlp = torch.split(self.in_proj(self.in_norm(x)), self.in_split, dim=-1)

        # The modulation vectors are hidden_size wide; tile them to each slice's width.
        qkv = t2i_modulate(self.norm1(qkv), shift_msa.repeat(1, 1, 3), scale_msa.repeat(1, 1, 3))
        mlp_rep = int(self.mlp_ratio * 2)
        x_mlp = t2i_modulate(
            self.norm2(x_mlp),
            shift_mlp.repeat(1, 1, mlp_rep),
            scale_mlp.repeat(1, 1, mlp_rep),
        )

        # branch 1: self-attention on the pre-projected qkv, then cross-attention
        # applied to the gated self-attention output
        x_attn = gate_msa * self.attn(qkv, HW=HW)
        x_attn = x_attn + self.cross_attn(x_attn, y, mask)

        # branch 2: FFN
        x_mlp = gate_mlp * self.mlp(x_mlp, HW=HW)

        # Add residual w/ drop path applied to the summed branches
        return x + self.drop_path(x_attn + x_mlp)