# MoMask-test / utils /eval_t2m.py
# andrewatef's picture
# Upload folder using huggingface_hub
# 823807d verified
# raw
# history blame
# 50.7 kB
import os
import clip
import numpy as np
import torch
# from scipy import linalg
from utils.metrics import *
import torch.nn.functional as F
# import visualization.plot_3d_global as plot_3d
from utils.motion_process import recover_from_ric
#
#
# def tensorborad_add_video_xyz(writer, xyz, nb_iter, tag, nb_vis=4, title_batch=None, outname=None):
# xyz = xyz[:1]
# bs, seq = xyz.shape[:2]
# xyz = xyz.reshape(bs, seq, -1, 3)
# plot_xyz = plot_3d.draw_to_batch(xyz.cpu().numpy(), title_batch, outname)
# plot_xyz = np.transpose(plot_xyz, (0, 1, 4, 2, 3))
# writer.add_video(tag, plot_xyz, nb_iter, fps=20)
@torch.no_grad()
def evaluation_res_conv(out_dir, val_loader, res_model, vq_model, writer, ep, best_fid, best_div, best_top1,
                        best_top2, best_top3, best_matching, eval_wrapper, plot_func,
                        save_ckpt=True, save_anim=False, draw=True):
    """Score ``res_model`` on ``val_loader`` and update the running best metrics.

    Every motion batch is encoded by ``vq_model``. The codes fed to
    ``vq_model.decoder`` come straight from the quantizer when ``ep == 0`` and
    from ``res_model`` otherwise. Decoded and ground-truth motions are embedded
    with ``eval_wrapper`` and compared through FID, diversity, top-1/2/3
    R-precision and the matching score.

    Args:
        out_dir: Directory receiving checkpoints and the ``animation`` folder.
        val_loader: Iterable of 7-tuples ``(word_embeddings, pos_one_hots,
            caption, sent_len, motion, m_length, token)``.
        res_model: Model under evaluation; its ``state_dict`` is checkpointed.
        vq_model: Provides ``encode``, ``quantizer.dequantize`` and ``decoder``.
        writer: Object exposing ``add_scalar(tag, value, step)``; used if ``draw``.
        ep: Epoch index (logging step, animation folder name, baseline switch).
        best_fid, best_div, best_top1, best_top2, best_top3, best_matching:
            Best values observed so far.
        eval_wrapper: Object exposing ``get_co_embeddings``.
        plot_func: Called as ``plot_func(data, save_dir, captions, lengths)``
            when ``save_anim`` is set.
        save_ckpt: Save ``res_model`` when FID or matching score improves.
        save_anim: Render three randomly chosen samples of the last batch.
        draw: Log scalars to ``writer`` and print the improvement messages.

    Returns:
        ``(best_fid, best_div, best_top1, best_top2, best_top3, best_matching,
        writer)`` with the updated best values.
    """
    vq_model.eval()
    res_model.eval()

    real_embeddings, pred_embeddings = [], []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    nb_sample = 0

    for word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, _token in val_loader:
        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs = motion.shape[0]

        code_indices, all_codes = vq_model.encode(motion)
        if ep != 0:
            # 3-D indices: only slice 0 of the last axis is given to the model
            # (presumably the first quantizer layer -- TODO confirm).
            model_input = code_indices[..., 0] if len(code_indices.shape) == 3 else code_indices
            pred_codes = res_model(model_input)
        elif len(all_codes.shape) == 4:
            pred_codes = all_codes[0]
        else:
            # NOTE(review): the permute is assumed to belong to the dequantize
            # branch only -- confirm against the upstream file.
            pred_codes = vq_model.quantizer.dequantize(code_indices)
            pred_codes = pred_codes.permute(0, 2, 1)

        pred_motions = vq_model.decoder(pred_codes)
        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_motions, m_length)

        pred_embeddings.append(em_pred)
        real_embeddings.append(em)

        # Convert each embedding to NumPy once and share it between both metrics.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(real_embeddings, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(pred_embeddings, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    summary_args = (ep, fid, diversity_real, diversity,
                    R_precision_real[0], R_precision_real[1], R_precision_real[2],
                    R_precision[0], R_precision[1], R_precision[2],
                    matching_score_real, matching_score_pred)
    msg = "--> \t Eva. Ep %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_score_real. %.4f, matching_score_pred. %.4f" % summary_args
    print(msg)

    if draw:
        scalars = (
            ('./Test/FID', fid),
            ('./Test/Diversity', diversity),
            ('./Test/top1', R_precision[0]),
            ('./Test/top2', R_precision[1]),
            ('./Test/top3', R_precision[2]),
            ('./Test/matching_score', matching_score_pred),
        )
        for tag, value in scalars:
            writer.add_scalar(tag, value, ep)

    if fid < best_fid:
        msg = "--> --> \t FID Improved from %.5f to %.5f !!!" % (best_fid, fid)
        if draw:
            print(msg)
        best_fid = fid
        if save_ckpt:
            torch.save({'res_model': res_model.state_dict(), 'ep': ep}, os.path.join(out_dir, 'net_best_fid.tar'))

    # Diversity is "better" when it is closer to the diversity of the real data.
    if abs(diversity_real - diversity) < abs(diversity_real - best_div):
        msg = "--> --> \t Diversity Improved from %.5f to %.5f !!!" % (best_div, diversity)
        if draw:
            print(msg)
        best_div = diversity

    top_templates = (
        "--> --> \t Top1 Improved from %.5f to %.5f !!!",
        "--> --> \t Top2 Improved from %.5f to %.5f!!!",
        "--> --> \t Top3 Improved from %.5f to %.5f !!!",
    )
    best_tops = [best_top1, best_top2, best_top3]
    for k, template in enumerate(top_templates):
        if R_precision[k] > best_tops[k]:
            msg = template % (best_tops[k], R_precision[k])
            if draw:
                print(msg)
            best_tops[k] = R_precision[k]
    best_top1, best_top2, best_top3 = best_tops

    if matching_score_pred < best_matching:
        msg = "--> --> \t matching_score Improved from %.5f to %.5f !!!" % (best_matching, matching_score_pred)
        if draw:
            print(msg)
        if save_ckpt:
            torch.save({'res_model': res_model.state_dict(), 'ep': ep}, os.path.join(out_dir, 'net_best_mm.tar'))
        best_matching = matching_score_pred

    if save_anim:
        # Samples come from the last batch seen by the loop above.
        rand_idx = torch.randint(bs, (3,))
        data = pred_motions[rand_idx].detach().cpu().numpy()
        captions = [caption[k] for k in rand_idx]
        lengths = m_length[rand_idx].cpu().numpy()
        save_dir = os.path.join(out_dir, 'animation', 'E%04d' % ep)
        os.makedirs(save_dir, exist_ok=True)
        plot_func(data, save_dir, captions, lengths)

    vq_model.train()
    res_model.train()
    return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer
