""" | |
USEMetric class: | |
------------------------------------------------------- | |
Class for calculating USE similarity on AttackResults | |
""" | |
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.constraints.semantics.sentence_encoders import (
    UniversalSentenceEncoder,
)
from textattack.metrics import Metric


class USEMetric(Metric):
    def __init__(self, **kwargs):
        self.use_obj = UniversalSentenceEncoder()
        # ``SentenceEncoder._sim_score`` encodes through its ``model`` attribute,
        # so a second USE instance is attached there for scoring.
        self.use_obj.model = UniversalSentenceEncoder()
        self.original_candidates = []
        self.successful_candidates = []
        self.all_metrics = {}
    def calculate(self, results):
        """Calculates average USE similarity on all successful attacks.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Example::

            >> import textattack
            >> import transformers
            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
            >> attack_args = textattack.AttackArgs(
                num_examples=1,
                log_to_csv="log.csv",
                checkpoint_interval=5,
                checkpoint_dir="checkpoints",
                disable_stdout=True
            )
            >> attacker = textattack.Attacker(attack, dataset, attack_args)
            >> results = attacker.attack_dataset()
            >> usem = textattack.metrics.quality_metrics.USEMetric().calculate(results)
        """
        self.results = results

        # Only successful attacks contribute to the metric; failed and skipped
        # results are ignored.
        for result in self.results:
            if isinstance(result, FailedAttackResult):
                continue
            elif isinstance(result, SkippedAttackResult):
                continue
            else:
                self.original_candidates.append(
                    result.original_result.attacked_text
                )
                self.successful_candidates.append(
                    result.perturbed_result.attacked_text
                )

        # USE similarity between each original text and its perturbed counterpart.
        use_scores = []
        for c in range(len(self.original_candidates)):
            use_scores.append(
                self.use_obj._sim_score(
                    self.original_candidates[c], self.successful_candidates[c]
                ).item()
            )

        # Average USE similarity across all successful attacks, rounded to 2 decimals.
        self.all_metrics["avg_attack_use_score"] = round(
            sum(use_scores) / len(use_scores), 2
        )

        return self.all_metrics
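

# Minimal usage sketch (assumes ``results`` is the list of ``AttackResult`` objects
# produced by ``Attacker.attack_dataset()``, as in the docstring example above):
#
#     usem = USEMetric().calculate(results)
#     usem["avg_attack_use_score"]  # average original-vs-perturbed USE similarity,
#                                   # rounded to two decimal places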