liuhuadai's picture
Upload 340 files
6efc863 verified
raw
history blame
20.4 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
from copy import deepcopy
import torch
import torch.nn as nn
import numpy as np
import math
import collections.abc
from itertools import repeat
from ldm.modules.new_attention import PositionEmbedding
from einops import rearrange
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a singleton axis inserted at position 1
    so that they broadcast along dim 1 of ``x``.
    NOTE(review): presumably ``x`` is (B, T, C) and ``shift``/``scale`` are
    (B, C) -- confirm against callers.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
def to_2tuple(x):
    """Return ``x`` unchanged if it is a non-string iterable, else ``(x, x)``.

    Strings are treated as scalars and therefore duplicated. Iterables are
    returned as-is (not converted to a tuple and not length-checked).
    """
    is_iterable = isinstance(x, collections.abc.Iterable)
    if isinstance(x, str) or not is_iterable:
        return tuple(repeat(x, 2))
    return x
################################################################
# Embedding Layers for Timesteps #
################################################################
class TimestepEmbedder(nn.Module):
    """
    Turns scalar timesteps into ``hidden_size``-dimensional vectors.

    A sinusoidal feature vector of size ``frequency_embedding_size`` is built
    for every timestep and fed through a Linear -> SiLU -> Linear MLP. An
    optional extra conditioning vector can be projected (bias-free linear)
    and added to the sinusoidal features before the MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        freq_dim = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(freq_dim, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        # Projection for the optional ``w_cond`` input of ``forward``.
        self.proj_w = nn.Linear(freq_dim, freq_dim, bias=False)
        self.frequency_embedding_size = freq_dim

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings for a batch of timesteps.

        :param t: 1-D tensor with one (possibly fractional) index per batch element.
        :param dim: size of each returned embedding.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: (N, dim) tensor laid out as [cos | sin | zero-pad if dim is odd].
        """
        # Adapted from
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        n_freqs = dim // 2
        steps = torch.arange(start=0, end=n_freqs, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / n_freqs)
        freqs = freqs.to(device=t.device)
        phase = t[:, None].float() * freqs[None]
        pieces = [torch.cos(phase), torch.sin(phase)]
        if dim % 2:
            # Odd sizes get one trailing zero column so the width equals ``dim``.
            pieces.append(torch.zeros_like(phase[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t, w_cond=None):
        """
        Embed timesteps ``t``; when ``w_cond`` is given, add its projection to
        the sinusoidal features before the MLP.
        NOTE(review): ``w_cond`` presumably has shape
        (N, frequency_embedding_size) -- confirm against callers.
        """
        freq_features = self.timestep_embedding(t, self.frequency_embedding_size)
        if w_cond is None:
            return self.mlp(freq_features)
        return self.mlp(freq_features + self.proj_w(w_cond))
class Conv1DFinalLayer(nn.Module):
    """
    Output head: a 16-group GroupNorm followed by a kernel-size-1 Conv1d that
    maps ``hidden_size`` channels to ``out_channels``.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        # ``hidden_size`` has to be divisible by the 16 groups used here.
        self.norm_final = nn.GroupNorm(num_groups=16, num_channels=hidden_size)
        self.conv1d = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        """Normalize then project ``x`` of shape (B, C, T); length T is preserved."""
        return self.conv1d(self.norm_final(x))
class ConditionEmbedder(nn.Module):
    """
    Projects conditioning features from ``context_dim`` to ``hidden_size``
    with Linear -> GELU(tanh) -> Linear -> LayerNorm, applied on the last axis.
    """

    def __init__(self, hidden_size, context_dim):
        super().__init__()
        stages = [
            nn.Linear(context_dim, hidden_size, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.LayerNorm(hidden_size),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return the embedded features; only the last dimension changes size."""
        embedded = self.mlp(x)
        return embedded
from ldm.modules.new_attention import CrossAttention,Conv1dFeedForward,checkpoint,Normalize,zero_module
class BasicTransformerBlock(nn.Module):
    """
    Pre-norm residual block: attention, attention, then a Conv1d feed-forward.

    Both attention modules are ``CrossAttention`` instances, but ``_forward``
    calls each of them with the normalized input only (no context argument).
    NOTE(review): per the original comment, ``CrossAttention`` acts as
    self-attention when no context is given -- confirm in
    ``ldm.modules.new_attention``.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., gated_ff=True, checkpoint=True):
        super().__init__()
        attn_kwargs = dict(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        # Registration order (attn1, ff, attn2, norms) is kept so parameter
        # ordering stays the same as before.
        self.attn1 = CrossAttention(**attn_kwargs)
        self.ff = Conv1dFeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(**attn_kwargs)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        # Flag handed to the imported ``checkpoint`` helper in ``forward``.
        self.checkpoint = checkpoint

    def forward(self, x):
        """Run ``_forward`` through the imported ``checkpoint`` helper."""
        return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)

    def _forward(self, x):
        """Apply the three residual sub-layers to ``x`` of shape (B, T, C)."""
        first = self.attn1(self.norm1(x))
        x = first + x
        second = self.attn2(self.norm2(x))
        x = second + x
        # The feed-forward is applied channel-first, so swap to (B, C, T) and back.
        channel_first = self.norm3(x).permute(0, 2, 1)
        x = self.ff(channel_first).permute(0, 2, 1) + x
        return x
class TemporalTransformer(nn.Module):
    """
    Transformer stage for 1-D sequence features of shape (b, c, t).

    The input is normalized, projected to ``n_heads * d_head`` channels with a
    pointwise Conv1d, rearranged to (b, t, c), passed through ``depth``
    ``BasicTransformerBlock`` layers, rearranged back, projected to
    ``in_channels`` and added to the original input.

    ``context_dim`` is accepted but not used by this class.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        pointwise = dict(kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.proj_in = nn.Conv1d(in_channels, inner_dim, **pointwise)

        layers = []
        for _ in range(depth):
            layers.append(
                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout)
            )
        self.transformer_blocks = nn.ModuleList(layers)

        # Output projection is wrapped in ``zero_module`` (per the original
        # comment: initialized with zero).
        self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, **pointwise))

    def forward(self, x):
        """Transform ``x`` of shape (b, c, t) and return a tensor of the same shape."""
        residual = x
        h = self.proj_in(self.norm(x))
        h = rearrange(h, 'b c t -> b t c')
        # No context is passed to the blocks; they operate on the sequence alone.
        for layer in self.transformer_blocks:
            h = layer(h)
        h = rearrange(h, 'b t c -> b c t')
        return self.proj_out(h) + residual
class ConcatDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone that conditions by concatenation.

    One timestep token and the embedded context tokens are prepended to the
    projected latent sequence, the combined sequence is processed by a stack of
    ``TemporalTransformer`` stages, and the prepended positions are dropped
    again before the output head.
    """

    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels  # vae dim
        self.out_channels = in_channels
        self.num_heads = num_heads

        kernel_size = 5
        head_dim = hidden_size // num_heads

        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        # "Same"-length convolution: padding of kernel_size // 2 keeps T unchanged.
        self.proj_in = nn.Conv1d(
            in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2
        )
        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)

        stages = []
        for _ in range(depth):
            stages.append(
                TemporalTransformer(
                    hidden_size, num_heads, d_head=head_dim, depth=1, context_dim=context_dim
                )
            )
        self.blocks = nn.ModuleList(stages)
        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every ``nn.Linear`` (zero bias), then re-init the timestep MLP weights."""

        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_init_linear)

        # The two linear layers of the timestep MLP use a small normal init instead.
        timestep_mlp = self.t_embedder.mlp
        for layer_idx in (0, 2):
            nn.init.normal_(timestep_mlp[layer_idx].weight, std=0.02)

    def forward(self, x, t, context, w_cond=None):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of melspec)
        t: (N,) tensor of diffusion timesteps
        context: (N, c_len, context_dim) conditioning tokens (c_len is 77 per the
            original annotation)
        w_cond: optional extra conditioning forwarded to the timestep embedder
        returns: (N, out_channels, T)
        """
        t_tok = self.t_embedder(t, w_cond=w_cond).unsqueeze(1)  # (N,1,hidden_size)
        c_tok = self.c_embedder(context)  # (N,c_len,hidden_size)
        n_prefix = 1 + c_tok.shape[1]  # timestep token + context tokens

        h = rearrange(self.proj_in(x), 'b c t -> b t c')
        h = self.pos_emb(torch.cat([t_tok, c_tok, h], dim=1))
        h = rearrange(h, 'b t c -> b c t')
        for stage in self.blocks:
            h = stage(h)  # (N, D, n_prefix+T)

        # Drop the prepended conditioning positions before the output head.
        return self.final_layer(h[..., n_prefix:])  # (N, out_channels, T)
class ConcatDiT2MLP(nn.Module):
    """
    Diffusion model with a Transformer backbone that conditions by concatenation,
    using two separate condition embedders.

    The context is split in two along the token axis; each half goes through
    its own ``ConditionEmbedder`` before the halves are joined again and
    prepended (after a timestep token) to the projected latent sequence.
    """

    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels  # vae dim
        self.out_channels = in_channels
        self.num_heads = num_heads

        kernel_size = 5
        head_dim = hidden_size // num_heads

        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c1_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.c2_embedder = ConditionEmbedder(hidden_size, context_dim)
        # "Same"-length convolution: padding of kernel_size // 2 keeps T unchanged.
        self.proj_in = nn.Conv1d(
            in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2
        )
        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)

        stages = []
        for _ in range(depth):
            stages.append(
                TemporalTransformer(
                    hidden_size, num_heads, d_head=head_dim, depth=1, context_dim=context_dim
                )
            )
        self.blocks = nn.ModuleList(stages)
        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every ``nn.Linear`` (zero bias), then re-init the timestep MLP weights."""

        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_init_linear)

        # The two linear layers of the timestep MLP use a small normal init instead.
        timestep_mlp = self.t_embedder.mlp
        for layer_idx in (0, 2):
            nn.init.normal_(timestep_mlp[layer_idx].weight, std=0.02)

    def forward(self, x, t, context, w_cond=None):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of melspec)
        t: (N,) tensor of diffusion timesteps
        context: (N, c_len, context_dim) conditioning tokens; chunked into two
            parts along dim 1, one per condition embedder
        w_cond: optional extra conditioning forwarded to the timestep embedder
        returns: (N, out_channels, T)
        """
        t_tok = self.t_embedder(t, w_cond=w_cond).unsqueeze(1)  # (N,1,hidden_size)

        first_part, second_part = context.chunk(2, dim=1)
        first_emb = self.c1_embedder(first_part)  # (N,c_len_1,hidden_size)
        second_emb = self.c2_embedder(second_part)  # (N,c_len_2,hidden_size)
        c_tok = torch.cat((first_emb, second_emb), dim=1)
        n_prefix = 1 + c_tok.shape[1]  # timestep token + context tokens

        h = rearrange(self.proj_in(x), 'b c t -> b t c')
        h = self.pos_emb(torch.cat([t_tok, c_tok, h], dim=1))
        h = rearrange(h, 'b t c -> b c t')
        for stage in self.blocks:
            h = stage(h)  # (N, D, n_prefix+T)

        # Drop the prepended conditioning positions before the output head.
        return self.final_layer(h[..., n_prefix:])  # (N, out_channels, T)
class ConcatOrderDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone that conditions by concatenation
    and adds a learned "order" embedding to the context tokens of each object.

    The context text is expected to list objects separated by a ``<|>`` token;
    every ordinary token of the k-th object receives the embedding of that
    object's order index (see ``add_order_embedding``).
    """

    # Token ids are tied to the tokenizer in use (values kept from the original
    # code: <start>=101, <eos>=102, <pad>=0, <|>=1064). If another tokenizer is
    # used, these must be changed.
    _SEPARATOR_TOKEN_ID = 1064
    _SPECIAL_TOKEN_IDS = frozenset((101, 102, 0, 1064))

    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels  # vae dim
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        # "Same"-length convolution: padding of kernel_size // 2 keeps T unchanged.
        self.proj_in = nn.Conv1d(
            in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2
        )
        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        # Order indices must lie in [0, 100).
        self.order_embedding = nn.Embedding(num_embeddings=100, embedding_dim=hidden_size)
        self.blocks = nn.ModuleList([
            TemporalTransformer(
                hidden_size, num_heads, d_head=hidden_size // num_heads,
                depth=1, context_dim=context_dim,
            )
            for _ in range(depth)
        ])
        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every ``nn.Linear`` (zero bias), then re-init the timestep MLP weights."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def add_order_embedding(self, token_emb, token_ids, orders_list):
        """
        Add each object's order embedding to that object's token embeddings.

        token_emb: shape (N, max_tokens_len=77, hidden_size); modified IN PLACE
        token_ids: shape (N, max_tokens); tensor/array of token ids
        orders_list: list of N lists; orders_list[i][k] is the order index of the
            k-th object in text i, so len(orders_list[i]) == objs_num in text[i]
        returns: the same ``token_emb`` object

        Special tokens (<start>, <eos>, <pad>, <|>) are left untouched; each
        ``<|>`` advances to the next object. An ``IndexError`` is raised by the
        tensor indexing if an ordinary token belongs to an object that has no
        entry in the corresponding order list.
        """
        device = self.order_embedding.weight.device
        separator_id = self._SEPARATOR_TOKEN_ID
        special_ids = self._SPECIAL_TOKEN_IDS
        for b, orderl in enumerate(orders_list):
            order_ids = torch.LongTensor(orderl).to(device=device)
            order_emb = self.order_embedding(order_ids)  # (objs_num, hidden_size)
            # Convert the row to Python ints once; comparing 0-dim tensors one
            # token at a time forces a device sync for every token.
            row_ids = token_ids[b].tolist()
            cur_obj = 0
            for i, token_id in enumerate(row_ids):
                if token_id == separator_id:
                    cur_obj += 1
                elif token_id not in special_ids:
                    token_emb[b][i] += order_emb[cur_obj]
        return token_emb

    def forward(self, x, t, context):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of melspec)
        t: (N,) tensor of diffusion timesteps
        context: dict{'token_embedding':(N,max_tokens_len=77, context_dim),'token_ids':tokens:(N,max_tokens_len=77),'orders':orders_list}
        returns: (N, out_channels, T)
        """
        token_embedding = context['token_embedding']
        token_ids = context['token_ids']
        orders = context['orders']
        t = self.t_embedder(t).unsqueeze(1)  # (N,1,hidden_size)
        c = self.c_embedder(token_embedding)  # (N,c_len,hidden_size)
        c = self.add_order_embedding(c, token_ids, orders)
        extra_len = c.shape[1] + 1  # timestep token + context tokens
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)  # (N, D, extra_len+T)
        x = x[..., extra_len:]  # (N,D,T) -- drop the prepended conditioning positions
        x = self.final_layer(x)  # (N, out_channels,T)
        return x
class ConcatOrderDiT2(nn.Module):
    """
    Diffusion model with a Transformer backbone. concat by token

    Instead of adding order information onto existing context tokens, one
    order-embedding token is inserted in front of every ``<|>`` separator of
    the context sequence, and the sequence is padded with a dedicated pad
    embedding (see ``concat_order_embedding``).
    """

    # Token id of the <|> separator that follows each object. It is tied to the
    # tokenizer in use; if another tokenizer is used, this must be changed.
    _SEPARATOR_TOKEN_ID = 1064

    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels  # vae dim
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        # "Same"-length convolution: padding of kernel_size // 2 keeps T unchanged.
        self.proj_in = nn.Conv1d(
            in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2
        )
        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        self.max_objs = 10
        self.max_objs_order = 100
        # One extra row: index ``max_objs_order`` is reserved for padding.
        self.order_embedding = nn.Embedding(
            num_embeddings=self.max_objs_order + 1, embedding_dim=hidden_size
        )
        self.blocks = nn.ModuleList([
            TemporalTransformer(
                hidden_size, num_heads, d_head=hidden_size // num_heads,
                depth=1, context_dim=context_dim,
            )
            for _ in range(depth)
        ])
        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every ``nn.Linear`` (zero bias), then re-init the timestep MLP weights."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def concat_order_embedding(self, token_emb, token_ids, orders_list):
        """
        Insert one order-embedding token before every ``<|>`` separator.

        token_emb: shape (N, max_tokens_len=77, hidden_size)
        token_ids: shape (N, max_tokens); tensor/array of token ids
        orders_list: list of N lists; len(orders_list[i]) == objs_num in text[i].
            The lists are not modified.
        returns: tensor of shape (N, max_tokens_len + self.max_objs, hidden_size),
            provided every text has as many separators as order entries and at
            most ``self.max_objs`` objects.

        Raises ``ValueError`` when ``orders_list`` does not have one entry per
        batch element, and ``IndexError`` when a text has more than
        ``objs_num + 1`` separators.
        """
        bsz = token_emb.shape[0]
        if len(orders_list) != bsz:
            raise ValueError(
                f"orders_list has {len(orders_list)} entries but the batch size is {bsz}"
            )
        device = self.order_embedding.weight.device
        separator_id = self._SEPARATOR_TOKEN_ID
        padded_rows = []
        for b, orderl in enumerate(orders_list):
            n_objs = len(orderl)
            # Build a fresh index list so the caller's list is never mutated;
            # the last index selects the pad embedding.
            order_ids = torch.LongTensor(list(orderl) + [self.max_objs_order]).to(device=device)
            order_emb = self.order_embedding(order_ids)  # (n_objs + 1, hidden_size)
            pad_row = order_emb[-1:]  # (1, hidden_size)

            # Convert the row to Python ints once; comparing 0-dim tensors one
            # token at a time forces a device sync for every token.
            sep_positions = [
                i for i, token_id in enumerate(token_ids[b].tolist())
                if token_id == separator_id
            ]

            seq = token_emb[b]  # (max_tokens_len, hidden_size)
            pieces = []
            prev = 0
            for k, pos in enumerate(sep_positions):
                pieces.append(seq[prev:pos])
                # Integer indexing raises IndexError if there are too many separators.
                pieces.append(order_emb[k].unsqueeze(0))
                prev = pos
            pieces.append(seq[prev:])
            # pad to max_tokens_len + self.max_objs (no padding if n_objs >= max_objs)
            pieces.extend([pad_row] * (self.max_objs - n_objs))
            padded_rows.append(torch.cat(pieces))
        return torch.stack(padded_rows)

    def forward(self, x, t, context):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of melspec)
        t: (N,) tensor of diffusion timesteps
        context: dict{'token_embedding':(N,max_tokens_len=77, context_dim),'token_ids':tokens:(N,max_tokens_len=77),'orders':orders_list}
        returns: (N, out_channels, T)
        """
        token_embedding = context['token_embedding']
        token_ids = context['token_ids']
        orders = context['orders']
        t = self.t_embedder(t).unsqueeze(1)  # (N,1,hidden_size)
        c = self.c_embedder(token_embedding)  # (N,c_len,hidden_size)
        c = self.concat_order_embedding(c, token_ids, orders)
        extra_len = c.shape[1] + 1  # timestep token + context/order tokens
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)  # (N, D, extra_len+T)
        x = x[..., extra_len:]  # (N,D,T) -- drop the prepended conditioning positions
        x = self.final_layer(x)  # (N, out_channels,T)
        return x