PFEemp2024's picture
add necessary file
63775f2
raw
history blame contribute delete
No virus
1.93 kB
"""
Goal Function that Attempts to Minimize the BLEU Score
-------------------------------------------------------
"""
import functools
import nltk
import textattack
from .text_to_text_goal_function import TextToTextGoalFunction
class MinimizeBleu(TextToTextGoalFunction):
    """Attempts to minimize the BLEU score between the current output
    translation and the reference translation.

    BLEU score was defined in (BLEU: a Method for Automatic Evaluation of
    Machine Translation).

    `BLEUPaperURL`_

    .. _BLEUPaperURL: https://www.aclweb.org/anthology/P02-1040.pdf

    This goal function is defined in (It’s Morphin’ Time! Combating
    Linguistic Discrimination with Inflectional Perturbations).

    `MorphinPaperURL`_

    .. _MorphinPaperURL: https://www.aclweb.org/anthology/2020.acl-main.263

    Args:
        target_bleu: The goal is treated as complete once the BLEU score of
            the model output (against the ground-truth translation) is at
            most ``target_bleu + EPS``. Defaults to ``0.0``.
    """

    # Tolerance added to ``target_bleu`` so that float rounding in the BLEU
    # computation does not prevent the goal from being reached.
    EPS = 1e-10

    def __init__(self, *args, target_bleu=0.0, **kwargs):
        self.target_bleu = target_bleu
        super().__init__(*args, **kwargs)

    def clear_cache(self):
        """Clear cached model outputs (when caching is enabled) and the
        module-level memoised BLEU results."""
        if self.use_cache:
            self._call_model_cache.clear()
        # ``get_bleu`` keeps its own lru_cache, independent of ``use_cache``,
        # so it is cleared unconditionally.
        get_bleu.cache_clear()

    def _is_goal_complete(self, model_output, _):
        # ``_get_score`` returns ``1 - BLEU``; invert it to recover BLEU.
        bleu_score = 1.0 - self._get_score(model_output, _)
        return bleu_score <= (self.target_bleu + MinimizeBleu.EPS)

    def _get_score(self, model_output, _):
        """Return ``1 - BLEU(model_output, ground truth)``.

        The score grows as the output diverges from the reference
        translation. Presumably the base class's search maximizes it —
        confirm against ``TextToTextGoalFunction``.
        """
        model_output_at = textattack.shared.AttackedText(model_output)
        # NOTE(review): ``ground_truth_output`` is assumed to be set by the
        # base class before scoring — confirm.
        ground_truth_at = textattack.shared.AttackedText(self.ground_truth_output)
        # ``get_bleu(a, b)`` uses ``a`` as the reference and ``b`` as the
        # hypothesis. BLEU is asymmetric, so the ground truth must be passed
        # first and the model output second.
        bleu_score = get_bleu(ground_truth_at, model_output_at)
        return 1.0 - bleu_score

    def extra_repr_keys(self):
        """Attribute names included in this goal function's ``repr``.

        ``target_bleu`` is only listed when the goal function is not
        maximizable.
        """
        if self.maximizable:
            return ["maximizable"]
        else:
            return ["maximizable", "target_bleu"]
@functools.lru_cache(maxsize=2**12)
def get_bleu(a, b):
    """Compute sentence-level BLEU of ``b`` against the single reference ``a``.

    Results are memoised in an LRU cache of up to 4096 entries, so both
    arguments must be hashable. Each argument must expose a ``words``
    attribute holding its token sequence; within this file the caller passes
    ``textattack.shared.AttackedText`` instances.

    Args:
        a: The reference text; ``a.words`` is the only reference.
        b: The hypothesis (candidate) text; ``b.words`` is scored.

    Returns:
        The result of ``nltk.translate.bleu_score.sentence_bleu`` with the
        library's default weights and no smoothing.
    """
    reference_tokens = a.words
    candidate_tokens = b.words
    return nltk.translate.bleu_score.sentence_bleu(
        [reference_tokens], candidate_tokens
    )