from typing import List
import os

import numpy as np
import spacy
import torch
from torch import Tensor
from torchmetrics import Metric
from bert_score import score as score_bert

from mGPT.config import instantiate_from_config
from .utils import *


class M2TMetrics(Metric):
    """Motion-to-text evaluation metrics: matching score / R-precision computed
    with the pretrained T2M evaluators, plus NLG metrics (BLEU, ROUGE-L, CIDEr,
    BertScore) on the generated captions."""

    def __init__(self,
                 cfg,
                 w_vectorizer,
                 dataname='humanml3d',
                 top_k=3,
                 bleu_k=4,
                 R_size=32,
                 max_text_len=40,
                 diversity_times=300,
                 dist_sync_on_step=True,
                 unit_length=4,
                 **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.cfg = cfg
        self.dataname = dataname
        self.w_vectorizer = w_vectorizer
        self.name = "matching, fid, and diversity scores"
        # self.text = True if cfg.TRAIN.STAGE in ["diffusion", "t2m_gpt"] else False
        self.max_text_len = max_text_len
        self.top_k = top_k
        self.bleu_k = bleu_k
        self.R_size = R_size
        self.diversity_times = diversity_times
        self.unit_length = unit_length

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        self.metrics = []

        # Matching scores
        self.add_state("Matching_score",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.add_state("gt_Matching_score",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.Matching_metrics = ["Matching_score", "gt_Matching_score"]
        for k in range(1, top_k + 1):
            self.add_state(
                f"R_precision_top_{str(k)}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            self.Matching_metrics.append(f"R_precision_top_{str(k)}")
        for k in range(1, top_k + 1):
            self.add_state(
                f"gt_R_precision_top_{str(k)}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            self.Matching_metrics.append(f"gt_R_precision_top_{str(k)}")

        self.metrics.extend(self.Matching_metrics)

        # NLG metrics: one state per BLEU order up to bleu_k
        for k in range(1, bleu_k + 1):
            self.add_state(
                f"Bleu_{str(k)}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            self.metrics.append(f"Bleu_{str(k)}")

        self.add_state("ROUGE_L",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.metrics.append("ROUGE_L")

        self.add_state("CIDEr",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.metrics.append("CIDEr")

        # Cached batches
        self.pred_texts = []
        self.gt_texts = []
        self.add_state("predtext_embeddings", default=[])
        self.add_state("gttext_embeddings", default=[])
        self.add_state("gtmotion_embeddings", default=[])

        # T2M evaluator
        self._get_t2m_evaluator(cfg)

        self.nlp = spacy.load('en_core_web_sm')

        if self.cfg.model.params.task == 'm2t':
            from nlgmetricverse import NLGMetricverse, load_metric
            metrics = [
                load_metric("bleu",
                            resulting_name=f"bleu_{k}",
                            compute_kwargs={"max_order": k})
                for k in range(1, bleu_k + 1)
            ]
            metrics += [load_metric("rouge"), load_metric("cider")]
            self.nlg_evaluator = NLGMetricverse(metrics)

    def _get_t2m_evaluator(self, cfg):
        """Load the pretrained T2M text, movement, and motion encoders used for
        evaluation."""
        # init modules
        self.t2m_textencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_textencoder)
        self.t2m_moveencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_moveencoder)
        self.t2m_motionencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_motionencoder)

        # load pretrained weights
        dataname = "kit" if self.dataname == "kit" else "t2m"
        t2m_checkpoint = torch.load(os.path.join(
            cfg.METRIC.TM2T.t2m_path, dataname,
            "text_mot_match/model/finest.tar"),
                                    map_location='cpu')
        self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"])
        self.t2m_moveencoder.load_state_dict(
            t2m_checkpoint["movement_encoder"])
        self.t2m_motionencoder.load_state_dict(
            t2m_checkpoint["motion_encoder"])

        # freeze params
        self.t2m_textencoder.eval()
        self.t2m_moveencoder.eval()
        self.t2m_motionencoder.eval()
        for p in self.t2m_textencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_moveencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_motionencoder.parameters():
            p.requires_grad = False

    def _process_text(self, sentence):
        """Tokenize a sentence with spaCy and lemmatize nouns/verbs (except
        'left'), returning the word list and its POS tags."""
        sentence = sentence.replace('-', '')
        doc = self.nlp(sentence)
        word_list = []
        pos_list = []
        for token in doc:
            word = token.text
            if not word.isalpha():
                continue
            if (token.pos_ == 'NOUN'
                    or token.pos_ == 'VERB') and (word != 'left'):
                word_list.append(token.lemma_)
            else:
                word_list.append(word)
            pos_list.append(token.pos_)
        return word_list, pos_list

    def _get_text_embeddings(self, texts):
        """Encode raw sentences with the T2M text encoder, restoring the
        original sentence order after length-sorted encoding."""
        word_embs = []
        pos_ohot = []
        text_lengths = []
        for sentence in texts:
            word_list, pos_list = self._process_text(sentence.strip())
            t_tokens = [
                '%s/%s' % (word_list[i], pos_list[i])
                for i in range(len(word_list))
            ]

            if len(t_tokens) < self.max_text_len:
                # pad with "unk"
                tokens = ['sos/OTHER'] + t_tokens + ['eos/OTHER']
                sent_len = len(tokens)
                tokens = tokens + ['unk/OTHER'
                                   ] * (self.max_text_len + 2 - sent_len)
            else:
                # crop
                tokens = t_tokens[:self.max_text_len]
                tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
                sent_len = len(tokens)

            pos_one_hots = []
            word_embeddings = []
            for token in tokens:
                word_emb, pos_oh = self.w_vectorizer[token]
                pos_one_hots.append(torch.tensor(pos_oh).float()[None])
                word_embeddings.append(torch.tensor(word_emb).float()[None])
            text_lengths.append(sent_len)
            pos_ohot.append(torch.cat(pos_one_hots, dim=0)[None])
            word_embs.append(torch.cat(word_embeddings, dim=0)[None])

        word_embs = torch.cat(word_embs, dim=0).to(self.Matching_score)
        pos_ohot = torch.cat(pos_ohot, dim=0).to(self.Matching_score)
        text_lengths = torch.tensor(text_lengths).to(self.Matching_score)

        align_idx = np.argsort(text_lengths.data.tolist())[::-1].copy()

        # get text embeddings (the encoder expects sequences sorted by length)
        text_embeddings = self.t2m_textencoder(word_embs[align_idx],
                                               pos_ohot[align_idx],
                                               text_lengths[align_idx])

        # undo the length sort so embeddings match the input order
        original_text_embeddings = text_embeddings.clone()
        for idx, sort in enumerate(align_idx):
            original_text_embeddings[sort] = text_embeddings[idx]

        return original_text_embeddings

    @torch.no_grad()
    def compute(self, sanity_flag):
        count = self.count.item()
        count_seq = self.count_seq.item()

        # Init metrics dict
        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        # Skip during the sanity-check stage
        if sanity_flag:
            return metrics

        # Cat cached batches and shuffle
        shuffle_idx = torch.randperm(count_seq)
        all_motions = torch.cat(self.gtmotion_embeddings,
                                axis=0).cpu()[shuffle_idx, :]
        all_gttexts = torch.cat(self.gttext_embeddings,
                                axis=0).cpu()[shuffle_idx, :]
        all_predtexts = torch.cat(self.predtext_embeddings,
                                  axis=0).cpu()[shuffle_idx, :]

        print("Computing metrics...")

        # R-precision between predicted texts and GT motions
        assert count_seq >= self.R_size
        top_k_mat = torch.zeros((self.top_k, ))
        for i in range(count_seq // self.R_size):
            # [bs=32, 1*256]
            group_texts = all_predtexts[i * self.R_size:(i + 1) * self.R_size]
            # [bs=32, 1*256]
            group_motions = all_motions[i * self.R_size:(i + 1) * self.R_size]
            # [bs=32, 32]
            dist_mat = euclidean_distance_matrix(group_texts,
                                                 group_motions).nan_to_num()
            self.Matching_score += dist_mat.trace()
            argsmax = torch.argsort(dist_mat, dim=1)
            top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0)

        R_count = count_seq // self.R_size * self.R_size
        metrics["Matching_score"] = self.Matching_score / R_count
        for k in range(self.top_k):
            metrics[f"R_precision_top_{str(k + 1)}"] = top_k_mat[k] / R_count

        # R-precision between GT texts and GT motions
        assert count_seq >= self.R_size
        top_k_mat = torch.zeros((self.top_k, ))
        for i in range(count_seq // self.R_size):
            # [bs=32, 1*256]
            group_texts = all_gttexts[i * self.R_size:(i + 1) * self.R_size]
            # [bs=32, 1*256]
            group_motions = all_motions[i * self.R_size:(i + 1) * self.R_size]
            # [bs=32, 32]
            dist_mat = euclidean_distance_matrix(group_texts,
                                                 group_motions).nan_to_num()
            # matching score
            self.gt_Matching_score += dist_mat.trace()
            argsmax = torch.argsort(dist_mat, dim=1)
            top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0)

        metrics["gt_Matching_score"] = self.gt_Matching_score / R_count
        for k in range(self.top_k):
            metrics[f"gt_R_precision_top_{str(k + 1)}"] = top_k_mat[k] / R_count

        # NLG metrics (BLEU, ROUGE-L, CIDEr)
        scores = self.nlg_evaluator(predictions=self.pred_texts,
                                    references=self.gt_texts)
        for k in range(1, self.bleu_k + 1):
            metrics[f"Bleu_{str(k)}"] = torch.tensor(
                scores[f"bleu_{str(k)}"]["score"], device=self.device)
        metrics["ROUGE_L"] = torch.tensor(scores["rouge"]["rougeL"],
                                          device=self.device)
        metrics["CIDEr"] = torch.tensor(scores["cider"]["score"],
                                        device=self.device)

        # BertScore
        P, R, F1 = score_bert(self.pred_texts,
                              self.gt_texts,
                              lang='en',
                              rescale_with_baseline=True,
                              idf=True,
                              device=self.device,
                              verbose=False)
        metrics["Bert_F1"] = F1.mean()

        # Reset
        self.reset()
        self.gt_texts = []
        self.pred_texts = []

        return {**metrics}

    @torch.no_grad()
    def update(self,
               feats_ref: Tensor,
               pred_texts: List[str],
               gt_texts: List[str],
               lengths: List[int],
               word_embs: Tensor = None,
               pos_ohot: Tensor = None,
               text_lengths: Tensor = None):
        self.count += sum(lengths)
        self.count_seq += len(lengths)

        # motion encoder (sequences sorted by length, longest first)
        m_lens = torch.tensor(lengths, device=feats_ref.device)
        align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
        feats_ref = feats_ref[align_idx]
        m_lens = m_lens[align_idx]
        m_lens = torch.div(m_lens,
                           self.cfg.DATASET.HUMANML3D.UNIT_LEN,
                           rounding_mode="floor")
        ref_mov = self.t2m_moveencoder(feats_ref[..., :-4]).detach()
        m_lens = m_lens // self.unit_length
        ref_emb = self.t2m_motionencoder(ref_mov, m_lens)
        gtmotion_embeddings = torch.flatten(ref_emb, start_dim=1).detach()
        self.gtmotion_embeddings.append(gtmotion_embeddings)

        # text encoder
        gttext_emb = self.t2m_textencoder(word_embs, pos_ohot,
                                          text_lengths)[align_idx]
        gttext_embeddings = torch.flatten(gttext_emb, start_dim=1).detach()
        predtext_emb = self._get_text_embeddings(pred_texts)[align_idx]
        predtext_embeddings = torch.flatten(predtext_emb, start_dim=1).detach()
        self.gttext_embeddings.append(gttext_embeddings)
        self.predtext_embeddings.append(predtext_embeddings)

        # cache raw texts for the NLG metrics
        self.pred_texts.extend(pred_texts)
        self.gt_texts.extend(gt_texts)
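
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): one way this metric
# might be driven from an evaluation loop. `cfg`, `w_vectorizer`,
# `val_dataloader`, and the batch field names below are hypothetical
# placeholders assumed for the example.
#
#   metric = M2TMetrics(cfg, w_vectorizer, dataname='humanml3d')
#   for batch in val_dataloader:
#       metric.update(feats_ref=batch["motion"],
#                     pred_texts=batch["pred_texts"],
#                     gt_texts=batch["gt_texts"],
#                     lengths=batch["lengths"],
#                     word_embs=batch["word_embs"],
#                     pos_ohot=batch["pos_ohot"],
#                     text_lengths=batch["text_lengths"])
#   results = metric.compute(sanity_flag=False)
#   # e.g. results["Bleu_4"], results["CIDEr"], results["R_precision_top_1"]
# ---------------------------------------------------------------------------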