import torch
import torch.nn as nn
from torch import Tensor
import numpy as np
from einops import rearrange
from typing import Optional, List
from torchtyping import TensorType
from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
allow_ops_in_compiled_graph()
# Dimension-name placeholders matching the labels used in the TensorType
# annotations below.
batch_size, num_cond_feats = None, None
class FusedMLP(nn.Sequential):
def __init__(
self,
dim_model: int,
dropout: float,
activation: nn.Module,
hidden_layer_multiplier: int = 4,
bias: bool = True,
):
super().__init__(
nn.Linear(dim_model, dim_model * hidden_layer_multiplier, bias=bias),
activation(),
nn.Dropout(dropout),
nn.Linear(dim_model * hidden_layer_multiplier, dim_model, bias=bias),
)
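# Illustrative sketch (never called in this module): FusedMLP keeps the last
# dimension unchanged, expanding to dim_model * hidden_layer_multiplier
# internally. All shapes below are arbitrary example values.
def _example_fused_mlp() -> None:
    mlp = FusedMLP(dim_model=64, dropout=0.0, activation=nn.GELU)
    tokens = torch.randn(2, 10, 64)  # (batch, sequence, dim_model)
    assert mlp(tokens).shape == tokens.shape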
def _cast_if_autocast_enabled(tensor):
if torch.is_autocast_enabled():
if tensor.device.type == "cuda":
dtype = torch.get_autocast_gpu_dtype()
elif tensor.device.type == "cpu":
dtype = torch.get_autocast_cpu_dtype()
else:
raise NotImplementedError()
return tensor.to(dtype=dtype)
return tensor
class LayerNorm16Bits(torch.nn.LayerNorm):
"""
16-bit friendly version of torch.nn.LayerNorm
"""
def __init__(
self,
normalized_shape,
eps=1e-06,
elementwise_affine=True,
device=None,
dtype=None,
):
super().__init__(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
device=device,
dtype=dtype,
)
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)
downcast_weight = (
_cast_if_autocast_enabled(self.weight)
if self.weight is not None
else self.weight
)
downcast_bias = (
_cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
)
with torch.autocast(enabled=False, device_type=module_device.type):
return nn.functional.layer_norm(
downcast_x,
self.normalized_shape,
downcast_weight,
downcast_bias,
self.eps,
)
class StochasticDepth(nn.Module):
    """
    Stochastic depth (DropPath): during training, randomly drops the whole
    residual branch for each sample and rescales the kept ones.
    """
    def __init__(self, p: float):
        super().__init__()
        self.survival_prob = 1.0 - p
def forward(self, x: Tensor) -> Tensor:
if self.training and self.survival_prob < 1:
mask = (
torch.empty(x.shape[0], 1, 1, device=x.device).uniform_()
+ self.survival_prob
)
mask = mask.floor()
if self.survival_prob > 0:
mask = mask / self.survival_prob
return x * mask
else:
return x
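# Alias kept so that external code importing the original (misspelled) class
# name keeps working.
StochatichDepth = StochasticDepth
# Illustrative sketch (never called in this module): in eval mode the module is
# the identity; in train mode each sample's residual branch is kept with
# probability 1 - p and rescaled by 1 / (1 - p). Shapes are arbitrary examples.
def _example_stochastic_depth() -> None:
    drop = StochasticDepth(p=0.5)
    x = torch.randn(4, 10, 64)
    drop.eval()
    assert torch.equal(drop(x), x)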
class CrossAttentionOp(nn.Module):
def __init__(
self, attention_dim, num_heads, dim_q, dim_kv, use_biases=True, is_sa=False
):
super().__init__()
self.dim_q = dim_q
self.dim_kv = dim_kv
self.attention_dim = attention_dim
self.num_heads = num_heads
self.use_biases = use_biases
self.is_sa = is_sa
if self.is_sa:
self.qkv = nn.Linear(dim_q, attention_dim * 3, bias=use_biases)
else:
self.q = nn.Linear(dim_q, attention_dim, bias=use_biases)
self.kv = nn.Linear(dim_kv, attention_dim * 2, bias=use_biases)
self.out = nn.Linear(attention_dim, dim_q, bias=use_biases)
def forward(self, x_to, x_from=None, attention_mask=None):
if x_from is None:
x_from = x_to
if self.is_sa:
q, k, v = self.qkv(x_to).chunk(3, dim=-1)
else:
q = self.q(x_to)
k, v = self.kv(x_from).chunk(2, dim=-1)
q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1)
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask
)
x = rearrange(x, "b h n d -> b n (h d)")
x = self.out(x)
return x
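# Illustrative sketch (never called in this module): the same op covers
# self-attention (is_sa=True, fused qkv projection) and cross-attention
# (is_sa=False, separate q and kv projections). Dimensions are example values.
def _example_cross_attention_op() -> None:
    ca = CrossAttentionOp(
        attention_dim=64, num_heads=4, dim_q=32, dim_kv=48, is_sa=False
    )
    x_to = torch.randn(2, 10, 32)  # queries
    x_from = torch.randn(2, 7, 48)  # keys / values
    assert ca(x_to, x_from).shape == (2, 10, 32)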
class CrossAttentionBlock(nn.Module):
def __init__(
self,
dim_q: int,
dim_kv: int,
num_heads: int,
attention_dim: int = 0,
mlp_multiplier: int = 4,
dropout: float = 0.0,
stochastic_depth: float = 0.0,
use_biases: bool = True,
retrieve_attention_scores: bool = False,
use_layernorm16: bool = True,
):
super().__init__()
layer_norm = (
nn.LayerNorm
if not use_layernorm16 or retrieve_attention_scores
else LayerNorm16Bits
)
self.retrieve_attention_scores = retrieve_attention_scores
self.initial_to_ln = layer_norm(dim_q, eps=1e-6)
attention_dim = min(dim_q, dim_kv) if attention_dim == 0 else attention_dim
self.ca = CrossAttentionOp(
attention_dim, num_heads, dim_q, dim_kv, is_sa=False, use_biases=use_biases
)
        self.ca_stochastic_depth = StochasticDepth(stochastic_depth)
self.middle_ln = layer_norm(dim_q, eps=1e-6)
self.ffn = FusedMLP(
dim_model=dim_q,
dropout=dropout,
activation=nn.GELU,
hidden_layer_multiplier=mlp_multiplier,
bias=use_biases,
)
        self.ffn_stochastic_depth = StochasticDepth(stochastic_depth)
def forward(
self,
to_tokens: Tensor,
from_tokens: Tensor,
to_token_mask: Optional[Tensor] = None,
from_token_mask: Optional[Tensor] = None,
) -> Tensor:
if to_token_mask is None and from_token_mask is None:
attention_mask = None
else:
if to_token_mask is None:
to_token_mask = torch.ones(
to_tokens.shape[0],
to_tokens.shape[1],
dtype=torch.bool,
device=to_tokens.device,
)
if from_token_mask is None:
from_token_mask = torch.ones(
from_tokens.shape[0],
from_tokens.shape[1],
dtype=torch.bool,
device=from_tokens.device,
)
attention_mask = from_token_mask.unsqueeze(1) * to_token_mask.unsqueeze(2)
attention_output = self.ca(
self.initial_to_ln(to_tokens),
from_tokens,
attention_mask=attention_mask,
)
to_tokens = to_tokens + self.ca_stochastic_depth(attention_output)
to_tokens = to_tokens + self.ffn_stochastic_depth(
self.ffn(self.middle_ln(to_tokens))
)
return to_tokens
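# Illustrative sketch (never called in this module): a boolean padding mask over
# the "from" tokens (True = valid) restricts which conditioning tokens each
# query can attend to. Dimensions are arbitrary example values.
def _example_cross_attention_block() -> None:
    block = CrossAttentionBlock(
        dim_q=64, dim_kv=32, num_heads=4, use_layernorm16=False
    )
    to_tokens = torch.randn(2, 10, 64)
    from_tokens = torch.randn(2, 5, 32)
    from_mask = torch.ones(2, 5, dtype=torch.bool)
    from_mask[:, -2:] = False  # last two conditioning tokens are padding
    out = block(to_tokens, from_tokens, from_token_mask=from_mask)
    assert out.shape == to_tokens.shape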
class SelfAttentionBlock(nn.Module):
def __init__(
self,
dim_qkv: int,
num_heads: int,
attention_dim: int = 0,
mlp_multiplier: int = 4,
dropout: float = 0.0,
stochastic_depth: float = 0.0,
use_biases: bool = True,
use_layer_scale: bool = False,
layer_scale_value: float = 0.0,
use_layernorm16: bool = True,
):
super().__init__()
layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
attention_dim = dim_qkv if attention_dim == 0 else attention_dim
self.sa = CrossAttentionOp(
attention_dim,
num_heads,
dim_qkv,
dim_qkv,
is_sa=True,
use_biases=use_biases,
)
        self.sa_stochastic_depth = StochasticDepth(stochastic_depth)
self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
self.ffn = FusedMLP(
dim_model=dim_qkv,
dropout=dropout,
activation=nn.GELU,
hidden_layer_multiplier=mlp_multiplier,
bias=use_biases,
)
        self.ffn_stochastic_depth = StochasticDepth(stochastic_depth)
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
def forward(
self,
tokens: torch.Tensor,
token_mask: Optional[torch.Tensor] = None,
):
if token_mask is None:
attention_mask = None
else:
attention_mask = token_mask.unsqueeze(1) * torch.ones(
tokens.shape[0],
tokens.shape[1],
1,
dtype=torch.bool,
device=tokens.device,
)
attention_output = self.sa(
self.initial_ln(tokens),
attention_mask=attention_mask,
)
if self.use_layer_scale:
tokens = tokens + self.sa_stochastic_depth(
self.layer_scale_1 * attention_output
)
tokens = tokens + self.ffn_stochastic_depth(
self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
)
else:
tokens = tokens + self.sa_stochastic_depth(attention_output)
tokens = tokens + self.ffn_stochastic_depth(
self.ffn(self.middle_ln(tokens))
)
return tokens
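# Illustrative sketch (never called in this module): with a boolean token mask
# (True = valid), padded positions are excluded from the attention keys.
# Dimensions are arbitrary example values.
def _example_self_attention_block() -> None:
    block = SelfAttentionBlock(dim_qkv=64, num_heads=4, use_layernorm16=False)
    tokens = torch.randn(2, 10, 64)
    token_mask = torch.ones(2, 10, dtype=torch.bool)
    token_mask[:, -3:] = False  # last three tokens are padding
    assert block(tokens, token_mask).shape == tokens.shape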
class AdaLNSABlock(nn.Module):
def __init__(
self,
dim_qkv: int,
dim_cond: int,
num_heads: int,
attention_dim: int = 0,
mlp_multiplier: int = 4,
dropout: float = 0.0,
stochastic_depth: float = 0.0,
use_biases: bool = True,
use_layer_scale: bool = False,
layer_scale_value: float = 0.1,
use_layernorm16: bool = True,
):
super().__init__()
layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
self.initial_ln = layer_norm(dim_qkv, eps=1e-6, elementwise_affine=False)
attention_dim = dim_qkv if attention_dim == 0 else attention_dim
self.adaln_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim_cond, dim_qkv * 6, bias=use_biases),
)
# Zero init
nn.init.zeros_(self.adaln_modulation[1].weight)
nn.init.zeros_(self.adaln_modulation[1].bias)
self.sa = CrossAttentionOp(
attention_dim,
num_heads,
dim_qkv,
dim_qkv,
is_sa=True,
use_biases=use_biases,
)
        self.sa_stochastic_depth = StochasticDepth(stochastic_depth)
self.middle_ln = layer_norm(dim_qkv, eps=1e-6, elementwise_affine=False)
self.ffn = FusedMLP(
dim_model=dim_qkv,
dropout=dropout,
activation=nn.GELU,
hidden_layer_multiplier=mlp_multiplier,
bias=use_biases,
)
        self.ffn_stochastic_depth = StochasticDepth(stochastic_depth)
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
def forward(
self,
tokens: torch.Tensor,
cond: torch.Tensor,
token_mask: Optional[torch.Tensor] = None,
):
if token_mask is None:
attention_mask = None
else:
attention_mask = token_mask.unsqueeze(1) * torch.ones(
tokens.shape[0],
tokens.shape[1],
1,
dtype=torch.bool,
device=tokens.device,
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.adaln_modulation(cond).chunk(6, dim=-1)
)
attention_output = self.sa(
modulate_shift_and_scale(self.initial_ln(tokens), shift_msa, scale_msa),
attention_mask=attention_mask,
)
if self.use_layer_scale:
tokens = tokens + self.sa_stochastic_depth(
gate_msa.unsqueeze(1) * self.layer_scale_1 * attention_output
)
tokens = tokens + self.ffn_stochastic_depth(
gate_mlp.unsqueeze(1)
* self.layer_scale_2
* self.ffn(
modulate_shift_and_scale(
self.middle_ln(tokens), shift_mlp, scale_mlp
)
)
)
else:
tokens = tokens + gate_msa.unsqueeze(1) * self.sa_stochastic_depth(
attention_output
)
tokens = tokens + self.ffn_stochastic_depth(
gate_mlp.unsqueeze(1)
* self.ffn(
modulate_shift_and_scale(
self.middle_ln(tokens), shift_mlp, scale_mlp
)
)
)
return tokens
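# Illustrative sketch (never called in this module): the block is conditioned
# through adaLN, so `cond` is one dim_cond-sized vector per sample, broadcast
# over the token dimension. Dimensions are arbitrary example values.
def _example_adaln_sa_block() -> None:
    block = AdaLNSABlock(
        dim_qkv=64, dim_cond=64, num_heads=4, use_layernorm16=False
    )
    tokens = torch.randn(2, 10, 64)
    cond = torch.randn(2, 64)  # one conditioning vector per sample
    assert block(tokens, cond).shape == tokens.shape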
class CrossAttentionSABlock(nn.Module):
def __init__(
self,
dim_qkv: int,
dim_cond: int,
num_heads: int,
attention_dim: int = 0,
mlp_multiplier: int = 4,
dropout: float = 0.0,
stochastic_depth: float = 0.0,
use_biases: bool = True,
use_layer_scale: bool = False,
layer_scale_value: float = 0.0,
use_layernorm16: bool = True,
):
super().__init__()
layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
attention_dim = dim_qkv if attention_dim == 0 else attention_dim
self.ca = CrossAttentionOp(
attention_dim,
num_heads,
dim_qkv,
dim_cond,
is_sa=False,
use_biases=use_biases,
)
        self.ca_stochastic_depth = StochasticDepth(stochastic_depth)
self.ca_ln = layer_norm(dim_qkv, eps=1e-6)
self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
self.sa = CrossAttentionOp(
attention_dim,
num_heads,
dim_qkv,
dim_qkv,
is_sa=True,
use_biases=use_biases,
)
        self.sa_stochastic_depth = StochasticDepth(stochastic_depth)
self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
self.ffn = FusedMLP(
dim_model=dim_qkv,
dropout=dropout,
activation=nn.GELU,
hidden_layer_multiplier=mlp_multiplier,
bias=use_biases,
)
        self.ffn_stochastic_depth = StochasticDepth(stochastic_depth)
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
def forward(
self,
tokens: torch.Tensor,
cond: torch.Tensor,
token_mask: Optional[torch.Tensor] = None,
cond_mask: Optional[torch.Tensor] = None,
):
        if cond_mask is None:
            cond_attention_mask = None
        else:
            # Rows follow the token mask (all tokens when none is given) and
            # columns follow the conditioning mask.
            row_mask = (
                token_mask
                if token_mask is not None
                else torch.ones(
                    tokens.shape[0],
                    tokens.shape[1],
                    dtype=torch.bool,
                    device=tokens.device,
                )
            )
            cond_attention_mask = cond_mask.unsqueeze(1) * row_mask.unsqueeze(2)
if token_mask is None:
attention_mask = None
else:
attention_mask = token_mask.unsqueeze(1) * torch.ones(
tokens.shape[0],
tokens.shape[1],
1,
dtype=torch.bool,
device=tokens.device,
)
ca_output = self.ca(
self.ca_ln(tokens),
cond,
attention_mask=cond_attention_mask,
)
        ca_output = torch.nan_to_num(
            ca_output, nan=0.0, posinf=0.0, neginf=0.0
        )  # Fully masked query rows attend to nothing and come back as NaN; zero them.
tokens = tokens + self.ca_stochastic_depth(ca_output)
attention_output = self.sa(
self.initial_ln(tokens),
attention_mask=attention_mask,
)
if self.use_layer_scale:
tokens = tokens + self.sa_stochastic_depth(
self.layer_scale_1 * attention_output
)
tokens = tokens + self.ffn_stochastic_depth(
self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
)
else:
tokens = tokens + self.sa_stochastic_depth(attention_output)
tokens = tokens + self.ffn_stochastic_depth(
self.ffn(self.middle_ln(tokens))
)
return tokens
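# Illustrative sketch (never called in this module): cross-attention to a
# sequence of conditioning tokens, followed by self-attention over the tokens.
# Dimensions are arbitrary example values.
def _example_cross_attention_sa_block() -> None:
    block = CrossAttentionSABlock(
        dim_qkv=64, dim_cond=32, num_heads=4, use_layernorm16=False
    )
    tokens = torch.randn(2, 10, 64)
    cond = torch.randn(2, 5, 32)
    assert block(tokens, cond).shape == tokens.shape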
class CAAdaLNSABlock(nn.Module):
def __init__(
self,
dim_qkv: int,
dim_cond: int,
num_heads: int,
attention_dim: int = 0,
mlp_multiplier: int = 4,
dropout: float = 0.0,
stochastic_depth: float = 0.0,
use_biases: bool = True,
use_layer_scale: bool = False,
layer_scale_value: float = 0.1,
use_layernorm16: bool = True,
):
super().__init__()
        layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
        # Resolve the default attention width before it is used by the
        # cross-attention op below.
        attention_dim = dim_qkv if attention_dim == 0 else attention_dim
        self.ca = CrossAttentionOp(
attention_dim,
num_heads,
dim_qkv,
dim_cond,
is_sa=False,
use_biases=use_biases,
)
        self.ca_stochastic_depth = StochasticDepth(stochastic_depth)
self.ca_ln = layer_norm(dim_qkv, eps=1e-6)
self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
self.adaln_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim_cond, dim_qkv * 6, bias=use_biases),
)
# Zero init
nn.init.zeros_(self.adaln_modulation[1].weight)
nn.init.zeros_(self.adaln_modulation[1].bias)
self.sa = CrossAttentionOp(
attention_dim,
num_heads,
dim_qkv,
dim_qkv,
is_sa=True,
use_biases=use_biases,
)
        self.sa_stochastic_depth = StochasticDepth(stochastic_depth)
self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
self.ffn = FusedMLP(
dim_model=dim_qkv,
dropout=dropout,
activation=nn.GELU,
hidden_layer_multiplier=mlp_multiplier,
bias=use_biases,
)
        self.ffn_stochastic_depth = StochasticDepth(stochastic_depth)
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
def forward(
self,
tokens: torch.Tensor,
cond_1: torch.Tensor,
cond_2: torch.Tensor,
cond_1_mask: Optional[torch.Tensor] = None,
token_mask: Optional[torch.Tensor] = None,
):
if token_mask is None and cond_1_mask is None:
cond_attention_mask = None
        elif token_mask is None:
            # All query tokens are valid; columns follow the conditioning mask.
            cond_attention_mask = cond_1_mask.unsqueeze(1) * torch.ones(
                tokens.shape[0],
                tokens.shape[1],
                1,
                dtype=torch.bool,
                device=tokens.device,
            )
elif cond_1_mask is None:
cond_attention_mask = torch.ones(
tokens.shape[0],
1,
tokens.shape[1],
dtype=torch.bool,
device=tokens.device,
) * token_mask.unsqueeze(2)
else:
cond_attention_mask = cond_1_mask.unsqueeze(1) * token_mask.unsqueeze(2)
if token_mask is None:
attention_mask = None
else:
attention_mask = token_mask.unsqueeze(1) * torch.ones(
tokens.shape[0],
tokens.shape[1],
1,
dtype=torch.bool,
device=tokens.device,
)
ca_output = self.ca(
self.ca_ln(tokens),
cond_1,
attention_mask=cond_attention_mask,
)
ca_output = torch.nan_to_num(ca_output, nan=0.0, posinf=0.0, neginf=0.0)
tokens = tokens + self.ca_stochastic_depth(ca_output)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.adaln_modulation(cond_2).chunk(6, dim=-1)
)
attention_output = self.sa(
modulate_shift_and_scale(self.initial_ln(tokens), shift_msa, scale_msa),
attention_mask=attention_mask,
)
if self.use_layer_scale:
tokens = tokens + self.sa_stochastic_depth(
gate_msa.unsqueeze(1) * self.layer_scale_1 * attention_output
)
tokens = tokens + self.ffn_stochastic_depth(
gate_mlp.unsqueeze(1)
* self.layer_scale_2
* self.ffn(
modulate_shift_and_scale(
self.middle_ln(tokens), shift_mlp, scale_mlp
)
)
)
else:
tokens = tokens + gate_msa.unsqueeze(1) * self.sa_stochastic_depth(
attention_output
)
tokens = tokens + self.ffn_stochastic_depth(
gate_mlp.unsqueeze(1)
* self.ffn(
modulate_shift_and_scale(
self.middle_ln(tokens), shift_mlp, scale_mlp
)
)
)
return tokens
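# Illustrative sketch (never called in this module): cond_1 is a sequence of
# conditioning tokens consumed by cross-attention, while cond_2 is a per-sample
# vector consumed by adaLN modulation. Dimensions are arbitrary example values.
def _example_ca_adaln_sa_block() -> None:
    block = CAAdaLNSABlock(
        dim_qkv=64, dim_cond=64, num_heads=4, use_layernorm16=False
    )
    tokens = torch.randn(2, 10, 64)
    cond_seq = torch.randn(2, 5, 64)  # cross-attended conditioning tokens
    cond_vec = torch.randn(2, 64)  # adaLN conditioning vector
    assert block(tokens, cond_seq, cond_vec).shape == tokens.shape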
class PositionalEmbedding(nn.Module):
"""
Taken from https://github.com/NVlabs/edm
"""
def __init__(self, num_channels, max_positions=10000, endpoint=False):
super().__init__()
self.num_channels = num_channels
self.max_positions = max_positions
self.endpoint = endpoint
freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32)
freqs = 2 * freqs / self.num_channels
freqs = (1 / self.max_positions) ** freqs
self.register_buffer("freqs", freqs)
    def forward(self, x):
        out = torch.outer(x, self.freqs)
        out = torch.cat([out.cos(), out.sin()], dim=1)
        return out.to(x.dtype)
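# Illustrative sketch (never called in this module): maps a 1-D batch of scalar
# positions to sinusoidal features of size num_channels (cosine then sine
# halves). Values are arbitrary examples.
def _example_positional_embedding() -> None:
    emb = PositionalEmbedding(num_channels=8)
    positions = torch.tensor([0.0, 1.0, 2.0])
    assert emb(positions).shape == (3, 8)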
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.0, max_len=10000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, x):
# not used in the final model
x = x + self.pe[:, : x.shape[1], :]
return self.dropout(x)
class TimeEmbedder(nn.Module):
def __init__(
self,
dim: int,
time_scaling: float,
expansion: int = 4,
):
super().__init__()
self.encode_time = PositionalEmbedding(num_channels=dim, endpoint=True)
self.time_scaling = time_scaling
self.map_time = nn.Sequential(
nn.Linear(dim, dim * expansion),
nn.SiLU(),
nn.Linear(dim * expansion, dim * expansion),
)
def forward(self, t: Tensor) -> Tensor:
time = self.encode_time(t * self.time_scaling)
time_mean = time.mean(dim=-1, keepdim=True)
time_std = time.std(dim=-1, keepdim=True)
time = (time - time_mean) / time_std
return self.map_time(time)
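# Illustrative sketch (never called in this module): diffusion timesteps are
# scaled, sinusoidally encoded, standardized, and expanded by an MLP to
# dim * expansion features. Values are arbitrary examples.
def _example_time_embedder() -> None:
    embedder = TimeEmbedder(dim=32, time_scaling=1000, expansion=4)
    timesteps = torch.rand(5)
    assert embedder(timesteps).shape == (5, 128)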
def modulate_shift_and_scale(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
return x * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)
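# Illustrative sketch (never called in this module): adaLN modulation applies a
# per-sample affine transform, x * (1 + scale) + shift, broadcast over the
# token dimension; with zero shift and scale (the zero-init state of the adaLN
# layers above) it is the identity.
def _example_modulate_shift_and_scale() -> None:
    x = torch.randn(2, 10, 64)
    shift = torch.zeros(2, 64)
    scale = torch.zeros(2, 64)
    assert torch.allclose(modulate_shift_and_scale(x, shift, scale), x)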
# ------------------------------------------------------------------------------------- #
class BaseDirector(nn.Module):
def __init__(
self,
name: str,
num_feats: int,
num_cond_feats: int,
num_cams: int,
latent_dim: int,
mlp_multiplier: int,
num_layers: int,
num_heads: int,
dropout: float,
stochastic_depth: float,
label_dropout: float,
num_rawfeats: int,
clip_sequential: bool = False,
cond_sequential: bool = False,
device: str = "cuda",
**kwargs,
):
super().__init__()
self.name = name
self.label_dropout = label_dropout
self.num_rawfeats = num_rawfeats
self.num_feats = num_feats
self.num_cams = num_cams
self.clip_sequential = clip_sequential
self.cond_sequential = cond_sequential
self.use_layernorm16 = device == "cuda"
self.input_projection = nn.Sequential(
nn.Linear(num_feats, latent_dim),
PositionalEncoding(latent_dim),
)
self.time_embedding = TimeEmbedder(latent_dim // 4, time_scaling=1000)
self.init_conds_mappings(num_cond_feats, latent_dim)
self.init_backbone(
num_layers, latent_dim, mlp_multiplier, num_heads, dropout, stochastic_depth
)
self.init_output_projection(num_feats, latent_dim)
def forward(
self,
x: Tensor,
timesteps: Tensor,
        y: Optional[List[Tensor]] = None,
        mask: Optional[Tensor] = None,
) -> Tensor:
mask = mask.logical_not() if mask is not None else None
x = rearrange(x, "b c n -> b n c")
x = self.input_projection(x)
t = self.time_embedding(timesteps)
if y is not None:
y = self.mask_cond(y)
y = self.cond_mapping(y, mask, t)
x = self.backbone(x, y, mask)
x = self.output_projection(x, y)
return rearrange(x, "b n c -> b c n")
def init_conds_mappings(self, num_cond_feats, latent_dim):
raise NotImplementedError(
"This method should be implemented in the derived class"
)
    def init_backbone(
        self,
        num_layers,
        latent_dim,
        mlp_multiplier,
        num_heads,
        dropout,
        stochastic_depth,
    ):
raise NotImplementedError(
"This method should be implemented in the derived class"
)
def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
raise NotImplementedError(
"This method should be implemented in the derived class"
)
def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
raise NotImplementedError(
"This method should be implemented in the derived class"
)
def mask_cond(
self, cond: List[TensorType["batch_size", "num_cond_feats"]]
) -> TensorType["batch_size", "num_cond_feats"]:
bs = cond[0].shape[0]
if self.training and self.label_dropout > 0.0:
# 1-> use null_cond, 0-> use real cond
prob = torch.ones(bs, device=cond[0].device) * self.label_dropout
masked_cond = []
common_mask = torch.bernoulli(prob) # Common to all modalities
for _cond in cond:
modality_mask = torch.bernoulli(prob) # Modality only
mask = torch.clip(common_mask + modality_mask, 0, 1)
mask = mask.view(bs, 1, 1) if _cond.dim() == 3 else mask.view(bs, 1)
masked_cond.append(_cond * (1.0 - mask))
return masked_cond
else:
return cond
def init_output_projection(self, num_feats, latent_dim):
raise NotImplementedError(
"This method should be implemented in the derived class"
)
def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
raise NotImplementedError(
"This method should be implemented in the derived class"
)
class AdaLNDirector(BaseDirector):
def __init__(
self,
name: str,
num_feats: int,
num_cond_feats: int,
num_cams: int,
latent_dim: int,
mlp_multiplier: int,
num_layers: int,
num_heads: int,
dropout: float,
stochastic_depth: float,
label_dropout: float,
num_rawfeats: int,
clip_sequential: bool = False,
cond_sequential: bool = False,
device: str = "cuda",
**kwargs,
):
super().__init__(
name=name,
num_feats=num_feats,
num_cond_feats=num_cond_feats,
num_cams=num_cams,
latent_dim=latent_dim,
mlp_multiplier=mlp_multiplier,
num_layers=num_layers,
num_heads=num_heads,
dropout=dropout,
stochastic_depth=stochastic_depth,
label_dropout=label_dropout,
num_rawfeats=num_rawfeats,
clip_sequential=clip_sequential,
cond_sequential=cond_sequential,
device=device,
)
assert not (clip_sequential and cond_sequential)
def init_conds_mappings(self, num_cond_feats, latent_dim):
self.joint_cond_projection = nn.Linear(sum(num_cond_feats), latent_dim)
def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
c_emb = torch.cat(cond, dim=-1)
return self.joint_cond_projection(c_emb) + t
def init_backbone(
self,
num_layers,
latent_dim,
mlp_multiplier,
num_heads,
dropout,
stochastic_depth,
):
self.backbone_module = nn.ModuleList(
[
AdaLNSABlock(
dim_qkv=latent_dim,
dim_cond=latent_dim,
num_heads=num_heads,
mlp_multiplier=mlp_multiplier,
dropout=dropout,
stochastic_depth=stochastic_depth,
use_layernorm16=self.use_layernorm16,
)
for _ in range(num_layers)
]
)
def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
for block in self.backbone_module:
x = block(x, y, mask)
return x
def init_output_projection(self, num_feats, latent_dim):
layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
self.final_norm = layer_norm(latent_dim, eps=1e-6, elementwise_affine=False)
self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
self.final_adaln = nn.Sequential(
nn.SiLU(),
nn.Linear(latent_dim, latent_dim * 2, bias=True),
)
# Zero init
nn.init.zeros_(self.final_adaln[1].weight)
nn.init.zeros_(self.final_adaln[1].bias)
def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
shift, scale = self.final_adaln(y).chunk(2, dim=-1)
x = modulate_shift_and_scale(self.final_norm(x), shift, scale)
return self.final_linear(x)
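# Illustrative sketch (never called in this module): a tiny AdaLNDirector with
# made-up dimensions. `x` is (batch, num_feats, num_frames) and each entry of
# `y` is one per-sample conditioning vector; all values are assumptions for
# demonstration only.
def _example_adaln_director() -> None:
    model = AdaLNDirector(
        name="demo",
        num_feats=6,
        num_cond_feats=[8, 16],
        num_cams=1,
        latent_dim=64,
        mlp_multiplier=4,
        num_layers=2,
        num_heads=4,
        dropout=0.0,
        stochastic_depth=0.0,
        label_dropout=0.0,
        num_rawfeats=12,
        device="cpu",
    )
    x = torch.randn(2, 6, 30)
    timesteps = torch.rand(2)
    y = [torch.randn(2, 8), torch.randn(2, 16)]
    assert model(x, timesteps, y=y).shape == x.shape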
class CrossAttentionDirector(BaseDirector):
def __init__(
self,
name: str,
num_feats: int,
num_cond_feats: int,
num_cams: int,
latent_dim: int,
mlp_multiplier: int,
num_layers: int,
num_heads: int,
dropout: float,
stochastic_depth: float,
label_dropout: float,
num_rawfeats: int,
num_text_registers: int,
clip_sequential: bool = True,
cond_sequential: bool = True,
device: str = "cuda",
**kwargs,
):
self.num_text_registers = num_text_registers
self.num_heads = num_heads
self.dropout = dropout
self.mlp_multiplier = mlp_multiplier
self.stochastic_depth = stochastic_depth
super().__init__(
name=name,
num_feats=num_feats,
num_cond_feats=num_cond_feats,
num_cams=num_cams,
latent_dim=latent_dim,
mlp_multiplier=mlp_multiplier,
num_layers=num_layers,
num_heads=num_heads,
dropout=dropout,
stochastic_depth=stochastic_depth,
label_dropout=label_dropout,
num_rawfeats=num_rawfeats,
clip_sequential=clip_sequential,
cond_sequential=cond_sequential,
device=device,
)
assert clip_sequential and cond_sequential
def init_conds_mappings(self, num_cond_feats, latent_dim):
self.cond_projection = nn.ModuleList(
[nn.Linear(num_cond_feat, latent_dim) for num_cond_feat in num_cond_feats]
)
self.cond_registers = nn.Parameter(
torch.randn(self.num_text_registers, latent_dim), requires_grad=True
)
nn.init.trunc_normal_(self.cond_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02)
self.cond_sa = nn.ModuleList(
[
SelfAttentionBlock(
dim_qkv=latent_dim,
num_heads=self.num_heads,
mlp_multiplier=self.mlp_multiplier,
dropout=self.dropout,
stochastic_depth=self.stochastic_depth,
use_layernorm16=self.use_layernorm16,
)
for _ in range(2)
]
)
self.cond_positional_embedding = PositionalEncoding(latent_dim, max_len=10000)
def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
batch_size = cond[0].shape[0]
cond_emb = [
cond_proj(rearrange(c, "b c n -> b n c"))
for cond_proj, c in zip(self.cond_projection, cond)
]
cond_emb = [
self.cond_registers.unsqueeze(0).expand(batch_size, -1, -1),
t.unsqueeze(1),
] + cond_emb
cond_emb = torch.cat(cond_emb, dim=1)
cond_emb = self.cond_positional_embedding(cond_emb)
for block in self.cond_sa:
cond_emb = block(cond_emb)
return cond_emb
def init_backbone(
self,
num_layers,
latent_dim,
mlp_multiplier,
num_heads,
dropout,
stochastic_depth,
):
self.backbone_module = nn.ModuleList(
[
CrossAttentionSABlock(
dim_qkv=latent_dim,
dim_cond=latent_dim,
num_heads=num_heads,
mlp_multiplier=mlp_multiplier,
dropout=dropout,
stochastic_depth=stochastic_depth,
use_layernorm16=self.use_layernorm16,
)
for _ in range(num_layers)
]
)
def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
for block in self.backbone_module:
x = block(x, y, mask, None)
return x
def init_output_projection(self, num_feats, latent_dim):
layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
self.final_norm = layer_norm(latent_dim, eps=1e-6)
self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
return self.final_linear(self.final_norm(x))
class InContextDirector(BaseDirector):
def __init__(
self,
name: str,
num_feats: int,
num_cond_feats: int,
num_cams: int,
latent_dim: int,
mlp_multiplier: int,
num_layers: int,
num_heads: int,
dropout: float,
stochastic_depth: float,
label_dropout: float,
num_rawfeats: int,
clip_sequential: bool = False,
cond_sequential: bool = False,
device: str = "cuda",
**kwargs,
):
super().__init__(
name=name,
num_feats=num_feats,
num_cond_feats=num_cond_feats,
num_cams=num_cams,
latent_dim=latent_dim,
mlp_multiplier=mlp_multiplier,
num_layers=num_layers,
num_heads=num_heads,
dropout=dropout,
stochastic_depth=stochastic_depth,
label_dropout=label_dropout,
num_rawfeats=num_rawfeats,
clip_sequential=clip_sequential,
cond_sequential=cond_sequential,
device=device,
)
def init_conds_mappings(self, num_cond_feats, latent_dim):
self.cond_projection = nn.ModuleList(
[nn.Linear(num_cond_feat, latent_dim) for num_cond_feat in num_cond_feats]
)
def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
for i in range(len(cond)):
if cond[i].dim() == 3:
cond[i] = rearrange(cond[i], "b c n -> b n c")
cond_emb = [cond_proj(c) for cond_proj, c in zip(self.cond_projection, cond)]
        cond_emb = [c.unsqueeze(1) if c.dim() == 2 else c for c in cond_emb]
cond_emb = torch.cat([t.unsqueeze(1)] + cond_emb, dim=1)
return cond_emb
def init_backbone(
self,
num_layers,
latent_dim,
mlp_multiplier,
num_heads,
dropout,
stochastic_depth,
):
self.backbone_module = nn.ModuleList(
[
SelfAttentionBlock(
dim_qkv=latent_dim,
num_heads=num_heads,
mlp_multiplier=mlp_multiplier,
dropout=dropout,
stochastic_depth=stochastic_depth,
use_layernorm16=self.use_layernorm16,
)
for _ in range(num_layers)
]
)
    def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
        bs, n_y, _ = y.shape
        if mask is not None:
            # Conditioning tokens are always valid; keep the mask boolean so
            # scaled_dot_product_attention treats it as a keep/ignore mask.
            mask = torch.cat(
                [torch.ones(bs, n_y, dtype=torch.bool, device=y.device), mask],
                dim=1,
            )
        x = torch.cat([y, x], dim=1)
        for block in self.backbone_module:
            x = block(x, mask)
        return x
def init_output_projection(self, num_feats, latent_dim):
layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
self.final_norm = layer_norm(latent_dim, eps=1e-6)
self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
num_y = y.shape[1]
x = x[:, num_y:]
return self.final_linear(self.final_norm(x))
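# Illustrative sketch (never called in this module): InContextDirector prepends
# the projected conditioning tokens to the sequence and strips them again after
# the self-attention backbone. All values are assumptions for demonstration.
def _example_in_context_director() -> None:
    model = InContextDirector(
        name="demo",
        num_feats=6,
        num_cond_feats=[8, 16],
        num_cams=1,
        latent_dim=64,
        mlp_multiplier=4,
        num_layers=2,
        num_heads=4,
        dropout=0.0,
        stochastic_depth=0.0,
        label_dropout=0.0,
        num_rawfeats=12,
        device="cpu",
    )
    x = torch.randn(2, 6, 30)
    timesteps = torch.rand(2)
    y = [torch.randn(2, 8), torch.randn(2, 16)]
    assert model(x, timesteps, y=y).shape == x.shape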