import numpy as np
import os
import random
import time
import json
from os.path import join as pjoin

import torch

from mGPT.config import instantiate_from_config
from mGPT.losses.mgpt import GPTLosses
from .base import BaseModel

import mGPT.render.matplot.plot_3d_global as plot_3d


class MotionGPT(BaseModel):
    """
    Stage 1: Motion tokenizer
    Stage 2: Motion-language pretraining
    Stage 3: Motion-language instruction tuning
    """

    def __init__(self,
                 cfg,
                 datamodule,
                 lm,
                 motion_vae,
                 codebook_size=512,
                 stage='vae',
                 debug=True,
                 condition='text',
                 task='t2m',
                 metrics_dict=['TM2TMetrics'],
                 **kwargs):
        self.save_hyperparameters(ignore='datamodule', logger=False)
        self.datamodule = datamodule
        super().__init__()

        # Instantiate the motion tokenizer
        if motion_vae is not None:
            self.vae = instantiate_from_config(motion_vae)

        # Instantiate the motion-language model
        self.lm = instantiate_from_config(lm)

        # Freeze the motion tokenizer for LM training
        if 'lm' in self.hparams.stage:
            self.vae.training = False
            for p in self.vae.parameters():
                p.requires_grad = False

        # Instantiate the losses
        self._losses = torch.nn.ModuleDict({
            split: GPTLosses(cfg, self.hparams.stage, self.datamodule.njoints)
            for split in ["losses_train", "losses_test", "losses_val"]
        })

        # Data transform
        self.feats2joints = datamodule.feats2joints

        # Count codebook frequency
        self.codePred = []
        self.codeFrequency = torch.zeros((self.hparams.codebook_size, ))

    def forward(self, batch, task="t2m"):
        texts = batch["text"]
        lengths_ref = batch["length"]

        # Forward
        # texts = ['Generate motion: ' + text for text in texts]
        outputs, output_texts = self.lm.generate_direct(texts, do_sample=True)

        # Motion decode
        feats_rst_lst = []
        lengths = []
        max_len = 0

        for i in range(len(texts)):
            if task == "pred":
                motion = self.vae.decode(
                    torch.cat((batch["motion"][i], outputs[i])))
            elif task in ["t2m", "m2t", "inbetween"]:
                motion = self.vae.decode(outputs[i])
                # motion = self.datamodule.denormalize(motion)
                lengths.append(motion.shape[1])
            else:
                raise NotImplementedError

            if motion.shape[1] > max_len:
                max_len = motion.shape[1]

            if task in ["t2m", "m2t", "pred"]:
                feats_rst_lst.append(motion)
            elif task == "inbetween":
                # Keep the reference heading/tailing clips around the
                # generated middle segment
                motion = torch.cat(
                    (batch["motion_heading"][i][None],
                     motion[:, lengths_ref[i] // 4:lengths_ref[i] // 4 * 3,
                            ...], batch["motion_tailing"][i][None]),
                    dim=1)
                feats_rst_lst.append(motion)

        feats_rst = torch.zeros(
            (len(feats_rst_lst), max_len, motion.shape[-1])).to(self.device)

        # Pad and concat
        for i in range(len(feats_rst_lst)):
            feats_rst[i, :feats_rst_lst[i].shape[1], ...] = feats_rst_lst[i]

        # Recover joints for evaluation
        joints_rst = self.feats2joints(feats_rst)

        # Return set
        outputs = {
            "texts": output_texts,
            "feats": feats_rst,
            "joints": joints_rst,
            "length": lengths
        }

        return outputs
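    # The forwards below are the stage-specific pipelines used by
    # `allsplit_step`: `train_lm_forward` feeds text plus pre-computed motion
    # tokens to the language model, and the `val_*_forward` methods decode
    # generated tokens back into motion features for metric computation.
    # Note (illustrative, not asserted by this file): each entry in `tasks`
    # is a dict of prompt templates with 'input'/'output' lists, matching the
    # empty template {'input': [''], 'output': ['']} used for the 'caption'
    # condition in `val_t2m_forward`.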
    def train_lm_forward(self, batch):
        tokens_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        tasks = batch["tasks"]
        all_captions = batch['all_captions']
        if self.hparams.condition == 'caption':
            texts = [random.choice(all_captions[i]) for i in range(len(texts))]

        # LLM forward
        outputs = self.lm(texts, tokens_ref, lengths, tasks)
        # outputs = self.t2m_gpt.generate(texts)
        return {'outputs': outputs}

    @torch.no_grad()
    def val_t2m_forward(self, batch):
        feats_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        tasks = None
        if self.trainer.datamodule.is_mm:
            texts = texts * self.hparams.cfg.METRIC.MM_NUM_REPEATS
            feats_ref = feats_ref.repeat_interleave(
                self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0)
            lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS
            instructions = pjoin(self.datamodule.hparams.data_root,
                                 'template_instructions.json')
            instructions = json.load(open(instructions, 'r'))
            tasks = [instructions["Text-to-Motion"]["caption"]] * len(texts)

        if self.hparams.condition == 'caption':
            tasks = [{
                'input': [''],
                'output': ['']
            }] * len(texts)

        if self.hparams.cfg.DATASET.TASK_PATH:
            instructions = pjoin(self.hparams.cfg.DATASET.TASK_PATH)
            instructions = json.load(open(instructions, 'r'))
            tasks = [instructions["Text-to-Motion"]["t2m"]] * len(texts)

        min_len = lengths.copy()

        # Forward
        outputs = self.lm.generate_conditional(texts,
                                               lengths=lengths,
                                               stage='test',
                                               tasks=tasks)

        # Motion decode
        feats_rst = torch.zeros_like(feats_ref)

        for i in range(len(texts)):
            outputs[i] = torch.clamp(outputs[i], 0,
                                     self.hparams.codebook_size - 1)

            if len(outputs[i]) > 1:
                motion = self.vae.decode(outputs[i])
            else:
                motion = torch.zeros_like(feats_ref[i:i + 1, ...])

            min_len[i] = min(motion.shape[1], lengths[i])

            # Cut motion
            feats_rst[i:i + 1, :min_len[i], ...] = motion[:, :lengths[i]]

        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Renorm for evaluation
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        # Return set
        rs_set = {
            "m_ref": feats_ref,
            "m_rst": feats_rst,
            "joints_ref": joints_ref,
            "joints_rst": joints_rst,
            "length": min_len
            # "length": lengths
        }

        return rs_set
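    # Motion-to-text validation: each reference motion is quantized into
    # discrete codebook indices by the frozen motion tokenizer, and the
    # language model then generates a caption conditioned on those tokens.
    # `motion_token[0]` drops the batch dimension because motions are encoded
    # one sample at a time.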
    @torch.no_grad()
    def val_m2t_forward(self, batch):
        self.hparams.metrics_dict = []
        feats_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        all_captions = batch['all_captions']

        # Motion encode
        motion_tokens = []
        lengths_tokens = []
        for i in range(len(feats_ref)):
            motion_token, _ = self.vae.encode(feats_ref[i:i + 1])
            motion_tokens.append(motion_token[0])
            lengths_tokens.append(motion_token.shape[1])

        # Forward
        outputs = self.lm.generate_conditional(motion_tokens=motion_tokens,
                                               lengths=lengths_tokens,
                                               task="m2t",
                                               stage='test')

        # Return set
        rs_set = {
            "m_ref": feats_ref,
            "t_ref": all_captions,
            # "t_ref": texts,
            "t_pred": outputs,
            "length": lengths
        }

        return rs_set

    @torch.no_grad()
    def val_m2m_forward(self, batch, task="pred"):
        feats_ref = batch["motion"]
        lengths = batch["length"]

        # Motion encode
        motion_tokens = []
        for i in range(len(feats_ref)):
            motion_token, _ = self.vae.encode(feats_ref[i:i + 1])
            motion_tokens.append(motion_token[0])

        # Forward
        outputs = self.lm.generate_conditional(motion_tokens=motion_tokens,
                                               lengths=lengths,
                                               task=task,
                                               stage='test')

        # Motion decode
        feats_rst = torch.zeros_like(feats_ref)
        min_len = lengths.copy()

        for i in range(len(lengths)):
            outputs[i] = torch.clamp(outputs[i], 0,
                                     self.hparams.codebook_size - 1)

            if len(outputs[i]) > 1:
                motion = self.vae.decode(outputs[i])
            else:
                motion = torch.zeros_like(feats_ref[i:i + 1, ...])

            min_len[i] = min(motion.shape[1], lengths[i])

            # Cut motion
            feats_rst[i:i + 1, :min_len[i], ...] = motion[:, :lengths[i]]

        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Renorm for evaluation
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        # Return set
        rs_set = {
            "m_ref": feats_ref,
            "m_rst": feats_rst,
            "joints_ref": joints_ref,
            "joints_rst": joints_rst,
            "length": min_len
            # "length": lengths
        }

        return rs_set
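    # Stage-1 tokenizer training: the motion VAE reconstructs the input
    # features and also returns the commitment loss and codebook perplexity,
    # which are passed through `rs_set` to the GPTLosses module created in
    # `__init__`.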
    def train_vae_forward(self, batch):
        # Batch detach
        feats_ref = batch["motion"]
        joints_ref = self.feats2joints(feats_ref)

        # Motion encode & decode
        feats_rst, loss_commit, perplexity = self.vae(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Return set
        rs_set = {
            "m_ref": feats_ref,
            "joints_ref": joints_ref,
            "m_rst": feats_rst,
            "joints_rst": joints_rst,
            "loss_commit": loss_commit,
            "perplexity": perplexity,
        }
        return rs_set

    @torch.no_grad()
    def val_vae_forward(self, batch, split="train"):
        # Batch detach
        feats_ref = batch["motion"]
        lengths = batch["length"]

        # Repeat for multimodal evaluation
        if self.trainer.datamodule.is_mm:
            feats_ref = feats_ref.repeat_interleave(
                self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0)
            lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS

        # Motion encode & decode
        feats_rst = torch.zeros_like(feats_ref)

        for i in range(len(feats_ref)):
            if lengths[i] == 0:
                continue
            feats_pred, _, _ = self.vae(feats_ref[i:i + 1, :lengths[i]])
            feats_rst[i:i + 1, :feats_pred.shape[1], :] = feats_pred

            code_pred, _ = self.vae.encode(feats_ref[i:i + 1, :lengths[i]])

            # codeFre_pred = torch.bincount(
            #     code_pred[0],
            #     minlength=self.hparams.codebook_size).to(
            #         self.codeFrequency.device)
            # self.codePred.append(code_pred[0])
            # self.codeFrequency += codeFre_pred

        # np.save('../memData/results/codeFrequency.npy',
        #         self.codeFrequency.cpu().numpy())

        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Renorm for evaluation
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        # Return set
        rs_set = {
            "m_ref": feats_ref,
            "joints_ref": joints_ref,
            "m_rst": feats_rst,
            "joints_rst": joints_rst,
            "length": lengths,
        }

        return rs_set
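    # `allsplit_step` is the single step shared by the train/val/test splits:
    # it first dispatches to the stage-specific forward to obtain a loss
    # (train) or an `rs_set` of reconstructions (val/test), then routes the
    # results into the metric objects' `update` calls; metric reduction is
    # expected to happen at epoch end in `BaseModel`.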
    def allsplit_step(self, split: str, batch, batch_idx):
        # Compute the losses
        loss = None

        if self.hparams.stage == "vae" and split in ["train", "val"]:
            rs_set = self.train_vae_forward(batch)
            loss = self._losses['losses_' + split].update(rs_set)
        elif self.hparams.stage in ["lm_instruct", "lm_pretrain"
                                    ] and split in ["train"]:
            rs_set = self.train_lm_forward(batch)
            loss = self._losses['losses_' + split].update(rs_set)
        elif self.hparams.stage == 'lm_rl' and split in ['train']:
            rs_set = self.train_rl_forward(batch)
            loss = None

        # Compute the metrics
        if split in ["val", "test"]:
            if self.hparams.stage == "vae":
                rs_set = self.val_vae_forward(batch, split)
            elif self.hparams.stage in ["lm_instruct", "lm_pretrain", "lm_rl"]:
                if self.hparams.task == "t2m":
                    rs_set = self.val_t2m_forward(batch)
                elif self.hparams.task == "m2t":
                    rs_set = self.val_m2t_forward(batch)
                elif self.hparams.task in ["m2m", "pred", "inbetween"]:
                    rs_set = self.val_m2m_forward(batch, self.hparams.task)

            if self.hparams.task not in ["m2t"]:
                # MultiModality evaluation separately
                if self.trainer.datamodule.is_mm:
                    metrics_dicts = ['MMMetrics']
                else:
                    # Copy so that removing entries below does not mutate
                    # the stored hyperparameter list across steps
                    metrics_dicts = list(self.hparams.metrics_dict)

                if self.hparams.task not in [
                        'pred', 'inbetween'
                ] and 'PredMetrics' in metrics_dicts:
                    metrics_dicts.remove('PredMetrics')

                for metric in metrics_dicts:
                    lengths = batch['length']
                    if metric == "TemosMetric":
                        getattr(self.metrics,
                                metric).update(rs_set["joints_rst"],
                                               rs_set["joints_ref"], lengths)
                    elif metric == "TM2TMetrics":
                        if self.hparams.stage in [
                                "lm_instruct", "lm_pretrain", "lm_rl"
                        ]:
                            word_embs = batch['word_embs']
                            pos_ohot = batch['pos_ohot']
                            text_lengths = batch['text_len']
                            if self.trainer.datamodule.is_mm:
                                word_embs = word_embs.repeat_interleave(
                                    self.hparams.cfg.METRIC.MM_NUM_REPEATS,
                                    dim=0)
                                pos_ohot = pos_ohot.repeat_interleave(
                                    self.hparams.cfg.METRIC.MM_NUM_REPEATS,
                                    dim=0)
                                text_lengths = text_lengths.repeat_interleave(
                                    self.hparams.cfg.METRIC.MM_NUM_REPEATS,
                                    dim=0)
                        else:
                            word_embs = None
                            pos_ohot = None
                            text_lengths = None

                        getattr(self.metrics, metric).update(
                            feats_ref=rs_set["m_ref"],
                            feats_rst=rs_set["m_rst"],
                            lengths_ref=lengths,
                            lengths_rst=rs_set['length'],
                            word_embs=word_embs,
                            pos_ohot=pos_ohot,
                            text_lengths=text_lengths,
                        )
                    elif metric == "UncondMetrics":
                        getattr(self.metrics, metric).update(
                            recmotion_embeddings=rs_set["lat_rm"],
                            gtmotion_embeddings=rs_set["lat_m"],
                            lengths=lengths,
                        )
                    elif metric == "MRMetrics":
                        getattr(self.metrics,
                                metric).update(rs_set["joints_rst"],
                                               rs_set["joints_ref"], lengths)
                    elif metric == "PredMetrics":
                        getattr(self.metrics,
                                metric).update(rs_set["joints_rst"],
                                               rs_set["joints_ref"], lengths)
                    elif metric == "MMMetrics":
                        getattr(self.metrics,
                                metric).update(rs_set["m_rst"],
                                               rs_set['length'])
                    else:
                        raise TypeError(f"Unsupported metric: {metric}")

            elif self.hparams.task == "m2t" and self.hparams.stage in [
                    "lm_instruct", "lm_pretrain", "lm_rl"
            ]:
                self.hparams.metrics_dict = metrics_dicts = ['M2TMetrics']
                for metric in metrics_dicts:
                    if metric == "M2TMetrics":
                        getattr(self.metrics, metric).update(
                            feats_ref=rs_set["m_ref"],
                            pred_texts=rs_set["t_pred"],
                            gt_texts=batch["all_captions"],
                            lengths=rs_set['length'],
                            word_embs=batch["word_embs"],
                            pos_ohot=batch["pos_ohot"],
                            text_lengths=batch["text_len"],
                        )

        # Return forward output rather than loss during test
        if split in ["test"]:
            if self.hparams.task == "t2m":
                return rs_set["joints_rst"], rs_set["length"], rs_set[
                    "joints_ref"]
            elif self.hparams.task == "m2t":
                return rs_set["t_pred"], batch["length"]
                # return batch["length"]

        return loss
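# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; assumes a loaded `cfg` and
# `datamodule` from the project's config/data pipeline, and the config paths
# for the lm/motion_vae sub-configs below are assumptions, not defined here):
#
#   model = MotionGPT(cfg, datamodule,
#                     lm=cfg.model.params.lm,              # assumed config key
#                     motion_vae=cfg.model.params.motion_vae,
#                     stage='lm_instruct')
#   batch = {"text": ["a person walks forward"], "length": [100]}
#   out = model(batch, task="t2m")
#   out["joints"]   # joints recovered via feats2joints, for evaluation/render
#   out["texts"]    # raw text responses from the language model
# ---------------------------------------------------------------------------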