|
import math |
|
import warnings |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from torch import nn |
|
from torch.nn.functional import scaled_dot_product_attention |
|
|
|
from models.helpers import DropPath |
|
from models.rope import apply_rotary_emb |
|
|
|
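# Optional fused MLP kernel from flash-attn; fall back to the plain PyTorch path if it is not installed.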
try: |
|
from flash_attn.ops.fused_dense import fused_mlp_func |
|
except ImportError: |
|
fused_mlp_func = None |
|
|
|
|
|
__all__ = ["FFN", "SwiGLUFFN", "RMSNorm", "AdaLNSelfCrossAttn", "AdaLNBeforeHead"] |
|
|
|
|
|
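# Prefer apex's FusedRMSNorm when available; otherwise define an equivalent vanilla RMSNorm below.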
try: |
|
from apex.normalization import FusedRMSNorm as RMSNorm |
|
except ImportError: |
|
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") |
|
|
|
class RMSNorm(torch.nn.Module): |
|
def __init__(self, dim: int, eps: float = 1e-6): |
|
""" |
|
Initialize the RMSNorm normalization layer. |
|
|
|
Args: |
|
dim (int): The dimension of the input tensor. |
|
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. |
|
|
|
Attributes: |
|
eps (float): A small value added to the denominator for numerical stability. |
|
weight (nn.Parameter): Learnable scaling parameter. |
|
|
|
""" |
|
super().__init__() |
|
self.eps = eps |
|
self.weight = nn.Parameter(torch.ones(dim)) |
|
|
|
def _norm(self, x): |
|
""" |
|
Apply the RMSNorm normalization to the input tensor. |
|
|
|
Args: |
|
x (torch.Tensor): The input tensor. |
|
|
|
Returns: |
|
torch.Tensor: The normalized tensor. |
|
|
|
""" |
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
|
|
|
def forward(self, x): |
|
""" |
|
Forward pass through the RMSNorm layer. |
|
|
|
Args: |
|
x (torch.Tensor): The input tensor. |
|
|
|
Returns: |
|
torch.Tensor: The output tensor after applying RMSNorm. |
|
|
|
""" |
|
output = self._norm(x.float()).type_as(x) |
|
return output * self.weight |
|
|
|
|
|
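# Standard transformer MLP: Linear -> GELU(tanh) -> Linear, optionally dispatched to flash-attn's fused kernel.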
class FFN(nn.Module): |
|
def __init__( |
|
self, |
|
in_features, |
|
hidden_features=None, |
|
out_features=None, |
|
drop=0.0, |
|
fused_if_available=True, |
|
): |
|
super().__init__() |
|
self.fused_mlp_func = fused_mlp_func if fused_if_available else None |
|
out_features = out_features or in_features |
|
hidden_features = hidden_features or in_features |
|
self.fc1 = nn.Linear(in_features, hidden_features) |
|
self.act = nn.GELU(approximate="tanh") |
|
self.fc2 = nn.Linear(hidden_features, out_features) |
|
self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity() |
|
|
|
def forward(self, x): |
|
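        # Use the fused flash-attn MLP kernel when available; otherwise run the plain fc1 -> GELU -> fc2 path.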
if self.fused_mlp_func is not None: |
|
return self.drop( |
|
self.fused_mlp_func( |
|
x=x, |
|
weight1=self.fc1.weight, |
|
weight2=self.fc2.weight, |
|
bias1=self.fc1.bias, |
|
bias2=self.fc2.bias, |
|
activation="gelu_approx", |
|
save_pre_act=self.training, |
|
return_residual=False, |
|
checkpoint_lvl=0, |
|
heuristic=0, |
|
process_group=None, |
|
) |
|
) |
|
else: |
|
return self.drop(self.fc2(self.act(self.fc1(x)))) |
|
|
|
def extra_repr(self) -> str: |
|
return f"fused_mlp_func={self.fused_mlp_func is not None}" |
|
|
|
|
|
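# SwiGLU feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x)), with bias-free projections.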
class SwiGLUFFN(nn.Module): |
|
def __init__( |
|
self, |
|
dim: int, |
|
ff_mult: float = 8 / 3, |
|
): |
|
""" |
|
        Initialize the SwiGLU feed-forward module.
|
|
|
Args: |
|
dim (int): Input dimension. |
|
            ff_mult (float, optional): Multiplier for the hidden dimension. Defaults to 8/3.
|
""" |
|
super().__init__() |
|
hidden_dim = int(dim * ff_mult) |
|
|
|
self.up_proj = nn.Linear(dim, hidden_dim, bias=False) |
|
self.down_proj = nn.Linear(hidden_dim, dim, bias=False) |
|
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) |
|
self.fused_mlp_func = None |
|
self._init() |
|
|
|
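    # Xavier-uniform initialization for every linear projection (biases, if any, are zeroed).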
def _init(self): |
|
for module in self.modules(): |
|
if isinstance(module, nn.Linear): |
|
nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
|
|
|
|
def _forward_silu_gating(self, x_gate: torch.Tensor, x_up: torch.Tensor): |
|
return F.silu(x_gate) * x_up |
|
|
|
def forward(self, x: torch.Tensor): |
|
return self.down_proj( |
|
self._forward_silu_gating(self.gate_proj(x), self.up_proj(x)) |
|
) |
|
|
|
def extra_repr(self) -> str: |
|
return f"fused_mlp_func={self.fused_mlp_func is not None}" |
|
|
|
|
|
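# Multi-head cross-attention from input tokens (queries) to a conditioning context (keys/values),
# with optional one-shot KV caching of the projected context.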
class CrossAttention(nn.Module): |
|
def __init__( |
|
self, |
|
embed_dim: int = 768, |
|
context_dim: int = 2048, |
|
num_heads: int = 12, |
|
attn_drop: float = 0.0, |
|
proj_drop: float = 0.0, |
|
qk_norm: bool = False, |
|
): |
|
super().__init__() |
|
assert embed_dim % num_heads == 0 |
|
assert attn_drop == 0.0 |
|
|
|
self.num_heads, self.head_dim = ( |
|
num_heads, |
|
embed_dim // num_heads, |
|
) |
|
self.qk_norm = qk_norm |
|
self.scale = 1 / math.sqrt(self.head_dim) |
|
|
|
self.q_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False) |
|
self.k_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False) |
|
|
|
self.to_q = nn.Linear(embed_dim, embed_dim, bias=True) |
|
self.to_kv = nn.Linear(context_dim, embed_dim * 2, bias=True) |
|
|
|
self.proj = nn.Linear(embed_dim, embed_dim) |
|
self.proj_drop = ( |
|
nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity() |
|
) |
|
self.attn_drop = attn_drop |
|
|
|
|
|
self.caching, self.cached_k, self.cached_v = False, None, None |
|
|
|
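    # Enable/disable caching of the projected context K/V; any previously cached tensors are cleared.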
def kv_caching(self, enable: bool): |
|
self.caching, self.cached_k, self.cached_v = enable, None, None |
|
|
|
def forward(self, x, context, context_attn_bias=None, freqs_cis=None): |
|
B, L, C = x.shape |
|
context_B, context_L, context_C = context.shape |
|
assert B == context_B |
|
|
|
q = self.to_q(x).view(B, L, -1) |
|
if self.qk_norm: |
|
q = self.q_norm(q) |
|
|
|
q = q.view(B, L, self.num_heads, self.head_dim) |
|
q = q.permute(0, 2, 1, 3) |
|
|
|
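        # The context is static across decoding steps: project K/V once and reuse them while caching is enabled.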
if self.cached_k is None: |
|
|
|
kv = self.to_kv(context).view(B, context_L, 2, -1) |
|
k, v = kv.permute(2, 0, 1, 3).unbind(dim=0) |
|
|
|
if self.qk_norm: |
|
k = self.k_norm(k) |
|
|
|
k = k.view(B, context_L, self.num_heads, self.head_dim) |
|
k = k.permute(0, 2, 1, 3) |
|
|
|
v = v.view(B, context_L, self.num_heads, self.head_dim) |
|
v = v.permute(0, 2, 1, 3) |
|
|
|
if self.caching: |
|
self.cached_k = k |
|
self.cached_v = v |
|
else: |
|
k = self.cached_k |
|
v = self.cached_v |
|
|
|
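        # Expand the per-context-token bias to broadcast over heads and query positions.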
if context_attn_bias is not None: |
|
context_attn_bias = rearrange(context_attn_bias, "b j -> b 1 1 j") |
|
|
|
dropout_p = self.attn_drop if self.training else 0.0 |
|
out = ( |
|
scaled_dot_product_attention( |
|
query=q, |
|
key=k, |
|
value=v, |
|
scale=self.scale, |
|
attn_mask=context_attn_bias, |
|
dropout_p=dropout_p, |
|
) |
|
.transpose(1, 2) |
|
.reshape(B, L, C) |
|
) |
|
|
|
return self.proj_drop(self.proj(out)) |
|
|
|
|
|
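# Multi-head self-attention with optional rotary position embeddings and a growing KV cache for autoregressive decoding.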
class SelfAttention(nn.Module): |
|
def __init__( |
|
self, |
|
block_idx: int, |
|
embed_dim: int = 768, |
|
num_heads: int = 12, |
|
attn_drop: float = 0.0, |
|
proj_drop: float = 0.0, |
|
qk_norm: bool = False, |
|
): |
|
super().__init__() |
|
assert embed_dim % num_heads == 0 |
|
self.block_idx, self.num_heads, self.head_dim = ( |
|
block_idx, |
|
num_heads, |
|
embed_dim // num_heads, |
|
) |
|
self.qk_norm = qk_norm |
|
self.scale = 1 / math.sqrt(self.head_dim) |
|
|
|
self.q_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False) |
|
self.k_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False) |
|
|
|
self.to_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True) |
|
self.proj = nn.Linear(embed_dim, embed_dim) |
|
self.proj_drop = ( |
|
nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity() |
|
) |
|
self.attn_drop = attn_drop |
|
|
|
|
|
self.caching, self.cached_k, self.cached_v = False, None, None |
|
|
|
def kv_caching(self, enable: bool): |
|
self.caching, self.cached_k, self.cached_v = enable, None, None |
|
|
|
|
|
def forward(self, x, attn_bias, freqs_cis: torch.Tensor = None): |
|
B, L, C = x.shape |
|
|
|
qkv = self.to_qkv(x).view(B, L, 3, -1) |
|
q, k, v = qkv.permute(2, 0, 1, 3).unbind(dim=0) |
|
|
|
if self.qk_norm: |
|
q = self.q_norm(q) |
|
k = self.k_norm(k) |
|
|
|
q = q.view(B, L, self.num_heads, self.head_dim) |
|
q = q.permute(0, 2, 1, 3) |
|
k = k.view(B, L, self.num_heads, self.head_dim) |
|
k = k.permute(0, 2, 1, 3) |
|
v = v.view(B, L, self.num_heads, self.head_dim) |
|
v = v.permute(0, 2, 1, 3) |
|
dim_cat = 2 |
|
|
|
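        # Apply rotary position embeddings to queries and keys when frequencies are provided.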
if freqs_cis is not None: |
|
q = apply_rotary_emb(q, freqs_cis=freqs_cis) |
|
k = apply_rotary_emb(k, freqs_cis=freqs_cis) |
|
|
|
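        # While decoding, append the new keys/values to the cache along the sequence dimension and attend over the full history.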
if self.caching: |
|
if self.cached_k is None: |
|
self.cached_k = k |
|
self.cached_v = v |
|
else: |
|
k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat) |
|
v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat) |
|
|
|
dropout_p = self.attn_drop if self.training else 0.0 |
|
out = ( |
|
scaled_dot_product_attention( |
|
query=q, |
|
key=k, |
|
value=v, |
|
scale=self.scale, |
|
attn_mask=attn_bias, |
|
dropout_p=dropout_p, |
|
) |
|
.transpose(1, 2) |
|
.reshape(B, L, C) |
|
) |
|
|
|
return self.proj_drop(self.proj(out)) |
|
|
|
def extra_repr(self) -> str: |
|
return f"attn_l2_norm={self.qk_norm}" |
|
|
|
|
|
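# Transformer block with AdaLN modulation: self-attention, optional cross-attention over a text context,
# and an FFN, each wrapped in pre/post RMSNorms ("sandwich" normalization).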
class AdaLNSelfCrossAttn(nn.Module): |
|
def __init__( |
|
self, |
|
block_idx, |
|
last_drop_p, |
|
embed_dim, |
|
cond_dim, |
|
num_heads, |
|
mlp_ratio=4.0, |
|
drop=0.0, |
|
attn_drop=0.0, |
|
drop_path=0.0, |
|
qk_norm=False, |
|
context_dim=None, |
|
use_swiglu_ffn=False, |
|
norm_eps=1e-6, |
|
use_crop_cond=False, |
|
): |
|
super().__init__() |
|
assert attn_drop == 0.0 |
|
assert qk_norm |
|
|
|
        self.block_idx, self.last_drop_p = block_idx, last_drop_p
        self.C, self.D = embed_dim, cond_dim
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
self.attn = SelfAttention( |
|
block_idx=block_idx, |
|
embed_dim=embed_dim, |
|
num_heads=num_heads, |
|
attn_drop=attn_drop, |
|
proj_drop=drop, |
|
qk_norm=qk_norm, |
|
) |
|
|
|
if context_dim: |
|
self.cross_attn = CrossAttention( |
|
embed_dim=embed_dim, |
|
context_dim=context_dim, |
|
num_heads=num_heads, |
|
attn_drop=attn_drop, |
|
proj_drop=drop, |
|
qk_norm=qk_norm, |
|
) |
|
else: |
|
self.cross_attn = None |
|
|
|
if use_swiglu_ffn: |
|
self.ffn = SwiGLUFFN(dim=embed_dim) |
|
else: |
|
self.ffn = FFN( |
|
in_features=embed_dim, |
|
hidden_features=round(embed_dim * mlp_ratio), |
|
drop=drop, |
|
) |
|
|
|
self.self_attention_norm1 = RMSNorm(embed_dim, eps=norm_eps) |
|
self.self_attention_norm2 = RMSNorm(embed_dim, eps=norm_eps) |
|
self.cross_attention_norm1 = RMSNorm(embed_dim, eps=norm_eps) |
|
self.cross_attention_norm2 = RMSNorm(embed_dim, eps=norm_eps) |
|
|
|
self.ffn_norm1 = RMSNorm(embed_dim, eps=norm_eps) |
|
self.ffn_norm2 = RMSNorm(embed_dim, eps=norm_eps) |
|
|
|
        self.attention_y_norm = (
            RMSNorm(context_dim, eps=norm_eps) if context_dim else nn.Identity()
        )
|
|
|
|
|
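        # AdaLN projection: maps the conditioning vector to 6*C values (gamma1, gamma2, scale1, scale2, shift1, shift2).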
lin = nn.Linear(cond_dim, 6 * embed_dim) |
|
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) |
|
|
|
self.fused_add_norm_fn = None |
|
|
|
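        # Optional crop conditioning: a learnable per-channel scale on the crop embedding added to cond_BD.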
self.use_crop_cond = use_crop_cond |
|
if use_crop_cond: |
|
self.crop_cond_scales = nn.Parameter(torch.zeros(1, cond_dim)) |
|
|
|
|
|
def forward( |
|
self, |
|
x, |
|
cond_BD, |
|
attn_bias, |
|
crop_cond=None, |
|
context=None, |
|
context_attn_bias=None, |
|
freqs_cis=None, |
|
): |
|
|
|
if self.use_crop_cond: |
|
assert crop_cond is not None |
|
cond_BD = cond_BD + self.crop_cond_scales * crop_cond |
|
|
|
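        # Split the AdaLN projection into per-branch gates (gamma), scales, and shifts for the self-attention and FFN branches.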
gamma1, gamma2, scale1, scale2, shift1, shift2 = ( |
|
self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2) |
|
) |
|
x = x + self.self_attention_norm2( |
|
self.attn( |
|
self.self_attention_norm1(x).mul(scale1.add(1)).add(shift1), |
|
attn_bias=attn_bias, |
|
freqs_cis=freqs_cis, |
|
) |
|
).mul(gamma1) |
|
if context is not None: |
|
x = x + self.cross_attention_norm2( |
|
self.cross_attn( |
|
self.cross_attention_norm1(x), |
|
self.attention_y_norm(context), |
|
context_attn_bias=context_attn_bias, |
|
freqs_cis=freqs_cis, |
|
) |
|
) |
|
x = x + self.ffn_norm2( |
|
self.ffn(self.ffn_norm1(x).mul(scale2.add(1)).add(shift2)) |
|
).mul(gamma2) |
|
return x |
|
|
|
|
|
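# Final AdaLN before the output head: a parameter-free norm modulated by a scale/shift predicted from the condition.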
class AdaLNBeforeHead(nn.Module): |
|
def __init__(self, C, D, norm_layer): |
|
super().__init__() |
|
self.C, self.D = C, D |
|
self.ln_wo_grad = norm_layer(C, elementwise_affine=False) |
|
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2 * C)) |
|
|
|
def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor): |
|
scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2) |
|
return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift) |
|
|