Spaces:
Running
Running
import os | |
import clip | |
import numpy as np | |
import torch | |
# from scipy import linalg | |
from utils.metrics import * | |
import torch.nn.functional as F | |
# import visualization.plot_3d_global as plot_3d | |
from utils.motion_process import recover_from_ric | |
# | |
# | |
# def tensorborad_add_video_xyz(writer, xyz, nb_iter, tag, nb_vis=4, title_batch=None, outname=None): | |
# xyz = xyz[:1] | |
# bs, seq = xyz.shape[:2] | |
# xyz = xyz.reshape(bs, seq, -1, 3) | |
# plot_xyz = plot_3d.draw_to_batch(xyz.cpu().numpy(), title_batch, outname) | |
# plot_xyz = np.transpose(plot_xyz, (0, 1, 4, 2, 3)) | |
# writer.add_video(tag, plot_xyz, nb_iter, fps=20) | |
@torch.no_grad()
def evaluation_vqvae(out_dir, val_loader, net, writer, ep, best_fid, best_div, best_top1,
                     best_top2, best_top3, best_matching, eval_wrapper, save=True, draw=True):
    """Evaluate VQ-VAE reconstructions on ``val_loader`` and update the best-so-far scores.

    Every batch is reconstructed with ``net(motion)``. Ground-truth and reconstructed
    motions are embedded with ``eval_wrapper.get_co_embeddings`` next to the text
    embeddings. From these the function computes FID, diversity, top-1/2/3
    R-precision and matching score.

    Args:
        out_dir: Directory that receives ``net_best_fid.tar`` / ``net_best_mm.tar``.
        val_loader: Iterable of batches unpacking to
            ``(word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token)``.
        net: Model whose call ``net(motion)`` returns ``(reconstruction, commit_loss, perplexity)``.
        writer: Object with ``add_scalar(tag, value, step)`` (presumably a TensorBoard writer).
        ep: Epoch index used in the log line, the scalars and the saved checkpoints.
        best_fid, best_div, best_top1, best_top2, best_top3, best_matching:
            Best values seen so far; returned (possibly updated).
        eval_wrapper: Object providing ``get_co_embeddings``.
        save: If true, write a checkpoint when FID or matching score improves.
        draw: If true, log scalars to ``writer`` and print the "Improved" messages.

    Returns:
        ``(best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer)``.

    Notes:
        Runs under ``torch.no_grad()`` because no gradients are needed for
        evaluation. ``net`` is switched back to train mode before returning.
    """
    net.eval()
    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    nb_sample = 0
    matching_score_real = 0
    matching_score_pred = 0
    for batch in val_loader:
        word_embeddings, pos_one_hots, _caption, sent_len, motion, m_length, _token = batch
        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs = motion.shape[0]
        # Only the reconstruction is used; commitment loss and perplexity are ignored here.
        pred_pose_eval, _loss_commit, _perplexity = net(motion)
        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_pose_eval, m_length)
        motion_pred_list.append(em_pred)
        motion_annotation_list.append(em)

        # Batch sums (sum_all=True); divided by nb_sample after the loop.
        # trace() of the distance matrix sums the distances of matching text/motion pairs.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)
    # Second argument of calculate_diversity (presumably the number of sampled pairs).
    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = "--> \t Eva. Ep %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_score_real. %.4f, matching_score_pred. %.4f" % \
          (ep, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
           R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred)
    print(msg)

    if draw:
        writer.add_scalar('./Test/FID', fid, ep)
        writer.add_scalar('./Test/Diversity', diversity, ep)
        writer.add_scalar('./Test/top1', R_precision[0], ep)
        writer.add_scalar('./Test/top2', R_precision[1], ep)
        writer.add_scalar('./Test/top3', R_precision[2], ep)
        writer.add_scalar('./Test/matching_score', matching_score_pred, ep)

    if fid < best_fid:
        msg = "--> --> \t FID Improved from %.5f to %.5f !!!" % (best_fid, fid)
        if draw:
            print(msg)
        best_fid = fid
        if save:
            torch.save({'vq_model': net.state_dict(), 'ep': ep}, os.path.join(out_dir, 'net_best_fid.tar'))

    # Diversity "improves" when it gets closer to the ground-truth diversity.
    if abs(diversity_real - diversity) < abs(diversity_real - best_div):
        msg = "--> --> \t Diversity Improved from %.5f to %.5f !!!" % (best_div, diversity)
        if draw:
            print(msg)
        best_div = diversity

    if R_precision[0] > best_top1:
        msg = "--> --> \t Top1 Improved from %.5f to %.5f !!!" % (best_top1, R_precision[0])
        if draw:
            print(msg)
        best_top1 = R_precision[0]

    if R_precision[1] > best_top2:
        msg = "--> --> \t Top2 Improved from %.5f to %.5f!!!" % (best_top2, R_precision[1])
        if draw:
            print(msg)
        best_top2 = R_precision[1]

    if R_precision[2] > best_top3:
        msg = "--> --> \t Top3 Improved from %.5f to %.5f !!!" % (best_top3, R_precision[2])
        if draw:
            print(msg)
        best_top3 = R_precision[2]

    if matching_score_pred < best_matching:
        msg = "--> --> \t matching_score Improved from %.5f to %.5f !!!" % (best_matching, matching_score_pred)
        if draw:
            print(msg)
        best_matching = matching_score_pred
        if save:
            torch.save({'vq_model': net.state_dict(), 'ep': ep}, os.path.join(out_dir, 'net_best_mm.tar'))

    net.train()
    return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer
