Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from ignite.metrics import Loss | |
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce | |
from typing import Callable, cast, Dict, Sequence, Tuple, Union | |
def get_correct_mask(y_pred, y_attack):
    """Flag samples whose ranked top predictions equal the attack targets.

    The number of targets ``k`` is read from ``y_attack.shape[-1]``. A sample
    counts as correct only when its ``k`` highest-scoring classes in
    ``y_pred`` (ranked in descending score order) match ``y_attack``
    position by position, so the order of the targets matters.

    Args:
        y_pred: score tensor, ranked along its last dimension. The slice
            ``[:, :k]`` requires it to be 2-D, presumably ``[N, C]``.
        y_attack: target class indices, presumably ``[N, k]`` (TODO confirm
            against callers).

    Returns:
        Boolean tensor obtained by reducing the element-wise comparison over
        the last dimension, i.e. one flag per sample.
    """
    num_targets = y_attack.shape[-1]
    # Class indices ordered from highest to lowest score, per sample.
    ranking = torch.argsort(y_pred, dim=-1, descending=True)
    leading = ranking[:, :num_targets]
    matches = leading == y_attack
    return matches.all(dim=-1)
class EnergyLoss(Loss):
    """Ignite metric that aggregates per-sample energies produced by ``loss_fn``.

    On every update, ``loss_fn(y_pred, y, **kwargs)`` is evaluated and its
    result is treated as a sequence of per-sample energies (presumably a 1-D
    tensor, e.g. the output of ``Energy``). The metric keeps a running sum,
    minimum, maximum and sample count. ``compute`` reduces them according to
    ``reduction``.

    Args:
        loss_fn: callable returning per-sample energies.
        reduction: one of ``"mean"``, ``"max"`` or ``"min"``.
        device: device on which the running statistics are stored.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "max", "min")

    def __init__(self, loss_fn, reduction="mean", device=torch.device("cpu")):
        # Fail at construction time rather than on the first ``compute``.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        super().__init__(loss_fn, device=device)
        self.reduction = reduction

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear the running sum, extrema and sample count."""
        self._sum = torch.tensor(0.0, device=self._device)
        # Start the extrema at the identities of min/max. The first real value
        # then always replaces them, and a process that saw no samples does
        # not distort a distributed MIN/MAX reduction.
        self._min = torch.tensor(torch.inf, device=self._device)
        self._max = torch.tensor(-torch.inf, device=self._device)
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output: Sequence[Union[torch.Tensor, Dict]]) -> None:
        """Accumulate the energies of one batch.

        Args:
            output: ``(y_pred, y)`` or ``(y_pred, y, kwargs)``. ``kwargs`` is
                forwarded to ``loss_fn`` as keyword arguments.
        """
        if len(output) == 2:
            y_pred, y = cast(Tuple[torch.Tensor, torch.Tensor], output)
            kwargs: Dict = {}
        else:
            y_pred, y, kwargs = cast(Tuple[torch.Tensor, torch.Tensor, Dict], output)
        # These values are statistics only. Detach them so no autograd graph
        # is kept alive, and move them to the metric's device so they can be
        # combined with the running tensors.
        sample_energies = self._loss_fn(y_pred, y, **kwargs).detach().to(self._device)
        n = len(sample_energies)
        if n == 0:
            # Nothing to record, e.g. no sample survived the loss_fn filtering.
            return
        self._sum += sample_energies.sum()
        self._min = torch.minimum(self._min, sample_energies.min())
        self._max = torch.maximum(self._max, sample_energies.max())
        self._num_examples += n

    @sync_all_reduce("_sum", "_num_examples", "_min:MIN", "_max:MAX")
    def compute(self) -> float:
        """Return the reduced energy as a Python float.

        When no sample has been recorded, ``"mean"`` and ``"min"`` return
        ``inf`` and ``"max"`` returns ``nan``.

        Raises:
            ValueError: if ``self.reduction`` holds an unsupported value.
        """
        if self.reduction == "mean":
            if self._num_examples == 0:
                return torch.inf
            return self._sum.item() / self._num_examples
        if self.reduction == "max":
            if self._num_examples == 0:
                return torch.nan
            return self._max.item()
        if self.reduction == "min":
            if self._num_examples == 0:
                return torch.inf
            return self._min.item()
        # Unlike ``assert False``, this is not stripped under ``python -O``.
        raise ValueError(
            f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
        )
class Energy(nn.Module):
    """Per-sample vector norm of the perturbations behind successful attacks.

    Args:
        p: order of the norm, passed to ``torch.linalg.vector_norm`` as
            ``ord``. Numbers are used as-is. Strings such as ``"2"``, ``"1"``
            or ``"inf"`` are converted with ``float``; the former default
            ``"2"`` was a string, which ``vector_norm`` does not accept.
    """

    def __init__(self, p=2) -> None:
        super().__init__()
        self.p = p

    def forward(self, y_pred, y_attack, perturbations, **kwargs):
        """Return the norms of the perturbations whose attack succeeded.

        Success is decided by ``get_correct_mask(y_pred, y_attack)``. Each
        retained perturbation is flattened from dimension 1 onward before
        the norm is taken.

        Args:
            y_pred: model scores used to judge attack success.
            y_attack: attack target indices.
            perturbations: one perturbation per sample, indexed along dim 0.
            **kwargs: ignored; accepted so callers can pass extra entries.

        Returns:
            1-D tensor with one norm per successful sample. It is empty when
            no attack succeeded.
        """
        correct = get_correct_mask(y_pred, y_attack)
        # Perturbations of unsuccessful attacks are excluded from the energy.
        # Move the mask next to the data, because boolean indexing requires
        # compatible devices.
        perturbations = perturbations[correct.to(perturbations.device)]
        perturbations = perturbations.flatten(1)
        order = float(self.p) if isinstance(self.p, str) else self.p
        return torch.linalg.vector_norm(perturbations, dim=-1, ord=order)