PFEemp2024's picture
solving GPU error for previous version
4a1df2e
"""
Metrics on perturbed words
---------------------------------------------------------------------
"""
import numpy as np
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric
class WordsPerturbed(Metric):
    """Metrics about how many words successful attacks perturbed.

    Call :meth:`calculate` with a sequence of attack results; it fills
    ``self.all_metrics`` and returns it.
    """

    # Initial length of the "words changed until success" histogram; bucket
    # ``k - 1`` counts successful attacks that changed exactly ``k`` words.
    _HISTOGRAM_SIZE = 2**16

    def __init__(self):
        self.total_attacks = 0
        self.all_num_words = None
        self.perturbed_word_percentages = None
        self.num_words_changed_until_success = 0
        # Initialized here so the attribute exists even before calculate().
        self.max_words_changed = 0
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates all metrics related to perturbed words in an attack.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with the keys ``"avg_word_perturbed"``,
            ``"avg_word_perturbed_perc"``, ``"max_words_changed"`` and
            ``"num_words_changed_until_success"``.
        """
        self.results = results
        self.total_attacks = len(self.results)
        self.all_num_words = np.zeros(len(self.results))
        self.perturbed_word_percentages = np.zeros(len(self.results))
        self.num_words_changed_until_success = np.zeros(self._HISTOGRAM_SIZE)
        self.max_words_changed = 0

        for i, result in enumerate(self.results):
            num_original_words = len(result.original_result.attacked_text.words)
            self.all_num_words[i] = num_original_words

            # Only successful attacks contribute perturbation statistics;
            # failed/skipped ones keep a 0 percentage (filtered out later).
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue

            num_words_changed = len(
                result.original_result.attacked_text.all_words_diff(
                    result.perturbed_result.attacked_text
                )
            )
            self._record_words_changed(num_words_changed)
            self.max_words_changed = max(self.max_words_changed, num_words_changed)

            if num_original_words > 0:
                perturbed_word_percentage = (
                    num_words_changed * 100.0 / num_original_words
                )
            else:
                perturbed_word_percentage = 0
            self.perturbed_word_percentages[i] = perturbed_word_percentage

        self.all_metrics["avg_word_perturbed"] = self.avg_number_word_perturbed_num()
        self.all_metrics["avg_word_perturbed_perc"] = self.avg_perturbation_perc()
        self.all_metrics["max_words_changed"] = self.max_words_changed
        self.all_metrics[
            "num_words_changed_until_success"
        ] = self.num_words_changed_until_success
        return self.all_metrics

    def _record_words_changed(self, num_words_changed):
        """Add one successful attack with ``num_words_changed`` changes to the histogram."""
        # Zero changed words has no bucket; indexing with -1 would silently
        # corrupt the last bucket, so such attacks are not counted here.
        if num_words_changed <= 0:
            return
        histogram = self.num_words_changed_until_success
        if num_words_changed > len(histogram):
            # Grow instead of raising IndexError on unusually long texts.
            histogram = np.concatenate(
                [histogram, np.zeros(num_words_changed - len(histogram))]
            )
            self.num_words_changed_until_success = histogram
        histogram[num_words_changed - 1] += 1

    def avg_number_word_perturbed_num(self):
        """Return the mean number of words per input, rounded to 2 decimals.

        Despite the name, this averages the length (in words) of every
        original input, not the number of perturbed words. Returns NaN when
        there are no results.
        """
        if self.all_num_words.size == 0:
            # Avoid numpy's "Mean of empty slice" warning; result is NaN as before.
            return float("nan")
        average_num_words = self.all_num_words.mean()
        return round(average_num_words, 2)

    def avg_perturbation_perc(self):
        """Return the mean perturbed-word percentage, rounded to 2 decimals.

        Only strictly positive percentages are averaged. This excludes failed
        and skipped attacks, which keep 0, as well as successful attacks with
        0% perturbation. Note that ``self.perturbed_word_percentages`` is
        replaced by this filtered array. Returns NaN when nothing remains.
        """
        self.perturbed_word_percentages = self.perturbed_word_percentages[
            self.perturbed_word_percentages > 0
        ]
        if self.perturbed_word_percentages.size == 0:
            # Avoid numpy's "Mean of empty slice" warning; result is NaN as before.
            return float("nan")
        average_perc_words_perturbed = self.perturbed_word_percentages.mean()
        return round(average_perc_words_perturbed, 2)