import torch
import torch.nn as nn
import numpy as np
# from networks.layers import *
import torch.nn.functional as F
import clip
from einops import rearrange, repeat
import math
from random import random
from tqdm.auto import tqdm
from typing import Callable, Optional, List, Dict
from copy import deepcopy
from functools import partial

from models.mask_transformer.tools import *
from torch.distributions.categorical import Categorical
class InputProcess(nn.Module):
    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        # x: [bs, ntokens, input_feats]
        x = x.permute((1, 0, 2))  # [seqlen, bs, input_feats]
        # print(x.shape)
        x = self.poseEmbedding(x)  # [seqlen, bs, d]
        return x
class PositionalEncoding(nn.Module):
    # Borrowed from MDM: standard sinusoidal positional encoding with dropout;
    # computing the frequencies via exp/log may improve numerical precision.
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x):
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
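
# A minimal shape sketch (illustrative only, not part of the model): with
# d_model=256, the buffer `pe` has shape [max_len, 1, d_model] and is added to
# an input of shape [seqlen, bs, d_model] by broadcasting over the batch dim.
#
#   penc = PositionalEncoding(d_model=256, dropout=0.1)
#   x = torch.zeros(49, 4, 256)           # (seqlen, bs, d_model)
#   assert penc(x).shape == (49, 4, 256)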
class OutputProcess_Bert(nn.Module):
    def __init__(self, out_feats, latent_dim):
        super().__init__()
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)  # Bias!

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        output = self.poseFinal(hidden_states)  # [seqlen, bs, out_feats]
        output = output.permute(1, 2, 0)  # [bs, c, seqlen]
        return output


class OutputProcess(nn.Module):
    def __init__(self, out_feats, latent_dim):
        super().__init__()
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)  # Bias!

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        output = self.poseFinal(hidden_states)  # [seqlen, bs, out_feats]
        output = output.permute(1, 2, 0)  # [bs, e, seqlen]
        return output
class MaskTransformer(nn.Module):
    def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0.1, clip_dim=512, cond_drop_prob=0.1,
                 clip_version=None, opt=None, **kargs):
        super(MaskTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.clip_dim = clip_dim
        self.dropout = dropout
        self.opt = opt

        self.cond_mode = cond_mode
        self.cond_drop_prob = cond_drop_prob

        if self.cond_mode == 'action':
            assert 'num_actions' in kargs
        self.num_actions = kargs.get('num_actions', 1)

        '''
        Preparing Networks
        '''
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=num_heads,
                                                          dim_feedforward=ff_size,
                                                          dropout=dropout,
                                                          activation='gelu')

        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=num_layers)

        self.encode_action = partial(F.one_hot, num_classes=self.num_actions)

        # if self.cond_mode != 'no_cond':
        if self.cond_mode == 'text':
            self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
        elif self.cond_mode == 'action':
            self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
        elif self.cond_mode == 'uncond':
            self.cond_emb = nn.Identity()
        else:
            raise KeyError("Unsupported condition mode!!!")

        _num_tokens = opt.num_tokens + 2  # two dummy tokens, one for masking, one for padding
        self.mask_id = opt.num_tokens
        self.pad_id = opt.num_tokens + 1

        self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)

        self.token_emb = nn.Embedding(_num_tokens, self.code_dim)

        self.apply(self.__init_weights)

        '''
        Preparing frozen weights
        '''
        if self.cond_mode == 'text':
            print('Loading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)

        self.noise_schedule = cosine_schedule
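
    # Token id layout used throughout this class (as defined above):
    #   0 .. opt.num_tokens - 1 -> VQ codebook entries
    #   opt.num_tokens          -> mask_id (the [MASK] token)
    #   opt.num_tokens + 1      -> pad_id  (padding token)
    # noise_schedule (cosine) maps a timestep in [0, 1] to a mask ratio.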
    def load_and_freeze_token_emb(self, codebook):
        '''
        :param codebook: (c, d)
        :return:
        '''
        assert self.training, 'Only necessary in training mode'
        c, d = codebook.shape
        self.token_emb.weight = nn.Parameter(torch.cat([codebook, torch.zeros(size=(2, d), device=codebook.device)], dim=0))  # add two dummy tokens, zero vectors
        self.token_emb.requires_grad_(False)
        # self.token_emb.weight.requires_grad = False
        # self.token_emb_ready = True
        print("Token embedding initialized!")

    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def parameters_wo_clip(self):
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
                                                jit=False)  # Must set jit=False for training
        # Convert weights to fp16. Strictly needed only when the model is not
        # loaded directly onto GPU; disable this to run CLIP on CPU.
        clip.model.convert_weights(clip_model)

        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model
    def encode_text(self, raw_text):
        device = next(self.parameters()).device
        text = clip.tokenize(raw_text, truncate=True).to(device)
        feat_clip_text = self.clip_model.encode_text(text).float()
        return feat_clip_text

    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond
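
    # mask_cond implements the condition dropout used for classifier-free
    # guidance: during training each sample's condition vector is zeroed with
    # probability cond_drop_prob; force_mask=True zeroes it deterministically,
    # which is how the unconditional branch is obtained at inference time.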
    def trans_forward(self, motion_ids, cond, padding_mask, force_mask=False):
        '''
        :param motion_ids: (b, seqlen)
        :param padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :param cond: (b, embed_dim) for text, (b, num_actions) for action
        :param force_mask: boolean
        :return:
            -logits: (b, num_token, seqlen)
        '''
        cond = self.mask_cond(cond, force_mask=force_mask)

        # print(motion_ids.shape)
        x = self.token_emb(motion_ids)
        # print(x.shape)
        # (b, seqlen, d) -> (seqlen, b, latent_dim)
        x = self.input_process(x)

        cond = self.cond_emb(cond).unsqueeze(0)  # (1, b, latent_dim)

        x = self.position_enc(x)
        xseq = torch.cat([cond, x], dim=0)  # (seqlen+1, b, latent_dim)

        padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:1]), padding_mask], dim=1)  # (b, seqlen+1)
        # print(xseq.shape, padding_mask.shape)

        output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[1:]  # (seqlen, b, e)
        logits = self.output_process(output)  # (seqlen, b, e) -> (b, ntoken, seqlen)
        return logits

    def forward(self, ids, y, m_lens):
        '''
        :param ids: (b, n)
        :param y: raw text for cond_mode=text, (b, ) for cond_mode=action
        :param m_lens: (b,)
        :return:
        '''

        bs, ntokens = ids.shape
        device = ids.device

        # Positions that are PADDED are ALL FALSE
        non_pad_mask = lengths_to_mask(m_lens, ntokens)  # (b, n)
        ids = torch.where(non_pad_mask, ids, self.pad_id)

        force_mask = False
        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(y)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
            force_mask = True
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        '''
        Prepare mask
        '''
        rand_time = uniform((bs,), device=device)
        rand_mask_probs = self.noise_schedule(rand_time)
        num_token_masked = (ntokens * rand_mask_probs).round().clamp(min=1)

        batch_randperm = torch.rand((bs, ntokens), device=device).argsort(dim=-1)
        # Positions to be MASKED are ALL TRUE
        mask = batch_randperm < num_token_masked.unsqueeze(-1)

        # Positions to be MASKED must also be NON-PADDED
        mask &= non_pad_mask

        # Note this is our training target, not input
        labels = torch.where(mask, ids, self.mask_id)

        x_ids = ids.clone()

        # Further apply a BERT-style corruption scheme
        # Step 1: 10% of the selected positions are replaced with a random (incorrect) token
        mask_rid = get_mask_subset_prob(mask, 0.1)
        rand_id = torch.randint_like(x_ids, high=self.opt.num_tokens)
        x_ids = torch.where(mask_rid, rand_id, x_ids)
        # Step 2: of the remaining ~90%, 88% are replaced with the mask token,
        # and the rest keep the correct token
        mask_mid = get_mask_subset_prob(mask & ~mask_rid, 0.88)

        # mask_mid = mask

        x_ids = torch.where(mask_mid, self.mask_id, x_ids)

        logits = self.trans_forward(x_ids, cond_vector, ~non_pad_mask, force_mask)
        ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id)

        return ce_loss, pred_id, acc
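
    # Corruption scheme of forward(), summarized (numbers follow the
    # probabilities hard-coded above): for a sequence with, say, 100 non-pad
    # tokens and a sampled mask ratio of 0.5, about 50 positions are selected;
    # ~10% of those get a random codebook id, ~88% of the remainder become
    # mask_id, and the rest keep their original id. The cross-entropy loss is
    # computed only on the selected positions (labels elsewhere are mask_id,
    # which cal_performance ignores).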
    def forward_with_cond_scale(self,
                                motion_ids,
                                cond_vector,
                                padding_mask,
                                cond_scale=3,
                                force_mask=False):
        # bs = motion_ids.shape[0]
        # if cond_scale == 1:
        if force_mask:
            return self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)

        logits = self.trans_forward(motion_ids, cond_vector, padding_mask)
        if cond_scale == 1:
            return logits

        aux_logits = self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)

        scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
        return scaled_logits
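
    # Classifier-free guidance as implemented above:
    #   scaled_logits = aux_logits + cond_scale * (logits - aux_logits)
    # where aux_logits are the unconditional logits (condition zeroed via
    # force_mask). cond_scale == 1 returns the purely conditional logits, and
    # larger values push predictions further toward the conditional branch.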
    def generate(self,
                 conds,
                 m_lens,
                 timesteps: int,
                 cond_scale: int,
                 temperature=1,
                 topk_filter_thres=0.9,
                 gsample=False,
                 force_mask=False
                 ):
        # print(self.opt.num_quantizers)
        # assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers

        device = next(self.parameters()).device
        seq_len = max(m_lens)
        batch_size = len(m_lens)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        padding_mask = ~lengths_to_mask(m_lens, seq_len)
        # print(padding_mask.shape)

        # Start from all tokens being masked
        ids = torch.where(padding_mask, self.pad_id, self.mask_id)
        scores = torch.where(padding_mask, 1e5, 0.)
        starting_temperature = temperature

        for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
            # timestep goes from 0 to 1 over the denoising iterations
            rand_mask_prob = self.noise_schedule(timestep)  # Tensor

            '''
            Maskout, and cope with variable length
            '''
            # Compute the mask ratio with respect to each true length, not the padded seq_len
            num_token_masked = torch.round(rand_mask_prob * m_lens).clamp(min=1)  # (b, )

            # Select the num_token_masked tokens with the lowest scores to be re-masked
            sorted_indices = scores.argsort(dim=1)  # (b, k), sorted_indices[i, j] = index of the j-th lowest element of scores[i]
            ranks = sorted_indices.argsort(dim=1)  # (b, k), ranks[i, j] = rank (0 = lowest) of scores[i, j]
            is_mask = (ranks < num_token_masked.unsqueeze(-1))
            ids = torch.where(is_mask, self.mask_id, ids)

            '''
            Preparing input
            '''
            # (b, num_token, seqlen)
            logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
                                                  padding_mask=padding_mask,
                                                  cond_scale=cond_scale,
                                                  force_mask=force_mask)

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # print(logits.shape, self.opt.num_tokens)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            '''
            Update ids
            '''
            # if force_mask:
            temperature = starting_temperature
            # else:
            #     temperature = starting_temperature * (steps_until_x0 / timesteps)
            #     temperature = max(temperature, 1e-4)
            # temperature could be annealed here to gradually reduce sampling randomness

            if gsample:  # use gumbel softmax sampling
                pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)
            else:  # use multinomial sampling
                probs = F.softmax(filtered_logits, dim=-1)  # (b, seqlen, ntoken)
                # print(temperature, starting_temperature, steps_until_x0, timesteps)
                pred_ids = Categorical(probs / temperature).sample()  # (b, seqlen)

            # print(pred_ids.max(), pred_ids.min())
            ids = torch.where(is_mask, pred_ids, ids)

            '''
            Updating scores
            '''
            probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
            scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1))  # (b, seqlen, 1)
            scores = scores.squeeze(-1)  # (b, seqlen)

            # We do not want to re-mask the previously kept tokens, or pad tokens
            scores = scores.masked_fill(~is_mask, 1e5)

        ids = torch.where(padding_mask, -1, ids)
        # print("Final", ids.max(), ids.min())
        return ids
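
    # A minimal usage sketch for generate() (hypothetical values; assumes a
    # trained model `t2m_transformer` and a `device`, neither defined here):
    #
    #   mids = t2m_transformer.generate(
    #       conds=["a person walks forward and waves"],
    #       m_lens=torch.LongTensor([49]).to(device),
    #       timesteps=10, cond_scale=4, temperature=1.0,
    #   )
    #   # `mids` holds per-position codebook indices; padded positions are -1
    #   # and may need to be replaced (e.g., with 0) before decoding with the VQ-VAE.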
    def edit(self,
             conds,
             tokens,
             m_lens,
             timesteps: int,
             cond_scale: int,
             temperature=1,
             topk_filter_thres=0.9,
             gsample=False,
             force_mask=False,
             edit_mask=None,
             padding_mask=None,
             ):

        if edit_mask is not None:
            assert edit_mask.shape == tokens.shape

        device = next(self.parameters()).device
        seq_len = tokens.shape[1]

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(1, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        if padding_mask is None:
            padding_mask = ~lengths_to_mask(m_lens, seq_len)

        # Start from all editable tokens being masked
        if edit_mask is None:
            mask_free = True
            ids = torch.where(padding_mask, self.pad_id, tokens)
            edit_mask = torch.ones_like(padding_mask)
            edit_mask = edit_mask & ~padding_mask
            edit_len = edit_mask.sum(dim=-1)
            scores = torch.where(edit_mask, 0., 1e5)
        else:
            mask_free = False
            edit_mask = edit_mask & ~padding_mask
            edit_len = edit_mask.sum(dim=-1)
            ids = torch.where(edit_mask, self.mask_id, tokens)
            scores = torch.where(edit_mask, 0., 1e5)
        starting_temperature = temperature

        for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
            # timestep goes from 0 to 1 over the denoising iterations
            rand_mask_prob = 0.16 if mask_free else self.noise_schedule(timestep)  # Tensor

            '''
            Maskout, and cope with variable length
            '''
            # Compute the mask ratio with respect to the number of editable tokens, not seq_len
            num_token_masked = torch.round(rand_mask_prob * edit_len).clamp(min=1)  # (b, )

            # Select the num_token_masked tokens with the lowest scores to be re-masked
            sorted_indices = scores.argsort(dim=1)  # (b, k), sorted_indices[i, j] = index of the j-th lowest element of scores[i]
            ranks = sorted_indices.argsort(dim=1)  # (b, k), ranks[i, j] = rank (0 = lowest) of scores[i, j]
            is_mask = (ranks < num_token_masked.unsqueeze(-1))
            # is_mask = (torch.rand_like(scores) < 0.8) * ~padding_mask if mask_free else is_mask
            ids = torch.where(is_mask, self.mask_id, ids)

            '''
            Preparing input
            '''
            # (b, num_token, seqlen)
            logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
                                                  padding_mask=padding_mask,
                                                  cond_scale=cond_scale,
                                                  force_mask=force_mask)

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # print(logits.shape, self.opt.num_tokens)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            '''
            Update ids
            '''
            # if force_mask:
            temperature = starting_temperature
            # else:
            #     temperature = starting_temperature * (steps_until_x0 / timesteps)
            #     temperature = max(temperature, 1e-4)
            # temperature could be annealed here to gradually reduce sampling randomness

            if gsample:  # use gumbel softmax sampling
                pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)
            else:  # use multinomial sampling
                probs = F.softmax(filtered_logits, dim=-1)  # (b, seqlen, ntoken)
                pred_ids = Categorical(probs / temperature).sample()  # (b, seqlen)

            ids = torch.where(is_mask, pred_ids, ids)

            '''
            Updating scores
            '''
            probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
            scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1))  # (b, seqlen, 1)
            scores = scores.squeeze(-1)  # (b, seqlen)

            # We do not want to re-mask the previously kept tokens, or pad tokens
            scores = scores.masked_fill(~edit_mask, 1e5) if mask_free else scores.masked_fill(~is_mask, 1e5)

        ids = torch.where(padding_mask, -1, ids)
        # print("Final", ids.max(), ids.min())
        return ids
    def edit_beta(self,
                  conds,
                  conds_og,
                  tokens,
                  m_lens,
                  cond_scale: int,
                  force_mask=False,
                  ):

        device = next(self.parameters()).device
        seq_len = tokens.shape[1]

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
                if conds_og is not None:
                    cond_vector_og = self.encode_text(conds_og)
                else:
                    cond_vector_og = None
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
            if conds_og is not None:
                cond_vector_og = self.encode_action(conds_og).to(device)
            else:
                cond_vector_og = None
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        padding_mask = ~lengths_to_mask(m_lens, seq_len)

        # Do not mask anything; score the existing tokens as they are
        ids = torch.where(padding_mask, self.pad_id, tokens)

        '''
        Preparing input
        '''
        # (b, num_token, seqlen)
        logits = self.forward_with_cond_scale(ids,
                                              cond_vector=cond_vector,
                                              cond_vector_neg=cond_vector_og,
                                              padding_mask=padding_mask,
                                              cond_scale=cond_scale,
                                              force_mask=force_mask)

        logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)

        '''
        Updating scores
        '''
        probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
        tokens[tokens == -1] = 0  # avoid an indexing error when gathering at index -1
        og_tokens_scores = probs_without_temperature.gather(2, tokens.unsqueeze(dim=-1))  # (b, seqlen, 1)
        og_tokens_scores = og_tokens_scores.squeeze(-1)  # (b, seqlen)
        return og_tokens_scores
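
    # edit_beta scores how well each existing token fits the (new) text prompt:
    # it returns, per position, the probability assigned to the original token
    # under the edited condition, which a caller can threshold to decide which
    # tokens to re-generate. Note that the call above passes cond_vector_neg,
    # which the forward_with_cond_scale defined in this class does not accept;
    # this path appears experimental and would need the signature extended.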
class ResidualTransformer(nn.Module):
    def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8, cond_drop_prob=0.1,
                 num_heads=4, dropout=0.1, clip_dim=512, shared_codebook=False, share_weight=False,
                 clip_version=None, opt=None, **kargs):
        super(ResidualTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        # assert shared_codebook == True, "Only support shared codebook right now!"

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.clip_dim = clip_dim
        self.dropout = dropout
        self.opt = opt

        self.cond_mode = cond_mode
        # self.cond_drop_prob = cond_drop_prob

        if self.cond_mode == 'action':
            assert 'num_actions' in kargs
        self.num_actions = kargs.get('num_actions', 1)
        self.cond_drop_prob = cond_drop_prob

        '''
        Preparing Networks
        '''
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=num_heads,
                                                          dim_feedforward=ff_size,
                                                          dropout=dropout,
                                                          activation='gelu')

        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=num_layers)

        self.encode_quant = partial(F.one_hot, num_classes=self.opt.num_quantizers)
        self.encode_action = partial(F.one_hot, num_classes=self.num_actions)

        self.quant_emb = nn.Linear(self.opt.num_quantizers, self.latent_dim)
        # if self.cond_mode != 'no_cond':
        if self.cond_mode == 'text':
            self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
        elif self.cond_mode == 'action':
            self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
        else:
            raise KeyError("Unsupported condition mode!!!")

        _num_tokens = opt.num_tokens + 1  # one dummy token for padding
        self.pad_id = opt.num_tokens

        # self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)
        self.output_process = OutputProcess(out_feats=code_dim, latent_dim=latent_dim)

        if shared_codebook:
            token_embed = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
            self.token_embed_weight = token_embed.expand(opt.num_quantizers - 1, _num_tokens, code_dim)
            if share_weight:
                self.output_proj_weight = self.token_embed_weight
                self.output_proj_bias = None
            else:
                output_proj = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
                output_bias = nn.Parameter(torch.zeros(size=(_num_tokens,)))
                # self.output_proj_bias = 0
                self.output_proj_weight = output_proj.expand(opt.num_quantizers - 1, _num_tokens, code_dim)
                self.output_proj_bias = output_bias.expand(opt.num_quantizers - 1, _num_tokens)
        else:
            if share_weight:
                self.embed_proj_shared_weight = nn.Parameter(torch.normal(mean=0, std=0.02, size=(opt.num_quantizers - 2, _num_tokens, code_dim)))
                self.token_embed_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
                self.output_proj_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
                self.output_proj_bias = None
                self.registered = False
            else:
                output_proj_weight = torch.normal(mean=0, std=0.02,
                                                  size=(opt.num_quantizers - 1, _num_tokens, code_dim))
                self.output_proj_weight = nn.Parameter(output_proj_weight)
                self.output_proj_bias = nn.Parameter(torch.zeros(size=(opt.num_quantizers, _num_tokens)))
                token_embed_weight = torch.normal(mean=0, std=0.02,
                                                  size=(opt.num_quantizers - 1, _num_tokens, code_dim))
                self.token_embed_weight = nn.Parameter(token_embed_weight)

        self.apply(self.__init_weights)
        self.shared_codebook = shared_codebook
        self.share_weight = share_weight

        if self.cond_mode == 'text':
            print('Loading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)
    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def parameters_wo_clip(self):
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
                                                jit=False)  # Must set jit=False for training
        # Convert weights to fp16. Strictly needed only when the model is not
        # loaded directly onto GPU; disable this to run CLIP on CPU.
        clip.model.convert_weights(clip_model)

        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

    def encode_text(self, raw_text):
        device = next(self.parameters()).device
        text = clip.tokenize(raw_text, truncate=True).to(device)
        feat_clip_text = self.clip_model.encode_text(text).float()
        return feat_clip_text

    def q_schedule(self, bs, low, high):
        noise = uniform((bs,), device=self.opt.device)
        schedule = 1 - cosine_schedule(noise)
        return torch.round(schedule * (high - low)) + low
    def process_embed_proj_weight(self):
        if self.share_weight and (not self.shared_codebook):
            # if not self.registered:
            self.output_proj_weight = torch.cat([self.embed_proj_shared_weight, self.output_proj_weight_], dim=0)
            self.token_embed_weight = torch.cat([self.token_embed_weight_, self.embed_proj_shared_weight], dim=0)
            # self.registered = True

    def output_project(self, logits, qids):
        '''
        :param logits: (bs, code_dim, seqlen)
        :param qids: (bs)
        :return:
            -logits: (bs, ntoken, seqlen)
        '''
        # (num_qlayers-1, ntoken, code_dim) -> (bs, ntoken, code_dim)
        output_proj_weight = self.output_proj_weight[qids]
        # (num_qlayers, ntoken) -> (bs, ntoken)
        output_proj_bias = None if self.output_proj_bias is None else self.output_proj_bias[qids]

        output = torch.einsum('bnc, bcs->bns', output_proj_weight, logits)
        if output_proj_bias is not None:
            output = output + output_proj_bias.unsqueeze(-1)
        return output
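
    # The einsum in output_project is a batched matrix product: for each sample
    # b, output[b] = output_proj_weight[b] @ logits[b], i.e.
    # (ntoken, code_dim) x (code_dim, seqlen) -> (ntoken, seqlen), projecting
    # the predicted continuous codes onto per-token logits for the quantizer
    # layer selected by qids.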
    def trans_forward(self, motion_codes, qids, cond, padding_mask, force_mask=False):
        '''
        :param motion_codes: (b, seqlen, d)
        :param padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :param qids: (b), quantizer layer ids
        :param cond: (b, embed_dim) for text, (b, num_actions) for action
        :return:
            -logits: (b, num_token, seqlen)
        '''
        cond = self.mask_cond(cond, force_mask=force_mask)

        # (b, seqlen, d) -> (seqlen, b, latent_dim)
        x = self.input_process(motion_codes)

        # (b, num_quantizer)
        q_onehot = self.encode_quant(qids).float().to(x.device)

        q_emb = self.quant_emb(q_onehot).unsqueeze(0)  # (1, b, latent_dim)
        cond = self.cond_emb(cond).unsqueeze(0)  # (1, b, latent_dim)

        x = self.position_enc(x)
        xseq = torch.cat([cond, q_emb, x], dim=0)  # (seqlen+2, b, latent_dim)

        padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:2]), padding_mask], dim=1)  # (b, seqlen+2)

        output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[2:]  # (seqlen, b, e)
        logits = self.output_process(output)
        return logits

    def forward_with_cond_scale(self,
                                motion_codes,
                                q_id,
                                cond_vector,
                                padding_mask,
                                cond_scale=3,
                                force_mask=False):
        bs = motion_codes.shape[0]
        # if cond_scale == 1:
        qids = torch.full((bs,), q_id, dtype=torch.long, device=motion_codes.device)
        if force_mask:
            logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
            logits = self.output_project(logits, qids - 1)
            return logits

        logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask)
        logits = self.output_project(logits, qids - 1)
        if cond_scale == 1:
            return logits

        aux_logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
        aux_logits = self.output_project(aux_logits, qids - 1)

        scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
        return scaled_logits
    def forward(self, all_indices, y, m_lens):
        '''
        :param all_indices: (b, n, q)
        :param y: raw text for cond_mode=text, (b, ) for cond_mode=action
        :param m_lens: (b,)
        :return:
        '''
        self.process_embed_proj_weight()

        bs, ntokens, num_quant_layers = all_indices.shape
        device = all_indices.device

        # Positions that are PADDED are ALL FALSE
        non_pad_mask = lengths_to_mask(m_lens, ntokens)  # (b, n)
        q_non_pad_mask = repeat(non_pad_mask, 'b n -> b n q', q=num_quant_layers)
        all_indices = torch.where(q_non_pad_mask, all_indices, self.pad_id)  # (b, n, q)

        # randomly sample quantization layers to work on, [1, num_q)
        active_q_layers = q_schedule(bs, low=1, high=num_quant_layers, device=device)

        # print(self.token_embed_weight.shape, all_indices.shape)
        token_embed = repeat(self.token_embed_weight, 'q c d -> b c d q', b=bs)
        gather_indices = repeat(all_indices[..., :-1], 'b n q -> b n d q', d=token_embed.shape[2])
        # print(token_embed.shape, gather_indices.shape)
        all_codes = token_embed.gather(1, gather_indices)  # (b, n, d, q-1)

        cumsum_codes = torch.cumsum(all_codes, dim=-1)  # (b, n, d, q-1)

        active_indices = all_indices[torch.arange(bs), :, active_q_layers]  # (b, n)
        history_sum = cumsum_codes[torch.arange(bs), :, :, active_q_layers - 1]

        force_mask = False
        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(y)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
            force_mask = True
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        logits = self.trans_forward(history_sum, active_q_layers, cond_vector, ~non_pad_mask, force_mask)
        logits = self.output_project(logits, active_q_layers - 1)
        ce_loss, pred_id, acc = cal_performance(logits, active_indices, ignore_index=self.pad_id)

        return ce_loss, pred_id, acc
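
    # Training input construction in forward(), in brief: for a sampled
    # quantizer layer q (1 <= q < num_quant_layers), the transformer receives
    # the sum of the code embeddings of layers 0..q-1 (history_sum, taken from
    # the cumulative sum) together with a one-hot embedding of q, and is
    # trained to predict the layer-q token ids (active_indices) with cross
    # entropy, ignoring padded positions via pad_id.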
    def generate(self,
                 motion_ids,
                 conds,
                 m_lens,
                 temperature=1,
                 topk_filter_thres=0.9,
                 cond_scale=2,
                 num_res_layers=-1,  # If it's -1, use all residual layers.
                 ):

        # print(self.opt.num_quantizers)
        # assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
        self.process_embed_proj_weight()

        device = next(self.parameters()).device
        seq_len = motion_ids.shape[1]
        batch_size = len(conds)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        # token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
        # gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
        # history_sum = token_embed.gather(1, gathered_ids)

        padding_mask = ~lengths_to_mask(m_lens, seq_len)
        # print(padding_mask.shape, motion_ids.shape)
        motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
        all_indices = [motion_ids]
        history_sum = 0

        num_quant_layers = self.opt.num_quantizers if num_res_layers == -1 else num_res_layers + 1

        for i in range(1, num_quant_layers):
            # print(f"--> Working on {i}-th quantizer")
            # qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
            token_embed = self.token_embed_weight[i - 1]
            token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
            gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
            history_sum += token_embed.gather(1, gathered_ids)

            logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
            # logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)

            # probs = F.softmax(filtered_logits, dim=-1)  # (b, seqlen, ntoken)
            # pred_ids = Categorical(probs / temperature).sample()  # (b, seqlen)

            ids = torch.where(padding_mask, self.pad_id, pred_ids)

            motion_ids = ids
            all_indices.append(ids)

        all_indices = torch.stack(all_indices, dim=-1)
        # padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
        # all_indices = torch.where(padding_mask, -1, all_indices)
        all_indices = torch.where(all_indices == self.pad_id, -1, all_indices)
        # all_indices = all_indices.masked_fill()
        return all_indices
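
    # Unlike MaskTransformer.generate, each residual layer here is predicted in
    # a single pass (no iterative re-masking): history_sum accumulates the
    # embeddings of all previously predicted layers, and the stacked result has
    # shape (b, seqlen, num_quant_layers) with -1 at padded positions.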
    def edit(self,
             motion_ids,
             conds,
             m_lens,
             temperature=1,
             topk_filter_thres=0.9,
             cond_scale=2
             ):

        # print(self.opt.num_quantizers)
        # assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
        self.process_embed_proj_weight()

        device = next(self.parameters()).device
        seq_len = motion_ids.shape[1]
        batch_size = len(conds)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        # token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
        # gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
        # history_sum = token_embed.gather(1, gathered_ids)

        padding_mask = ~lengths_to_mask(m_lens, seq_len)
        # print(padding_mask.shape, motion_ids.shape)
        motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
        all_indices = [motion_ids]
        history_sum = 0

        for i in range(1, self.opt.num_quantizers):
            # print(f"--> Working on {i}-th quantizer")
            # qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
            token_embed = self.token_embed_weight[i - 1]
            token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
            gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
            history_sum += token_embed.gather(1, gathered_ids)

            logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
            # logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)

            # probs = F.softmax(filtered_logits, dim=-1)  # (b, seqlen, ntoken)
            # pred_ids = Categorical(probs / temperature).sample()  # (b, seqlen)

            ids = torch.where(padding_mask, self.pad_id, pred_ids)

            motion_ids = ids
            all_indices.append(ids)

        all_indices = torch.stack(all_indices, dim=-1)
        # padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
        # all_indices = torch.where(padding_mask, -1, all_indices)
        all_indices = torch.where(all_indices == self.pad_id, -1, all_indices)
        # all_indices = all_indices.masked_fill()
        return all_indices
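

# A minimal end-to-end sketch (hypothetical names; assumes a trained
# MaskTransformer `mtrans`, a trained ResidualTransformer `rtrans`, and an
# RVQ-VAE `vq_model` with a decoder are created elsewhere in this repo):
#
#   base_ids = mtrans.generate(conds=captions, m_lens=m_lens,
#                              timesteps=10, cond_scale=4)
#   all_ids = rtrans.generate(base_ids.clamp(min=0), conds=captions,
#                             m_lens=m_lens, cond_scale=5)
#   motions = vq_model.forward_decoder(all_ids)  # hypothetical decoder call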