@torch.no_grad()
def evaluation_vqvae(out_dir, val_loader, net, writer, ep, best_fid, best_div, best_top1,
                     best_top2, best_top3, best_matching, eval_wrapper, save=True, draw=True):
    """Score the reconstructions of ``net`` on ``val_loader`` and track best metrics.

    Each motion batch is passed through ``net``; the reconstruction and the
    ground truth are embedded with ``eval_wrapper`` and compared through FID,
    diversity, top-1/2/3 R-precision and the matching score.

    Args:
        out_dir: Directory receiving the checkpoints.
        val_loader: Iterable of 7-tuples ``(word_embeddings, pos_one_hots,
            caption, sent_len, motion, m_length, token)``.
        net: Model returning ``(reconstruction, loss_commit, perplexity)``.
        writer: Object exposing ``add_scalar(tag, value, step)``; used if ``draw``.
        ep: Epoch index, used as logging step and stored in checkpoints.
        best_fid, best_div, best_top1, best_top2, best_top3, best_matching:
            Best values observed so far.
        eval_wrapper: Object exposing ``get_co_embeddings``.
        save: Save ``net`` when FID or matching score improves.
        draw: Log scalars to ``writer`` and print the improvement messages.

    Returns:
        ``(best_fid, best_div, best_top1, best_top2, best_top3, best_matching,
        writer)`` with the updated best values.
    """
    net.eval()

    real_embeddings, pred_embeddings = [], []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    nb_sample = 0

    for word_embeddings, pos_one_hots, _caption, sent_len, motion, m_length, _token in val_loader:
        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs = motion.shape[0]

        pred_pose_eval, _loss_commit, _perplexity = net(motion)
        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_pose_eval, m_length)

        pred_embeddings.append(em_pred)
        real_embeddings.append(em)

        # Convert each embedding to NumPy once and share it between both metrics.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(real_embeddings, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(pred_embeddings, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    summary_args = (ep, fid, diversity_real, diversity,
                    R_precision_real[0], R_precision_real[1], R_precision_real[2],
                    R_precision[0], R_precision[1], R_precision[2],
                    matching_score_real, matching_score_pred)
    msg = "--> \t Eva. Ep %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_score_real. %.4f, matching_score_pred. %.4f" % summary_args
    print(msg)

    if draw:
        scalars = (
            ('./Test/FID', fid),
            ('./Test/Diversity', diversity),
            ('./Test/top1', R_precision[0]),
            ('./Test/top2', R_precision[1]),
            ('./Test/top3', R_precision[2]),
            ('./Test/matching_score', matching_score_pred),
        )
        for tag, value in scalars:
            writer.add_scalar(tag, value, ep)

    if fid < best_fid:
        msg = "--> --> \t FID Improved from %.5f to %.5f !!!" % (best_fid, fid)
        if draw:
            print(msg)
        best_fid = fid
        if save:
            torch.save({'vq_model': net.state_dict(), 'ep': ep}, os.path.join(out_dir, 'net_best_fid.tar'))

    # Diversity is "better" when it is closer to the diversity of the real data.
    if abs(diversity_real - diversity) < abs(diversity_real - best_div):
        msg = "--> --> \t Diversity Improved from %.5f to %.5f !!!" % (best_div, diversity)
        if draw:
            print(msg)
        best_div = diversity

    top_templates = (
        "--> --> \t Top1 Improved from %.5f to %.5f !!!",
        "--> --> \t Top2 Improved from %.5f to %.5f!!!",
        "--> --> \t Top3 Improved from %.5f to %.5f !!!",
    )
    best_tops = [best_top1, best_top2, best_top3]
    for k, template in enumerate(top_templates):
        if R_precision[k] > best_tops[k]:
            msg = template % (best_tops[k], R_precision[k])
            if draw:
                print(msg)
            best_tops[k] = R_precision[k]
    best_top1, best_top2, best_top3 = best_tops

    if matching_score_pred < best_matching:
        msg = "--> --> \t matching_score Improved from %.5f to %.5f !!!" % (best_matching, matching_score_pred)
        if draw:
            print(msg)
        best_matching = matching_score_pred
        if save:
            torch.save({'vq_model': net.state_dict(), 'ep': ep}, os.path.join(out_dir, 'net_best_mm.tar'))

    net.train()
    return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer
