Spaces:
Running
Running
""" | |
GoalFunctionResult class | |
==================================== | |
""" | |
from abc import ABC, abstractmethod | |
import torch | |
from textattack.shared import utils | |
class GoalFunctionResultStatus:
    """Integer status codes attached to a goal function result.

    The codes are consecutive ints starting at 0:
    SUCCEEDED=0, SEARCHING=1, MAXIMIZING=2, SKIPPED=3.
    """

    # SEARCHING means the search is still in process of looking for a success.
    SUCCEEDED, SEARCHING, MAXIMIZING, SKIPPED = range(4)
class GoalFunctionResult(ABC):
    """Represents the result of a goal function evaluating an AttackedText
    object.

    Args:
        attacked_text: The sequence that was evaluated.
        raw_output: The raw model output for ``attacked_text``. If it is a
            ``torch.Tensor`` it is detached, moved to the CPU and stored as a
            NumPy array.
        output: The display-friendly output.
        goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal.
        score: A score representing how close the model is to achieving its goal.
            If it is a ``torch.Tensor`` it is stored as a Python scalar
            (via ``Tensor.item()``).
        num_queries: How many model queries have been used
        ground_truth_output: The ground truth output
        goal_function_result_type: A label for the kind of result, shown in
            ``repr``. Defaults to the empty string.
    """

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
        goal_function_result_type="",
    ):
        self.attacked_text = attacked_text
        self.raw_output = raw_output
        self.output = output
        self.score = score
        self.goal_status = goal_status
        self.num_queries = num_queries
        self.ground_truth_output = ground_truth_output
        self.goal_function_result_type = goal_function_result_type

        if isinstance(self.raw_output, torch.Tensor):
            # ``Tensor.numpy()`` refuses tensors that require grad or that live
            # on a non-CPU device; detach and copy to CPU first so model outputs
            # produced with autograd enabled or on a GPU are accepted too.
            self.raw_output = self.raw_output.detach().cpu().numpy()
        if isinstance(self.score, torch.Tensor):
            self.score = self.score.item()

    def __repr__(self):
        """Return a multi-line summary listing the result type, attacked text,
        ground truth output, model output and score."""
        main_str = "GoalFunctionResult( "
        lines = []
        lines.append(
            utils.add_indent(
                f"(goal_function_result_type): {self.goal_function_result_type}", 2
            )
        )
        lines.append(utils.add_indent(f"(attacked_text): {self.attacked_text.text}", 2))
        lines.append(
            utils.add_indent(f"(ground_truth_output): {self.ground_truth_output}", 2)
        )
        lines.append(utils.add_indent(f"(model_output): {self.output}", 2))
        lines.append(utils.add_indent(f"(score): {self.score}", 2))
        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    def get_text_color_input(self):
        """A string representing the color this result's changed portion should
        be if it represents the original input.

        Not implemented here; always raises ``NotImplementedError`` unless a
        subclass overrides it.
        """
        raise NotImplementedError()

    def get_text_color_perturbed(self):
        """A string representing the color this result's changed portion should
        be if it represents the perturbed input.

        Not implemented here; always raises ``NotImplementedError`` unless a
        subclass overrides it.
        """
        raise NotImplementedError()

    def get_colored_output(self, color_method=None):
        """Returns a string representation of this result's output, colored
        according to `color_method`.

        Not implemented here; always raises ``NotImplementedError`` unless a
        subclass overrides it.
        """
        raise NotImplementedError()