File size: 4,236 Bytes
03f6091
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# -*- coding: utf-8 -*-
r"""
Translation Ranking Base Model
==============================
    Abstract base class used to build new ranking systems inside Polos.
    This task consists of ranking "good" translations above "worse" ones.
"""
from argparse import Namespace
from typing import List

import pandas as pd
import torch
import torch.nn as nn

from polos.models.model_base import ModelBase
from polos.models.utils import average_pooling, max_pooling
from polos.modules.scalar_mix import ScalarMixWithDropout
from polos.metrics import WMTKendall


class RankingBase(ModelBase):
    """
    Ranking Model base class used to fine-tune pretrained models such as XLM-R
    to produce better sentence embeddings by optimizing Triplet Margin Loss.

    :param hparams: Namespace containing the hyperparameters.
    """

    def __init__(
        self,
        hparams: Namespace,
    ) -> None:
        super().__init__(hparams)

    def read_csv(self, path: str) -> List[dict]:
        """Reads a comma separated value file with ranking samples.

        Only the ``src``, ``ref``, ``pos`` and ``neg`` columns are kept and
        every value is cast to ``str`` (so missing cells become ``"nan"``).

        :param path: path to a csv file.

        :return: List of records as dictionaries
        """
        columns = ["src", "ref", "pos", "neg"]
        # Select and cast in a single expression so pandas builds a fresh
        # frame instead of assigning into a slice of the original one
        # (which triggers SettingWithCopyWarning).
        df = pd.read_csv(path)[columns].astype(str)
        return df.to_dict("records")

    def _build_loss(self):
        """ Initializes the loss function/s. """
        self.loss = nn.TripletMarginLoss(margin=1.0, p=2)

    def _build_model(self) -> None:
        """
        Initializes the ranking model architecture.

        After the parent builds the encoder, this sets ``self.metrics`` to a
        ``WMTKendall`` instance and, for non-LASER encoders, configures
        ``self.layer`` (an encoder layer index or the literal ``"mix"``) and
        ``self.scalar_mix`` (a scalar mix over all encoder layers, or None).
        """
        super()._build_model()
        self.metrics = WMTKendall()
        if self.hparams.encoder_model != "LASER":
            # hparams.layer may arrive as a string (e.g. from the CLI); any
            # value other than "mix" is interpreted as a layer index.
            self.layer = (
                int(self.hparams.layer)
                if self.hparams.layer != "mix"
                else self.hparams.layer
            )

            # The scalar mix is only needed when word embeddings are pooled;
            # with "default" pooling the encoder's sentence embedding is used.
            self.scalar_mix = (
                ScalarMixWithDropout(
                    mixture_size=self.encoder.num_layers,
                    dropout=self.hparams.scalar_mix_dropout,
                    do_layer_norm=True,
                )
                if self.layer == "mix" and self.hparams.pool != "default"
                else None
            )

    def get_sentence_embedding(
        self, tokens: torch.Tensor, lengths: torch.Tensor
    ) -> torch.Tensor:
        """Auxiliary function that extracts sentence embeddings for
            a single sentence.

        :param tokens: sequences [batch_size x seq_len]
        :param lengths: lengths [batch_size]

        :return: torch.Tensor [batch_size x hidden_size]

        :raises Exception: if the configured layer index is out of range or
            the pooling technique is unknown.
        """
        # When using just one GPU this should not change behavior
        # but when splitting batches across GPU the tokens have padding
        # from the entire original batch
        if self.trainer and self.trainer.use_dp and self.trainer.num_gpus > 1:
            tokens = tokens[:, : lengths.max()]

        encoder_out = self.encoder(tokens, lengths)

        # Word-level embeddings; only consumed by the max/avg/cls poolings.
        embeddings = None
        if self.hparams.encoder_model == "LASER":
            # for LASER we dont care about the word embeddings
            pass
        elif self.scalar_mix is not None:
            embeddings = self.scalar_mix(encoder_out["all_layers"], encoder_out["mask"])
        elif self.layer == "mix":
            # layer="mix" without a scalar mix only happens with "default"
            # pooling, which uses the encoder's sentence embedding below.
            # Without this branch the next check would compare "mix" >= 0
            # and raise TypeError.
            pass
        elif 0 <= self.layer < self.encoder.num_layers:
            embeddings = encoder_out["all_layers"][self.layer]
        else:
            raise Exception("Invalid model layer {}.".format(self.layer))

        pool = self.hparams.pool
        if pool == "default" or self.hparams.encoder_model == "LASER":
            sentemb = encoder_out["sentemb"]

        elif pool == "max":
            sentemb = max_pooling(
                tokens, embeddings, self.encoder.tokenizer.padding_index
            )

        elif pool == "avg":
            sentemb = average_pooling(
                tokens,
                embeddings,
                encoder_out["mask"],
                self.encoder.tokenizer.padding_index,
            )

        elif pool == "cls":
            # Embedding of the first token of each sequence.
            sentemb = embeddings[:, 0, :]

        else:
            raise Exception("Invalid pooling technique.")

        return sentemb