# MoMask: models/mask_transformer/transformer.py
import torch
import torch.nn as nn
import numpy as np
# from networks.layers import *
import torch.nn.functional as F
import clip
from einops import rearrange, repeat
import math
from random import random
from tqdm.auto import tqdm
from typing import Callable, Optional, List, Dict
from copy import deepcopy
from functools import partial
from models.mask_transformer.tools import *
from torch.distributions.categorical import Categorical
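# This module defines the two conditional transformers used by MoMask:
#   - MaskTransformer: models the base-layer motion tokens with BERT-style masked
#     modeling and generates them by iterative, confidence-based remasking.
#   - ResidualTransformer: predicts the tokens of the residual quantization layers,
#     one layer at a time, conditioned on the summed embeddings of the lower layers.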
class InputProcess(nn.Module):
def __init__(self, input_feats, latent_dim):
super().__init__()
self.input_feats = input_feats
self.latent_dim = latent_dim
self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
def forward(self, x):
# [bs, ntokens, input_feats]
x = x.permute((1, 0, 2)) # [seqen, bs, input_feats]
# print(x.shape)
x = self.poseEmbedding(x) # [seqlen, bs, d]
return x
class PositionalEncoding(nn.Module):
    # Borrowed from MDM: sinusoidal positional encoding with dropout; computing the
    # division term via exp/log improves numerical precision.
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1) #[max_len, 1, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
        # add positional encoding up to the current sequence length
        x = x + self.pe[:x.shape[0], :]
return self.dropout(x)
class OutputProcess_Bert(nn.Module):
def __init__(self, out_feats, latent_dim):
super().__init__()
self.dense = nn.Linear(latent_dim, latent_dim)
self.transform_act_fn = F.gelu
self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
self.poseFinal = nn.Linear(latent_dim, out_feats) #Bias!
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
output = self.poseFinal(hidden_states) # [seqlen, bs, out_feats]
output = output.permute(1, 2, 0) # [bs, c, seqlen]
return output
class OutputProcess(nn.Module):
def __init__(self, out_feats, latent_dim):
super().__init__()
self.dense = nn.Linear(latent_dim, latent_dim)
self.transform_act_fn = F.gelu
self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
self.poseFinal = nn.Linear(latent_dim, out_feats) #Bias!
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
output = self.poseFinal(hidden_states) # [seqlen, bs, out_feats]
output = output.permute(1, 2, 0) # [bs, e, seqlen]
return output
class MaskTransformer(nn.Module):
def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
num_heads=4, dropout=0.1, clip_dim=512, cond_drop_prob=0.1,
clip_version=None, opt=None, **kargs):
super(MaskTransformer, self).__init__()
print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')
self.code_dim = code_dim
self.latent_dim = latent_dim
self.clip_dim = clip_dim
self.dropout = dropout
self.opt = opt
self.cond_mode = cond_mode
self.cond_drop_prob = cond_drop_prob
if self.cond_mode == 'action':
assert 'num_actions' in kargs
self.num_actions = kargs.get('num_actions', 1)
'''
Preparing Networks
'''
self.input_process = InputProcess(self.code_dim, self.latent_dim)
self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
nhead=num_heads,
dim_feedforward=ff_size,
dropout=dropout,
activation='gelu')
self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
num_layers=num_layers)
self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
# if self.cond_mode != 'no_cond':
if self.cond_mode == 'text':
self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
elif self.cond_mode == 'action':
self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
elif self.cond_mode == 'uncond':
self.cond_emb = nn.Identity()
else:
raise KeyError("Unsupported condition mode!!!")
_num_tokens = opt.num_tokens + 2 # two dummy tokens, one for masking, one for padding
self.mask_id = opt.num_tokens
self.pad_id = opt.num_tokens + 1
self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)
self.token_emb = nn.Embedding(_num_tokens, self.code_dim)
self.apply(self.__init_weights)
'''
Preparing frozen weights
'''
if self.cond_mode == 'text':
print('Loading CLIP...')
self.clip_version = clip_version
self.clip_model = self.load_and_freeze_clip(clip_version)
self.noise_schedule = cosine_schedule
def load_and_freeze_token_emb(self, codebook):
'''
:param codebook: (c, d)
:return:
'''
assert self.training, 'Only necessary in training mode'
c, d = codebook.shape
self.token_emb.weight = nn.Parameter(torch.cat([codebook, torch.zeros(size=(2, d), device=codebook.device)], dim=0)) #add two dummy tokens, 0 vectors
self.token_emb.requires_grad_(False)
# self.token_emb.weight.requires_grad = False
# self.token_emb_ready = True
print("Token embedding initialized!")
def __init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def parameters_wo_clip(self):
return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
def load_and_freeze_clip(self, clip_version):
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
jit=False) # Must set jit=False for training
        # convert_weights casts CLIP to float16. This is only unnecessary when the model
        # is loaded directly onto the GPU; disable it if CLIP needs to run on the CPU.
        clip.model.convert_weights(clip_model)
# Freeze CLIP weights
clip_model.eval()
for p in clip_model.parameters():
p.requires_grad = False
return clip_model
def encode_text(self, raw_text):
device = next(self.parameters()).device
text = clip.tokenize(raw_text, truncate=True).to(device)
feat_clip_text = self.clip_model.encode_text(text).float()
return feat_clip_text
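    # Condition dropout for classifier-free guidance: during training the condition
    # vector is zeroed with probability cond_drop_prob; force_mask=True drops it
    # deterministically (used for the unconditional branch at inference).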
def mask_cond(self, cond, force_mask=False):
bs, d = cond.shape
if force_mask:
return torch.zeros_like(cond)
elif self.training and self.cond_drop_prob > 0.:
mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
return cond * (1. - mask)
else:
return cond
def trans_forward(self, motion_ids, cond, padding_mask, force_mask=False):
'''
:param motion_ids: (b, seqlen)
:padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
:param cond: (b, embed_dim) for text, (b, num_actions) for action
:param force_mask: boolean
:return:
-logits: (b, num_token, seqlen)
'''
cond = self.mask_cond(cond, force_mask=force_mask)
# print(motion_ids.shape)
x = self.token_emb(motion_ids)
# print(x.shape)
# (b, seqlen, d) -> (seqlen, b, latent_dim)
x = self.input_process(x)
cond = self.cond_emb(cond).unsqueeze(0) #(1, b, latent_dim)
x = self.position_enc(x)
xseq = torch.cat([cond, x], dim=0) #(seqlen+1, b, latent_dim)
padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:1]), padding_mask], dim=1) #(b, seqlen+1)
# print(xseq.shape, padding_mask.shape)
# print(padding_mask.shape, xseq.shape)
output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[1:] #(seqlen, b, e)
logits = self.output_process(output) #(seqlen, b, e) -> (b, ntoken, seqlen)
return logits
def forward(self, ids, y, m_lens):
'''
:param ids: (b, n)
:param y: raw text for cond_mode=text, (b, ) for cond_mode=action
:m_lens: (b,)
:return:
'''
bs, ntokens = ids.shape
device = ids.device
# Positions that are PADDED are ALL FALSE
non_pad_mask = lengths_to_mask(m_lens, ntokens) #(b, n)
ids = torch.where(non_pad_mask, ids, self.pad_id)
force_mask = False
if self.cond_mode == 'text':
with torch.no_grad():
cond_vector = self.encode_text(y)
elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
elif self.cond_mode == 'uncond':
cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
force_mask = True
else:
raise NotImplementedError("Unsupported condition mode!!!")
'''
Prepare mask
'''
rand_time = uniform((bs,), device=device)
rand_mask_probs = self.noise_schedule(rand_time)
num_token_masked = (ntokens * rand_mask_probs).round().clamp(min=1)
batch_randperm = torch.rand((bs, ntokens), device=device).argsort(dim=-1)
# Positions to be MASKED are ALL TRUE
mask = batch_randperm < num_token_masked.unsqueeze(-1)
# Positions to be MASKED must also be NON-PADDED
mask &= non_pad_mask
# Note this is our training target, not input
labels = torch.where(mask, ids, self.mask_id)
x_ids = ids.clone()
# Further Apply Bert Masking Scheme
# Step 1: 10% replace with an incorrect token
mask_rid = get_mask_subset_prob(mask, 0.1)
rand_id = torch.randint_like(x_ids, high=self.opt.num_tokens)
x_ids = torch.where(mask_rid, rand_id, x_ids)
        # Step 2: of the remaining 90%, replace 88% with the mask token and keep 12% unchanged
mask_mid = get_mask_subset_prob(mask & ~mask_rid, 0.88)
# mask_mid = mask
x_ids = torch.where(mask_mid, self.mask_id, x_ids)
logits = self.trans_forward(x_ids, cond_vector, ~non_pad_mask, force_mask)
ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id)
return ce_loss, pred_id, acc
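    # Classifier-free guidance at inference: combine conditional and unconditional
    # logits as uncond + cond_scale * (cond - uncond); cond_scale == 1 reduces to the
    # purely conditional prediction.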
def forward_with_cond_scale(self,
motion_ids,
cond_vector,
padding_mask,
cond_scale=3,
force_mask=False):
# bs = motion_ids.shape[0]
# if cond_scale == 1:
if force_mask:
return self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)
logits = self.trans_forward(motion_ids, cond_vector, padding_mask)
if cond_scale == 1:
return logits
aux_logits = self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)
scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
return scaled_logits
@torch.no_grad()
@eval_decorator
def generate(self,
conds,
m_lens,
timesteps: int,
cond_scale: int,
temperature=1,
topk_filter_thres=0.9,
gsample=False,
force_mask=False
):
# print(self.opt.num_quantizers)
# assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
device = next(self.parameters()).device
seq_len = max(m_lens)
batch_size = len(m_lens)
if self.cond_mode == 'text':
with torch.no_grad():
cond_vector = self.encode_text(conds)
elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device).float()
elif self.cond_mode == 'uncond':
cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
else:
raise NotImplementedError("Unsupported condition mode!!!")
padding_mask = ~lengths_to_mask(m_lens, seq_len)
# print(padding_mask.shape, )
# Start from all tokens being masked
ids = torch.where(padding_mask, self.pad_id, self.mask_id)
scores = torch.where(padding_mask, 1e5, 0.)
starting_temperature = temperature
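        # Iterative decoding: at each step, remask the num_token_masked lowest-confidence
        # positions (per the cosine noise schedule), re-predict them, and keep the rest.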
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
            # timestep runs from 0 to 1 over the decoding steps
rand_mask_prob = self.noise_schedule(timestep) # Tensor
'''
Maskout, and cope with variable length
'''
            # compute the masking ratio w.r.t. each sequence's true length, not the padded seq_len
num_token_masked = torch.round(rand_mask_prob * m_lens).clamp(min=1) # (b, )
# select num_token_masked tokens with lowest scores to be masked
sorted_indices = scores.argsort(
dim=1) # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
ranks = sorted_indices.argsort(dim=1) # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
is_mask = (ranks < num_token_masked.unsqueeze(-1))
ids = torch.where(is_mask, self.mask_id, ids)
'''
Preparing input
'''
# (b, num_token, seqlen)
logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
padding_mask=padding_mask,
cond_scale=cond_scale,
force_mask=force_mask)
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
# print(logits.shape, self.opt.num_tokens)
# clean low prob token
filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
'''
Update ids
'''
            # temperature is kept fixed here; the commented lines below would anneal it,
            # gradually reducing randomness as decoding approaches the final step
            temperature = starting_temperature
            # temperature = starting_temperature * (steps_until_x0 / timesteps)
            # temperature = max(temperature, 1e-4)
if gsample: # use gumbel_softmax sampling
# print("1111")
pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
else: # use multinomial sampling
# print("2222")
probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
# print(temperature, starting_temperature, steps_until_x0, timesteps)
# print(probs / temperature)
pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
# print(pred_ids.max(), pred_ids.min())
# if pred_ids.
ids = torch.where(is_mask, pred_ids, ids)
'''
Updating scores
'''
probs_without_temperature = logits.softmax(dim=-1) # (b, seqlen, ntoken)
scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1)) # (b, seqlen, 1)
scores = scores.squeeze(-1) # (b, seqlen)
# We do not want to re-mask the previously kept tokens, or pad tokens
scores = scores.masked_fill(~is_mask, 1e5)
ids = torch.where(padding_mask, -1, ids)
# print("Final", ids.max(), ids.min())
return ids
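    # edit(): token-level inpainting. Positions selected by edit_mask (or all non-padded
    # positions when edit_mask is None) are masked and re-generated with the same
    # confidence-based loop as generate(); tokens outside the edit region are kept.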
@torch.no_grad()
@eval_decorator
def edit(self,
conds,
tokens,
m_lens,
timesteps: int,
cond_scale: int,
temperature=1,
topk_filter_thres=0.9,
gsample=False,
force_mask=False,
edit_mask=None,
padding_mask=None,
):
        if edit_mask is not None:
            assert edit_mask.shape == tokens.shape
device = next(self.parameters()).device
seq_len = tokens.shape[1]
if self.cond_mode == 'text':
with torch.no_grad():
cond_vector = self.encode_text(conds)
elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device).float()
elif self.cond_mode == 'uncond':
cond_vector = torch.zeros(1, self.latent_dim).float().to(device)
else:
raise NotImplementedError("Unsupported condition mode!!!")
        if padding_mask is None:
padding_mask = ~lengths_to_mask(m_lens, seq_len)
# Start from all tokens being masked
        if edit_mask is None:
mask_free = True
ids = torch.where(padding_mask, self.pad_id, tokens)
edit_mask = torch.ones_like(padding_mask)
edit_mask = edit_mask & ~padding_mask
edit_len = edit_mask.sum(dim=-1)
scores = torch.where(edit_mask, 0., 1e5)
else:
mask_free = False
edit_mask = edit_mask & ~padding_mask
edit_len = edit_mask.sum(dim=-1)
ids = torch.where(edit_mask, self.mask_id, tokens)
scores = torch.where(edit_mask, 0., 1e5)
starting_temperature = temperature
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
            # timestep runs from 0 to 1 over the decoding steps
rand_mask_prob = 0.16 if mask_free else self.noise_schedule(timestep) # Tensor
'''
Maskout, and cope with variable length
'''
            # compute the masking ratio w.r.t. each sequence's true length, not the padded seq_len
num_token_masked = torch.round(rand_mask_prob * edit_len).clamp(min=1) # (b, )
# select num_token_masked tokens with lowest scores to be masked
sorted_indices = scores.argsort(
dim=1) # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
ranks = sorted_indices.argsort(dim=1) # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
is_mask = (ranks < num_token_masked.unsqueeze(-1))
# is_mask = (torch.rand_like(scores) < 0.8) * ~padding_mask if mask_free else is_mask
ids = torch.where(is_mask, self.mask_id, ids)
'''
Preparing input
'''
# (b, num_token, seqlen)
logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
padding_mask=padding_mask,
cond_scale=cond_scale,
force_mask=force_mask)
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
# print(logits.shape, self.opt.num_tokens)
# clean low prob token
filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
'''
Update ids
'''
            # temperature is kept fixed here; the commented lines below would anneal it,
            # gradually reducing randomness as decoding approaches the final step
            temperature = starting_temperature
            # temperature = starting_temperature * (steps_until_x0 / timesteps)
            # temperature = max(temperature, 1e-4)
if gsample: # use gumbel_softmax sampling
# print("1111")
pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
else: # use multinomial sampling
# print("2222")
probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
# print(temperature, starting_temperature, steps_until_x0, timesteps)
# print(probs / temperature)
pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
# print(pred_ids.max(), pred_ids.min())
# if pred_ids.
ids = torch.where(is_mask, pred_ids, ids)
'''
Updating scores
'''
probs_without_temperature = logits.softmax(dim=-1) # (b, seqlen, ntoken)
scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1)) # (b, seqlen, 1)
scores = scores.squeeze(-1) # (b, seqlen)
# We do not want to re-mask the previously kept tokens, or pad tokens
scores = scores.masked_fill(~edit_mask, 1e5) if mask_free else scores.masked_fill(~is_mask, 1e5)
ids = torch.where(padding_mask, -1, ids)
# print("Final", ids.max(), ids.min())
return ids
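    # edit_beta(): no resampling is performed; it returns, for every position, the
    # probability assigned to the original token under the (new) condition, which can
    # be used to decide which tokens should be edited.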
@torch.no_grad()
@eval_decorator
def edit_beta(self,
conds,
conds_og,
tokens,
m_lens,
cond_scale: int,
force_mask=False,
):
device = next(self.parameters()).device
seq_len = tokens.shape[1]
if self.cond_mode == 'text':
with torch.no_grad():
cond_vector = self.encode_text(conds)
if conds_og is not None:
cond_vector_og = self.encode_text(conds_og)
else:
cond_vector_og = None
elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device).float()
            if conds_og is not None:
                cond_vector_og = self.encode_action(conds_og).to(device).float()
else:
cond_vector_og = None
else:
raise NotImplementedError("Unsupported condition mode!!!")
padding_mask = ~lengths_to_mask(m_lens, seq_len)
# Start from all tokens being masked
ids = torch.where(padding_mask, self.pad_id, tokens) # Do not mask anything
'''
Preparing input
'''
# (b, num_token, seqlen)
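        # NOTE: cond_vector_neg is passed here, but forward_with_cond_scale as defined
        # above does not accept it; this call assumes a variant that supports a negative
        # condition vector.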
logits = self.forward_with_cond_scale(ids,
cond_vector=cond_vector,
cond_vector_neg=cond_vector_og,
padding_mask=padding_mask,
cond_scale=cond_scale,
force_mask=force_mask)
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
'''
Updating scores
'''
probs_without_temperature = logits.softmax(dim=-1) # (b, seqlen, ntoken)
        tokens[tokens == -1] = 0  # replace -1 (padding) with a valid index so gather below does not error
og_tokens_scores = probs_without_temperature.gather(2, tokens.unsqueeze(dim=-1)) # (b, seqlen, 1)
og_tokens_scores = og_tokens_scores.squeeze(-1) # (b, seqlen)
return og_tokens_scores
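# ResidualTransformer predicts the tokens of residual quantization layers 1..Q-1.
# Its input is the cumulative sum of the code embeddings of all lower layers plus a
# learned embedding of the active quantizer id; its output is projected with a
# per-layer output matrix (optionally shared with the token embeddings).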
class ResidualTransformer(nn.Module):
def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8, cond_drop_prob=0.1,
num_heads=4, dropout=0.1, clip_dim=512, shared_codebook=False, share_weight=False,
clip_version=None, opt=None, **kargs):
super(ResidualTransformer, self).__init__()
print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')
# assert shared_codebook == True, "Only support shared codebook right now!"
self.code_dim = code_dim
self.latent_dim = latent_dim
self.clip_dim = clip_dim
self.dropout = dropout
self.opt = opt
self.cond_mode = cond_mode
# self.cond_drop_prob = cond_drop_prob
if self.cond_mode == 'action':
assert 'num_actions' in kargs
self.num_actions = kargs.get('num_actions', 1)
self.cond_drop_prob = cond_drop_prob
'''
Preparing Networks
'''
self.input_process = InputProcess(self.code_dim, self.latent_dim)
self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
nhead=num_heads,
dim_feedforward=ff_size,
dropout=dropout,
activation='gelu')
self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
num_layers=num_layers)
self.encode_quant = partial(F.one_hot, num_classes=self.opt.num_quantizers)
self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
self.quant_emb = nn.Linear(self.opt.num_quantizers, self.latent_dim)
# if self.cond_mode != 'no_cond':
if self.cond_mode == 'text':
self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
elif self.cond_mode == 'action':
self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
else:
raise KeyError("Unsupported condition mode!!!")
        _num_tokens = opt.num_tokens + 1  # one dummy token for padding
self.pad_id = opt.num_tokens
# self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)
self.output_process = OutputProcess(out_feats=code_dim, latent_dim=latent_dim)
if shared_codebook:
token_embed = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
self.token_embed_weight = token_embed.expand(opt.num_quantizers-1, _num_tokens, code_dim)
if share_weight:
self.output_proj_weight = self.token_embed_weight
self.output_proj_bias = None
else:
output_proj = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
output_bias = nn.Parameter(torch.zeros(size=(_num_tokens,)))
# self.output_proj_bias = 0
self.output_proj_weight = output_proj.expand(opt.num_quantizers-1, _num_tokens, code_dim)
self.output_proj_bias = output_bias.expand(opt.num_quantizers-1, _num_tokens)
else:
if share_weight:
self.embed_proj_shared_weight = nn.Parameter(torch.normal(mean=0, std=0.02, size=(opt.num_quantizers - 2, _num_tokens, code_dim)))
self.token_embed_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
self.output_proj_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
self.output_proj_bias = None
self.registered = False
else:
output_proj_weight = torch.normal(mean=0, std=0.02,
size=(opt.num_quantizers - 1, _num_tokens, code_dim))
self.output_proj_weight = nn.Parameter(output_proj_weight)
self.output_proj_bias = nn.Parameter(torch.zeros(size=(opt.num_quantizers, _num_tokens)))
token_embed_weight = torch.normal(mean=0, std=0.02,
size=(opt.num_quantizers - 1, _num_tokens, code_dim))
self.token_embed_weight = nn.Parameter(token_embed_weight)
self.apply(self.__init_weights)
self.shared_codebook = shared_codebook
self.share_weight = share_weight
if self.cond_mode == 'text':
print('Loading CLIP...')
self.clip_version = clip_version
self.clip_model = self.load_and_freeze_clip(clip_version)
# def
def mask_cond(self, cond, force_mask=False):
bs, d = cond.shape
if force_mask:
return torch.zeros_like(cond)
elif self.training and self.cond_drop_prob > 0.:
mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
return cond * (1. - mask)
else:
return cond
def __init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def parameters_wo_clip(self):
return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
def load_and_freeze_clip(self, clip_version):
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
jit=False) # Must set jit=False for training
        # convert_weights casts CLIP to float16. This is only unnecessary when the model
        # is loaded directly onto the GPU; disable it if CLIP needs to run on the CPU.
        clip.model.convert_weights(clip_model)
# Freeze CLIP weights
clip_model.eval()
for p in clip_model.parameters():
p.requires_grad = False
return clip_model
def encode_text(self, raw_text):
device = next(self.parameters()).device
text = clip.tokenize(raw_text, truncate=True).to(device)
feat_clip_text = self.clip_model.encode_text(text).float()
return feat_clip_text
def q_schedule(self, bs, low, high):
noise = uniform((bs,), device=self.opt.device)
schedule = 1 - cosine_schedule(noise)
return torch.round(schedule * (high - low)) + low
def process_embed_proj_weight(self):
if self.share_weight and (not self.shared_codebook):
# if not self.registered:
self.output_proj_weight = torch.cat([self.embed_proj_shared_weight, self.output_proj_weight_], dim=0)
self.token_embed_weight = torch.cat([self.token_embed_weight_, self.embed_proj_shared_weight], dim=0)
# self.registered = True
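    # Select the output projection belonging to each sample's active quantizer layer and
    # apply it to the transformer features: (b, ntoken, code_dim) x (b, code_dim, seqlen).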
def output_project(self, logits, qids):
'''
:logits: (bs, code_dim, seqlen)
:qids: (bs)
:return:
-logits (bs, ntoken, seqlen)
'''
# (num_qlayers-1, num_token, code_dim) -> (bs, ntoken, code_dim)
output_proj_weight = self.output_proj_weight[qids]
# (num_qlayers, ntoken) -> (bs, ntoken)
output_proj_bias = None if self.output_proj_bias is None else self.output_proj_bias[qids]
output = torch.einsum('bnc, bcs->bns', output_proj_weight, logits)
if output_proj_bias is not None:
            output = output + output_proj_bias.unsqueeze(-1)
return output
def trans_forward(self, motion_codes, qids, cond, padding_mask, force_mask=False):
'''
:param motion_codes: (b, seqlen, d)
:padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
:param qids: (b), quantizer layer ids
:param cond: (b, embed_dim) for text, (b, num_actions) for action
:return:
-logits: (b, num_token, seqlen)
'''
cond = self.mask_cond(cond, force_mask=force_mask)
# (b, seqlen, d) -> (seqlen, b, latent_dim)
x = self.input_process(motion_codes)
# (b, num_quantizer)
q_onehot = self.encode_quant(qids).float().to(x.device)
q_emb = self.quant_emb(q_onehot).unsqueeze(0) # (1, b, latent_dim)
cond = self.cond_emb(cond).unsqueeze(0) # (1, b, latent_dim)
x = self.position_enc(x)
xseq = torch.cat([cond, q_emb, x], dim=0) # (seqlen+2, b, latent_dim)
padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:2]), padding_mask], dim=1) # (b, seqlen+2)
output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[2:] # (seqlen, b, e)
logits = self.output_process(output)
return logits
def forward_with_cond_scale(self,
motion_codes,
q_id,
cond_vector,
padding_mask,
cond_scale=3,
force_mask=False):
bs = motion_codes.shape[0]
# if cond_scale == 1:
qids = torch.full((bs,), q_id, dtype=torch.long, device=motion_codes.device)
if force_mask:
logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
logits = self.output_project(logits, qids-1)
return logits
logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask)
logits = self.output_project(logits, qids-1)
if cond_scale == 1:
return logits
aux_logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
aux_logits = self.output_project(aux_logits, qids-1)
scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
return scaled_logits
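    # Training: sample one active quantizer layer per sequence, feed the cumulative sum
    # of the code embeddings of all lower layers, and predict the token ids of the
    # active layer with cross-entropy (padding ignored).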
def forward(self, all_indices, y, m_lens):
'''
:param all_indices: (b, n, q)
:param y: raw text for cond_mode=text, (b, ) for cond_mode=action
:m_lens: (b,)
:return:
'''
self.process_embed_proj_weight()
bs, ntokens, num_quant_layers = all_indices.shape
device = all_indices.device
# Positions that are PADDED are ALL FALSE
non_pad_mask = lengths_to_mask(m_lens, ntokens) # (b, n)
q_non_pad_mask = repeat(non_pad_mask, 'b n -> b n q', q=num_quant_layers)
all_indices = torch.where(q_non_pad_mask, all_indices, self.pad_id) #(b, n, q)
# randomly sample quantization layers to work on, [1, num_q)
active_q_layers = q_schedule(bs, low=1, high=num_quant_layers, device=device)
# print(self.token_embed_weight.shape, all_indices.shape)
token_embed = repeat(self.token_embed_weight, 'q c d-> b c d q', b=bs)
gather_indices = repeat(all_indices[..., :-1], 'b n q -> b n d q', d=token_embed.shape[2])
# print(token_embed.shape, gather_indices.shape)
all_codes = token_embed.gather(1, gather_indices) # (b, n, d, q-1)
cumsum_codes = torch.cumsum(all_codes, dim=-1) #(b, n, d, q-1)
active_indices = all_indices[torch.arange(bs), :, active_q_layers] # (b, n)
history_sum = cumsum_codes[torch.arange(bs), :, :, active_q_layers - 1]
force_mask = False
if self.cond_mode == 'text':
with torch.no_grad():
cond_vector = self.encode_text(y)
elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
elif self.cond_mode == 'uncond':
cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
force_mask = True
else:
raise NotImplementedError("Unsupported condition mode!!!")
logits = self.trans_forward(history_sum, active_q_layers, cond_vector, ~non_pad_mask, force_mask)
logits = self.output_project(logits, active_q_layers-1)
ce_loss, pred_id, acc = cal_performance(logits, active_indices, ignore_index=self.pad_id)
return ce_loss, pred_id, acc
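    # Generation: given the base-layer tokens, decode the residual layers one at a time,
    # accumulating the embeddings of already-decoded layers into history_sum and sampling
    # the next layer's tokens from the (optionally CFG-scaled) logits.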
@torch.no_grad()
@eval_decorator
def generate(self,
motion_ids,
conds,
m_lens,
temperature=1,
topk_filter_thres=0.9,
cond_scale=2,
num_res_layers=-1, # If it's -1, use all.
):
# print(self.opt.num_quantizers)
# assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
self.process_embed_proj_weight()
device = next(self.parameters()).device
seq_len = motion_ids.shape[1]
batch_size = len(conds)
if self.cond_mode == 'text':
with torch.no_grad():
cond_vector = self.encode_text(conds)
elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device).float()
elif self.cond_mode == 'uncond':
cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
else:
raise NotImplementedError("Unsupported condition mode!!!")
# token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
# gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
# history_sum = token_embed.gather(1, gathered_ids)
# print(pa, seq_len)
padding_mask = ~lengths_to_mask(m_lens, seq_len)
# print(padding_mask.shape, motion_ids.shape)
motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
all_indices = [motion_ids]
history_sum = 0
num_quant_layers = self.opt.num_quantizers if num_res_layers==-1 else num_res_layers+1
for i in range(1, num_quant_layers):
# print(f"--> Working on {i}-th quantizer")
# Start from all tokens being masked
# qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
token_embed = self.token_embed_weight[i-1]
token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
history_sum += token_embed.gather(1, gathered_ids)
logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
# logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
# clean low prob token
filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
# probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
# # print(temperature, starting_temperature, steps_until_x0, timesteps)
# # print(probs / temperature)
# pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
ids = torch.where(padding_mask, self.pad_id, pred_ids)
motion_ids = ids
all_indices.append(ids)
all_indices = torch.stack(all_indices, dim=-1)
# padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
# all_indices = torch.where(padding_mask, -1, all_indices)
all_indices = torch.where(all_indices==self.pad_id, -1, all_indices)
# all_indices = all_indices.masked_fill()
return all_indices
@torch.no_grad()
@eval_decorator
def edit(self,
motion_ids,
conds,
m_lens,
temperature=1,
topk_filter_thres=0.9,
cond_scale=2
):
# print(self.opt.num_quantizers)
# assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
self.process_embed_proj_weight()
device = next(self.parameters()).device
seq_len = motion_ids.shape[1]
batch_size = len(conds)
if self.cond_mode == 'text':
with torch.no_grad():
cond_vector = self.encode_text(conds)
elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device).float()
elif self.cond_mode == 'uncond':
cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
else:
raise NotImplementedError("Unsupported condition mode!!!")
# token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
# gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
# history_sum = token_embed.gather(1, gathered_ids)
# print(pa, seq_len)
padding_mask = ~lengths_to_mask(m_lens, seq_len)
# print(padding_mask.shape, motion_ids.shape)
motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
all_indices = [motion_ids]
history_sum = 0
for i in range(1, self.opt.num_quantizers):
# print(f"--> Working on {i}-th quantizer")
# Start from all tokens being masked
# qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
token_embed = self.token_embed_weight[i-1]
token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
history_sum += token_embed.gather(1, gathered_ids)
logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
# logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
# clean low prob token
filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
# probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
# # print(temperature, starting_temperature, steps_until_x0, timesteps)
# # print(probs / temperature)
# pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
ids = torch.where(padding_mask, self.pad_id, pred_ids)
motion_ids = ids
all_indices.append(ids)
all_indices = torch.stack(all_indices, dim=-1)
# padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
# all_indices = torch.where(padding_mask, -1, all_indices)
all_indices = torch.where(all_indices==self.pad_id, -1, all_indices)
# all_indices = all_indices.masked_fill()
return all_indices
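
if __name__ == "__main__":
    # Minimal, self-contained sketch of the confidence-based remasking step used by
    # MaskTransformer.generate(), run on random logits. It does not build the model or
    # load CLIP; the cosine schedule below is a local stand-in for cosine_schedule in
    # models.mask_transformer.tools, and all sizes here are illustrative.
    torch.manual_seed(0)
    bs, seqlen, ntoken, steps = 2, 8, 16, 4
    mask_id = ntoken                                               # stand-in for the extra [MASK] id
    lens = torch.tensor([8, 6])
    pad = torch.arange(seqlen)[None, :] >= lens[:, None]           # (b, seqlen), True on padding
    ids = torch.where(pad, ntoken + 1, torch.full((bs, seqlen), mask_id))
    scores = torch.where(pad, 1e5, 0.)                             # low score = low confidence
    for t in torch.linspace(0, 1, steps):
        ratio = torch.cos(t * math.pi * 0.5)                       # cosine schedule stand-in
        num_masked = (ratio * lens).round().clamp(min=1)           # per-sequence mask budget
        ranks = scores.argsort(dim=1).argsort(dim=1)               # rank 0 = least confident
        is_mask = ranks < num_masked.unsqueeze(-1)
        ids = torch.where(is_mask, mask_id, ids)
        logits = torch.randn(bs, seqlen, ntoken)                   # placeholder for the transformer
        pred = logits.argmax(dim=-1)
        ids = torch.where(is_mask, pred, ids)
        scores = logits.softmax(-1).gather(2, pred.unsqueeze(-1)).squeeze(-1)
        scores = scores.masked_fill(~is_mask, 1e5)                 # never re-mask kept or padded tokens
    print("decoded ids:\n", torch.where(pad, -1, ids))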