@torch.no_grad()
def evaluation_vqvae_plus_mpjpe(val_loader, net, repeat_id, eval_wrapper, num_joint):
    """Evaluate the reconstructions of ``net`` and additionally report MPJPE.

    Besides FID, diversity, R-precision and matching score, every sample is
    de-normalised with ``val_loader.dataset.inv_transform``, cropped to its
    ``m_length``, converted with ``recover_from_ric`` and compared to the
    ground truth through ``calculate_mpjpe``. The summed error is divided by
    the total number of compared poses.

    Args:
        val_loader: Iterable of 7-tuples ``(word_embeddings, pos_one_hots,
            caption, sent_len, motion, m_length, token)`` whose ``dataset``
            exposes ``inv_transform``.
        net: Model returning ``(reconstruction, loss_commit, perplexity)``.
        repeat_id: Identifier printed in the summary line.
        eval_wrapper: Object exposing ``get_co_embeddings``.
        num_joint: Joint count forwarded to ``recover_from_ric``.

    Returns:
        ``(fid, diversity, R_precision, matching_score_pred, mpjpe)``.
    """
    net.eval()

    real_embeddings, pred_embeddings = [], []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    nb_sample = 0
    mpjpe = 0
    num_poses = 0

    for word_embeddings, pos_one_hots, _caption, sent_len, motion, m_length, _token in val_loader:
        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs = motion.shape[0]

        pred_pose_eval, _loss_commit, _perplexity = net(motion)
        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_pose_eval, m_length)

        bgt = val_loader.dataset.inv_transform(motion.detach().cpu().numpy())
        bpred = val_loader.dataset.inv_transform(pred_pose_eval.detach().cpu().numpy())
        for idx in range(bs):
            n_frames = m_length[idx]
            gt = recover_from_ric(torch.from_numpy(bgt[idx, :n_frames]).float(), num_joint)
            pred = recover_from_ric(torch.from_numpy(bpred[idx, :n_frames]).float(), num_joint)
            mpjpe += torch.sum(calculate_mpjpe(gt, pred))
            num_poses += gt.shape[0]

        pred_embeddings.append(em_pred)
        real_embeddings.append(em)

        # Convert each embedding to NumPy once and share it between both metrics.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(real_embeddings, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(pred_embeddings, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    mpjpe = mpjpe / num_poses

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    summary_args = (repeat_id, fid, diversity_real, diversity,
                    R_precision_real[0], R_precision_real[1], R_precision_real[2],
                    R_precision[0], R_precision[1], R_precision[2],
                    matching_score_real, matching_score_pred, mpjpe)
    msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, MPJPE. %.4f" % summary_args
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, mpjpe
@torch.no_grad()
def evaluation_vqvae_plus_l1(val_loader, net, repeat_id, eval_wrapper, num_joint):
    """Evaluate the reconstructions of ``net`` and additionally report an L1 error.

    Besides FID, diversity, R-precision and matching score, every sample is
    de-normalised with ``val_loader.dataset.inv_transform``, cropped to its
    ``m_length``, converted with ``recover_from_ric`` and compared to the
    ground truth with ``F.l1_loss``. Each sample's mean error is weighted by
    its number of poses, and the total is divided by the total pose count.

    Args:
        val_loader: Iterable of 7-tuples ``(word_embeddings, pos_one_hots,
            caption, sent_len, motion, m_length, token)`` whose ``dataset``
            exposes ``inv_transform``.
        net: Model returning ``(reconstruction, loss_commit, perplexity)``.
        repeat_id: Identifier printed in the summary line.
        eval_wrapper: Object exposing ``get_co_embeddings``.
        num_joint: Joint count forwarded to ``recover_from_ric``.

    Returns:
        ``(fid, diversity, R_precision, matching_score_pred, l1_dist)``.
    """
    net.eval()

    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    nb_sample = 0
    matching_score_real = 0
    matching_score_pred = 0
    l1_dist = 0
    # Count from zero: starting at one would divide the error by N + 1 poses.
    num_poses = 0

    for batch in val_loader:
        word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token = batch

        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs = motion.shape[0]

        pred_pose_eval, loss_commit, perplexity = net(motion)
        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_pose_eval, m_length)

        bgt = val_loader.dataset.inv_transform(motion.detach().cpu().numpy())
        bpred = val_loader.dataset.inv_transform(pred_pose_eval.detach().cpu().numpy())
        for i in range(bs):
            gt = recover_from_ric(torch.from_numpy(bgt[i, :m_length[i]]).float(), num_joint)
            pred = recover_from_ric(torch.from_numpy(bpred[i, :m_length[i]]).float(), num_joint)
            num_pose = gt.shape[0]
            # l1_loss is a per-sample mean; re-weight it by the sample's pose count.
            l1_dist += F.l1_loss(gt, pred) * num_pose
            num_poses += num_pose

        motion_pred_list.append(em_pred)
        motion_annotation_list.append(em)

        # Convert each embedding to NumPy once and share it between both metrics.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    # Guard keeps the division defined when no pose was compared.
    l1_dist = l1_dist / max(num_poses, 1)

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, mae. %.4f" % \
          (repeat_id, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
           R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred, l1_dist)
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, l1_dist
@torch.no_grad()
def evaluation_res_plus_l1(val_loader, vq_model, res_model, repeat_id, eval_wrapper, num_joint, do_vq_res=True):
    """Evaluate ``res_model`` on top of ``vq_model`` and additionally report an L1 error.

    With ``do_vq_res`` the motion is encoded by ``vq_model``, the code indices
    are passed to ``res_model`` and its output is decoded by
    ``vq_model.decoder``. Otherwise ``res_model`` is applied directly to the
    reconstruction returned by ``vq_model``. Predictions are scored with FID,
    diversity, R-precision, matching score and a pose-weighted ``F.l1_loss``
    computed on de-normalised, ``recover_from_ric``-converted sequences.

    Args:
        val_loader: Iterable of 7-tuples ``(word_embeddings, pos_one_hots,
            caption, sent_len, motion, m_length, token)`` whose ``dataset``
            exposes ``inv_transform``.
        vq_model: Provides ``encode``, ``decoder`` and a forward pass returning
            a 3-tuple whose first item is the reconstruction.
        res_model: Model under evaluation.
        repeat_id: Identifier printed in the summary line.
        eval_wrapper: Object exposing ``get_co_embeddings``.
        num_joint: Joint count forwarded to ``recover_from_ric``.
        do_vq_res: Select the code-level path (``True``) or the motion-level path.

    Returns:
        ``(fid, diversity, R_precision, matching_score_pred, l1_dist)``.
    """
    vq_model.eval()
    res_model.eval()

    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    nb_sample = 0
    matching_score_real = 0
    matching_score_pred = 0
    l1_dist = 0
    # Count from zero: starting at one would divide the error by N + 1 poses.
    num_poses = 0

    for batch in val_loader:
        word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token = batch

        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs = motion.shape[0]

        if do_vq_res:
            code_ids, all_codes = vq_model.encode(motion)
            # 3-D indices: only slice 0 of the last axis is given to the model
            # (presumably the first quantizer layer -- TODO confirm).
            if len(code_ids.shape) == 3:
                pred_vq_codes = res_model(code_ids[..., 0])
            else:
                pred_vq_codes = res_model(code_ids)
            pred_pose_eval = vq_model.decoder(pred_vq_codes)
        else:
            rec_motions, _, _ = vq_model(motion)
            pred_pose_eval = res_model(rec_motions)

        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_pose_eval, m_length)

        bgt = val_loader.dataset.inv_transform(motion.detach().cpu().numpy())
        bpred = val_loader.dataset.inv_transform(pred_pose_eval.detach().cpu().numpy())
        for i in range(bs):
            gt = recover_from_ric(torch.from_numpy(bgt[i, :m_length[i]]).float(), num_joint)
            pred = recover_from_ric(torch.from_numpy(bpred[i, :m_length[i]]).float(), num_joint)
            num_pose = gt.shape[0]
            # l1_loss is a per-sample mean; re-weight it by the sample's pose count.
            l1_dist += F.l1_loss(gt, pred) * num_pose
            num_poses += num_pose

        motion_pred_list.append(em_pred)
        motion_annotation_list.append(em)

        # Convert each embedding to NumPy once and share it between both metrics.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    # Guard keeps the division defined when no pose was compared.
    l1_dist = l1_dist / max(num_poses, 1)

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, mae. %.4f" % \
          (repeat_id, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
           R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred, l1_dist)
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, l1_dist
@torch.no_grad()
def evaluation_mask_transformer(out_dir, val_loader, trans, vq_model, writer, ep, best_fid, best_div,
                                best_top1, best_top2, best_top3, best_matching, eval_wrapper, plot_func,
                                save_ckpt=False, save_anim=False, time_steps=18, cond_scale=None,
                                temperature=1):
    """Evaluate text-conditioned generation of ``trans`` and track best metrics.

    For each batch ``trans.generate`` produces token ids from the captions,
    ``vq_model.forward_decoder`` turns them into motions, and generated and
    ground-truth motions are embedded with ``eval_wrapper``. FID, diversity,
    top-1/2/3 R-precision and matching score are printed, logged to ``writer``
    and compared against the ``best_*`` inputs.

    Args:
        out_dir: Run directory; checkpoints go to ``<out_dir>/model`` and
            animations to ``<out_dir>/animation``.
        val_loader: Iterable of 7-tuples ``(word_embeddings, pos_one_hots,
            clip_text, sent_len, pose, m_length, token)``.
        trans: Model exposing ``generate`` and ``state_dict``.
        vq_model: Model exposing ``forward_decoder``.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        ep: Epoch index, used as logging step and stored in checkpoints.
        best_fid, best_div, best_top1, best_top2, best_top3, best_matching:
            Best values observed so far.
        eval_wrapper: Object exposing ``get_co_embeddings``.
        plot_func: Called as ``plot_func(data, save_dir, captions, lengths)``
            when ``save_anim`` is set.
        save_ckpt: Save ``trans`` (without ``clip_model.*`` weights) when FID improves.
        save_anim: Render three randomly chosen samples of the last batch.
        time_steps: Third positional argument of ``trans.generate``; defaults
            to the previously hard-coded 18.
        cond_scale: Fourth positional argument of ``trans.generate``. ``None``
            keeps the previous rule: 2 if ``"kit"`` occurs in ``out_dir``, else 4.
        temperature: ``temperature`` keyword of ``trans.generate``.

    Returns:
        ``(best_fid, best_div, best_top1, best_top2, best_top3, best_matching,
        writer)`` with the updated best values.
    """

    def save(file_name, ep):
        """Write the state dict of ``trans`` minus ``clip_model.*`` entries."""
        t2m_trans_state_dict = trans.state_dict()
        clip_weights = [e for e in t2m_trans_state_dict.keys() if e.startswith('clip_model.')]
        for e in clip_weights:
            del t2m_trans_state_dict[e]
        state = {
            't2m_transformer': t2m_trans_state_dict,
            'ep': ep,
        }
        # torch.save does not create missing parent directories.
        target_dir = os.path.dirname(file_name)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        torch.save(state, file_name)

    trans.eval()
    vq_model.eval()

    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0

    if cond_scale is None:
        cond_scale = 2 if "kit" in out_dir else 4

    nb_sample = 0
    for batch in val_loader:
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token = batch
        m_length = m_length.cuda()
        bs = pose.shape[0]

        # Lengths are divided by 4 before generation (presumably the VQ
        # temporal down-sampling factor -- TODO confirm).
        mids = trans.generate(clip_text, m_length // 4, time_steps, cond_scale, temperature=temperature)
        mids.unsqueeze_(-1)
        pred_motions = vq_model.forward_decoder(mids)

        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_motions.clone(), m_length)

        pose = pose.cuda().float()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        motion_annotation_list.append(em)
        motion_pred_list.append(em_pred)

        # Convert each embedding to NumPy once and share it between both metrics.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Ep {ep} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
    print(msg)

    writer.add_scalar('./Test/FID', fid, ep)
    writer.add_scalar('./Test/Diversity', diversity, ep)
    writer.add_scalar('./Test/top1', R_precision[0], ep)
    writer.add_scalar('./Test/top2', R_precision[1], ep)
    writer.add_scalar('./Test/top3', R_precision[2], ep)
    writer.add_scalar('./Test/matching_score', matching_score_pred, ep)

    if fid < best_fid:
        msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
        print(msg)
        best_fid = fid
        if save_ckpt:
            save(os.path.join(out_dir, 'model', 'net_best_fid.tar'), ep)

    if matching_score_pred < best_matching:
        msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
        print(msg)
        best_matching = matching_score_pred

    # Diversity is "better" when it is closer to the diversity of the real data.
    if abs(diversity_real - diversity) < abs(diversity_real - best_div):
        msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
        print(msg)
        best_div = diversity

    if R_precision[0] > best_top1:
        msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
        print(msg)
        best_top1 = R_precision[0]

    if R_precision[1] > best_top2:
        msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
        print(msg)
        best_top2 = R_precision[1]

    if R_precision[2] > best_top3:
        msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
        print(msg)
        best_top3 = R_precision[2]

    if save_anim:
        # Samples come from the last batch seen by the loop above.
        rand_idx = torch.randint(bs, (3,))
        data = pred_motions[rand_idx].detach().cpu().numpy()
        captions = [clip_text[k] for k in rand_idx]
        lengths = m_length[rand_idx].cpu().numpy()
        save_dir = os.path.join(out_dir, 'animation', 'E%04d' % ep)
        os.makedirs(save_dir, exist_ok=True)
        plot_func(data, save_dir, captions, lengths)

    return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer
