"""This file contains the model definition of TiTok. | |
Copyright (2024) Bytedance Ltd. and/or its affiliates | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
Reference: | |
https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py | |
https://github.com/facebookresearch/DiT/blob/main/models.py | |
""" | |
import torch
import torch.nn as nn
import torch.nn.functional as F

from modeling.modules import BaseModel
from functools import partial
from timm.layers import Mlp
from typing import Optional

import numpy as np
import random


# util function
def build_causal_mask(seq_length):
    mask = torch.empty(seq_length, seq_length)
    mask.fill_(float("-inf"))
    mask.triu_(1)  # zero out the diagonal and below; keep -inf strictly above the diagonal
    return mask
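
# A hedged illustration (not from the original source): build_causal_mask(4) returns
# the additive attention mask
#   [[0., -inf, -inf, -inf],
#    [0.,   0., -inf, -inf],
#    [0.,   0.,   0., -inf],
#    [0.,   0.,   0.,   0.]]
# which is added to the attention logits, so each position can only attend to itself
# and to earlier positions after the softmax.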


# weight init
def init_weights(module):
    if (isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or
            isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d)):
        module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        if module.bias is not None:
            module.bias.data.zero_()
        if module.weight is not None:
            module.weight.data.fill_(1.0)


# attention layer with KV cache supported
class Attention(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = True

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.kv_cache = False
        self.k_cache = None
        self.v_cache = None

    def reset_kv_cache(self):
        self.k_cache = None
        self.v_cache = None

    def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.kv_cache:
            if self.k_cache is None and self.v_cache is None:
                # prefill: cache keys/values for the whole prompt
                k_cache = k
                v_cache = v
            else:
                # decode: append the new token(s) to the cache
                assert N in [1, 2], f"x.shape {x.shape}"
                k_cache = torch.cat([self.k_cache, k], dim=-2)
                v_cache = torch.cat([self.v_cache, v], dim=-2)

            self.k_cache = k_cache
            self.v_cache = v_cache
            k = k_cache
            v = v_cache

        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
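
# Hedged usage sketch of the KV cache above (illustrative only; RAR.generate() below
# is the actual caller):
#
#   attn = Attention(dim=64, num_heads=4, qkv_bias=True)
#   attn.kv_cache = True
#   attn.reset_kv_cache()
#   y = attn(torch.randn(1, 8, 64), attn_mask=build_causal_mask(8))  # prefill: caches K/V for 8 positions
#   y = attn(torch.randn(1, 1, 64))                                  # decode: only the new token is passed
#   attn.reset_kv_cache()
#
# During decoding the caller passes attn_mask=None, since the single query is allowed
# to attend to every cached key.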


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class FinalLayer(nn.Module):
    def __init__(self, dim, norm_layer):
        super().__init__()
        self.norm_final = norm_layer(dim, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, 2 * dim)
        )

    def forward(self, x, c):
        scale, shift = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        return x


# basic transformer block
class Block(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True)
        )

    def forward(self, x: torch.Tensor, attn_mask=None, c=None) -> torch.Tensor:
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
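
# Note on adaLN-Zero (following DiT): RAR.__init__ below zero-initializes the final
# linear layer of every adaLN_modulation, so at initialization shift = scale = gate = 0,
# modulate(norm(x), 0, 0) reduces to norm(x), and both gated residual branches add
# zero. Each Block therefore starts out as an identity mapping, and the conditioning
# pathway grows from zero during training.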


class RAR(BaseModel):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # parse the configs
        embed_dim = config.model.generator.hidden_size
        depth = config.model.generator.num_hidden_layers
        num_heads = config.model.generator.num_attention_heads
        intermediate_size = config.model.generator.intermediate_size
        mlp_ratio = intermediate_size / embed_dim
        image_seq_len = config.model.generator.image_seq_len
        target_codebook_size = config.model.vq_model.codebook_size
        condition_num_classes = config.model.generator.condition_num_classes
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        dropout_rate = config.model.generator.dropout
        attn_dropout_rate = config.model.generator.attn_drop

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                qk_norm=True,
                proj_drop=dropout_rate,
                attn_drop=attn_dropout_rate,
                norm_layer=norm_layer)
            for i in range(depth)])

        self.embeddings = nn.Embedding(
            target_codebook_size + 1 + condition_num_classes + 1, embed_dim)
        self.pos_embed = nn.init.trunc_normal_(
            nn.Parameter(torch.zeros(1, image_seq_len + 1024, embed_dim)), 0., 0.02)
        self.target_aware_pos_embed = nn.init.trunc_normal_(
            nn.Parameter(torch.zeros(1, image_seq_len + 1024, embed_dim)), 0., 0.02)

        # number of steps == image_seq_len
        self.timesteps_embeddings = nn.init.trunc_normal_(
            nn.Parameter(torch.zeros(1, image_seq_len + 100, embed_dim)), 0., 0.02)

        self.adaln_before_head = FinalLayer(embed_dim, norm_layer=norm_layer)
        self.lm_head = nn.Linear(embed_dim, target_codebook_size, bias=True)

        self.condition_num_classes = condition_num_classes
        self.image_seq_len = image_seq_len
        self.target_codebook_size = target_codebook_size
        self.none_condition_id = self.condition_num_classes + self.target_codebook_size + 1

        self.apply(init_weights)

        attn_mask = build_causal_mask(self.image_seq_len + 1024)  # include condition
        self.register_buffer('attn_mask', attn_mask, persistent=False)

        self.use_checkpoint = config.model.generator.get("use_checkpoint", False)

        # init for adaln-zero.
        nn.init.constant_(self.adaln_before_head.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaln_before_head.adaLN_modulation[-1].bias, 0)
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        self.random_ratio = 0.0

    def enable_kv_cache(self):
        for block in self.blocks:
            block.attn.kv_cache = True
            block.attn.reset_kv_cache()

    def disable_kv_cache(self):
        for block in self.blocks:
            block.attn.kv_cache = False
            block.attn.reset_kv_cache()

    def sample_orders(self, x):
        batch_size = x.shape[0]
        shuffled_orders = []
        for _ in range(batch_size):
            if random.random() < self.random_ratio:
                # random order
                shuffled_orders.append(torch.randperm(self.image_seq_len, device=x.device))
            else:
                # raster order
                shuffled_orders.append(torch.arange(self.image_seq_len, device=x.device))
        shuffled_orders = torch.stack(shuffled_orders)
        return shuffled_orders.to(x.device)

    def set_random_ratio(self, new_ratio):
        self.random_ratio = new_ratio

    def get_raster_orders(self, x):
        batch_size = x.shape[0]
        shuffled_orders = torch.stack([torch.arange(self.image_seq_len, device=x.device) for _ in range(batch_size)])
        return shuffled_orders

    def shuffle(self, x, orders):
        batch_size, seq_len = x.shape[:2]
        batch_indices = torch.arange(batch_size).unsqueeze(1).expand(-1, seq_len)
        shuffled_x = x[batch_indices, orders]
        return shuffled_x

    def unshuffle(self, shuffled_x, orders):
        # Unshuffle the tensor based on the original orders
        batch_size, seq_len = shuffled_x.shape[:2]
        batch_indices = torch.arange(batch_size).unsqueeze(1).expand(-1, seq_len)
        unshuffled_x = torch.zeros_like(shuffled_x)
        unshuffled_x[batch_indices, orders] = shuffled_x
        return unshuffled_x
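
    # Hedged sanity check (not in the original source): shuffle and unshuffle are
    # inverses for any sampled permutation, e.g.
    #   x = torch.randn(2, model.image_seq_len, 4)
    #   orders = model.sample_orders(x)
    #   assert torch.equal(model.unshuffle(model.shuffle(x, orders), orders), x)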

    def preprocess_condition(self, condition, cond_drop_prob=0.0):
        # Set class condition to None condition
        drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        condition = condition + self.target_codebook_size + 1  # [0, 999] -> [codebook_size + 1, codebook_size + 1000]
        condition[drop_label_mask] = self.none_condition_id
        return condition

    def get_none_condition(self, condition):
        return torch.full_like(condition, self.none_condition_id)

    def forward(self, input_ids, condition, return_labels=False):
        orders = self.sample_orders(input_ids)
        return self.forward_fn(input_ids, condition, return_labels, orders)

    def forward_fn(self, input_ids, condition,
                   return_labels=False,
                   orders=None,
                   is_sampling=False):
        # TODO: optimize inference time; the computation of pos_embed etc. can be shared across sampling steps.
        # Token space:
        #   [0, codebook_size - 1]                      : the learned quantized image tokens
        #   codebook_size                               : the mask token used to mask image tokens
        #   [codebook_size + 1, codebook_size + nclass] : the imagenet class tokens
        #   codebook_size + 1 + nclass                  : the class drop label
        if orders is None:
            orders = self.get_raster_orders(input_ids)
        labels = input_ids.clone()
        # prepend condition token
        input_ids = torch.cat([condition.view(condition.shape[0], -1),
                               input_ids.view(input_ids.shape[0], -1)], dim=1)
        embeddings = self.embeddings(input_ids)
        condition_token = embeddings[:, 0]

        # prepare positional embeddings.
        # shuffle pos embed
        pos_embed = self.pos_embed.repeat(input_ids.shape[0], 1, 1)
        # cls_token and condition: the permutation does not affect these prefix tokens.
        prefix = 2
        pos_embed_prefix = pos_embed[:, :prefix]
        pos_embed_postfix = self.shuffle(pos_embed[:, prefix:prefix+self.image_seq_len], orders)

        # prepare target-aware positional embeddings.
        target_aware_pos_embed = self.target_aware_pos_embed.repeat(input_ids.shape[0], 1, 1)
        # target_aware_pos_embed_prefix = target_aware_pos_embed[:, :prefix]
        target_aware_pos_embed_postfix = self.shuffle(target_aware_pos_embed[:, prefix:prefix+self.image_seq_len], orders)

        if not is_sampling:
            # shuffle labels
            labels = self.shuffle(labels, orders)
            # randomized permutation: during training we shuffle the order of input_ids, but not for sampling
            embeddings = torch.cat([embeddings[:, :1], self.shuffle(embeddings[:, 1:], orders)], dim=1)

        x = embeddings
        # prepend the cls token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # add original pos embed
        x = x + torch.cat([pos_embed_prefix, pos_embed_postfix], dim=1)[:, :x.shape[1]]
        # add target-aware pos embed
        target_aware_pos_embed = torch.cat(
            [torch.zeros_like(x[:, :prefix-1]), target_aware_pos_embed_postfix, torch.zeros_like(x[:, -1:])], dim=1
        )
        x = x + target_aware_pos_embed[:, :x.shape[1]]

        # causal attention masking
        attn_mask = self.attn_mask[:x.shape[1], :x.shape[1]]

        # separate condition token for each step; at generation, we step from 1 to seq len
        condition_token = condition_token.unsqueeze(1) + self.timesteps_embeddings[:, :x.shape[1]]

        if self.blocks[0].attn.kv_cache:
            if self.blocks[0].attn.k_cache is not None and self.blocks[0].attn.v_cache is not None:
                # only need to process the last token
                x = x[:, -1:]
                attn_mask = None
                # only keep the last condition
                condition_token = condition_token[:, -1:]

        for idx, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                x = torch.utils.checkpoint.checkpoint(
                    blk.forward, x, attn_mask, condition_token, use_reentrant=False)
            else:
                x = blk(x, attn_mask=attn_mask, c=condition_token)

        if not self.blocks[0].attn.kv_cache:
            # remove cls token
            x = x[:, prefix - 1:]
            condition_token = condition_token[:, prefix - 1:]

        x = self.adaln_before_head(x, condition_token)
        x = self.lm_head(x)

        if return_labels:
            return x, labels
        return x

    def generate(self,
                 condition,
                 guidance_scale,
                 randomize_temperature,
                 guidance_scale_pow,
                 kv_cache=True,
                 **kwargs):
        condition = self.preprocess_condition(
            condition, cond_drop_prob=0.0)
        device = condition.device
        num_samples = condition.shape[0]
        ids = torch.full((num_samples, 0), -1, device=device)
        cfg_scale = 0.
        if kv_cache:
            self.enable_kv_cache()

        orders = None
        cfg_orders = None

        for step in range(self.image_seq_len):
            # ref: https://github.com/sail-sg/MDT/blob/441d6a1d49781dbca22b708bbd9ed81e9e3bdee4/masked_diffusion/models.py#L513C13-L513C23
            scale_pow = torch.ones((1), device=device) * guidance_scale_pow
            scale_step = (1 - torch.cos(
                ((step / self.image_seq_len) ** scale_pow) * torch.pi)) * 1/2
            cfg_scale = (guidance_scale - 1) * scale_step + 1

            if guidance_scale != 0:
                logits = self.forward_fn(
                    torch.cat([ids, ids], dim=0),
                    torch.cat([condition, self.get_none_condition(condition)], dim=0),
                    orders=cfg_orders, is_sampling=True)
                cond_logits, uncond_logits = logits[:num_samples], logits[num_samples:]
                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
            else:
                logits = self.forward_fn(
                    ids, condition, orders=orders, is_sampling=True
                )

            # keep the logits of the last token only
            logits = logits[:, -1]
            logits = logits / randomize_temperature
            probs = F.softmax(logits, dim=-1)
            sampled = torch.multinomial(probs, num_samples=1)
            ids = torch.cat((ids, sampled), dim=-1)

        self.disable_kv_cache()
        return ids
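

# Minimal smoke test: a hedged sketch, not part of the original training/inference
# pipeline. It assumes an OmegaConf-style config exposing the fields read in
# RAR.__init__; the sizes below are illustrative only.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.create({
        "model": {
            "vq_model": {"codebook_size": 1024},
            "generator": {
                "hidden_size": 256,
                "num_hidden_layers": 2,
                "num_attention_heads": 4,
                "intermediate_size": 1024,
                "image_seq_len": 16,
                "condition_num_classes": 1000,
                "dropout": 0.0,
                "attn_drop": 0.0,
            },
        }
    })
    model = RAR(config).eval()

    batch_size = 2
    condition = torch.randint(0, 1000, (batch_size,))
    input_ids = torch.randint(0, 1024, (batch_size, config.model.generator.image_seq_len))

    # Training-style forward pass: per-position logits over the codebook plus shuffled labels.
    logits, labels = model(input_ids, model.preprocess_condition(condition), return_labels=True)
    print(logits.shape, labels.shape)  # torch.Size([2, 17, 1024]) torch.Size([2, 16])

    # Sampling with classifier-free guidance and KV caching enabled.
    with torch.no_grad():
        tokens = model.generate(
            condition=condition,
            guidance_scale=4.0,
            randomize_temperature=1.0,
            guidance_scale_pow=2.75,
        )
    print(tokens.shape)  # torch.Size([2, 16])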