File size: 1,925 Bytes
63775f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Goal function that attempts to minimize the BLEU score
-------------------------------------------------------


"""

import functools

import nltk

import textattack

from .text_to_text_goal_function import TextToTextGoalFunction


class MinimizeBleu(TextToTextGoalFunction):
    """Attempts to minimize the BLEU score between the current output
    translation and the reference translation.

    BLEU score was defined in (BLEU: a Method for Automatic Evaluation of
    Machine Translation).

    `ArxivURL`_

    .. _ArxivURL: https://www.aclweb.org/anthology/P02-1040.pdf

    This goal function is defined in (It’s Morphin’ Time! Combating Linguistic
    Discrimination with Inflectional Perturbations).

    `ArxivURL2`_

    .. _ArxivURL2: https://www.aclweb.org/anthology/2020.acl-main.263

    Args:
        target_bleu (float): BLEU threshold. The goal is considered complete
            once the BLEU score of the model output is at or below this value
            (within ``EPS``). Defaults to ``0.0``.
        *args, **kwargs: Forwarded unchanged to ``TextToTextGoalFunction``.
    """

    # Tolerance added to ``target_bleu`` so that floating-point round-off in
    # the BLEU computation does not prevent the goal from being reached.
    EPS = 1e-10

    def __init__(self, *args, target_bleu=0.0, **kwargs):
        """Store the BLEU threshold, then delegate to the base initializer."""
        self.target_bleu = target_bleu
        super().__init__(*args, **kwargs)

    def clear_cache(self):
        """Empty the model-call cache (when caching is enabled) and the
        module-level memoized BLEU cache."""
        if self.use_cache:
            self._call_model_cache.clear()
        get_bleu.cache_clear()

    def _is_goal_complete(self, model_output, _):
        """Return ``True`` when the BLEU score of ``model_output`` is at or
        below ``target_bleu`` (plus ``EPS``).

        The second positional argument is only passed through to
        ``_get_score``, which ignores it.
        """
        # ``_get_score`` returns ``1 - BLEU``; invert it to recover BLEU.
        bleu_score = 1.0 - self._get_score(model_output, _)
        return bleu_score <= (self.target_bleu + MinimizeBleu.EPS)

    def _get_score(self, model_output, _):
        """Return ``1.0 - BLEU`` for ``model_output`` scored against
        ``self.ground_truth_output``, so that a higher score corresponds to a
        lower BLEU.

        The second positional argument is unused.
        """
        hypothesis_at = textattack.shared.AttackedText(model_output)
        reference_at = textattack.shared.AttackedText(self.ground_truth_output)
        # ``get_bleu`` uses its first argument as the reference and its second
        # as the hypothesis. BLEU is not symmetric (the n-gram precision and
        # the brevity penalty are both computed over the hypothesis), so the
        # ground truth must be passed first and the model output second.
        bleu_score = get_bleu(reference_at, hypothesis_at)
        return 1.0 - bleu_score

    def extra_repr_keys(self):
        """Return the attribute names to include in this object's repr.

        ``target_bleu`` is listed only when ``self.maximizable`` is falsy.
        """
        if self.maximizable:
            return ["maximizable"]
        return ["maximizable", "target_bleu"]


@functools.lru_cache(maxsize=2**12)
def get_bleu(a, b):
    """Return the sentence-level BLEU score of ``b`` measured against ``a``.

    ``a.words`` is used as the single reference token sequence and ``b.words``
    as the hypothesis token sequence passed to NLTK's ``sentence_bleu``.

    Results are memoized with an LRU cache holding up to ``2**12`` entries, so
    both arguments must be hashable.
    """
    return nltk.translate.bleu_score.sentence_bleu([a.words], b.words)