|
from torch import Tensor, nn |
|
from os.path import join as pjoin |
|
from .mr import MRMetrics |
|
from .t2m import TM2TMetrics |
|
from .mm import MMMetrics |
|
from .m2t import M2TMetrics |
|
from .m2m import PredMetrics |
|
|
|
|
|
class BaseMetrics(nn.Module):
    """Container that instantiates the evaluation metric objects.

    Each metric is stored as an attribute named after its class
    (``TM2TMetrics``, ``M2TMetrics``, ``MMMetrics``, ``MRMetrics``,
    ``PredMetrics``).  Presumably the metric classes are ``nn.Module``
    subclasses so that assigning them here registers them as submodules --
    TODO confirm against the metric implementations.
    """

    def __init__(self, cfg, datamodule, debug, **kwargs) -> None:
        """Build the metric objects from the config and the datamodule.

        Args:
            cfg: Config object; this constructor reads
                ``METRIC.DIST_SYNC_ON_STEP``, ``METRIC.DIVERSITY_TIMES``,
                ``METRIC.MM_NUM_TIMES``, ``DATASET.JOINT_TYPE`` and
                ``model.params.task`` from it, and also forwards the whole
                object to several metrics.
            datamodule: Object providing ``njoints``, ``name`` and (for the
                text/motion metrics) ``hparams.w_vectorizer``.
            debug: When truthy, a fixed diversity repeat count of 30 is used
                instead of ``cfg.METRIC.DIVERSITY_TIMES``.
            **kwargs: Accepted but not used by this constructor.
        """
        super().__init__()

        joint_count = datamodule.njoints
        dataset = datamodule.name
        sync_on_step = cfg.METRIC.DIST_SYNC_ON_STEP

        # These three metrics are only created for the two named datasets.
        if dataset in ("humanml3d", "kit"):
            if debug:
                diversity_times = 30
            else:
                diversity_times = cfg.METRIC.DIVERSITY_TIMES

            self.TM2TMetrics = TM2TMetrics(
                cfg=cfg,
                dataname=dataset,
                diversity_times=diversity_times,
                dist_sync_on_step=sync_on_step,
            )
            self.M2TMetrics = M2TMetrics(
                cfg=cfg,
                w_vectorizer=datamodule.hparams.w_vectorizer,
                diversity_times=diversity_times,
                dist_sync_on_step=sync_on_step,
            )
            self.MMMetrics = MMMetrics(
                cfg=cfg,
                mm_num_times=cfg.METRIC.MM_NUM_TIMES,
                dist_sync_on_step=sync_on_step,
            )

        # NOTE(review): indentation was not recoverable from the reviewed
        # copy; the two metrics below are assumed to be created for every
        # dataset, since neither depends on the dataset name -- confirm.
        joint_type = cfg.DATASET.JOINT_TYPE
        self.MRMetrics = MRMetrics(
            njoints=joint_count,
            jointstype=joint_type,
            dist_sync_on_step=sync_on_step,
        )
        self.PredMetrics = PredMetrics(
            cfg=cfg,
            njoints=joint_count,
            jointstype=joint_type,
            dist_sync_on_step=sync_on_step,
            task=cfg.model.params.task,
        )
|
|