"""
AttackResult Class
======================

"""

from abc import ABC

from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException

from textattack.goal_function_results import GoalFunctionResult
from textattack.shared import utils


class AttackResult(ABC):
    """Result of an Attack run on a single (output, text_input) pair.

    Args:
        original_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the original text
        perturbed_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the perturbed text. May or may not have been successful.
    """

    # Language codes (as reported by ``langdetect.detect``) for which
    # ``diff_color`` skips word-level coloring and returns plain text.
    # NOTE(review): "zh-tw" (Traditional Chinese) is not listed; confirm
    # whether it should be treated like "zh-cn".
    _UNCOLORED_LANGUAGES = ("zh-cn", "ko")

    def __init__(self, original_result, perturbed_result):
        """Validate and store the original/perturbed goal function results.

        Raises:
            ValueError: if either result is ``None``.
            TypeError: if either result is not a ``GoalFunctionResult``.
        """
        if original_result is None:
            raise ValueError("Attack original result cannot be None")
        elif not isinstance(original_result, GoalFunctionResult):
            raise TypeError(f"Invalid original goal function result: {original_result}")
        if perturbed_result is None:
            raise ValueError("Attack perturbed result cannot be None")
        elif not isinstance(perturbed_result, GoalFunctionResult):
            raise TypeError(
                f"Invalid perturbed goal function result: {perturbed_result}"
            )

        self.original_result = original_result
        self.perturbed_result = perturbed_result
        self.num_queries = perturbed_result.num_queries

        # We don't want the AttackedText attributes sticking around clogging up
        # space on our devices. Delete them here, if they're still present,
        # because we won't need them anymore anyway.
        self.original_result.attacked_text.free_memory()
        self.perturbed_result.attacked_text.free_memory()

    def original_text(self, color_method=None):
        """Returns the text portion of `self.original_result`.

        Helper method.
        """
        return self.original_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def perturbed_text(self, color_method=None):
        """Returns the text portion of `self.perturbed_result`.

        Helper method.
        """
        return self.perturbed_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def str_lines(self, color_method=None):
        """A list of the lines to be printed for this result's string
        representation: the goal-function summary followed by the
        (possibly colored) original and perturbed texts."""
        lines = [self.goal_function_result_str(color_method=color_method)]
        lines.extend(self.diff_color(color_method))
        return lines

    def __str__(self, color_method=None):
        """Join ``str_lines`` with blank lines between them."""
        return "\n\n".join(self.str_lines(color_method=color_method))

    def goal_function_result_str(self, color_method=None):
        """Returns a string illustrating the results of the goal function,
        formatted as ``<original output> --> <perturbed output>``."""
        orig_colored = self.original_result.get_colored_output(color_method)
        pert_colored = self.perturbed_result.get_colored_output(color_method)
        return orig_colored + " --> " + pert_colored

    @staticmethod
    def _detect_language(text):
        """Return the ``langdetect`` language code for ``text``, or ``None``.

        ``detect`` raises ``LangDetectException`` when the text has no
        detectable features (e.g. empty, or only digits/punctuation); that
        is treated as "unknown language" rather than failing to print.
        """
        try:
            return detect(text)
        except LangDetectException:
            return None

    def diff_color(self, color_method=None):
        """Highlights the difference between two texts using color.

        Has to account for deletions and insertions from original text to
        perturbed. Relies on the index map stored in
        ``self.perturbed_result.attacked_text.attack_attrs["original_index_map"]``.

        Returns:
            A ``(original_text, perturbed_text)`` tuple of printable strings.
            Plain (uncolored) text is returned when ``color_method`` is
            ``None`` or the original text is detected as one of
            ``_UNCOLORED_LANGUAGES``.
        """
        t1 = self.original_result.attacked_text
        t2 = self.perturbed_result.attacked_text

        # Cheap check first: without a color method the output is plain text
        # regardless of language, so skip (slow) language detection.
        if color_method is None:
            return t1.printable_text(), t2.printable_text()

        # Detect exactly once: langdetect is non-deterministic unless seeded,
        # so two separate calls on the same text could disagree.
        if self._detect_language(t1.text) in self._UNCOLORED_LANGUAGES:
            return t1.printable_text(), t2.printable_text()

        color_1 = self.original_result.get_text_color_input()
        color_2 = self.perturbed_result.get_text_color_perturbed()

        # iterate through and count equal/unequal words
        words_1_idxs = []
        t2_equal_idxs = set()
        original_index_map = t2.attack_attrs["original_index_map"]
        for t1_idx, t2_idx in enumerate(original_index_map):
            if t2_idx == -1:
                # add words in t1 that are not in t2
                words_1_idxs.append(t1_idx)
            else:
                w1 = t1.words[t1_idx]
                w2 = t2.words[t2_idx]
                if w1 == w2:
                    t2_equal_idxs.add(t2_idx)
                else:
                    words_1_idxs.append(t1_idx)

        # words to color in t2 are all the words that didn't have an equal,
        # mapped word in t1
        words_2_idxs = sorted(set(range(t2.num_words)) - t2_equal_idxs)

        # make lists of colored words
        words_1 = [
            utils.color_text(t1.words[i], color_1, color_method) for i in words_1_idxs
        ]
        words_2 = [
            utils.color_text(t2.words[i], color_2, color_method) for i in words_2_idxs
        ]

        t1 = self.original_result.attacked_text.replace_words_at_indices(
            words_1_idxs, words_1
        )
        t2 = self.perturbed_result.attacked_text.replace_words_at_indices(
            words_2_idxs, words_2
        )

        key_color = ("bold", "underline")
        return (
            t1.printable_text(key_color=key_color, key_color_method=color_method),
            t2.printable_text(key_color=key_color, key_color_method=color_method),
        )