@torch.no_grad()
def evaluation_vqvae_plus_mpjpe(val_loader, net, repeat_id, eval_wrapper, num_joint):
    """Evaluate VQ-VAE reconstructions, including a joint-position error (MPJPE).

    Computes FID, diversity, top-1/2/3 R-precision and matching score from
    ``eval_wrapper`` embeddings. It also sums ``calculate_mpjpe`` between
    ground-truth and reconstructed joints over each sample's first
    ``m_length[i]`` frames, then divides by the total frame count.

    Args:
        val_loader: Iterable of batches unpacking to
            ``(word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token)``.
            ``val_loader.dataset`` must provide ``inv_transform``.
        net: Model whose call ``net(motion)`` returns ``(reconstruction, commit_loss, perplexity)``.
        repeat_id: Repeat index; used only in the printed message.
        eval_wrapper: Object providing ``get_co_embeddings``.
        num_joint: Joint count passed to ``recover_from_ric``.

    Returns:
        ``(fid, diversity, R_precision, matching_score_pred, mpjpe)``, where
        ``R_precision`` holds the top-1/2/3 values.

    Notes:
        Runs under ``torch.no_grad()``. ``net`` is left in eval mode.
    """
    net.eval()
    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    nb_sample = 0
    matching_score_real = 0
    matching_score_pred = 0
    mpjpe = 0
    num_poses = 0
    for batch in val_loader:
        word_embeddings, pos_one_hots, _caption, sent_len, motion, m_length, _token = batch
        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs = motion.shape[0]
        pred_pose_eval, _loss_commit, _perplexity = net(motion)
        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_pose_eval, m_length)

        # inv_transform presumably undoes the dataset's feature normalisation.
        bgt = val_loader.dataset.inv_transform(motion.detach().cpu().numpy())
        bpred = val_loader.dataset.inv_transform(pred_pose_eval.detach().cpu().numpy())
        for i in range(bs):
            # Keep only the first m_length[i] frames (the rest is presumably padding).
            gt = recover_from_ric(torch.from_numpy(bgt[i, :m_length[i]]).float(), num_joint)
            pred = recover_from_ric(torch.from_numpy(bpred[i, :m_length[i]]).float(), num_joint)
            mpjpe += torch.sum(calculate_mpjpe(gt, pred))
            num_poses += gt.shape[0]

        motion_pred_list.append(em_pred)
        motion_annotation_list.append(em)

        # Batch sums (sum_all=True); divided by nb_sample after the loop.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)
    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    mpjpe = mpjpe / num_poses

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, MPJPE. %.4f" % \
          (repeat_id, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
           R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred, mpjpe)
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, mpjpe
@torch.no_grad()
def evaluation_vqvae_plus_l1(val_loader, net, repeat_id, eval_wrapper, num_joint):
    """Evaluate VQ-VAE reconstructions, including an L1 (MAE) joint error.

    Computes FID, diversity, top-1/2/3 R-precision and matching score from
    ``eval_wrapper`` embeddings. It also computes a frame-weighted mean L1 error
    between ground-truth and reconstructed joints over each sample's first
    ``m_length[i]`` frames.

    Args:
        val_loader: Iterable of batches unpacking to
            ``(word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token)``.
            ``val_loader.dataset`` must provide ``inv_transform``.
        net: Model whose call ``net(motion)`` returns ``(reconstruction, commit_loss, perplexity)``.
        repeat_id: Repeat index; used only in the printed message.
        eval_wrapper: Object providing ``get_co_embeddings``.
        num_joint: Joint count passed to ``recover_from_ric``.

    Returns:
        ``(fid, diversity, R_precision, matching_score_pred, l1_dist)``.

    Notes:
        Runs under ``torch.no_grad()``. ``net`` is left in eval mode.
    """
    net.eval()
    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    nb_sample = 0
    matching_score_real = 0
    matching_score_pred = 0
    l1_dist = 0
    # Total number of frames that contributed to l1_dist. Starting at 0 keeps
    # the average exact (starting at 1 under-reported the MAE).
    num_poses = 0
    for batch in val_loader:
        word_embeddings, pos_one_hots, _caption, sent_len, motion, m_length, _token = batch
        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs = motion.shape[0]
        pred_pose_eval, _loss_commit, _perplexity = net(motion)
        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_pose_eval, m_length)

        # inv_transform presumably undoes the dataset's feature normalisation.
        bgt = val_loader.dataset.inv_transform(motion.detach().cpu().numpy())
        bpred = val_loader.dataset.inv_transform(pred_pose_eval.detach().cpu().numpy())
        for i in range(bs):
            gt = recover_from_ric(torch.from_numpy(bgt[i, :m_length[i]]).float(), num_joint)
            pred = recover_from_ric(torch.from_numpy(bpred[i, :m_length[i]]).float(), num_joint)
            num_pose = gt.shape[0]
            # F.l1_loss averages over the sample's elements; weight by its frame
            # count so the final division yields a per-frame dataset average.
            l1_dist += F.l1_loss(gt, pred) * num_pose
            num_poses += num_pose

        motion_pred_list.append(em_pred)
        motion_annotation_list.append(em)

        # Batch sums (sum_all=True); divided by nb_sample after the loop.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)
    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    l1_dist = l1_dist / num_poses if num_poses > 0 else 0.0

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, mae. %.4f" % \
          (repeat_id, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
           R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred, l1_dist)
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, l1_dist
@torch.no_grad()
def evaluation_res_plus_l1(val_loader, vq_model, res_model, repeat_id, eval_wrapper, num_joint, do_vq_res=True):
    """Evaluate a VQ model combined with a residual model, including an L1 (MAE) joint error.

    Computes FID, diversity, top-1/2/3 R-precision, matching score and a
    frame-weighted mean L1 error between ground-truth and predicted joints.

    Args:
        val_loader: Iterable of batches unpacking to
            ``(word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token)``.
            ``val_loader.dataset`` must provide ``inv_transform``.
        vq_model: Model providing ``encode``, ``decoder`` and a call returning a 3-tuple
            whose first item is the reconstruction.
        res_model: Model applied to code ids (``do_vq_res=True``) or to the
            reconstruction (``do_vq_res=False``).
        repeat_id: Repeat index; used only in the printed message.
        eval_wrapper: Object providing ``get_co_embeddings``.
        num_joint: Joint count passed to ``recover_from_ric``.
        do_vq_res: Selects one of two pipelines.
            True: encode the motion, feed the code ids to ``res_model`` (only the
            first slice of the last axis when the ids are 3-D), and decode the
            result with ``vq_model.decoder``.
            False: run ``res_model`` on ``vq_model(motion)``'s reconstruction.

    Returns:
        ``(fid, diversity, R_precision, matching_score_pred, l1_dist)``.

    Notes:
        Runs under ``torch.no_grad()``. Both models are left in eval mode.
    """
    vq_model.eval()
    res_model.eval()
    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    nb_sample = 0
    matching_score_real = 0
    matching_score_pred = 0
    l1_dist = 0
    # Total number of frames that contributed to l1_dist. Starting at 0 keeps
    # the average exact (starting at 1 under-reported the MAE).
    num_poses = 0
    for batch in val_loader:
        word_embeddings, pos_one_hots, _caption, sent_len, motion, m_length, _token = batch
        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs = motion.shape[0]

        if do_vq_res:
            code_ids, _all_codes = vq_model.encode(motion)
            if len(code_ids.shape) == 3:
                pred_vq_codes = res_model(code_ids[..., 0])
            else:
                pred_vq_codes = res_model(code_ids)
            pred_pose_eval = vq_model.decoder(pred_vq_codes)
        else:
            rec_motions, _, _ = vq_model(motion)
            pred_pose_eval = res_model(rec_motions)

        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_pose_eval, m_length)

        # inv_transform presumably undoes the dataset's feature normalisation.
        bgt = val_loader.dataset.inv_transform(motion.detach().cpu().numpy())
        bpred = val_loader.dataset.inv_transform(pred_pose_eval.detach().cpu().numpy())
        for i in range(bs):
            gt = recover_from_ric(torch.from_numpy(bgt[i, :m_length[i]]).float(), num_joint)
            pred = recover_from_ric(torch.from_numpy(bpred[i, :m_length[i]]).float(), num_joint)
            num_pose = gt.shape[0]
            # F.l1_loss averages over the sample's elements; weight by frame count.
            l1_dist += F.l1_loss(gt, pred) * num_pose
            num_poses += num_pose

        motion_pred_list.append(em_pred)
        motion_annotation_list.append(em)

        # Batch sums (sum_all=True); divided by nb_sample after the loop.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)
    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    l1_dist = l1_dist / num_poses if num_poses > 0 else 0.0

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, mae. %.4f" % \
          (repeat_id, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
           R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred, l1_dist)
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, l1_dist
@torch.no_grad()
def evaluation_mask_transformer(out_dir, val_loader, trans, vq_model, writer, ep, best_fid, best_div,
                                best_top1, best_top2, best_top3, best_matching, eval_wrapper, plot_func,
                                save_ckpt=False, save_anim=False):
    """Evaluate text-conditioned generation with the masked transformer and update best scores.

    For each batch the transformer generates token ids from the captions; these
    are decoded with ``vq_model.forward_decoder`` and compared with the
    ground-truth motion through ``eval_wrapper`` embeddings. The function reports
    FID, diversity, top-1/2/3 R-precision and matching score, and logs them to
    ``writer``.

    Args:
        out_dir: Run directory. It receives ``model/net_best_fid.tar`` and
            ``animation/E####``. If it contains "kit", ``cond_scale`` is 2 (else 4).
        val_loader: Iterable of batches unpacking to
            ``(word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token)``.
        trans: Transformer providing ``generate`` and ``state_dict``.
        vq_model: Model providing ``forward_decoder``.
        writer: Object with ``add_scalar(tag, value, step)``.
        ep: Epoch index for logging, checkpoints and the animation folder name.
        best_fid, best_div, best_top1, best_top2, best_top3, best_matching:
            Best values seen so far; returned (possibly updated).
        eval_wrapper: Object providing ``get_co_embeddings``.
        plot_func: Called as ``plot_func(data, save_dir, captions, lengths)`` when ``save_anim``.
        save_ckpt: If true, save the transformer when FID improves.
        save_anim: If true, plot 3 random samples from the last batch.

    Returns:
        ``(best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer)``.

    Notes:
        Runs under ``torch.no_grad()``.
    """

    def save(file_name, ep):
        # Store the transformer weights without any 'clip_model.' entries.
        t2m_trans_state_dict = trans.state_dict()
        clip_weights = [e for e in t2m_trans_state_dict.keys() if e.startswith('clip_model.')]
        for e in clip_weights:
            del t2m_trans_state_dict[e]
        state = {
            't2m_transformer': t2m_trans_state_dict,
            'ep': ep,
        }
        torch.save(state, file_name)

    trans.eval()
    vq_model.eval()

    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    nb_sample = 0

    time_steps = 18
    # cond_scale (presumably the classifier-free guidance scale) depends on the dataset name.
    if "kit" in out_dir:
        cond_scale = 2
    else:
        cond_scale = 4

    for batch in val_loader:
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, _token = batch
        m_length = m_length.cuda()
        bs = pose.shape[0]

        # (b, seqlen) token ids; token length is m_length // 4.
        mids = trans.generate(clip_text, m_length // 4, time_steps, cond_scale, temperature=1)
        # Append a trailing axis of size 1 before decoding.
        mids.unsqueeze_(-1)
        pred_motions = vq_model.forward_decoder(mids)

        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_motions.clone(), m_length)
        pose = pose.cuda().float()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        motion_annotation_list.append(em)
        motion_pred_list.append(em_pred)

        # Batch sums (sum_all=True); divided by nb_sample after the loop.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)
    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Ep {ep} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
    print(msg)

    writer.add_scalar('./Test/FID', fid, ep)
    writer.add_scalar('./Test/Diversity', diversity, ep)
    writer.add_scalar('./Test/top1', R_precision[0], ep)
    writer.add_scalar('./Test/top2', R_precision[1], ep)
    writer.add_scalar('./Test/top3', R_precision[2], ep)
    writer.add_scalar('./Test/matching_score', matching_score_pred, ep)

    if fid < best_fid:
        msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
        print(msg)
        best_fid = fid
        if save_ckpt:
            save(os.path.join(out_dir, 'model', 'net_best_fid.tar'), ep)

    if matching_score_pred < best_matching:
        msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
        print(msg)
        best_matching = matching_score_pred

    # Diversity "improves" when it gets closer to the ground-truth diversity.
    if abs(diversity_real - diversity) < abs(diversity_real - best_div):
        msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
        print(msg)
        best_div = diversity

    if R_precision[0] > best_top1:
        msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
        print(msg)
        best_top1 = R_precision[0]

    if R_precision[1] > best_top2:
        msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
        print(msg)
        best_top2 = R_precision[1]

    if R_precision[2] > best_top3:
        msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
        print(msg)
        best_top3 = R_precision[2]

    if save_anim:
        # Sample 3 indices (with replacement) from the last evaluated batch.
        rand_idx = torch.randint(bs, (3,))
        data = pred_motions[rand_idx].detach().cpu().numpy()
        captions = [clip_text[k] for k in rand_idx]
        lengths = m_length[rand_idx].cpu().numpy()
        save_dir = os.path.join(out_dir, 'animation', 'E%04d' % ep)
        os.makedirs(save_dir, exist_ok=True)
        plot_func(data, save_dir, captions, lengths)

    return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer
