PFEemp2024's picture
solving GPU error for previous version
4a1df2e
"""
USEMetric class:
-------------------------------------------------------
Class for calculating the average Universal Sentence Encoder (USE) similarity on AttackResults.
"""
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.metrics import Metric
class USEMetric(Metric):
    """Metric reporting the average USE similarity of successful attacks."""

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted (and ignored) so the
        # constructor signature stays compatible with existing callers.
        self.use_obj = UniversalSentenceEncoder()
        # NOTE(review): a second UniversalSentenceEncoder is assigned as the
        # ``model`` attribute of the first; this is kept from the original
        # code, but confirm it is intended, since ``model`` presumably should
        # hold the loaded embedding model rather than another encoder object.
        self.use_obj.model = UniversalSentenceEncoder()
        # Original / perturbed texts collected by the most recent
        # ``calculate`` call (index-aligned with each other).
        self.original_candidates = []
        self.successful_candidates = []
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates average USE similarity on all successful attacks.

        Failed and skipped attack results are ignored; every other result
        contributes the ``_sim_score`` between its original and perturbed
        ``attacked_text``.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics``, whose ``"avg_attack_use_score"`` entry
            is the mean similarity rounded to 2 decimals, or ``0.0`` when no
            attack succeeded.

        Example::
            >> import textattack
            >> import transformers
            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
            >> attack_args = textattack.AttackArgs(
                num_examples=1,
                log_to_csv="log.csv",
                checkpoint_interval=5,
                checkpoint_dir="checkpoints",
                disable_stdout=True
            )
            >> attacker = textattack.Attacker(attack, dataset, attack_args)
            >> results = attacker.attack_dataset()
            >> usem = textattack.metrics.quality_metrics.USEMetric().calculate(results)
        """
        self.results = results
        # Reset per call: otherwise repeated calls on the same instance would
        # average in candidates left over from earlier result sets.
        self.original_candidates = []
        self.successful_candidates = []

        for result in self.results:
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue
            self.original_candidates.append(result.original_result.attacked_text)
            self.successful_candidates.append(result.perturbed_result.attacked_text)

        use_scores = [
            self.use_obj._sim_score(original, perturbed).item()
            for original, perturbed in zip(
                self.original_candidates, self.successful_candidates
            )
        ]

        # Guard against ZeroDivisionError when every attack failed or was
        # skipped.
        if use_scores:
            avg_score = round(sum(use_scores) / len(use_scores), 2)
        else:
            avg_score = 0.0
        self.all_metrics["avg_attack_use_score"] = avg_score
        return self.all_metrics