# -*- coding: utf-8 -*-
r"""
Text Encoder
============
    Base class defining a common interface between the different tokenizers used
    in Polos.
"""
from typing import List, Dict, Tuple, Iterator

import torch

from torchnlp.encoders import Encoder
from torchnlp.encoders.text import stack_and_pad_tensors
from torchnlp.encoders.text.text_encoder import TextEncoder


class TextEncoderBase(TextEncoder):
    """
    Base class for the specific tokenizers of each model.
    """

    def __init__(self) -> None:
        # Skip torchnlp's encode/decode reversibility check by default.
        self.enforce_reversible = False

    @property
    def unk_index(self) -> int:
        """ Returns the index used for the unknown token. """
        return self._unk_index

    @property
    def bos_index(self) -> int:
        """ Returns the index used for the begin-of-sentence token. """
        return self._bos_index

    @property
    def eos_index(self) -> int:
        """ Returns the index used for the end-of-sentence token. """
        return self._eos_index

    @property
    def padding_index(self) -> int:
        """ Returns the index used for padding. """
        return self._pad_index

    @property
    def mask_index(self) -> int:
        """ Returns the index used for masking. """
        return self._mask_index

    @property
    def vocab(self) -> Dict[str, int]:
        """
        Returns:
            Dict[str, int]: Dictionary mapping tokens to indexes.
        """
        return self.stoi

    @property
    def vocab_size(self) -> int:
        """
        Returns:
            int: Number of tokens in the dictionary.
        """
        return len(self.itos)

    def tokenize(self, sequence: str) -> List[str]:
        """Splits a string into a list of tokens.
        - To be implemented by subclasses.
        """
        raise NotImplementedError

    def encode(self, sequence: str) -> torch.Tensor:
        """Encodes a text sequence into a tensor of vocabulary indexes.

        :param sequence: String to encode.

        Returns:
            torch.Tensor: Encoding of the `sequence`.
        """
        # The parent class only runs the optional reversibility check and
        # returns the sequence unchanged.
        sequence = super().encode(sequence)
        tokens = self.tokenize(sequence)
        vector = [self.stoi.get(token, self.unk_index) for token in tokens]
        return torch.tensor(vector)

    def batch_encode(
        self, iterator: Iterator[str], dim: int = 0, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param iterator (iterator): Batch of text to encode.
        :param dim (int, optional): Dimension along which to concatenate tensors.
        :param **kwargs: Keyword arguments passed to 'encode'.

        Returns
            torch.Tensor, torch.Tensor: Encoded and padded batch of sequences; Original lengths of
                sequences.
        """
        return stack_and_pad_tensors(
            Encoder.batch_encode(self, iterator, **kwargs),
            padding_index=self.padding_index,
            dim=dim,
        )
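

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): a hypothetical
# whitespace tokenizer showing how a subclass is expected to provide
# `tokenize` plus the `stoi`/`itos` vocabulary and the special-token indexes
# that the base class relies on. The class name, token layout, and sample
# strings below are illustrative assumptions, not part of Polos.
# ---------------------------------------------------------------------------
class WhitespaceEncoder(TextEncoderBase):
    """Toy subclass that tokenizes on whitespace."""

    def __init__(self, samples: List[str]) -> None:
        super().__init__()
        specials = ["<pad>", "<unk>", "<bos>", "<eos>", "<mask>"]
        corpus_tokens = sorted({t for s in samples for t in s.split()})
        self.itos = specials + corpus_tokens
        self.stoi = {token: index for index, token in enumerate(self.itos)}
        self._pad_index = 0
        self._unk_index = 1
        self._bos_index = 2
        self._eos_index = 3
        self._mask_index = 4

    def tokenize(self, sequence: str) -> List[str]:
        return sequence.split()


if __name__ == "__main__":
    encoder = WhitespaceEncoder(["hello world", "hello there"])
    # A single sequence becomes a tensor of vocabulary indexes, e.g. tensor([5, 7]).
    print(encoder.encode("hello world"))
    # A batch becomes a padded tensor plus the original sequence lengths;
    # the shorter sequence is padded with `padding_index`.
    padded, lengths = encoder.batch_encode(["hello world", "hello"])
    print(padded, lengths)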