import os
from typing import List

import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric

from mGPT.config import instantiate_from_config

from .utils import *


class MMMetrics(Metric):
    full_state_update = True

    def __init__(self,
                 cfg,
                 dataname='humanml3d',
                 mm_num_times=10,
                 dist_sync_on_step=True,
                 **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = "MultiModality scores"
        self.cfg = cfg
        self.dataname = dataname
        self.mm_num_times = mm_num_times

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        self.metrics = ["MultiModality"]
        self.add_state("MultiModality",
                       default=torch.tensor(0.),
                       dist_reduce_fx="sum")

        # One list entry per update() call; concatenated in compute()
        self.add_state("mm_motion_embeddings", default=[], dist_reduce_fx=None)

        self._get_t2m_evaluator(cfg)

    def _get_t2m_evaluator(self, cfg):
        """Load the pretrained T2M text and motion encoders used for evaluation."""
        self.t2m_textencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_textencoder)
        self.t2m_moveencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_moveencoder)
        self.t2m_motionencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_motionencoder)

        # KIT-ML has its own evaluator checkpoint; all other datasets use the
        # HumanML3D ("t2m") weights.
        dataname = "kit" if self.dataname == "kit" else "t2m"
        t2m_checkpoint = torch.load(os.path.join(
            cfg.METRIC.TM2T.t2m_path, dataname,
            "text_mot_match/model/finest.tar"),
                                    map_location="cpu")

        self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"])
        self.t2m_moveencoder.load_state_dict(
            t2m_checkpoint["movement_encoder"])
        self.t2m_motionencoder.load_state_dict(
            t2m_checkpoint["motion_encoder"])

        # The evaluators are fixed feature extractors: keep them in eval mode
        # and freeze their parameters.
        self.t2m_textencoder.eval()
        self.t2m_moveencoder.eval()
        self.t2m_motionencoder.eval()
        for p in self.t2m_textencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_moveencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_motionencoder.parameters():
            p.requires_grad = False

    def compute(self, sanity_flag):
        count = self.count.item()
        count_seq = self.count_seq.item()

        # Init metrics
        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        # Return placeholder metrics during the sanity-check pass
        if sanity_flag:
            return metrics

        # Concatenate the embeddings gathered across update() calls
        all_mm_motions = torch.cat(self.mm_motion_embeddings,
                                   dim=0).cpu().numpy()
        metrics['MultiModality'] = calculate_multimodality_np(
            all_mm_motions, self.mm_num_times)

        self.reset()

        return {**metrics}
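
    # MultiModality quantifies per-prompt diversity: update() receives one
    # batch of motions generated from the *same* text prompt, and compute()
    # averages the distances between randomly paired embeddings of those
    # generations (via calculate_multimodality_np from .utils, which is
    # assumed to sample mm_num_times pairs per set). Higher is more diverse.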
    def update(
        self,
        feats_rst: Tensor,
        lengths_rst: List[int],
    ):
        self.count += sum(lengths_rst)
        self.count_seq += len(lengths_rst)

        # The T2M motion encoder expects sequences sorted by descending length;
        # sort, embed, then scatter the embeddings back to the original order.
        align_idx = np.argsort(lengths_rst)[::-1].copy()
        feats_rst = feats_rst[align_idx]
        lengths_rst = np.array(lengths_rst)[align_idx]
        recmotion_embeddings = self.get_motion_embeddings(
            feats_rst, lengths_rst)
        cache = [0] * len(lengths_rst)
        for i in range(len(lengths_rst)):
            cache[align_idx[i]] = recmotion_embeddings[i:i + 1]

        # [1, num_repeats, dim]: one row per set of repeated generations
        mm_motion_embeddings = torch.cat(cache, dim=0).unsqueeze(0)

        # Store all mm motion embeddings
        self.mm_motion_embeddings.append(mm_motion_embeddings)

    def get_motion_embeddings(self, feats: Tensor, lengths: List[int]):
        # Convert frame counts to movement-feature lengths (UNIT_LEN frames
        # per movement token produced by the movement encoder)
        m_lens = torch.tensor(lengths)
        m_lens = torch.div(m_lens,
                           self.cfg.DATASET.HUMANML3D.UNIT_LEN,
                           rounding_mode="floor")

        # The last 4 feature dims (foot-contact labels in the HumanML3D
        # representation) are excluded from the movement encoder input
        mov = self.t2m_moveencoder(feats[..., :-4]).detach()
        emb = self.t2m_motionencoder(mov, m_lens)

        # Flatten to [batch, dim] embeddings
        return torch.flatten(emb, start_dim=1).detach()
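
# Usage sketch (illustrative only): how this metric is typically driven during
# evaluation. `cfg`, `generate_repeats`, and `val_prompts` are hypothetical
# stand-ins for the project config, a model call that produces several motions
# for one prompt, and the evaluation prompt set.
#
#     metric = MMMetrics(cfg, dataname="humanml3d", mm_num_times=10)
#     for prompt in val_prompts:
#         # feats_rst: [num_repeats, max_len, dim]; lengths_rst: List[int]
#         feats_rst, lengths_rst = generate_repeats(prompt)
#         metric.update(feats_rst, lengths_rst)
#     results = metric.compute(sanity_flag=False)
#     print(results["MultiModality"])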