@torch.no_grad()
def evaluation_res_transformer(out_dir, val_loader, trans, vq_model, writer, ep, best_fid, best_div,
                               best_top1, best_top2, best_top3, best_matching, eval_wrapper, plot_func,
                               save_ckpt=False, save_anim=False, cond_scale=2, temperature=1):
    """Evaluate the residual transformer ``trans`` and track best metrics.

    Each ground-truth motion is encoded by ``vq_model``. At ``ep == 0`` only
    slice ``0:1`` of the last index axis is decoded. Otherwise
    ``trans.generate`` is called with slice ``0`` of that axis, the captions
    and the lengths divided by 4, and its output is decoded with
    ``vq_model.forward_decoder``. Decoded and ground-truth motions are embedded
    with ``eval_wrapper`` and compared through FID, diversity, top-1/2/3
    R-precision and matching score.

    Args:
        out_dir: Run directory; checkpoints go to ``<out_dir>/model`` and
            animations to ``<out_dir>/animation``.
        val_loader: Iterable of 7-tuples ``(word_embeddings, pos_one_hots,
            clip_text, sent_len, pose, m_length, token)``.
        trans: Model exposing ``generate`` and ``state_dict``.
        vq_model: Model exposing ``encode`` and ``forward_decoder``.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        ep: Epoch index (logging step, checkpoint field, baseline switch).
        best_fid, best_div, best_top1, best_top2, best_top3, best_matching:
            Best values observed so far.
        eval_wrapper: Object exposing ``get_co_embeddings``.
        plot_func: Called as ``plot_func(data, save_dir, captions, lengths)``
            when ``save_anim`` is set.
        save_ckpt: Save ``trans`` (without ``clip_model.*`` weights) when FID improves.
        save_anim: Render three randomly chosen samples of the last batch.
        cond_scale: ``cond_scale`` keyword of ``trans.generate``.
        temperature: ``temperature`` keyword of ``trans.generate``.

    Returns:
        ``(best_fid, best_div, best_top1, best_top2, best_top3, best_matching,
        writer)`` with the updated best values.
    """

    def save(file_name, ep):
        """Write the state dict of ``trans`` minus ``clip_model.*`` entries."""
        res_trans_state_dict = trans.state_dict()
        clip_weights = [e for e in res_trans_state_dict.keys() if e.startswith('clip_model.')]
        for e in clip_weights:
            del res_trans_state_dict[e]
        state = {
            'res_transformer': res_trans_state_dict,
            'ep': ep,
        }
        # torch.save does not create missing parent directories.
        target_dir = os.path.dirname(file_name)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        torch.save(state, file_name)

    trans.eval()
    vq_model.eval()

    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0

    nb_sample = 0
    for batch in val_loader:
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token = batch
        m_length = m_length.cuda().long()
        # Moved and cast once; the same tensor serves encoding and embedding.
        pose = pose.cuda().float()
        bs = pose.shape[0]

        code_indices, all_codes = vq_model.encode(pose)
        if ep == 0:
            pred_ids = code_indices[..., 0:1]
        else:
            pred_ids = trans.generate(code_indices[..., 0], clip_text, m_length // 4,
                                      temperature=temperature, cond_scale=cond_scale)

        pred_motions = vq_model.forward_decoder(pred_ids)

        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_motions.clone(), m_length)
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        motion_annotation_list.append(em)
        motion_pred_list.append(em_pred)

        # Convert each embedding to NumPy once and share it between both metrics.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Ep {ep} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
    print(msg)

    writer.add_scalar('./Test/FID', fid, ep)
    writer.add_scalar('./Test/Diversity', diversity, ep)
    writer.add_scalar('./Test/top1', R_precision[0], ep)
    writer.add_scalar('./Test/top2', R_precision[1], ep)
    writer.add_scalar('./Test/top3', R_precision[2], ep)
    writer.add_scalar('./Test/matching_score', matching_score_pred, ep)

    if fid < best_fid:
        msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
        print(msg)
        best_fid = fid
        if save_ckpt:
            save(os.path.join(out_dir, 'model', 'net_best_fid.tar'), ep)

    if matching_score_pred < best_matching:
        msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
        print(msg)
        best_matching = matching_score_pred

    # Diversity is "better" when it is closer to the diversity of the real data.
    if abs(diversity_real - diversity) < abs(diversity_real - best_div):
        msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
        print(msg)
        best_div = diversity

    if R_precision[0] > best_top1:
        msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
        print(msg)
        best_top1 = R_precision[0]

    if R_precision[1] > best_top2:
        msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
        print(msg)
        best_top2 = R_precision[1]

    if R_precision[2] > best_top3:
        msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
        print(msg)
        best_top3 = R_precision[2]

    if save_anim:
        # Samples come from the last batch seen by the loop above.
        rand_idx = torch.randint(bs, (3,))
        data = pred_motions[rand_idx].detach().cpu().numpy()
        captions = [clip_text[k] for k in rand_idx]
        lengths = m_length[rand_idx].cpu().numpy()
        save_dir = os.path.join(out_dir, 'animation', 'E%04d' % ep)
        os.makedirs(save_dir, exist_ok=True)
        plot_func(data, save_dir, captions, lengths)

    return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer
