"""
Word Embedding Distance
--------------------------
"""

from textattack.constraints import Constraint
from textattack.shared import AbstractWordEmbedding, WordEmbedding
from textattack.shared.validators import transformation_consists_of_word_swaps


class WordEmbeddingDistance(Constraint):
    """A constraint on word substitutions which places a maximum distance
    between the embedding of the word being deleted and the word being
    inserted.

    Args:
        embedding (obj): Wrapper for word embedding.
        include_unknown_words (bool): Whether or not the constraint is fulfilled if the embedding of x or x_adv is unknown.
        min_cos_sim (:obj:`float`, optional): The minimum cosine similarity between word embeddings.
        max_mse_dist (:obj:`float`, optional): The maximum euclidean distance between word embeddings.
        cased (bool): Whether embedding supports uppercase & lowercase (defaults to False, or just lowercase).
        compare_against_original (bool):  If `True`, compare new `x_adv` against the original `x`. Otherwise, compare it against the previous `x_adv`.
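
    Example:
        A minimal usage sketch (assumes the default counter-fitted GloVe
        embedding can be downloaded on first use):

        >>> from textattack.constraints.semantics import WordEmbeddingDistance
        >>> constraint = WordEmbeddingDistance(min_cos_sim=0.8)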
    """

    def __init__(
        self,
        embedding=None,
        include_unknown_words=True,
        min_cos_sim=None,
        max_mse_dist=None,
        cased=False,
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if embedding is None:
            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
        self.include_unknown_words = include_unknown_words
        self.cased = cased

        if (min_cos_sim is None) == (max_mse_dist is None):
            raise ValueError(
                "You must provide exactly one of `min_cos_sim` or `max_mse_dist`."
            )
        self.min_cos_sim = min_cos_sim
        self.max_mse_dist = max_mse_dist

        if not isinstance(embedding, AbstractWordEmbedding):
            raise ValueError(
                "`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
            )
        self.embedding = embedding

    def get_cos_sim(self, a, b):
        """Returns the cosine similarity of words with IDs a and b."""
        return self.embedding.get_cos_sim(a, b)

    def get_mse_dist(self, a, b):
        """Returns the MSE distance of words with IDs a and b."""
        return self.embedding.get_mse_dist(a, b)
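
    # Illustrative sketch (comments only; the example words are hypothetical,
    # but the calls mirror this module's own API):
    #
    #   emb = WordEmbedding.counterfitted_GLOVE_embedding()
    #   a, b = emb.word2index("happy"), emb.word2index("glad")
    #   emb.get_cos_sim(a, b)   # compared against min_cos_sim
    #   emb.get_mse_dist(a, b)  # compared against max_mse_dist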

    def _check_constraint(self, transformed_text, reference_text):
        """Returns true if (``transformed_text`` and ``reference_text``) are
        closer than ``self.min_cos_sim`` or ``self.max_mse_dist``."""
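        # Worked example (comments only; the words are hypothetical): with
        # min_cos_sim=0.8, a swap "happy" -> "glad" passes only if the two
        # embeddings' cosine similarity is >= 0.8; with max_mse_dist=0.5, it
        # passes only if their MSE distance is <= 0.5.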
        try:
            indices = transformed_text.attack_attrs["newly_modified_indices"]
        except KeyError:
            raise KeyError(
                "Cannot apply part-of-speech constraint without `newly_modified_indices`"
            )

        # FIXME The index i is sometimes larger than the number of tokens - 1
        if any(
            i >= len(reference_text.words) or i >= len(transformed_text.words)
            for i in indices
        ):
            return False

        for i in indices:
            ref_word = reference_text.words[i]
            transformed_word = transformed_text.words[i]

            if not self.cased:
                # If embedding vocabulary is all lowercase, lowercase words.
                ref_word = ref_word.lower()
                transformed_word = transformed_word.lower()

            try:
                ref_id = self.embedding.word2index(ref_word)
                transformed_id = self.embedding.word2index(transformed_word)
            except KeyError:
                # This error is thrown if x or x_adv has no corresponding ID.
                if self.include_unknown_words:
                    continue
                return False

            # Check cosine similarity.
            if self.min_cos_sim is not None:
                cos_sim = self.get_cos_sim(ref_id, transformed_id)
                if cos_sim < self.min_cos_sim:
                    return False
            # Check MSE distance.
            if self.max_mse_dist is not None:
                mse_dist = self.get_mse_dist(ref_id, transformed_id)
                if mse_dist > self.max_mse_dist:
                    return False

        return True

    def check_compatibility(self, transformation):
        """WordEmbeddingDistance requires a word being both deleted and
        inserted at the same index in order to compare their embeddings,
        therefore it's restricted to word swaps."""
        return transformation_consists_of_word_swaps(transformation)

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys.

        To print customized extra information, you should reimplement
        this method in your own constraint. Both single-line and multi-
        line strings are acceptable.
        """
        if self.min_cos_sim is None:
            metric = "max_mse_dist"
        else:
            metric = "min_cos_sim"
        return [
            "embedding",
            metric,
            "cased",
            "include_unknown_words",
        ] + super().extra_repr_keys()