# Hugging Face file view: uploaded by bill-jiang, commit 4409449 ("Init"), 3.19 kB.
from typing import List
import torch
from torch import Tensor
from torchmetrics import Metric
from .utils import *
# motion reconstruction metric
class MRMetrics(Metric):
    """Accumulate motion-reconstruction errors: MPJPE, PA-MPJPE and ACCEL.

    ``update`` compares reconstructed joint positions with reference joint
    positions and adds the per-frame errors to running sums. ``compute``
    divides those sums by the number of contributing frames, scales them, and
    resets the accumulated state.

    Args:
        njoints: Unused by this class; accepted for interface compatibility.
        jointstype: Skeleton convention. Root alignment in ``update`` is only
            applied when this is ``"mmm"`` or ``"humanml3d"``.
        force_in_meter: If True, ``compute`` multiplies every metric by 1000
            (presumably metres -> millimetres; confirm the input units).
        align_root: If True (and ``jointstype`` supports it), joint index 0 is
            passed to ``calc_mpjpe`` as the alignment joint.
        dist_sync_on_step: Forwarded to ``torchmetrics.Metric``.
        **kwargs: Ignored.
    """

    def __init__(self,
                 njoints,
                 jointstype: str = "mmm",
                 force_in_meter: bool = True,
                 align_root: bool = True,
                 dist_sync_on_step=True,
                 **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = 'Motion Reconstructions'
        self.jointstype = jointstype
        self.align_root = align_root
        self.force_in_meter = force_in_meter

        # Total number of valid frames and of sequences seen so far; both are
        # summed across processes when states are synchronized.
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        # Running sums of per-frame errors; normalized in ``compute``.
        self.add_state("MPJPE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("PAMPJPE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("ACCEL",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        # TODO: root error metric
        # self.add_state("ROOT", default=torch.tensor([0.0]), dist_reduce_fx="sum")

        self.MR_metrics = ["MPJPE", "PAMPJPE", "ACCEL"]

        # All metrics reported by this class.
        self.metrics = self.MR_metrics

    def compute(self, sanity_flag):
        """Return the averaged metrics and reset the accumulated state.

        Args:
            sanity_flag: Unused; kept so existing callers keep working.

        Returns:
            dict mapping ``"MPJPE"``, ``"PAMPJPE"`` and ``"ACCEL"`` to
            one-element tensors. If no frames were accumulated, the division
            by a zero count yields nan/inf (there is no guard).
        """
        # Different jointstypes could need different scale factors (for
        # humanml3d, 1000.0 * 0.75 / 480 was once considered). Currently a
        # single factor of 1000 is applied to every jointstype.
        factor = 1000.0 if self.force_in_meter else 1.0

        count = self.count
        count_seq = self.count_seq
        mr_metrics = {}
        mr_metrics["MPJPE"] = self.MPJPE / count * factor
        mr_metrics["PAMPJPE"] = self.PAMPJPE / count * factor
        # accel error: joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
        # -> n - 2 terms for each sequence of n valid frames.
        mr_metrics["ACCEL"] = self.ACCEL / (count - 2 * count_seq) * factor

        # Reset accumulated state so the next evaluation starts fresh.
        self.reset()
        return mr_metrics

    def update(self, joints_rst: Tensor, joints_ref: Tensor,
               lengths: List[int]):
        """Accumulate errors for a batch of (possibly padded) sequences.

        Args:
            joints_rst: Reconstructed joints, shape (bs, seq, njoints, 3).
            joints_ref: Reference joints, same shape as ``joints_rst``.
            lengths: Number of valid frames for each sequence in the batch.
                Frames at index ``lengths[i]`` and beyond are treated as
                padding and excluded.
        """
        assert joints_rst.shape == joints_ref.shape
        assert joints_rst.dim() == 4
        # (bs, seq, njoint=22, 3)

        self.count += sum(lengths)
        self.count_seq += len(lengths)

        # avoid cuda error of DDP in pampjpe
        rst = joints_rst.detach().cpu()
        ref = joints_ref.detach().cpu()

        # Joint indices used for root alignment in MPJPE.
        if self.align_root and self.jointstype in ['mmm', 'humanml3d']:
            align_inds = [0]
        else:
            align_inds = None

        for i, length in enumerate(lengths):
            # Restrict to the valid frames. `count` (sum of lengths) and the
            # ACCEL denominator (n - 2 per sequence) only account for these,
            # so padded frames must not contribute error terms.
            valid = int(length)
            rst_i = rst[i, :valid]
            ref_i = ref[i, :valid]
            self.MPJPE += torch.sum(
                calc_mpjpe(rst_i, ref_i, align_inds=align_inds))
            self.PAMPJPE += torch.sum(calc_pampjpe(rst_i, ref_i))
            self.ACCEL += torch.sum(calc_accel(rst_i, ref_i))