# Provenance (Hugging Face file view): uploaded by bill-jiang,
# commit 4409449 "Init", 18.4 kB, virus scan: no virus found.
import numpy as np
import os
import random
import torch
import time
from mGPT.config import instantiate_from_config
from os.path import join as pjoin
from mGPT.losses.mgpt import GPTLosses
from mGPT.models.base import BaseModel
from .base import BaseModel
import json
import mGPT.render.matplot.plot_3d_global as plot_3d
class MotionGPT(BaseModel):
    """Motion-language model built from a motion tokenizer and a language model.

    Training stages:

    * Stage 1 -- motion tokenizer (``stage='vae'``)
    * Stage 2 -- motion-language pretraining (``stage='lm_pretrain'``)
    * Stage 3 -- motion-language instruction tuning (``stage='lm_instruct'``)

    ``stage='lm_rl'`` is also recognised by :meth:`allsplit_step`.
    """

    def __init__(self,
                 cfg,
                 datamodule,
                 lm,
                 motion_vae,
                 codebook_size=512,
                 stage='vae',
                 debug=True,
                 condition='text',
                 task='t2m',
                 metrics_dict=None,
                 **kwargs):
        # Resolve the default before save_hyperparameters(): Lightning collects
        # the __init__ argument values at the moment that call runs. A fresh
        # list per instance avoids a shared mutable default.
        if metrics_dict is None:
            metrics_dict = ['TM2TMetrics']
        self.save_hyperparameters(ignore='datamodule', logger=False)
        self.datamodule = datamodule
        super().__init__()

        # Motion tokenizer; only built when a config is provided.
        if motion_vae is not None:
            self.vae = instantiate_from_config(motion_vae)

        # Motion-language model.
        self.lm = instantiate_from_config(lm)

        # Freeze the motion tokenizer for language-model stages. Only the
        # top-level ``training`` flag is cleared here; a later ``.train()``
        # call on the parent module would set it back.
        if 'lm' in self.hparams.stage:
            self.vae.training = False
            self.vae.requires_grad_(False)

        # One loss container per split.
        self._losses = torch.nn.ModuleDict({
            split: GPTLosses(cfg, self.hparams.stage, self.datamodule.njoints)
            for split in ["losses_train", "losses_test", "losses_val"]
        })

        # Feature -> joint conversion supplied by the datamodule.
        self.feats2joints = datamodule.feats2joints

        # Codebook usage buffers (not updated by this class).
        self.codePred = []
        self.codeFrequency = torch.zeros((self.hparams.codebook_size, ))

    def forward(self, batch, task="t2m"):
        """Generate with the language model and decode the result to motion.

        Args:
            batch: dict with ``"text"`` (prompts) and ``"length"``. Task
                ``"pred"`` also reads ``"motion"`` (conditioning tokens) and
                task ``"inbetween"`` reads ``"motion_heading"`` /
                ``"motion_tailing"``.
            task: one of ``"t2m"``, ``"m2t"``, ``"pred"``, ``"inbetween"``.

        Returns:
            dict with the generated ``"texts"``, zero-padded ``"feats"`` of
            shape ``(num_samples, max_len, feature_dim)``, ``"joints"``
            recovered from those features, and ``"length"`` -- the decoded
            length of each sample.

        Raises:
            NotImplementedError: for any other ``task``.
        """
        texts = batch["text"]
        lengths_ref = batch["length"]

        outputs, output_texts = self.lm.generate_direct(texts, do_sample=True)

        feats_rst_lst = []
        lengths = []
        max_len = 0
        for i in range(len(texts)):
            if task == "pred":
                # Decode the conditioning tokens together with the continuation.
                motion = self.vae.decode(
                    torch.cat((batch["motion"][i], outputs[i])))
            elif task in ["t2m", "m2t", "inbetween"]:
                motion = self.vae.decode(outputs[i])
            else:
                raise NotImplementedError
            # Record the decoded length for every supported task (incl. "pred").
            lengths.append(motion.shape[1])
            max_len = max(max_len, motion.shape[1])

            if task == "inbetween":
                # Keep the generated frames between a quarter and three
                # quarters of the reference length, framed by the given head
                # and tail clips.
                motion = torch.cat(
                    (batch["motion_heading"][i][None],
                     motion[:, lengths_ref[i] // 4:lengths_ref[i] // 4 * 3,
                            ...],
                     batch["motion_tailing"][i][None]),
                    dim=1)
                # The stitched clip may be longer than the decoded one; grow
                # the padded buffer so it always fits.
                max_len = max(max_len, motion.shape[1])
            feats_rst_lst.append(motion)

        # Zero-pad every sample to a common length.
        feats_rst = torch.zeros(
            (len(feats_rst_lst), max_len, feats_rst_lst[-1].shape[-1]),
            device=self.device)
        for i, motion in enumerate(feats_rst_lst):
            feats_rst[i, :motion.shape[1], ...] = motion

        joints_rst = self.feats2joints(feats_rst)

        return {
            "texts": output_texts,
            "feats": feats_rst,
            "joints": joints_rst,
            "length": lengths
        }

    def train_lm_forward(self, batch):
        """Run the language model on a training batch.

        With ``condition == 'caption'`` each prompt is replaced by a random
        entry of ``batch['all_captions'][i]``.

        Returns:
            ``{'outputs': <language-model output>}``.
        """
        texts = batch["text"]
        if self.hparams.condition == 'caption':
            all_captions = batch['all_captions']
            texts = [random.choice(all_captions[i]) for i in range(len(texts))]

        outputs = self.lm(texts, batch["motion"], batch["length"],
                          batch["tasks"])
        return {'outputs': outputs}

    def _t2m_tasks(self, num_samples):
        """Return the text-to-motion instruction for each of ``num_samples`` prompts.

        Precedence: the ``["Text-to-Motion"]["t2m"]`` entry of
        ``cfg.DATASET.TASK_PATH`` when set; otherwise the caption placeholder
        when ``condition == 'caption'``; otherwise the
        ``["Text-to-Motion"]["caption"]`` entry of
        ``<data_root>/template_instructions.json``.
        """
        if self.hparams.cfg.DATASET.TASK_PATH:
            task_file = self.hparams.cfg.DATASET.TASK_PATH
            key = "t2m"
        elif self.hparams.condition == 'caption':
            return [{
                'input': ['<Caption_Placeholder>'],
                'output': ['']
            }] * num_samples
        else:
            task_file = pjoin(self.datamodule.hparams.data_root,
                              'template_instructions.json')
            key = "caption"

        with open(task_file, 'r', encoding='utf-8') as f:
            instructions = json.load(f)
        return [instructions["Text-to-Motion"][key]] * num_samples

    def _decode_to_feats(self, outputs, feats_ref, lengths):
        """Decode generated token sequences into a tensor shaped like ``feats_ref``.

        Tokens are clamped to ``[0, codebook_size - 1]`` before decoding; an
        output with ``len(...) <= 1`` is replaced by zeros. Sample ``i`` is cut
        to ``min(decoded_frames, lengths[i])`` frames; all other frames of the
        returned tensor stay zero.

        Returns:
            ``(feats_rst, min_len)``: the decoded features and the per-sample
            number of frames written.
        """
        feats_rst = torch.zeros_like(feats_ref)
        min_len = lengths.copy()
        for i in range(len(lengths)):
            tokens = torch.clamp(outputs[i], 0,
                                 self.hparams.codebook_size - 1)
            if len(tokens) > 1:
                motion = self.vae.decode(tokens)
            else:
                motion = torch.zeros_like(feats_ref[i:i + 1, ...])
            min_len[i] = min(motion.shape[1], lengths[i])
            feats_rst[i:i + 1, :min_len[i], ...] = motion[:, :min_len[i]]
        return feats_rst, min_len

    def _motion_result(self, feats_ref, feats_rst, lengths):
        """Package reference / generated features for the metric updates.

        Joints are recovered from the features *before* both feature tensors
        are renormalized with ``datamodule.renorm4t2m``.
        """
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)
        return {
            "m_ref": self.datamodule.renorm4t2m(feats_ref),
            "m_rst": self.datamodule.renorm4t2m(feats_rst),
            "joints_ref": joints_ref,
            "joints_rst": joints_rst,
            "length": lengths,
        }

    @torch.no_grad()
    def val_t2m_forward(self, batch):
        """Evaluate text-to-motion: generate tokens from text and decode them.

        In multimodality mode (``datamodule.is_mm``) every sample is repeated
        ``cfg.METRIC.MM_NUM_REPEATS`` times.

        Returns:
            dict with renormalized ``m_ref`` / ``m_rst``, their recovered
            ``joints_ref`` / ``joints_rst``, and ``length`` -- the number of
            generated frames kept per sample.
        """
        feats_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]

        if self.trainer.datamodule.is_mm:
            n_repeats = self.hparams.cfg.METRIC.MM_NUM_REPEATS
            # Repeat element-wise (a, a, b, b, ...) so texts and lengths stay
            # aligned with feats_ref.repeat_interleave(...) for any batch size.
            texts = [text for text in texts for _ in range(n_repeats)]
            lengths = [length for length in lengths for _ in range(n_repeats)]
            feats_ref = feats_ref.repeat_interleave(n_repeats, dim=0)

        outputs = self.lm.generate_conditional(texts,
                                               lengths=lengths,
                                               stage='test',
                                               tasks=self._t2m_tasks(
                                                   len(texts)))

        feats_rst, min_len = self._decode_to_feats(outputs, feats_ref, lengths)
        return self._motion_result(feats_ref, feats_rst, min_len)

    @torch.no_grad()
    def val_m2t_forward(self, batch):
        """Evaluate motion-to-text: tokenize each reference motion and caption it.

        Note: clears ``self.hparams.metrics_dict``; :meth:`allsplit_step` sets
        it to ``['M2TMetrics']`` for this task.

        Returns:
            dict with ``m_ref`` (reference features), ``t_ref`` (all reference
            captions), ``t_pred`` (generated text) and ``length``.
        """
        self.hparams.metrics_dict = []
        feats_ref = batch["motion"]

        motion_tokens = []
        lengths_tokens = []
        for i in range(len(feats_ref)):
            # NOTE(review): the whole (possibly padded) sequence is encoded,
            # not just the first batch["length"][i] frames -- confirm intended.
            motion_token, _ = self.vae.encode(feats_ref[i:i + 1])
            motion_tokens.append(motion_token[0])
            lengths_tokens.append(motion_token.shape[1])

        outputs = self.lm.generate_conditional(motion_tokens=motion_tokens,
                                               lengths=lengths_tokens,
                                               task="m2t",
                                               stage='test')

        return {
            "m_ref": feats_ref,
            "t_ref": batch['all_captions'],
            "t_pred": outputs,
            "length": batch["length"]
        }

    @torch.no_grad()
    def val_m2m_forward(self, batch, task="pred"):
        """Evaluate motion-to-motion tasks (e.g. ``"pred"``, ``"inbetween"``).

        Each reference motion is tokenized, the language model generates
        tokens for ``task``, and the result is decoded and cut per sample.

        Returns:
            dict with renormalized ``m_ref`` / ``m_rst``, their recovered
            ``joints_ref`` / ``joints_rst``, and ``length`` -- the number of
            generated frames kept per sample.
        """
        feats_ref = batch["motion"]
        lengths = batch["length"]

        motion_tokens = []
        for i in range(len(feats_ref)):
            motion_token, _ = self.vae.encode(feats_ref[i:i + 1])
            motion_tokens.append(motion_token[0])

        # NOTE(review): frame lengths (not token counts) are passed here,
        # unlike val_m2t_forward -- confirm generate_conditional expects that.
        outputs = self.lm.generate_conditional(motion_tokens=motion_tokens,
                                               lengths=lengths,
                                               task=task,
                                               stage='test')

        feats_rst, min_len = self._decode_to_feats(outputs, feats_ref, lengths)
        return self._motion_result(feats_ref, feats_rst, min_len)

    def train_vae_forward(self, batch):
        """Reconstruct a batch with the motion tokenizer for the VAE loss.

        Returns:
            dict with reference/reconstructed features and joints plus the
            tokenizer's ``loss_commit`` and ``perplexity``.
        """
        feats_ref = batch["motion"]
        joints_ref = self.feats2joints(feats_ref)

        # Encode and decode in a single tokenizer call.
        feats_rst, loss_commit, perplexity = self.vae(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        return {
            "m_ref": feats_ref,
            "joints_ref": joints_ref,
            "m_rst": feats_rst,
            "joints_rst": joints_rst,
            "loss_commit": loss_commit,
            "perplexity": perplexity,
        }

    @torch.no_grad()
    def val_vae_forward(self, batch, split="train"):
        """Reconstruct each sample (cut to its length) for VAE evaluation.

        ``split`` is accepted for interface compatibility and is unused.
        Samples with length 0 are left as zeros.
        """
        feats_ref = batch["motion"]
        lengths = batch["length"]

        if self.trainer.datamodule.is_mm:
            n_repeats = self.hparams.cfg.METRIC.MM_NUM_REPEATS
            # Repeat element-wise to stay aligned with repeat_interleave.
            feats_ref = feats_ref.repeat_interleave(n_repeats, dim=0)
            lengths = [length for length in lengths for _ in range(n_repeats)]

        feats_rst = torch.zeros_like(feats_ref)
        for i in range(len(feats_ref)):
            if lengths[i] == 0:
                continue
            feats_pred, _, _ = self.vae(feats_ref[i:i + 1, :lengths[i]])
            feats_rst[i:i + 1, :feats_pred.shape[1], :] = feats_pred

        return self._motion_result(feats_ref, feats_rst, lengths)

    def _update_motion_metrics(self, batch, rs_set):
        """Feed one evaluation step's motion results into the configured metrics.

        Raises:
            TypeError: for a metric name this method does not handle.
        """
        if self.trainer.datamodule.is_mm:
            # Multimodality is evaluated separately.
            metrics_dicts = ['MMMetrics']
        else:
            metrics_dicts = self.hparams.metrics_dict

        # PredMetrics only applies to "pred"/"inbetween". This edits
        # self.hparams.metrics_dict in place when it is the active list; the
        # membership check keeps later steps (and the MM-only list) from
        # raising ValueError once the entry is absent.
        if (self.hparams.task not in ['pred', 'inbetween']
                and 'PredMetrics' in metrics_dicts):
            metrics_dicts.remove('PredMetrics')

        lengths = batch['length']
        for metric in metrics_dicts:
            if metric in ("TemosMetric", "MRMetrics", "PredMetrics"):
                getattr(self.metrics, metric).update(rs_set["joints_rst"],
                                                     rs_set["joints_ref"],
                                                     lengths)
            elif metric == "TM2TMetrics":
                word_embs = pos_ohot = text_lengths = None
                if self.hparams.stage in [
                        "lm_instruct", "lm_pretrain", "lm_rl"
                ]:
                    word_embs = batch['word_embs']
                    pos_ohot = batch['pos_ohot']
                    text_lengths = batch['text_len']
                    if self.trainer.datamodule.is_mm:
                        n_repeats = self.hparams.cfg.METRIC.MM_NUM_REPEATS
                        word_embs = word_embs.repeat_interleave(n_repeats,
                                                                dim=0)
                        pos_ohot = pos_ohot.repeat_interleave(n_repeats,
                                                              dim=0)
                        text_lengths = text_lengths.repeat_interleave(
                            n_repeats, dim=0)
                getattr(self.metrics, metric).update(
                    feats_ref=rs_set["m_ref"],
                    feats_rst=rs_set["m_rst"],
                    lengths_ref=lengths,
                    lengths_rst=rs_set['length'],
                    word_embs=word_embs,
                    pos_ohot=pos_ohot,
                    text_lengths=text_lengths,
                )
            elif metric == "UncondMetrics":
                # NOTE(review): none of the forwards in this class produce
                # "lat_rm"/"lat_m"; this branch would raise KeyError here.
                getattr(self.metrics, metric).update(
                    recmotion_embeddings=rs_set["lat_rm"],
                    gtmotion_embeddings=rs_set["lat_m"],
                    lengths=lengths,
                )
            elif metric == "MMMetrics":
                getattr(self.metrics, metric).update(rs_set["m_rst"],
                                                     rs_set['length'])
            else:
                raise TypeError(f"Not support this metric {metric}")

    def allsplit_step(self, split: str, batch, batch_idx):
        """Shared train / val / test step.

        Computes the stage's loss on ``train`` (and on ``val`` for the VAE
        stage), updates evaluation metrics on ``val``/``test``, and on
        ``test`` returns forward results instead of the loss.

        Returns:
            For ``split == "test"``: ``(joints_rst, length, joints_ref)`` when
            ``task == "t2m"``, or ``(t_pred, batch["length"])`` when
            ``task == "m2t"``. Otherwise the loss (``None`` when none was
            computed).
        """
        lm_stages = ["lm_instruct", "lm_pretrain", "lm_rl"]
        stage = self.hparams.stage
        task = self.hparams.task

        # Losses
        loss = None
        if stage == "vae" and split in ["train", "val"]:
            rs_set = self.train_vae_forward(batch)
            loss = self._losses['losses_' + split].update(rs_set)
        elif stage in ["lm_instruct", "lm_pretrain"] and split in ["train"]:
            rs_set = self.train_lm_forward(batch)
            loss = self._losses['losses_' + split].update(rs_set)
        elif stage == 'lm_rl' and split in ['train']:
            # NOTE(review): train_rl_forward is not defined in this class;
            # confirm it is provided elsewhere or this raises AttributeError.
            rs_set = self.train_rl_forward(batch)
            loss = None

        # Metrics
        if split in ["val", "test"]:
            if stage == "vae":
                rs_set = self.val_vae_forward(batch, split)
            elif stage in lm_stages:
                if task == "t2m":
                    rs_set = self.val_t2m_forward(batch)
                elif task == "m2t":
                    rs_set = self.val_m2t_forward(batch)
                elif task in ["m2m", "pred", "inbetween"]:
                    rs_set = self.val_m2m_forward(batch, task)

            if task != "m2t":
                self._update_motion_metrics(batch, rs_set)
            elif stage in lm_stages:
                self.hparams.metrics_dict = ['M2TMetrics']
                getattr(self.metrics, 'M2TMetrics').update(
                    feats_ref=rs_set["m_ref"],
                    pred_texts=rs_set["t_pred"],
                    gt_texts=batch["all_captions"],
                    lengths=rs_set['length'],
                    word_embs=batch["word_embs"],
                    pos_ohot=batch["pos_ohot"],
                    text_lengths=batch["text_len"],
                )

        # During test, return forward outputs rather than the loss.
        if split in ["test"]:
            if task == "t2m":
                return (rs_set["joints_rst"], rs_set["length"],
                        rs_set["joints_ref"])
            elif task == "m2t":
                return rs_set["t_pred"], batch["length"]
        return loss