Spaces:
Sleeping
Sleeping
""" | |
Goal function that attempts to minimize the BLEU score
------------------------------------------------------- | |
""" | |
import functools | |
import nltk | |
import textattack | |
from .text_to_text_goal_function import TextToTextGoalFunction | |
class MinimizeBleu(TextToTextGoalFunction):
    """Goal function that tries to push down the BLEU score between the
    current output translation and the reference translation.

    BLEU score was defined in (BLEU: a Method for Automatic Evaluation of Machine Translation).

    `ArxivURL`_

    .. _ArxivURL: https://www.aclweb.org/anthology/P02-1040.pdf

    This goal function is defined in (It’s Morphin’ Time! Combating Linguistic Discrimination with Inflectional Perturbations).

    `ArxivURL2`_

    .. _ArxivURL2: https://www.aclweb.org/anthology/2020.acl-main.263
    """

    # Tolerance added to ``target_bleu`` when checking goal completion.
    EPS = 1e-10

    def __init__(self, *args, target_bleu=0.0, **kwargs):
        """Record ``target_bleu`` and forward everything else to the parent.

        Args:
            target_bleu: BLEU value at or below which the goal counts as
                complete (compared with an ``EPS`` tolerance).
        """
        # Set before the parent constructor runs, as in the original ordering.
        self.target_bleu = target_bleu
        super().__init__(*args, **kwargs)

    def clear_cache(self):
        """Empty the model-call cache (only if caching is on) and the BLEU cache."""
        # ``use_cache`` / ``_call_model_cache`` are presumably provided by the
        # parent class -- they are not defined in this file.
        if self.use_cache:
            self._call_model_cache.clear()
        # The BLEU cache is cleared regardless of ``use_cache``.
        get_bleu.cache_clear()

    def _is_goal_complete(self, model_output, _):
        """Return whether the BLEU score has dropped to the target (within EPS)."""
        # ``_get_score`` returns 1 - BLEU, so invert it to recover BLEU.
        current_bleu = 1.0 - self._get_score(model_output, _)
        threshold = self.target_bleu + MinimizeBleu.EPS
        return current_bleu <= threshold

    def _get_score(self, model_output, _):
        """Return ``1.0 - BLEU`` so that a lower BLEU gives a higher score."""
        make_text = textattack.shared.AttackedText
        output_text = make_text(model_output)
        truth_text = make_text(self.ground_truth_output)
        # NOTE(review): ``get_bleu`` uses its first argument as the BLEU
        # reference, so the model output plays the reference role here --
        # confirm this ordering is intended.
        return 1.0 - get_bleu(output_text, truth_text)

    def extra_repr_keys(self):
        """List the attribute names to show in this object's repr."""
        keys = ["maximizable"]
        # ``target_bleu`` is only reported when the goal is not maximizable.
        if not self.maximizable:
            keys.append("target_bleu")
        return keys
@functools.lru_cache(maxsize=2**12)
def get_bleu(a, b):
    """Return the sentence-level BLEU score between two texts, memoized.

    The function is wrapped in ``functools.lru_cache`` so repeated
    comparisons of the same pair are not recomputed, and so that
    ``get_bleu.cache_clear()`` (called by ``MinimizeBleu.clear_cache``)
    exists; on a plain function that call raises ``AttributeError``.

    Args:
        a: Object exposing a ``words`` sequence; used as the single
            BLEU reference.
        b: Object exposing a ``words`` sequence; used as the BLEU
            hypothesis.

    Returns:
        The value returned by ``nltk.translate.bleu_score.sentence_bleu``.

    Note:
        ``lru_cache`` requires both arguments to be hashable -- presumably
        true of the ``AttackedText`` objects passed by the caller; confirm.
    """
    # NOTE(review): the first argument is treated as the reference and the
    # second as the hypothesis; the caller passes the model output first.
    # BLEU is not symmetric -- confirm this ordering is intended.
    ref = a.words
    hyp = b.words
    return nltk.translate.bleu_score.sentence_bleu([ref], hyp)