import torch
from os.path import join as pjoin
import numpy as np
from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo
from utils.word_vectorizer import POS_enumerator


def build_models(opt):
    movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
    text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
                                  pos_size=opt.dim_pos_ohot,
                                  hidden_size=opt.dim_text_hidden,
                                  output_size=opt.dim_coemb_hidden,
                                  device=opt.device)
    motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
                                      hidden_size=opt.dim_motion_hidden,
                                      output_size=opt.dim_coemb_hidden,
                                      device=opt.device)

    checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
                            map_location=opt.device)
    movement_enc.load_state_dict(checkpoint['movement_encoder'])
    text_enc.load_state_dict(checkpoint['text_encoder'])
    motion_enc.load_state_dict(checkpoint['motion_encoder'])
    print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
    return text_enc, motion_enc, movement_enc


class EvaluatorModelWrapper(object):

    def __init__(self, opt):
        if opt.dataset_name == 't2m':
            opt.dim_pose = 263
        elif opt.dataset_name == 'kit':
            opt.dim_pose = 251
        else:
            raise KeyError('Dataset not Recognized!!!')

        opt.dim_word = 300
        opt.max_motion_length = 196
        opt.dim_pos_ohot = len(POS_enumerator)
        opt.dim_motion_hidden = 1024
        opt.max_text_len = 20
        opt.dim_text_hidden = 512
        opt.dim_coemb_hidden = 512

        # print(opt)

        self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
        self.opt = opt
        self.device = opt.device

        self.text_encoder.to(opt.device)
        self.motion_encoder.to(opt.device)
        self.movement_encoder.to(opt.device)

        self.text_encoder.eval()
        self.motion_encoder.eval()
        self.movement_encoder.eval()

    # Please note that the results do not follow the order of the inputs
    def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
        with torch.no_grad():
            word_embs = word_embs.detach().to(self.device).float()
            pos_ohot = pos_ohot.detach().to(self.device).float()
            motions = motions.detach().to(self.device).float()

            '''Movement Encoding'''
            movements = self.movement_encoder(motions[..., :-4]).detach()
            m_lens = m_lens // self.opt.unit_length
            motion_embedding = self.motion_encoder(movements, m_lens)

            '''Text Encoding'''
            text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
        return text_embedding, motion_embedding

    # Please note that the results do not follow the order of the inputs
    def get_motion_embeddings(self, motions, m_lens):
        with torch.no_grad():
            motions = motions.detach().to(self.device).float()

            align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
            motions = motions[align_idx]
            m_lens = m_lens[align_idx]

            '''Movement Encoding'''
            movements = self.movement_encoder(motions[..., :-4]).detach()
            m_lens = m_lens // self.opt.unit_length
            motion_embedding = self.motion_encoder(movements, m_lens)
        return motion_embedding
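

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the evaluation pipeline).
    # It assumes a trained 'text_mot_match' checkpoint exists at
    # <checkpoints_dir>/<dataset_name>/text_mot_match/model/finest.tar, and the
    # option values below (paths, unit_length, movement-encoder dimensions) are
    # assumptions that must match whatever checkpoint is actually on disk.
    from argparse import Namespace

    opt = Namespace(
        checkpoints_dir='./checkpoints',   # hypothetical checkpoint root
        dataset_name='t2m',                # 't2m' (263-D pose) or 'kit' (251-D pose)
        device=torch.device('cpu'),
        unit_length=4,                     # assumed motion-to-movement downsampling factor
        dim_movement_enc_hidden=512,       # assumed movement-encoder hidden size
        dim_movement_latent=512,           # assumed movement latent size
    )
    wrapper = EvaluatorModelWrapper(opt)

    # Dummy batch: 2 motions of max length 196 with 263-D poses (t2m layout).
    motions = torch.randn(2, opt.max_motion_length, opt.dim_pose)
    m_lens = torch.tensor([196, 160])

    # Note: embeddings come back sorted by descending motion length, not input order.
    motion_emb = wrapper.get_motion_embeddings(motions, m_lens)
    print(motion_emb.shape)  # (2, opt.dim_coemb_hidden)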