File size: 5,476 Bytes
4a1df2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
AttackResult Class
======================

"""

from abc import ABC

from langdetect import detect

from textattack.goal_function_results import GoalFunctionResult
from textattack.shared import utils


class AttackResult(ABC):
    """Result of an Attack run on a single (output, text_input) pair.

    Args:
        original_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the original text
        perturbed_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the perturbed text. May or may not have been successful.

    Raises:
        ValueError: If ``original_result`` or ``perturbed_result`` is ``None``.
        TypeError: If either argument is not a ``GoalFunctionResult``.
    """

    def __init__(self, original_result, perturbed_result):
        if original_result is None:
            raise ValueError("Attack original result cannot be None")
        elif not isinstance(original_result, GoalFunctionResult):
            raise TypeError(f"Invalid original goal function result: {original_result}")
        if perturbed_result is None:
            raise ValueError("Attack perturbed result cannot be None")
        elif not isinstance(perturbed_result, GoalFunctionResult):
            raise TypeError(
                f"Invalid perturbed goal function result: {perturbed_result}"
            )

        self.original_result = original_result
        self.perturbed_result = perturbed_result
        # The query count is taken from the perturbed result.
        self.num_queries = perturbed_result.num_queries

        # We don't want the AttackedText attributes sticking around clogging up
        # space on our devices. Delete them here, if they're still present,
        # because we won't need them anymore anyway.
        self.original_result.attacked_text.free_memory()
        self.perturbed_result.attacked_text.free_memory()

    def original_text(self, color_method=None):
        """Returns the text portion of `self.original_result`.

        Helper method. ``color_method`` is forwarded to ``printable_text`` as
        ``key_color_method``.
        """
        return self.original_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def perturbed_text(self, color_method=None):
        """Returns the text portion of `self.perturbed_result`.

        Helper method. ``color_method`` is forwarded to ``printable_text`` as
        ``key_color_method``.
        """
        return self.perturbed_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def str_lines(self, color_method=None):
        """A list of the lines to be printed for this result's string
        representation.

        The first entry is the goal-function summary; the following two are
        the original and perturbed texts as produced by :meth:`diff_color`.
        """
        lines = [self.goal_function_result_str(color_method=color_method)]
        lines.extend(self.diff_color(color_method))
        return lines

    def __str__(self, color_method=None):
        return "\n\n".join(self.str_lines(color_method=color_method))

    def goal_function_result_str(self, color_method=None):
        """Returns a string illustrating the results of the goal function,
        formatted as ``<original output> --> <perturbed output>``."""
        orig_colored = self.original_result.get_colored_output(color_method)
        pert_colored = self.perturbed_result.get_colored_output(color_method)
        return orig_colored + " --> " + pert_colored

    def diff_color(self, color_method=None):
        """Highlights the difference between two texts using color.

        Has to account for deletions and insertions from original text to
        perturbed. Relies on the index map stored in
        ``self.perturbed_result.attacked_text.attack_attrs["original_index_map"]``.

        Args:
            color_method: Coloring scheme handed to ``utils.color_text`` and
                ``printable_text``. When ``None`` no highlighting is applied.

        Returns:
            A 2-tuple ``(original_text, perturbed_text)`` of printable strings.
            The texts are returned without highlighting when ``color_method``
            is ``None`` or when the original text is detected as ``"zh-cn"``
            or ``"ko"``.
        """
        t1 = self.original_result.attacked_text
        t2 = self.perturbed_result.attacked_text

        # Nothing to highlight without a color method. Checking this first
        # avoids running language detection when its outcome cannot change
        # the result.
        if color_method is None:
            return t1.printable_text(), t2.printable_text()

        # Local import keeps this block self-contained; langdetect is already
        # a dependency of this module.
        from langdetect.lang_detect_exception import LangDetectException

        # Detect once: langdetect is comparatively expensive and may give
        # different answers on repeated calls for ambiguous text.
        try:
            language = detect(t1.text)
        except LangDetectException:
            # Raised when the text has no detectable features (e.g. empty or
            # digits only); fall through to the regular word-level diff.
            language = None
        if language in ("zh-cn", "ko"):
            return t1.printable_text(), t2.printable_text()

        color_1 = self.original_result.get_text_color_input()
        color_2 = self.perturbed_result.get_text_color_perturbed()

        # iterate through and count equal/unequal words
        words_1_idxs = []
        t2_equal_idxs = set()
        original_index_map = t2.attack_attrs["original_index_map"]
        for t1_idx, t2_idx in enumerate(original_index_map):
            if t2_idx == -1:
                # add words in t1 that are not in t2
                words_1_idxs.append(t1_idx)
            else:
                w1 = t1.words[t1_idx]
                w2 = t2.words[t2_idx]
                if w1 == w2:
                    t2_equal_idxs.add(t2_idx)
                else:
                    words_1_idxs.append(t1_idx)

        # words to color in t2 are all the words that didn't have an equal,
        # mapped word in t1 (range() already yields them in ascending order)
        words_2_idxs = [i for i in range(t2.num_words) if i not in t2_equal_idxs]

        # make lists of colored words
        words_1 = [
            utils.color_text(t1.words[i], color_1, color_method) for i in words_1_idxs
        ]
        words_2 = [
            utils.color_text(t2.words[i], color_2, color_method) for i in words_2_idxs
        ]

        t1 = self.original_result.attacked_text.replace_words_at_indices(
            words_1_idxs, words_1
        )
        t2 = self.perturbed_result.attacked_text.replace_words_at_indices(
            words_2_idxs, words_2
        )

        key_color = ("bold", "underline")
        return (
            t1.printable_text(key_color=key_color, key_color_method=color_method),
            t2.printable_text(key_color=key_color, key_color_method=color_method),
        )