# -*- coding: utf-8 -*-
r"""
Polos Ranker Model
======================
The goal of this model is to rank good translations closer to the reference
and source text, and bad translations further away, by a small margin.

https://pytorch.org/docs/stable/nn.html#tripletmarginloss
"""
from argparse import Namespace
from typing import Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
from tqdm import tqdm

from polos.models.ranking.ranking_base import RankingBase
from polos.models.utils import move_to_cuda
from torchnlp.utils import collate_tensors


class PolosRanker(RankingBase):  # extends ptl.LightningModule
    """
    Polos Ranker class that uses a pretrained encoder to extract features
    from the sequences and then passes those features through a
    Triplet Margin Loss.

    :param hparams: Namespace containing the hyperparameters.
    """

    def __init__(self, hparams: Namespace) -> None:
        super().__init__(hparams)

    @staticmethod
    def _harmonic_distance(
        hyp_embedding: torch.Tensor,
        src_embedding: torch.Tensor,
        ref_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Harmonic mean of the hypothesis embedding's pairwise distances to the
        source and reference embeddings.

        :param hyp_embedding: hypothesis sentence embeddings [batch_size x dim].
        :param src_embedding: source sentence embeddings [batch_size x dim].
        :param ref_embedding: reference sentence embeddings [batch_size x dim].
        :return: per-sample harmonic-mean distance [batch_size].
        """
        distance_src = F.pairwise_distance(hyp_embedding, src_embedding)
        distance_ref = F.pairwise_distance(hyp_embedding, ref_embedding)
        return (2 * distance_src * distance_ref) / (distance_src + distance_ref)

    def compute_metrics(self, outputs: List[Dict[str, torch.Tensor]]) -> dict:
        """
        Computes WMT19 shared task kendall tau like metric.

        :param outputs: list of validation-step outputs; each entry holds a
            "val_prediction" dict with src/ref/pos/neg sentence embeddings.
        :return: dictionary with the "kendall" metric value.
        """
        distance_pos, distance_neg = [], []
        for minibatch in outputs:
            prediction = minibatch["val_prediction"]
            src_embedding = prediction["src_sentemb"]
            ref_embedding = prediction["ref_sentemb"]
            pos_embedding = prediction["pos_sentemb"]
            neg_embedding = prediction["neg_sentemb"]

            # Positive (good translation) and negative (bad translation)
            # distances are each the harmonic mean of the distances to the
            # source and to the reference.
            distance_pos.append(
                self._harmonic_distance(pos_embedding, src_embedding, ref_embedding)
            )
            distance_neg.append(
                self._harmonic_distance(neg_embedding, src_embedding, ref_embedding)
            )

        return {
            "kendall": self.metrics.compute(
                torch.cat(distance_pos), torch.cat(distance_neg)
            )
        }

    def compute_loss(self, model_out: Dict[str, torch.Tensor], *args) -> torch.Tensor:
        """
        Computes the Triplet Margin Loss for both the reference and the source
        anchors, using the forward() output (see forward() for the keys).

        :param model_out: model specific output with src_sentemb, ref_sentemb,
            pos_sentemb and neg_sentemb sentence embeddings.
        :return: sum of the source-anchored and reference-anchored triplet losses.
        """
        ref_anchor = model_out["ref_sentemb"]
        src_anchor = model_out["src_sentemb"]
        positive = model_out["pos_sentemb"]
        negative = model_out["neg_sentemb"]
        return self.loss(src_anchor, positive, negative) + self.loss(
            ref_anchor, positive, negative
        )

    def predict(
        self,
        samples: List[Dict[str, str]],
        cuda: bool = False,
        show_progress: bool = False,
        batch_size: int = -1,
    ) -> Tuple[List[Dict[str, Union[str, float]]], List[float]]:
        """Function that runs a model prediction.

        :param samples: List of dictionaries with 'src', 'mt' and 'ref' keys
            (optionally 'alt' for a second reference).
        :param cuda: Flag that runs inference using 1 single GPU.
        :param show_progress: Flag to show progress during inference of
            multiple examples.
        :param batch_size: Batch size used during inference. By default uses
            the same batch size used during training.
        :return: Tuple with the input samples (annotated in place with
            'predicted_score', 'reference_distance' and 'source_distance')
            and the list of predicted scores.
        """
        if self.training:
            self.eval()

        use_cuda = cuda and torch.cuda.is_available()
        if use_cuda:
            self.to("cuda")

        batch_size = self.hparams.batch_size if batch_size < 1 else batch_size
        with torch.no_grad():
            batches = [
                samples[i : i + batch_size] for i in range(0, len(samples), batch_size)
            ]
            model_inputs = []
            if show_progress:
                pbar = tqdm(
                    total=len(batches), desc="Preparing batches....", dynamic_ncols=True
                )
            for batch in batches:
                model_inputs.append(self.prepare_sample(batch, inference=True))
                if show_progress:
                    pbar.update(1)
            if show_progress:
                pbar.close()

            if show_progress:
                pbar = tqdm(
                    total=len(batches), desc="Scoring hypothesis...", dynamic_ncols=True
                )
            distance_weighted, distance_src, distance_ref = [], [], []
            for model_input in model_inputs:
                src_input, mt_input, ref_input, alt_input = model_input
                # Single device-agnostic path: previously the CPU branch
                # silently ignored the second reference (alt_input); now both
                # devices behave identically. .cpu() is a no-op on CPU tensors.
                if use_cuda:
                    src_input = move_to_cuda(src_input)
                    mt_input = move_to_cuda(mt_input)
                    ref_input = move_to_cuda(ref_input)
                    if alt_input is not None:
                        alt_input = move_to_cuda(alt_input)

                src_embeddings = self.get_sentence_embedding(**src_input)
                mt_embeddings = self.get_sentence_embedding(**mt_input)
                ref_embeddings = self.get_sentence_embedding(**ref_input)
                ref_distances = F.pairwise_distance(mt_embeddings, ref_embeddings).cpu()
                src_distances = F.pairwise_distance(mt_embeddings, src_embeddings).cpu()

                # When 2 references are given the distance to the reference is
                # the min between both references.
                if alt_input is not None:
                    alt_embeddings = self.get_sentence_embedding(**alt_input)
                    alt_distances = F.pairwise_distance(
                        mt_embeddings, alt_embeddings
                    ).cpu()
                    ref_distances = torch.stack(
                        [ref_distances, alt_distances]
                    ).min(dim=0).values

                # Harmonic mean between the distances:
                distances = (2 * ref_distances * src_distances) / (
                    ref_distances + src_distances
                )

                # BUG FIX: the original assigned ref_distances to src_distances
                # here, so the reported "source_distance" was in fact the
                # reference distance.
                src_distances = src_distances.numpy().tolist()
                ref_distances = ref_distances.numpy().tolist()
                distances = distances.numpy().tolist()
                for i in range(len(distances)):
                    # Map distances in [0, inf) to similarity scores in (0, 1].
                    distance_weighted.append(1 / (1 + distances[i]))
                    distance_src.append(1 / (1 + src_distances[i]))
                    distance_ref.append(1 / (1 + ref_distances[i]))
                if show_progress:
                    pbar.update(1)
            if show_progress:
                pbar.close()

        assert len(distance_weighted) == len(samples)
        scores = []
        for i, sample in enumerate(samples):
            scores.append(distance_weighted[i])
            sample["predicted_score"] = scores[-1]
            sample["reference_distance"] = distance_ref[i]
            sample["source_distance"] = distance_src[i]

        return samples, scores

    def prepare_sample(
        self, sample: List[Dict[str, Union[str, float]]], inference: bool = False
    ) -> Union[Tuple[Dict[str, torch.Tensor], None], List[Dict[str, torch.Tensor]]]:
        """
        Function that prepares a sample to input the model.

        :param sample: list of dictionaries.
        :param inference: If set to False, then the model expects a MT and
            reference instead of anchor, pos, and neg segments.

        :return: Tuple with a dictionary containing the model inputs and None
            OR List with source, MT and reference tokenized and vectorized.
        """
        sample = collate_tensors(sample)
        if inference:
            src_inputs = self.encoder.prepare_sample(sample["src"])
            mt_inputs = self.encoder.prepare_sample(sample["mt"])
            ref_inputs = self.encoder.prepare_sample(sample["ref"])
            # 'alt' is an optional second reference.
            alt_inputs = (
                self.encoder.prepare_sample(sample["alt"]) if "alt" in sample else None
            )
            return src_inputs, mt_inputs, ref_inputs, alt_inputs

        ref_inputs = self.encoder.prepare_sample(sample["ref"])
        src_inputs = self.encoder.prepare_sample(sample["src"])
        pos_inputs = self.encoder.prepare_sample(sample["pos"])
        neg_inputs = self.encoder.prepare_sample(sample["neg"])

        # Prefix keys so all four segments can be merged into one flat dict
        # that matches forward()'s keyword arguments.
        ref_inputs = {"ref_" + k: v for k, v in ref_inputs.items()}
        src_inputs = {"src_" + k: v for k, v in src_inputs.items()}
        pos_inputs = {"pos_" + k: v for k, v in pos_inputs.items()}
        neg_inputs = {"neg_" + k: v for k, v in neg_inputs.items()}

        return {**ref_inputs, **src_inputs, **pos_inputs, **neg_inputs}, torch.empty(0)

    def forward(
        self,
        src_tokens: torch.Tensor,
        ref_tokens: torch.Tensor,
        pos_tokens: torch.Tensor,
        neg_tokens: torch.Tensor,
        src_lengths: torch.Tensor,
        ref_lengths: torch.Tensor,
        pos_lengths: torch.Tensor,
        neg_lengths: torch.Tensor,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Function that encodes the anchor, positive samples and negative samples
        and returns embeddings for the triplet.

        :param src_tokens: anchor sequences [batch_size x anchor_seq_len]
        :param ref_tokens: anchor sequences [batch_size x anchor_seq_len]
        :param pos_tokens: positive sequences [batch_size x pos_seq_len]
        :param neg_tokens: negative sequences [batch_size x neg_seq_len]
        :param src_lengths: anchor lengths [batch_size]
        :param ref_lengths: anchor lengths [batch_size]
        :param pos_lengths: positive lengths [batch_size]
        :param neg_lengths: negative lengths [batch_size]

        :return: Dictionary with model outputs to be passed to the loss function.
        """
        return {
            "src_sentemb": self.get_sentence_embedding(src_tokens, src_lengths),
            "ref_sentemb": self.get_sentence_embedding(ref_tokens, ref_lengths),
            "pos_sentemb": self.get_sentence_embedding(pos_tokens, pos_lengths),
            "neg_sentemb": self.get_sentence_embedding(neg_tokens, neg_lengths),
        }