@torch.no_grad()
def evaluation_res_transformer(out_dir, val_loader, trans, vq_model, writer, ep, best_fid, best_div,
                               best_top1, best_top2, best_top3, best_matching, eval_wrapper, plot_func,
                               save_ckpt=False, save_anim=False, cond_scale=2, temperature=1):
    """Evaluate the residual transformer on ground-truth base codes and update best scores.

    Ground-truth motion is encoded with ``vq_model.encode``. The transformer is
    conditioned on the first code slice ``code_indices[..., 0]`` plus the captions,
    and its output ids are decoded with ``vq_model.forward_decoder``. At
    ``ep == 0`` the transformer is skipped and only ``code_indices[..., 0:1]`` is
    decoded (presumably a base-layer-only baseline).

    Args:
        out_dir: Run directory; receives ``model/net_best_fid.tar`` and ``animation/E####``.
        val_loader: Iterable of batches unpacking to
            ``(word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token)``.
        trans: Transformer providing ``generate`` and ``state_dict``.
        vq_model: Model providing ``encode`` and ``forward_decoder``.
        writer: Object with ``add_scalar(tag, value, step)``.
        ep: Epoch index for logging, checkpoints and the animation folder name.
        best_fid, best_div, best_top1, best_top2, best_top3, best_matching:
            Best values seen so far; returned (possibly updated).
        eval_wrapper: Object providing ``get_co_embeddings``.
        plot_func: Called as ``plot_func(data, save_dir, captions, lengths)`` when ``save_anim``.
        save_ckpt: If true, save the transformer when FID improves.
        save_anim: If true, plot 3 random samples from the last batch.
        cond_scale, temperature: Forwarded to ``trans.generate``.

    Returns:
        ``(best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer)``.

    Notes:
        Runs under ``torch.no_grad()``.
    """

    def save(file_name, ep):
        # Store the transformer weights without any 'clip_model.' entries.
        res_trans_state_dict = trans.state_dict()
        clip_weights = [e for e in res_trans_state_dict.keys() if e.startswith('clip_model.')]
        for e in clip_weights:
            del res_trans_state_dict[e]
        state = {
            'res_transformer': res_trans_state_dict,
            'ep': ep,
        }
        torch.save(state, file_name)

    trans.eval()
    vq_model.eval()

    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    nb_sample = 0

    for batch in val_loader:
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, _token = batch
        m_length = m_length.cuda().long()
        pose = pose.cuda().float()
        bs = pose.shape[0]

        code_indices, _all_codes = vq_model.encode(pose)
        if ep == 0:
            pred_ids = code_indices[..., 0:1]
        else:
            pred_ids = trans.generate(code_indices[..., 0], clip_text, m_length // 4,
                                      temperature=temperature, cond_scale=cond_scale)
        pred_motions = vq_model.forward_decoder(pred_ids)

        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_motions.clone(), m_length)
        # pose is already a float CUDA tensor (converted at the top of the loop).
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        motion_annotation_list.append(em)
        motion_pred_list.append(em_pred)

        # Batch sums (sum_all=True); divided by nb_sample after the loop.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)
    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Ep {ep} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
    print(msg)

    writer.add_scalar('./Test/FID', fid, ep)
    writer.add_scalar('./Test/Diversity', diversity, ep)
    writer.add_scalar('./Test/top1', R_precision[0], ep)
    writer.add_scalar('./Test/top2', R_precision[1], ep)
    writer.add_scalar('./Test/top3', R_precision[2], ep)
    writer.add_scalar('./Test/matching_score', matching_score_pred, ep)

    if fid < best_fid:
        msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
        print(msg)
        best_fid = fid
        if save_ckpt:
            save(os.path.join(out_dir, 'model', 'net_best_fid.tar'), ep)

    if matching_score_pred < best_matching:
        msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
        print(msg)
        best_matching = matching_score_pred

    # Diversity "improves" when it gets closer to the ground-truth diversity.
    if abs(diversity_real - diversity) < abs(diversity_real - best_div):
        msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
        print(msg)
        best_div = diversity

    if R_precision[0] > best_top1:
        msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
        print(msg)
        best_top1 = R_precision[0]

    if R_precision[1] > best_top2:
        msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
        print(msg)
        best_top2 = R_precision[1]

    if R_precision[2] > best_top3:
        msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
        print(msg)
        best_top3 = R_precision[2]

    if save_anim:
        # Sample 3 indices (with replacement) from the last evaluated batch.
        rand_idx = torch.randint(bs, (3,))
        data = pred_motions[rand_idx].detach().cpu().numpy()
        captions = [clip_text[k] for k in rand_idx]
        lengths = m_length[rand_idx].cpu().numpy()
        save_dir = os.path.join(out_dir, 'animation', 'E%04d' % ep)
        os.makedirs(save_dir, exist_ok=True)
        plot_func(data, save_dir, captions, lengths)

    return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer
