Spaces:
Sleeping
Sleeping
File size: 5,097 Bytes
4a1df2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""
Word Embedding Distance
--------------------------
"""
from textattack.constraints import Constraint
from textattack.shared import AbstractWordEmbedding, WordEmbedding
from textattack.shared.validators import transformation_consists_of_word_swaps
class WordEmbeddingDistance(Constraint):
    """A constraint on word substitutions which places a maximum distance
    between the embedding of the word being deleted and the word being
    inserted.

    Exactly one of ``min_cos_sim`` and ``max_mse_dist`` must be provided.

    Args:
        embedding (obj): Wrapper for word embedding. Must be an instance of
            ``textattack.shared.AbstractWordEmbedding``. When ``None``, the
            embedding returned by
            ``WordEmbedding.counterfitted_GLOVE_embedding()`` is used.
        include_unknown_words (bool): Whether or not the constraint is
            fulfilled if the embedding of x or x_adv is unknown.
        min_cos_sim (:obj:`float`, optional): The minimum cosine similarity
            between word embeddings. ``0`` is a valid threshold.
        max_mse_dist (:obj:`float`, optional): The maximum MSE distance
            between word embeddings, as reported by the embedding's
            ``get_mse_dist``. ``0`` is a valid threshold.
        cased (bool): Whether embedding supports uppercase & lowercase
            (defaults to False, or just lowercase).
        compare_against_original (bool): If `True`, compare new `x_adv`
            against the original `x`. Otherwise, compare it against the
            previous `x_adv`.

    Raises:
        ValueError: If neither or both of ``min_cos_sim`` and
            ``max_mse_dist`` are given, or if ``embedding`` is not an
            ``AbstractWordEmbedding``.
    """

    def __init__(
        self,
        embedding=None,
        include_unknown_words=True,
        min_cos_sim=None,
        max_mse_dist=None,
        cased=False,
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)

        # Exactly one threshold must be supplied. Compare against ``None``
        # rather than relying on truthiness so that a threshold of ``0`` is
        # not mistaken for "not provided". Validate before loading the
        # default embedding so bad arguments fail fast.
        if (min_cos_sim is None) == (max_mse_dist is None):
            raise ValueError("You must choose either `min_cos_sim` or `max_mse_dist`.")

        if embedding is None:
            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
        if not isinstance(embedding, AbstractWordEmbedding):
            raise ValueError(
                "`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
            )

        self.include_unknown_words = include_unknown_words
        self.cased = cased
        self.min_cos_sim = min_cos_sim
        self.max_mse_dist = max_mse_dist
        self.embedding = embedding

    def get_cos_sim(self, a, b):
        """Returns the cosine similarity of words with IDs a and b."""
        return self.embedding.get_cos_sim(a, b)

    def get_mse_dist(self, a, b):
        """Returns the MSE distance of words with IDs a and b."""
        return self.embedding.get_mse_dist(a, b)

    def _check_constraint(self, transformed_text, reference_text):
        """Returns true if (``transformed_text`` and ``reference_text``) are
        closer than ``self.min_cos_sim`` or ``self.max_mse_dist``.

        Every index listed in
        ``transformed_text.attack_attrs["newly_modified_indices"]`` is
        checked; the constraint holds only if all swapped words satisfy the
        configured threshold.

        Raises:
            KeyError: If ``newly_modified_indices`` is missing from
                ``transformed_text.attack_attrs``.
        """
        try:
            indices = transformed_text.attack_attrs["newly_modified_indices"]
        except KeyError as err:
            raise KeyError(
                "Cannot apply word embedding distance constraint without `newly_modified_indices`"
            ) from err

        ref_words = reference_text.words
        transformed_words = transformed_text.words

        # FIXME The index i is sometimes larger than the number of tokens - 1
        # Reject the transformation outright if any index is out of range
        # for either text, before any embedding lookup is attempted.
        if any(i >= len(ref_words) or i >= len(transformed_words) for i in indices):
            return False

        for i in indices:
            ref_word = ref_words[i]
            transformed_word = transformed_words[i]

            if not self.cased:
                # If embedding vocabulary is all lowercase, lowercase words.
                ref_word = ref_word.lower()
                transformed_word = transformed_word.lower()

            try:
                ref_id = self.embedding.word2index(ref_word)
                transformed_id = self.embedding.word2index(transformed_word)
            except KeyError:
                # This error is thrown if x or x_adv has no corresponding ID.
                if self.include_unknown_words:
                    continue
                return False

            # Check cosine similarity. ``is not None`` keeps a threshold of
            # 0 active instead of silently skipping the check.
            if self.min_cos_sim is not None:
                cos_sim = self.get_cos_sim(ref_id, transformed_id)
                if cos_sim < self.min_cos_sim:
                    return False

            # Check MSE distance.
            if self.max_mse_dist is not None:
                mse_dist = self.get_mse_dist(ref_id, transformed_id)
                if mse_dist > self.max_mse_dist:
                    return False

        return True

    def check_compatibility(self, transformation):
        """WordEmbeddingDistance requires a word being both deleted and
        inserted at the same index in order to compare their embeddings,
        therefore it's restricted to word swaps."""
        return transformation_consists_of_word_swaps(transformation)

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys.

        To print customized extra information, you should reimplement
        this method in your own constraint. Both single-line and multi-
        line strings are acceptable.
        """
        # Report whichever threshold was actually configured.
        if self.min_cos_sim is None:
            metric = "max_mse_dist"
        else:
            metric = "min_cos_sim"
        return [
            "embedding",
            metric,
            "cased",
            "include_unknown_words",
        ] + super().extra_repr_keys()
|