Spaces:
Sleeping
Sleeping
"""
Genetic Algorithm Word Swap
====================================
"""
from abc import ABC, abstractmethod | |
import numpy as np | |
import torch | |
from textattack.goal_function_results import GoalFunctionResultStatus | |
from textattack.search_methods import PopulationBasedSearch, PopulationMember | |
from textattack.shared.validators import transformation_consists_of_word_swaps | |
class GeneticAlgorithm(PopulationBasedSearch, ABC):
    """Base class for attacking a model with word substitutions using a
    genetic algorithm.

    Each generation keeps the best-scoring member (elitism) and fills the rest
    of the population with children produced by crossing over two sampled
    parents and then perturbing a single word of the child. Subclasses provide
    population initialization, the crossover operation, and the bookkeeping of
    per-member attributes.

    Args:
        pop_size (int): The population size. Defaults to 60.
        max_iters (int): The maximum number of iterations (generations) to use.
            Defaults to 20.
        temp (float): Temperature of the softmax that turns population scores
            into parent-sampling probabilities. Higher temperature flattens the
            distribution, giving lower-scoring members a better chance of being
            picked. Defaults to 0.3.
        give_up_if_no_improvement (bool): If True, stop the search early when a
            generation's best score does not improve on the best seen so far.
            Defaults to False.
        post_crossover_check (bool): If True, check whether the child produced
            by the crossover step passes the constraints. Defaults to True.
        max_crossover_retries (int): Maximum number of crossover retries if the
            resulting child fails the constraints. Applied only when
            `post_crossover_check` is `True`. Setting it to 0 means one of the
            parents is taken at random as the child immediately upon failure.
            Defaults to 20.
    """

    def __init__(
        self,
        pop_size=60,
        max_iters=20,
        temp=0.3,
        give_up_if_no_improvement=False,
        post_crossover_check=True,
        max_crossover_retries=20,
    ):
        self.max_iters = max_iters
        self.pop_size = pop_size
        self.temp = temp
        self.give_up_if_no_improvement = give_up_if_no_improvement
        self.post_crossover_check = post_crossover_check
        self.max_crossover_retries = max_crossover_retries

        # internal flag to indicate if search should end immediately
        self._search_over = False

    def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
        """Modify `pop_member` by returning a new copy with `new_text`,
        `new_result`, and `attributes` altered appropriately for the given
        `word_idx`."""
        raise NotImplementedError()

    def _get_word_select_prob_weights(self, pop_member):
        """Get the attribute of `pop_member` that is used for determining the
        probability of each word being selected for perturbation."""
        raise NotImplementedError()

    def _perturb(self, pop_member, original_result, index=None):
        """Perturb `pop_member` and return it. Replaces the word at a random
        index (or at `index`, if specified) with the candidate transformation
        that most improves the score.

        Random indices are sampled in proportion to
        `_get_word_select_prob_weights(pop_member)`. A word that has been tried
        is not sampled again within the same call, and at most as many words
        are tried as there are non-zero weights. When `index` is given, only
        that word is tried, exactly once.

        Args:
            pop_member (PopulationMember): The population member being perturbed.
            original_result (GoalFunctionResult): Result of original sample being attacked
            index (int, optional): Index of word to perturb. `0` is a valid index.
        Returns:
            Perturbed `PopulationMember`, or `pop_member` itself if no candidate
            improved its score.
        """
        num_words = pop_member.attacked_text.num_words
        # Copy so that zeroing out tried words does not mutate pop_member's attributes.
        word_select_prob_weights = np.copy(
            self._get_word_select_prob_weights(pop_member)
        )
        non_zero_indices = np.count_nonzero(word_select_prob_weights)
        if non_zero_indices == 0:
            return pop_member

        # The member's text does not change until we return, so retrying the
        # same word would recompute identical candidates and waste queries.
        max_attempts = 1 if index is not None else non_zero_indices
        for _ in range(max_attempts):
            if index is not None:
                idx = index
            else:
                w_select_probs = word_select_prob_weights / np.sum(
                    word_select_prob_weights
                )
                idx = np.random.choice(num_words, 1, p=w_select_probs)[0]

            transformed_texts = self.get_transformations(
                pop_member.attacked_text,
                original_text=original_result.attacked_text,
                indices_to_modify=[idx],
            )
            # Each sampled idx had non-zero weight, so zeroing it here keeps the
            # remaining weight sum positive for every later iteration.
            word_select_prob_weights[idx] = 0

            if not transformed_texts:
                continue

            new_results, self._search_over = self.get_goal_results(transformed_texts)

            diff_scores = (
                torch.Tensor([r.score for r in new_results]) - pop_member.result.score
            )
            if len(diff_scores) and diff_scores.max() > 0:
                idx_with_max_score = int(diff_scores.argmax())
                return self._modify_population_member(
                    pop_member,
                    transformed_texts[idx_with_max_score],
                    new_results[idx_with_max_score],
                    idx,
                )

            if self._search_over:
                break

        return pop_member

    def _crossover_operation(self, pop_member1, pop_member2):
        """Actual operation that takes `pop_member1` text and `pop_member2`
        text and mixes the two to generate crossover between `pop_member1` and
        `pop_member2`.

        Args:
            pop_member1 (PopulationMember): The first population member.
            pop_member2 (PopulationMember): The second population member.
        Returns:
            Tuple of `AttackedText` and a dictionary of attributes.
        """
        raise NotImplementedError()

    def _post_crossover_check(
        self, new_text, parent_text1, parent_text2, original_text
    ):
        """Check if `new_text` that has been produced by performing crossover
        between `parent_text1` and `parent_text2` aligns with the constraints.

        Args:
            new_text (AttackedText): Text produced by crossover operation
            parent_text1 (AttackedText): Parent text of `new_text`
            parent_text2 (AttackedText): Second parent text of `new_text`
            original_text (AttackedText): Original text
        Returns:
            `True` if `new_text` meets the constraints. If otherwise, return `False`.
        """
        if "last_transformation" not in new_text.attack_attrs:
            # `new_text` has not been actually transformed, so return True
            return True

        previous_text = (
            parent_text1
            if "last_transformation" in parent_text1.attack_attrs
            else parent_text2
        )
        return self._check_constraints(
            new_text, previous_text, original_text=original_text
        )

    def _crossover(self, pop_member1, pop_member2, original_text):
        """Generates a crossover between pop_member1 and pop_member2.

        If the child fails to satisfy the constraints, crossover is retried up
        to `max_crossover_retries` times (negative values are treated as 0)
        before one of the parents is taken at random as the resulting child.

        Args:
            pop_member1 (PopulationMember): The first population member.
            pop_member2 (PopulationMember): The second population member.
            original_text (AttackedText): Original text
        Returns:
            A population member containing the crossover. A newly built child
            is scored via `get_goal_results`, which also updates `_search_over`.
        """
        x1_text = pop_member1.attacked_text
        x2_text = pop_member2.attacked_text

        passed_constraints = False
        # Clamp so that at least one crossover is always attempted; otherwise
        # `new_text` would be unbound below when the check is disabled.
        for _ in range(max(self.max_crossover_retries, 0) + 1):
            new_text, attributes = self._crossover_operation(pop_member1, pop_member2)

            # For indices the crossover replaced, take parent 2's modification
            # status; everywhere else keep parent 1's.
            replaced_indices = new_text.attack_attrs["newly_modified_indices"]
            new_text.attack_attrs["modified_indices"] = (
                x1_text.attack_attrs["modified_indices"] - replaced_indices
            ) | (x2_text.attack_attrs["modified_indices"] & replaced_indices)

            if "last_transformation" in x1_text.attack_attrs:
                new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[
                    "last_transformation"
                ]
            elif "last_transformation" in x2_text.attack_attrs:
                new_text.attack_attrs["last_transformation"] = x2_text.attack_attrs[
                    "last_transformation"
                ]

            if not self.post_crossover_check:
                break
            passed_constraints = self._post_crossover_check(
                new_text, x1_text, x2_text, original_text
            )
            if passed_constraints:
                break

        if self.post_crossover_check and not passed_constraints:
            # If we cannot find a child that passes the constraints,
            # we just randomly pick one of the parents to be the child for the next iteration.
            return pop_member1 if np.random.uniform() < 0.5 else pop_member2

        new_results, self._search_over = self.get_goal_results([new_text])
        return PopulationMember(new_text, result=new_results[0], attributes=attributes)

    def _initialize_population(self, initial_result, pop_size):
        """
        Initialize a population of size `pop_size` with `initial_result`

        Args:
            initial_result (GoalFunctionResult): Original text
            pop_size (int): size of population
        Returns:
            population as `list[PopulationMember]`
        """
        raise NotImplementedError()

    def perform_search(self, initial_result):
        """Run the genetic search starting from `initial_result`.

        Args:
            initial_result (GoalFunctionResult): Result of the original sample
                being attacked.
        Returns:
            The `GoalFunctionResult` of the highest-scoring member of the final
            population, or `initial_result` if the initial population is empty.
        """
        self._search_over = False
        population = self._initialize_population(initial_result, self.pop_size)
        if not population:
            # Nothing to evolve; report the unmodified original result.
            return initial_result
        pop_size = len(population)
        current_score = initial_result.score

        for _ in range(self.max_iters):
            population = sorted(population, key=lambda x: x.result.score, reverse=True)
            if (
                self._search_over
                or population[0].result.goal_status
                == GoalFunctionResultStatus.SUCCEEDED
            ):
                break

            if population[0].result.score > current_score:
                current_score = population[0].result.score
            elif self.give_up_if_no_improvement:
                break

            # Fitness-proportional parent selection: a softmax over the scores
            # (higher score = better) so stronger members breed more often.
            pop_scores = torch.Tensor([pm.result.score for pm in population])
            select_probs = torch.softmax(pop_scores / self.temp, dim=0).cpu().numpy()

            parent1_idx = np.random.choice(pop_size, size=pop_size - 1, p=select_probs)
            parent2_idx = np.random.choice(pop_size, size=pop_size - 1, p=select_probs)

            children = []
            for p1, p2 in zip(parent1_idx, parent2_idx):
                child = self._crossover(
                    population[p1],
                    population[p2],
                    initial_result.attacked_text,
                )
                if self._search_over:
                    # `_search_over` only flips when `_crossover` scored a new
                    # child, so keep that result instead of discarding it.
                    children.append(child)
                    break

                child = self._perturb(child, initial_result)
                children.append(child)

                # We need two `search_over` checks b/c value might change both in
                # `crossover` method and `perturb` method.
                if self._search_over:
                    break

            population = [population[0]] + children

        # Children from the last generation were never ranked inside the loop,
        # so select the best member explicitly rather than trusting index 0.
        return max(population, key=lambda pm: pm.result.score).result

    def check_transformation_compatibility(self, transformation):
        """The genetic algorithm is specifically designed for word
        substitutions."""
        return transformation_consists_of_word_swaps(transformation)

    def is_black_box(self):
        return True

    def extra_repr_keys(self):
        return [
            "pop_size",
            "max_iters",
            "temp",
            "give_up_if_no_improvement",
            "post_crossover_check",
            "max_crossover_retries",
        ]