@torch.no_grad()
def evaluation_res_transformer_plus_l1(val_loader, vq_model, trans, repeat_id, eval_wrapper, num_joint,
                                       cond_scale=2, temperature=1, topkr=0.9, cal_l1=True):
    """Evaluate ``trans`` on ``val_loader`` and report retrieval metrics plus an L1 error.

    Per batch, the ground-truth motion is encoded with ``vq_model.encode``, the
    ``[..., 0]`` slice of the returned code indices is handed to
    ``trans.generate``, and the generated ids are decoded with
    ``vq_model.forward_decoder``.  Real and decoded motions are embedded with
    ``eval_wrapper.get_co_embeddings`` to accumulate R-precision and matching
    scores; FID and diversity are computed once over all collected embeddings.
    A one-line summary is printed before returning.

    Args:
        val_loader: Iterable yielding 7-tuples ``(word_embeddings, pos_one_hots,
            clip_text, sent_len, pose, m_length, token)``.  When ``cal_l1`` is
            true, ``val_loader.dataset`` must expose ``inv_transform``.
        vq_model: Object providing ``eval``, ``encode`` and ``forward_decoder``.
        trans: Object providing ``eval`` and ``generate``.
        repeat_id: Identifier used only in the printed summary (formatted with ``%d``).
        eval_wrapper: Object providing ``get_co_embeddings``.
        num_joint: Forwarded unchanged to ``recover_from_ric``.
        cond_scale: Forwarded to ``trans.generate``.
        temperature: Forwarded to ``trans.generate``.
        topkr: Forwarded to ``trans.generate`` as ``topk_filter_thres``.
        cal_l1: When false the L1 computation is skipped and ``l1_dist`` is ``0.0``.

    Returns:
        Tuple ``(fid, diversity, R_precision, matching_score_pred, l1_dist)``.
    """
    trans.eval()
    vq_model.eval()

    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0

    nb_sample = 0
    l1_dist = 0
    # Count frames from zero; the division below is guarded instead of
    # seeding the counter with a phantom frame that would bias the mean.
    num_poses = 0

    for batch in val_loader:
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, _ = batch
        m_length = m_length.cuda().long()
        pose = pose.cuda().float()
        bs = pose.shape[0]

        code_indices, _ = vq_model.encode(pose)
        # Only the [..., 0] slice of the encoder indices is given to the transformer.
        pred_ids = trans.generate(code_indices[..., 0], clip_text, m_length // 4, topk_filter_thres=topkr,
                                  temperature=temperature, cond_scale=cond_scale)
        pred_motions = vq_model.forward_decoder(pred_ids)

        if cal_l1:
            bgt = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
            bpred = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy())
            for i in range(bs):
                # Compare only the first m_length[i] frames of each sample.
                gt = recover_from_ric(torch.from_numpy(bgt[i, :m_length[i]]).float(), num_joint)
                pred = recover_from_ric(torch.from_numpy(bpred[i, :m_length[i]]).float(), num_joint)
                num_pose = gt.shape[0]
                # l1_loss averages over elements by default; re-weight by frame
                # count so the final division yields a frame-weighted mean.
                l1_dist += F.l1_loss(gt, pred) * num_pose
                num_poses += num_pose

        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_motions.clone(), m_length)
        # `pose` is already a CUDA float tensor here, so no second conversion is needed.
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        motion_annotation_list.append(em)
        motion_pred_list.append(em_pred)

        temp_R = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
        temp_match = euclidean_distance_matrix(et.cpu().numpy(), em.cpu().numpy()).trace()
        R_precision_real += temp_R
        matching_score_real += temp_match

        temp_R = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
        temp_match = euclidean_distance_matrix(et_pred.cpu().numpy(), em_pred.cpu().numpy()).trace()
        R_precision += temp_R
        matching_score_pred += temp_match

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
    diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    # max(..., 1) keeps the cal_l1=False path (no frames counted) at 0.0.
    l1_dist = l1_dist / max(num_poses, 1)

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, mae. %.4f" % \
          (repeat_id, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
           R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred, l1_dist)
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, l1_dist
@torch.no_grad()
def evaluation_mask_transformer_test(val_loader, vq_model, trans, repeat_id, eval_wrapper,
                                     time_steps, cond_scale, temperature, topkr, gsample=True, force_mask=False,
                                     cal_mm=True):
    """Evaluate ``trans`` (decoded through ``vq_model``) on ``val_loader``.

    Per batch, ids are produced by ``trans.generate``, given a trailing axis
    with ``unsqueeze_(-1)`` and decoded by ``vq_model.forward_decoder``.  Real
    and generated motions are embedded with ``eval_wrapper.get_co_embeddings``
    to accumulate R-precision and matching scores; FID and diversity are
    computed once over all collected embeddings.  When multimodality is
    requested, the first three batches are sampled 30 times each and the
    stacked embeddings are passed to ``calculate_multimodality``.  A summary
    line is printed before returning.

    Args:
        val_loader: Iterable yielding 7-tuples ``(word_embeddings, pos_one_hots,
            clip_text, sent_len, pose, m_length, token)``.
        vq_model: Object providing ``eval`` and ``forward_decoder``.
        trans: Object providing ``eval`` and ``generate``.
        repeat_id: Identifier used only in the printed summary.
        eval_wrapper: Object providing ``get_co_embeddings``.
        time_steps: Forwarded positionally to ``trans.generate``.
        cond_scale: Forwarded positionally to ``trans.generate``.
        temperature: Forwarded to ``trans.generate``.
        topkr: Forwarded to ``trans.generate`` as ``topk_filter_thres``.
        gsample: Forwarded to ``trans.generate`` only on the repeated-sampling
            batches; other batches rely on ``generate``'s own default.
        force_mask: Forwarded to ``trans.generate``; when true, multimodality
            is not computed.
        cal_mm: Whether to compute multimodality (ignored when ``force_mask``).

    Returns:
        Tuple ``(fid, diversity, R_precision, matching_score_pred, multimodality)``;
        ``multimodality`` stays ``0`` when it is not computed.
    """
    trans.eval()
    vq_model.eval()

    motion_annotation_list = []
    motion_pred_list = []
    motion_multimodality = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    multimodality = 0
    nb_sample = 0

    # Repeated sampling only feeds the multimodality score, which is never
    # computed under force_mask, so skip it in that case as well.  This matches
    # the gating used by evaluation_mask_transformer_test_plus_res.
    num_mm_batch = 3 if (cal_mm and not force_mask) else 0

    for i, batch in enumerate(val_loader):
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, _ = batch
        m_length = m_length.cuda()
        bs = pose.shape[0]

        if i < num_mm_batch:
            motion_multimodality_batch = []
            for _ in range(30):
                mids = trans.generate(clip_text, m_length // 4, time_steps, cond_scale,
                                      temperature=temperature, topk_filter_thres=topkr,
                                      gsample=gsample, force_mask=force_mask)
                mids.unsqueeze_(-1)
                pred_motions = vq_model.forward_decoder(mids)
                et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                                  pred_motions.clone(), m_length)
                motion_multimodality_batch.append(em_pred.unsqueeze(1))
            motion_multimodality_batch = torch.cat(motion_multimodality_batch, dim=1)  # (bs, 30, d)
            motion_multimodality.append(motion_multimodality_batch)
            # et_pred / em_pred from the last of the 30 draws feed the metrics below.
        else:
            # NOTE(review): gsample is not forwarded on this path in the original
            # code either; kept as-is, confirm whether that is intentional.
            mids = trans.generate(clip_text, m_length // 4, time_steps, cond_scale,
                                  temperature=temperature, topk_filter_thres=topkr,
                                  force_mask=force_mask)
            mids.unsqueeze_(-1)
            pred_motions = vq_model.forward_decoder(mids)
            et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                              pred_motions.clone(), m_length)

        pose = pose.cuda().float()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        motion_annotation_list.append(em)
        motion_pred_list.append(em_pred)

        temp_R = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
        temp_match = euclidean_distance_matrix(et.cpu().numpy(), em.cpu().numpy()).trace()
        R_precision_real += temp_R
        matching_score_real += temp_match

        temp_R = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
        temp_match = euclidean_distance_matrix(et_pred.cpu().numpy(), em_pred.cpu().numpy()).trace()
        R_precision += temp_R
        matching_score_pred += temp_match

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    if not force_mask and cal_mm:
        motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
        multimodality = calculate_multimodality(motion_multimodality, 10)

    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
    diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Repeat {repeat_id} :, FID. {fid:.4f}, " \
          f"Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, " \
          f"R_precision_real. {R_precision_real}, R_precision. {R_precision}, " \
          f"matching_score_real. {matching_score_real:.4f}, matching_score_pred. {matching_score_pred:.4f}," \
          f"multimodality. {multimodality:.4f}"
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, multimodality
@torch.no_grad()
def evaluation_mask_transformer_test_plus_res(val_loader, vq_model, res_model, trans, repeat_id, eval_wrapper,
                                              time_steps, cond_scale, temperature, topkr, gsample=True,
                                              force_mask=False, cal_mm=True, res_cond_scale=5):
    """Evaluate the ``trans`` -> ``res_model`` -> ``vq_model`` generation chain on ``val_loader``.

    Per batch, ``trans.generate`` produces ids, ``res_model.generate`` turns
    them into the ids consumed by ``vq_model.forward_decoder``, and the decoded
    motion is embedded with ``eval_wrapper.get_co_embeddings``.  R-precision
    and matching scores are accumulated per batch; FID and diversity are
    computed once over all embeddings.  Unless ``force_mask`` is set or
    ``cal_mm`` is false, the first three batches are sampled 30 times each
    (with ``gsample`` forwarded) and the stacked embeddings are passed to
    ``calculate_multimodality``.  A summary line is printed before returning.

    Args:
        val_loader: Iterable yielding 7-tuples ``(word_embeddings, pos_one_hots,
            clip_text, sent_len, pose, m_length, token)``.
        vq_model: Object providing ``eval`` and ``forward_decoder``.
        res_model: Object providing ``eval`` and ``generate``.
        trans: Object providing ``eval`` and ``generate``.
        repeat_id: Identifier used only in the printed summary.
        eval_wrapper: Object providing ``get_co_embeddings``.
        time_steps, cond_scale: Forwarded positionally to ``trans.generate``.
        temperature: Forwarded to ``trans.generate``.
        topkr: Forwarded to ``trans.generate`` as ``topk_filter_thres``.
        gsample: Forwarded to ``trans.generate`` on repeated-sampling batches only.
        force_mask: Forwarded to ``trans.generate``; disables multimodality.
        cal_mm: Whether to compute multimodality.
        res_cond_scale: Forwarded to ``res_model.generate`` as ``cond_scale``.

    Returns:
        Tuple ``(fid, diversity, R_precision, matching_score_pred, multimodality)``.
    """
    for net in (trans, vq_model, res_model):
        net.eval()

    def _sample_embeddings(texts, word_emb, pos_ohot, cap_lens, lengths, **extra_generate_kwargs):
        # One full pass: base ids -> refined ids -> decoded motion -> co-embeddings.
        base_ids = trans.generate(texts, lengths // 4, time_steps, cond_scale,
                                  temperature=temperature, topk_filter_thres=topkr,
                                  force_mask=force_mask, **extra_generate_kwargs)
        refined_ids = res_model.generate(base_ids, texts, lengths // 4, temperature=1,
                                         cond_scale=res_cond_scale)
        decoded = vq_model.forward_decoder(refined_ids)
        return eval_wrapper.get_co_embeddings(word_emb, pos_ohot, cap_lens, decoded.clone(), lengths)

    real_embeddings = []
    pred_embeddings = []
    mm_embeddings = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    multimodality = 0
    nb_sample = 0

    want_mm = cal_mm and not force_mask
    mm_batches = 3 if want_mm else 0

    for batch_idx, batch in enumerate(val_loader):
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, _ = batch
        m_length = m_length.cuda()
        batch_size, _ = pose.shape[:2]

        if batch_idx >= mm_batches:
            et_pred, em_pred = _sample_embeddings(clip_text, word_embeddings, pos_one_hots, sent_len, m_length)
        else:
            repeats = []
            for _ in range(30):
                et_pred, em_pred = _sample_embeddings(clip_text, word_embeddings, pos_one_hots, sent_len,
                                                      m_length, gsample=gsample)
                repeats.append(em_pred.unsqueeze(1))
            # (bs, 30, d); the last draw's et_pred / em_pred feed the metrics below.
            mm_embeddings.append(torch.cat(repeats, dim=1))

        pose = pose.cuda().float()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        real_embeddings.append(em)
        pred_embeddings.append(em_pred)

        # Ground-truth text/motion agreement.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        # Generated text/motion agreement.
        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += batch_size

    motion_annotation_np = torch.cat(real_embeddings, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(pred_embeddings, dim=0).cpu().numpy()
    if want_mm:
        mm_np = torch.cat(mm_embeddings, dim=0).cpu().numpy()
        multimodality = calculate_multimodality(mm_np, 10)

    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    print(
        f"--> \t Eva. Repeat {repeat_id} :, FID. {fid:.4f}, "
        f"Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, "
        f"R_precision_real. {R_precision_real}, R_precision. {R_precision}, "
        f"matching_score_real. {matching_score_real:.4f}, matching_score_pred. {matching_score_pred:.4f},"
        f"multimodality. {multimodality:.4f}"
    )
    return fid, diversity, R_precision, matching_score_pred, multimodality