@torch.no_grad()
def evaluation_res_transformer_plus_l1(val_loader, vq_model, trans, repeat_id, eval_wrapper, num_joint,
                                       cond_scale=2, temperature=1, topkr=0.9, cal_l1=True):
    """Evaluate the residual transformer on ground-truth base codes, with an optional L1 (MAE) joint error.

    Ground-truth motion is encoded with ``vq_model.encode``. The transformer is
    conditioned on ``code_indices[..., 0]`` plus the captions, and its ids are
    decoded with ``vq_model.forward_decoder``. The function reports FID,
    diversity, top-1/2/3 R-precision, matching score and, when ``cal_l1`` is true,
    a frame-weighted mean L1 error between ground-truth and generated joints.

    Args:
        val_loader: Iterable of batches unpacking to
            ``(word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token)``.
            ``val_loader.dataset`` must provide ``inv_transform``.
        vq_model: Model providing ``encode`` and ``forward_decoder``.
        trans: Transformer providing ``generate``.
        repeat_id: Repeat index; used only in the printed message.
        eval_wrapper: Object providing ``get_co_embeddings``.
        num_joint: Joint count passed to ``recover_from_ric``.
        cond_scale, temperature: Forwarded to ``trans.generate``.
        topkr: Forwarded to ``trans.generate`` as ``topk_filter_thres``.
        cal_l1: If false, the L1 error is skipped and reported as 0.

    Returns:
        ``(fid, diversity, R_precision, matching_score_pred, l1_dist)``.

    Notes:
        Runs under ``torch.no_grad()``. Both models are left in eval mode.
    """
    trans.eval()
    vq_model.eval()
    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    nb_sample = 0
    l1_dist = 0
    # Total number of frames that contributed to l1_dist. It starts at 0 so the
    # average is exact; the final division is guarded for cal_l1=False.
    num_poses = 0
    for batch in val_loader:
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, _token = batch
        m_length = m_length.cuda().long()
        pose = pose.cuda().float()
        bs = pose.shape[0]

        code_indices, _all_codes = vq_model.encode(pose)
        pred_ids = trans.generate(code_indices[..., 0], clip_text, m_length // 4, topk_filter_thres=topkr,
                                  temperature=temperature, cond_scale=cond_scale)
        pred_motions = vq_model.forward_decoder(pred_ids)

        if cal_l1:
            # inv_transform presumably undoes the dataset's feature normalisation.
            bgt = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
            bpred = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy())
            for i in range(bs):
                gt = recover_from_ric(torch.from_numpy(bgt[i, :m_length[i]]).float(), num_joint)
                pred = recover_from_ric(torch.from_numpy(bpred[i, :m_length[i]]).float(), num_joint)
                num_pose = gt.shape[0]
                # F.l1_loss averages over the sample's elements; weight by frame count.
                l1_dist += F.l1_loss(gt, pred) * num_pose
                num_poses += num_pose

        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                          pred_motions.clone(), m_length)
        # pose is already a float CUDA tensor (converted at the top of the loop).
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        motion_annotation_list.append(em)
        motion_pred_list.append(em_pred)

        # Batch sums (sum_all=True); divided by nb_sample after the loop.
        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)
    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    l1_dist = l1_dist / num_poses if num_poses > 0 else 0.0

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, mae. %.4f" % \
          (repeat_id, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
           R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred, l1_dist)
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, l1_dist
@torch.no_grad()
def evaluation_mask_transformer_test(val_loader, vq_model, trans, repeat_id, eval_wrapper,
                                     time_steps, cond_scale, temperature, topkr, gsample=True, force_mask=False,
                                     cal_mm=True):
    """Evaluate a masked transformer + VQ decoder on a text-to-motion loader.

    For every batch, token ids are generated by ``trans.generate`` from the
    batch's ``clip_text`` and decoded with ``vq_model.forward_decoder``. The
    generated motion and the ground-truth ``pose`` are both embedded by
    ``eval_wrapper.get_co_embeddings``. The embeddings are reduced to FID,
    diversity, R-precision, matching score and (optionally) multimodality, and a
    summary line is printed.

    Runs under ``torch.no_grad()``: this is inference-only, so no autograd graph
    is built for the decoder/transformer forward passes.

    Args:
        val_loader: iterable of 7-tuples
            ``(word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token)``.
        vq_model: model exposing ``eval()`` and ``forward_decoder(ids)``.
        trans: model exposing ``eval()`` and ``generate(...)``.
        repeat_id: repetition index; only used in the printed message.
        eval_wrapper: object whose ``get_co_embeddings(...)`` returns a pair
            ``(et, em)`` (presumably text and motion embeddings).
        time_steps, cond_scale, temperature: forwarded to ``trans.generate``.
        topkr: forwarded to ``trans.generate`` as ``topk_filter_thres``.
        gsample: forwarded to ``trans.generate`` for every generation.
        force_mask: forwarded to ``trans.generate``; when True, multimodality
            is not computed (reported as 0).
        cal_mm: when True (and ``force_mask`` is False), the first 3 batches
            are generated 30 times each to compute multimodality.

    Returns:
        ``(fid, diversity, R_precision, matching_score_pred, multimodality)``.
        ``R_precision`` and ``matching_score_pred`` are per-sample averages.
    """
    trans.eval()
    vq_model.eval()

    motion_annotation_list = []
    motion_pred_list = []
    motion_multimodality = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    multimodality = 0
    nb_sample = 0

    # Multimodality is only reported when cal_mm is set and force_mask is off,
    # so only pay for the extra 30 generations per batch in that case.
    compute_mm = cal_mm and not force_mask
    num_mm_batch = 3 if compute_mm else 0

    for i, batch in enumerate(val_loader):
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token = batch
        m_length = m_length.cuda()
        bs = pose.shape[0]
        # NOTE(review): // 4 presumably is the VQ temporal downsampling rate — confirm.
        token_lens = m_length // 4

        is_mm_batch = i < num_mm_batch
        num_draws = 30 if is_mm_batch else 1
        mm_draws = []
        # Same generate arguments (including gsample) for every draw, whether
        # or not this batch contributes to multimodality.
        for _ in range(num_draws):
            mids = trans.generate(clip_text, token_lens, time_steps, cond_scale,
                                  temperature=temperature, topk_filter_thres=topkr,
                                  gsample=gsample, force_mask=force_mask)
            mids.unsqueeze_(-1)
            pred_motions = vq_model.forward_decoder(mids)
            et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                              pred_motions.clone(), m_length)
            if is_mm_batch:
                mm_draws.append(em_pred.unsqueeze(1))
        if is_mm_batch:
            motion_multimodality.append(torch.cat(mm_draws, dim=1))  # (bs, 30, d)

        pose = pose.cuda().float()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        motion_annotation_list.append(em)
        # For multimodality batches this is the last of the 30 draws.
        motion_pred_list.append(em_pred)

        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    if compute_mm:
        motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
        multimodality = calculate_multimodality(motion_multimodality, 10)

    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)
    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Repeat {repeat_id} :, FID. {fid:.4f}, " \
          f"Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, " \
          f"R_precision_real. {R_precision_real}, R_precision. {R_precision}, " \
          f"matching_score_real. {matching_score_real:.4f}, matching_score_pred. {matching_score_pred:.4f}," \
          f"multimodality. {multimodality:.4f}"
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, multimodality
@torch.no_grad()
def evaluation_mask_transformer_test_plus_res(val_loader, vq_model, res_model, trans, repeat_id, eval_wrapper,
                                              time_steps, cond_scale, temperature, topkr, gsample=True,
                                              force_mask=False, cal_mm=True, res_cond_scale=5):
    """Evaluate masked transformer + residual model + VQ decoder on a text-to-motion loader.

    For every batch, base token ids are generated by ``trans.generate`` and
    refined by ``res_model.generate``. The result is decoded with
    ``vq_model.forward_decoder``. The generated motion and the ground-truth
    ``pose`` are both embedded by ``eval_wrapper.get_co_embeddings``. The
    embeddings are reduced to FID, diversity, R-precision, matching score and
    (optionally) multimodality, and a summary line is printed.

    Runs under ``torch.no_grad()``: this is inference-only, so no autograd graph
    is built for the model forward passes.

    Args:
        val_loader: iterable of 7-tuples
            ``(word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token)``.
        vq_model: model exposing ``eval()`` and ``forward_decoder(ids)``.
        res_model: model exposing ``eval()`` and ``generate(...)``; called with
            ``temperature=1`` and ``cond_scale=res_cond_scale``.
        trans: model exposing ``eval()`` and ``generate(...)``.
        repeat_id: repetition index; only used in the printed message.
        eval_wrapper: object whose ``get_co_embeddings(...)`` returns a pair
            ``(et, em)`` (presumably text and motion embeddings).
        time_steps, cond_scale, temperature: forwarded to ``trans.generate``.
        topkr: forwarded to ``trans.generate`` as ``topk_filter_thres``.
        gsample: forwarded to ``trans.generate`` for every generation.
        force_mask: forwarded to ``trans.generate``; when True, multimodality
            is not computed (reported as 0).
        cal_mm: when True (and ``force_mask`` is False), the first 3 batches
            are generated 30 times each to compute multimodality.
        res_cond_scale: ``cond_scale`` passed to ``res_model.generate``.

    Returns:
        ``(fid, diversity, R_precision, matching_score_pred, multimodality)``.
        ``R_precision`` and ``matching_score_pred`` are per-sample averages.
    """
    trans.eval()
    vq_model.eval()
    res_model.eval()

    motion_annotation_list = []
    motion_pred_list = []
    motion_multimodality = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0
    multimodality = 0
    nb_sample = 0

    # Multimodality is only reported when cal_mm is set and force_mask is off.
    compute_mm = cal_mm and not force_mask
    num_mm_batch = 3 if compute_mm else 0

    for i, batch in enumerate(val_loader):
        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token = batch
        m_length = m_length.cuda()
        bs = pose.shape[0]
        # NOTE(review): // 4 presumably is the VQ temporal downsampling rate — confirm.
        token_lens = m_length // 4

        is_mm_batch = i < num_mm_batch
        num_draws = 30 if is_mm_batch else 1
        mm_draws = []
        # Same generate arguments (including gsample) for every draw, whether
        # or not this batch contributes to multimodality.
        for _ in range(num_draws):
            mids = trans.generate(clip_text, token_lens, time_steps, cond_scale,
                                  temperature=temperature, topk_filter_thres=topkr,
                                  gsample=gsample, force_mask=force_mask)
            pred_ids = res_model.generate(mids, clip_text, token_lens, temperature=1, cond_scale=res_cond_scale)
            pred_motions = vq_model.forward_decoder(pred_ids)
            et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len,
                                                              pred_motions.clone(), m_length)
            if is_mm_batch:
                mm_draws.append(em_pred.unsqueeze(1))
        if is_mm_batch:
            motion_multimodality.append(torch.cat(mm_draws, dim=1))  # (bs, 30, d)

        pose = pose.cuda().float()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
        motion_annotation_list.append(em)
        # For multimodality batches this is the last of the 30 draws.
        motion_pred_list.append(em_pred)

        et_np, em_np = et.cpu().numpy(), em.cpu().numpy()
        R_precision_real += calculate_R_precision(et_np, em_np, top_k=3, sum_all=True)
        matching_score_real += euclidean_distance_matrix(et_np, em_np).trace()

        et_pred_np, em_pred_np = et_pred.cpu().numpy(), em_pred.cpu().numpy()
        R_precision += calculate_R_precision(et_pred_np, em_pred_np, top_k=3, sum_all=True)
        matching_score_pred += euclidean_distance_matrix(et_pred_np, em_pred_np).trace()

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    if compute_mm:
        motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
        multimodality = calculate_multimodality(motion_multimodality, 10)

    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)
    diversity_times = 300 if nb_sample > 300 else 100
    diversity_real = calculate_diversity(motion_annotation_np, diversity_times)
    diversity = calculate_diversity(motion_pred_np, diversity_times)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample
    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample
    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Repeat {repeat_id} :, FID. {fid:.4f}, " \
          f"Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, " \
          f"R_precision_real. {R_precision_real}, R_precision. {R_precision}, " \
          f"matching_score_real. {matching_score_real:.4f}, matching_score_pred. {matching_score_pred:.4f}," \
          f"multimodality. {multimodality:.4f}"
    print(msg)
    return fid, diversity, R_precision, matching_score_pred, multimodality