Spaces:
Sleeping
Sleeping
File size: 5,476 Bytes
4a1df2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
"""
AttackResult Class
======================
"""
from abc import ABC

from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException

from textattack.goal_function_results import GoalFunctionResult
from textattack.shared import utils
class AttackResult(ABC):
    """Result of an Attack run on a single (output, text_input) pair.

    Args:
        original_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the original text
        perturbed_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the perturbed text. May or may not have been successful.
    """

    # Key color passed to ``printable_text`` when rendering texts.
    _KEY_COLOR = ("bold", "underline")

    # langdetect codes for which ``diff_color`` skips word-level highlighting
    # and returns the plain printable texts instead. Subclasses may override.
    # NOTE(review): the reason is not visible here -- presumably the word
    # indices do not line up usefully for these scripts. langdetect can also
    # report "zh-tw" (and "ja"), which are not listed; confirm whether they
    # should be.
    _PLAIN_TEXT_LANGUAGES = frozenset({"zh-cn", "ko"})

    def __init__(self, original_result, perturbed_result):
        """Validate and store the original and perturbed goal function results.

        Raises:
            ValueError: If either result is ``None``.
            TypeError: If either result is not a ``GoalFunctionResult``.
        """
        if original_result is None:
            raise ValueError("Attack original result cannot be None")
        if not isinstance(original_result, GoalFunctionResult):
            raise TypeError(f"Invalid original goal function result: {original_result}")
        if perturbed_result is None:
            raise ValueError("Attack perturbed result cannot be None")
        if not isinstance(perturbed_result, GoalFunctionResult):
            raise TypeError(
                f"Invalid perturbed goal function result: {perturbed_result}"
            )

        self.original_result = original_result
        self.perturbed_result = perturbed_result
        # The query count reported for the attack is the perturbed result's.
        self.num_queries = perturbed_result.num_queries

        # We don't want the AttackedText attributes sticking around clogging up
        # space on our devices. Delete them here, if they're still present,
        # because we won't need them anymore anyway.
        self.original_result.attacked_text.free_memory()
        self.perturbed_result.attacked_text.free_memory()

    def original_text(self, color_method=None):
        """Return the printable text of ``self.original_result``.

        Args:
            color_method: Forwarded to ``printable_text`` as
                ``key_color_method``.
        """
        return self.original_result.attacked_text.printable_text(
            key_color=self._KEY_COLOR, key_color_method=color_method
        )

    def perturbed_text(self, color_method=None):
        """Return the printable text of ``self.perturbed_result``.

        Args:
            color_method: Forwarded to ``printable_text`` as
                ``key_color_method``.
        """
        return self.perturbed_result.attacked_text.printable_text(
            key_color=self._KEY_COLOR, key_color_method=color_method
        )

    def str_lines(self, color_method=None):
        """Return the lines making up this result's string representation.

        The first line summarizes the goal function outputs; the next two are
        the original and perturbed texts as returned by ``diff_color``.
        """
        lines = [self.goal_function_result_str(color_method=color_method)]
        lines.extend(self.diff_color(color_method))
        return lines

    def __str__(self, color_method=None):
        """Join ``str_lines`` with blank lines between entries."""
        return "\n\n".join(self.str_lines(color_method=color_method))

    def goal_function_result_str(self, color_method=None):
        """Return ``"<original output> --> <perturbed output>"``."""
        orig_colored = self.original_result.get_colored_output(color_method)
        pert_colored = self.perturbed_result.get_colored_output(color_method)
        return orig_colored + " --> " + pert_colored

    def _skip_word_diff(self, text):
        """Return True if word-level highlighting should be skipped for ``text``.

        Language detection runs exactly once per call. langdetect is
        non-deterministic unless seeded, so calling it repeatedly could yield
        inconsistent answers.
        """
        try:
            language = detect(text)
        except LangDetectException:
            # langdetect raises when the text has no usable features (e.g. it
            # is empty or only digits/punctuation). Fall back to the normal
            # word-level diff instead of failing while printing a result.
            return False
        return language in self._PLAIN_TEXT_LANGUAGES

    def diff_color(self, color_method=None):
        """Highlight the words that differ between the original and perturbed texts.

        Deletions and insertions are accounted for through the index map stored
        in ``self.perturbed_result.attacked_text.attack_attrs["original_index_map"]``.
        Entry ``i`` of that map is the index in the perturbed text of word
        ``i`` of the original text, or ``-1`` if the word is absent from the
        perturbed text.

        Args:
            color_method: Color method forwarded to ``utils.color_text`` and
                ``printable_text``. When ``None``, the plain printable texts
                are returned without highlighting.

        Returns:
            An ``(original, perturbed)`` tuple of printable strings.
            Highlighting is also skipped, and plain texts are returned, when
            the original text is detected as one of ``_PLAIN_TEXT_LANGUAGES``.
        """
        t1 = self.original_result.attacked_text
        t2 = self.perturbed_result.attacked_text

        # Nothing to highlight without a color method. Checking this first
        # avoids running language detection (which can raise) for no reason.
        if color_method is None or self._skip_word_diff(t1.text):
            return t1.printable_text(), t2.printable_text()

        color_1 = self.original_result.get_text_color_input()
        color_2 = self.perturbed_result.get_text_color_perturbed()

        t1_words = t1.words
        t2_words = t2.words

        # Original-text words to color: words removed in the perturbed text
        # (mapped to -1) and words whose mapped counterpart differs. Perturbed
        # indices with an equal counterpart are recorded so they stay uncolored.
        words_1_idxs = []
        t2_equal_idxs = set()
        for t1_idx, t2_idx in enumerate(t2.attack_attrs["original_index_map"]):
            if t2_idx != -1 and t1_words[t1_idx] == t2_words[t2_idx]:
                t2_equal_idxs.add(t2_idx)
            else:
                words_1_idxs.append(t1_idx)

        # Perturbed-text words to color: every word without an equal, mapped
        # word in the original text (this includes inserted words).
        words_2_idxs = sorted(set(range(t2.num_words)) - t2_equal_idxs)

        colored_1 = [
            utils.color_text(t1_words[i], color_1, color_method) for i in words_1_idxs
        ]
        colored_2 = [
            utils.color_text(t2_words[i], color_2, color_method) for i in words_2_idxs
        ]

        t1 = t1.replace_words_at_indices(words_1_idxs, colored_1)
        t2 = t2.replace_words_at_indices(words_2_idxs, colored_2)
        return (
            t1.printable_text(key_color=self._KEY_COLOR, key_color_method=color_method),
            t2.printable_text(key_color=self._KEY_COLOR, key_color_method=color_method),
        )
|