Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from .. import losses | |
import ignite.distributed as idist | |
import torch_optimizer | |
from tqdm import tqdm | |
import matplotlib.pyplot as plt | |
from torch.nn import functional as F | |
import os | |
import shutil | |
from modelguidedattacks.cls_models.registry import MMPretrainVisualTransformerWrapper | |
from modelguidedattacks.data.imagenet_metadata import imgnet_idx_to_name | |
class Unguided(nn.Module):
    """Top-K targeted attack that directly optimises an additive image perturbation.

    :meth:`attack` runs gradient descent on a perturbation so that the model's K
    highest-scoring classes become exactly ``attack_targets`` (in order), with an
    L2 penalty on the perturbation. :meth:`forward` wraps :meth:`attack` in a
    per-sample binary search over the weight given to the attack loss.
    """

    def __init__(self, model: nn.Module, config, optimizer=torch.optim.AdamW, seed=0, iterations=1000,
                 loss_fn=losses.CVXProjLoss, lr=1e-3,
                 binary_search_steps=1, topk_loss_coef_upper=10.,
                 topk_loss_coef_lower=0.) -> None:
        """
        Args:
            model: classifier called as ``model(x, return_features=True)``, returning
                ``(logits, features)``.
            config: options object. This class reads ``opt_warmup_its``, ``dump_plots``,
                ``plot_out`` and ``plot_idx`` from it and passes it to ``loss.precompute``.
            optimizer: optimizer class, instantiated as ``optimizer([param], lr=lr)``.
            seed: stored on the instance; not otherwise used by this class.
            iterations: optimisation steps per call to :meth:`attack`.
            loss_fn: loss class, instantiated with no arguments.
            lr: learning rate handed to ``optimizer``.
            binary_search_steps: number of :meth:`attack` runs performed by :meth:`forward`.
            topk_loss_coef_upper: initial upper bound of the attack-loss weight search.
            topk_loss_coef_lower: initial lower bound of the attack-loss weight search.
        """
        super().__init__()
        self.guided = False
        self.model = model
        self.seed = seed
        self.iterations = iterations
        self.loss = loss_fn()
        self.optimizer = optimizer
        self.lr = lr
        self.binary_search_steps = binary_search_steps
        self.topk_loss_coef_upper = topk_loss_coef_upper
        self.topk_loss_coef_lower = topk_loss_coef_lower
        self.config = config

    def surject_perturbation(self, x, max_norm=5.):
        """Project every sample of ``x`` onto the L2 ball of radius ``max_norm``.

        The norm is taken over all non-batch dimensions. Samples already inside the
        ball are returned unchanged; samples outside are rescaled to norm ``max_norm``.
        The result has the same shape as ``x``.
        """
        x_shape = x.shape
        flat = x.flatten(1)
        # keepdim gives [B, 1] so the per-sample factor broadcasts over the features
        # (a bare [B] mask cannot be expanded to [B, D]).
        norms = flat.norm(dim=-1, keepdim=True)
        # max_norm / norm >= 1 inside the ball, so clamping to 1 leaves those samples
        # untouched; an all-zero sample gives inf -> 1 instead of the 0/0 of x / ||x||.
        scale = (max_norm / norms).clamp(max=1.)
        return (flat * scale).reshape(x_shape)

    @staticmethod
    def _project_to_image_range(x, x_perturbation):
        """Clamp ``x_perturbation`` in place so that ``x + x_perturbation`` lies in [0, 1]."""
        with torch.no_grad():
            x_perturbation.copy_((x + x_perturbation).clamp_(min=0., max=1.) - x)

    def attack(self, x, attack_targets, gt_labels, topk_coefs):
        """
        For a given set of topk coefficients, this function computes
        best energy attack in the given number of iterations and configuration

        x: [B, C, H, W] [0-1 for colors]
        attack_targets: [B, K] (long)
        gt_labels: [B] (long)
        topk_coefs: [B] (floats) per-sample weight of the attack loss. A private copy
            is multiplied by 0.75 for a sample every time it finds an improved
            successful attack.

        Returns:
            prediction_logits: logits of ``x + best_perturbations`` (no autograd history).
            best_perturbations: [B, C, H, W] lowest-L2-norm perturbation seen whose top-K
                predictions equal ``attack_targets``; zeros where no attack succeeded.
            best_energy: [B] L2 norm of ``best_perturbations``; ``inf`` where no attack
                succeeded.
        """
        topk_coefs = topk_coefs.clone()  # decayed in place below; keep the caller's tensor intact
        K = attack_targets.shape[-1]

        x_perturbation = nn.Parameter(torch.randn(x.shape, device=x.device) * 2e-3)
        # Start from a feasible point, exactly like the projection after every step,
        # so an early success can never record a perturbation leaving the image range.
        self._project_to_image_range(x, x_perturbation)
        optimizer = self.optimizer([x_perturbation], lr=self.lr)

        precomputed_state = self.loss.precompute(attack_targets, gt_labels, self.config)

        with torch.no_grad():
            prediction_logits_0, prediction_feats_0 = self.model(x, return_features=True)

        best_perturbations = torch.zeros_like(x)  # [B, C, H, W]
        has_successful_attack = torch.zeros(x.shape[0], dtype=torch.bool, device=x.device)  # [B]
        best_energy = torch.full((x.shape[0],), float('inf'), device=x.device)  # [B]

        pbar = tqdm(range(self.iterations))
        for i in pbar:
            if i == self.config.opt_warmup_its:
                # Reset optimizer state
                optimizer = self.optimizer([x_perturbation], lr=self.lr)

            x_perturbed = x + x_perturbation
            prediction_logits, prediction_feats = self.model(x_perturbed, return_features=True)

            # Bookkeeping only: done under no_grad so the running bests never carry
            # autograd history across iterations.
            with torch.no_grad():
                # Success: the K highest-scoring classes are exactly the targets, in order.
                pred_classes = prediction_logits.argsort(dim=-1, descending=True)  # [B, C]
                attack_successful = (pred_classes[:, :K] == attack_targets).all(dim=-1)  # [B]
                attack_energy = x_perturbation.flatten(1).norm(dim=-1)  # [B]
                attack_improved = attack_successful & (attack_energy <= best_energy)

                best_perturbations[attack_improved] = x_perturbation[attack_improved]
                has_successful_attack |= attack_improved
                best_energy[attack_improved] = attack_energy[attack_improved]

            loss = self.loss(logits_pred=prediction_logits,
                             feats_pred=prediction_feats,
                             feats_pred_0=prediction_feats_0,
                             attack_targets=attack_targets,
                             model=self.model, **precomputed_state)
            loss = (loss * topk_coefs).sum()
            pbar.set_description(f"Loss: {loss.item():.3f}")
            # L2 penalty keeps the perturbation small.
            loss = loss + x_perturbation.flatten(1).square().sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # If we were successful let's start taking the norm down
            topk_coefs[attack_improved] *= 0.75

            # Project perturbation to be within image limits
            self._project_to_image_range(x, x_perturbation)

        x_perturbed_best = x + best_perturbations
        with torch.no_grad():
            prediction_logits, prediction_feats = self.model(x_perturbed_best, return_features=True)

        if self.config.dump_plots:
            if os.path.isdir(self.config.plot_out):
                shutil.rmtree(self.config.plot_out)
            if has_successful_attack.any():
                self._dump_plots(x, x_perturbed_best, best_perturbations, best_energy,
                                 has_successful_attack, attack_targets,
                                 prediction_logits_0, prediction_logits)
                # Debug mode: terminate the run once a batch has been plotted.
                # NOTE(review): placement inside the success branch was reconstructed
                # from whitespace-mangled source — confirm against the original.
                import sys
                sys.exit(1)

        return prediction_logits, best_perturbations, best_energy

    def _dump_plots(self, x, x_perturbed_best, best_perturbations, best_energy,
                    has_successful_attack, attack_targets,
                    prediction_logits_0, prediction_logits):
        """Write debug images, class-probability charts and text files for one sample.

        Output goes to ``config.plot_out``. The sample is drawn at random among the
        successfully attacked ones when ``config.plot_idx == "find"``; otherwise it is
        ``int(config.plot_idx)``.
        """
        out_dir = self.config.plot_out
        os.makedirs(out_dir, exist_ok=True)
        K = attack_targets.shape[-1]

        if self.config.plot_idx == "find":
            successful_idxs = has_successful_attack.nonzero()[:, 0]
            selected_idx = int(successful_idxs[torch.randperm(len(successful_idxs))[0]])
        else:
            selected_idx = int(self.config.plot_idx)
        print ("Selected idx", selected_idx)

        def imgnet_names(idxs):
            # First comma-separated synonym of each ImageNet class name.
            return [imgnet_idx_to_name[int(idx)].split(",")[0] for idx in idxs]

        top_classes = prediction_logits_0[selected_idx].argsort(dim=-1, descending=True)
        attack_targets_selected = attack_targets[selected_idx]
        top_class_names = imgnet_names(top_classes[:K])
        attack_targets_selected_names = imgnet_names(attack_targets_selected)

        def as_display_image(img):
            # [C, H, W] -> [H, W, C]; the channel flip reverses channel order
            # (presumably BGR -> RGB for display — TODO confirm input channel order).
            return img.permute(1, 2, 0).flip(dims=(-1,)).detach().cpu()

        def plot_attn_map(attn_map):
            # NOTE(review): layout of get_attention_maps() output is not visible here.
            # The last 196 entries are reshaped to a 14x14 patch grid and the leading
            # token is dropped — TODO confirm which axes are heads/tokens.
            attn_map = attn_map[0].mean(dim=0)[1:]  # [196]
            attn_map = attn_map.view(14, 14)
            attn_map = F.interpolate(
                attn_map[None, None],
                x.shape[-2:],
                mode="bilinear"
            ).view(x.shape[-2:])
            plt.imshow(attn_map.detach().cpu(), alpha=0.5)

        def save_image(img, name, attn_map=None):
            fig = plt.figure()
            plt.imshow(as_display_image(img))
            if attn_map is not None:
                plot_attn_map(attn_map)
            plt.axis("off")
            plt.savefig(f"{out_dir}/{name}", bbox_inches="tight", pad_inches=0)
            plt.close(fig)

        save_image(x[selected_idx], "clean_image.png")
        save_image(x_perturbed_best[selected_idx], "perturbed_image.png")

        fig = plt.figure()
        plt.imshow(best_perturbations[selected_idx].mean(dim=0).abs().detach().cpu(), cmap="hot")
        plt.colorbar()
        plt.savefig(f"{out_dir}/perturbation.png", bbox_inches="tight")
        plt.close(fig)

        if isinstance(self.model, MMPretrainVisualTransformerWrapper):
            with torch.no_grad():
                attn_maps_clean = self.model.get_attention_maps(x)[-1][selected_idx]
                attn_maps_attacked = self.model.get_attention_maps(x_perturbed_best)[-1][selected_idx]
            # Both maps are overlaid on the clean image.
            save_image(x[selected_idx], "clean_map.png", attn_maps_clean)
            save_image(x[selected_idx], "attacked_map.png", attn_maps_attacked)

        def write_text(name, text):
            with open(f"{out_dir}/{name}", 'w') as f:
                f.write(text)

        write_text("clean_classes_names.txt", ", ".join(top_class_names))
        write_text("attack_targets_names.txt", ", ".join(attack_targets_selected_names))
        write_text("selected_idx.txt", str(selected_idx))
        write_text("energy.txt", str(best_energy[selected_idx].item()))

        C = prediction_logits_0.shape[-1]
        class_idxs = torch.arange(C) + 1
        clean_probs = prediction_logits_0[selected_idx].detach().cpu().softmax(dim=-1)
        attacked_probs = prediction_logits[selected_idx].detach().cpu().softmax(dim=-1)
        targets = attack_targets_selected.tolist()

        def label_classes(bars):
            # Annotate each target's bar with its rank "[i]". Labels of targets within
            # 40 class indices of an already-labelled target are stacked 0.05 above it
            # so they do not overlap.
            adjusted_heights = {}
            for i, cls_idx in enumerate(targets):
                bar = bars[cls_idx]
                height = bar.get_height()
                ann_x = bar.get_x() + bar.get_width()
                neighbour_heights = [adjusted_heights[other] for other in targets
                                     if other != cls_idx and abs(cls_idx - other) <= 40
                                     and other in adjusted_heights]
                max_neighboring_height = max(neighbour_heights, default=-1)
                if max_neighboring_height > 0:
                    height = max_neighboring_height + 0.05
                adjusted_heights[cls_idx] = height
                plt.text(ann_x, height, f"[{i}]", rotation=90,
                         ha='center', va='bottom', fontsize=10, color='red')

        for probs, name in ((clean_probs, "clean_probs.png"), (attacked_probs, "attacked_probs.png")):
            fig = plt.figure()
            bars = plt.bar(class_idxs, probs, width=4)
            plt.ylim(0, 1)
            label_classes(bars)
            plt.savefig(f"{out_dir}/{name}", bbox_inches="tight", pad_inches=0)
            plt.close(fig)

        print ("Idx", selected_idx)
        print (best_energy[selected_idx])
        print ("Finished plotting")

    def forward(self, x, attack_targets, gt_labels):
        """
        This function is in charge of performing a binary search through
        topk loss coefficients and running attacks on each.

        Per sample, the search interval starts at
        ``[topk_loss_coef_lower, topk_loss_coef_upper]``. Each step attacks with the
        midpoint. A failed attack raises the lower bound; a successful one lowers the
        upper bound to look for a lower-energy attack.

        Returns:
            best_prediction_logits: logits at the lowest-energy perturbation found
                (taken from the first step for samples that never succeeded).
            best_perturbations: [B, C, H, W] lowest-energy perturbation found; zeros
                where no step succeeded.
        """
        B = x.shape[0]
        device = x.device

        topk_coefs_lower = torch.full((B,), fill_value=self.topk_loss_coef_lower,
                                      device=device, dtype=torch.float)
        topk_coefs_upper = torch.full((B,), fill_value=self.topk_loss_coef_upper,
                                      device=device, dtype=torch.float)

        best_perturbations = torch.zeros_like(x)  # [B, C, H, W]
        best_energy = torch.full((B,), float('inf'), device=device)  # [B]
        best_prediction_logits = None

        for search_step_i in range(self.binary_search_steps):
            # Print only from the process on device index 0 (or on CPU).
            if x.device.index is None or x.device.index == 0:
                print ("Running binary search step", search_step_i + 1)

            current_topk_coefs = (topk_coefs_lower + topk_coefs_upper) / 2
            current_logits, current_perturbations, current_energy = \
                self.attack(x, attack_targets, gt_labels, current_topk_coefs)

            # attack() reports inf energy for samples it could not attack.
            current_attack_succeeded = ~torch.isinf(current_energy)
            update_mask = current_energy < best_energy

            best_perturbations[update_mask] = current_perturbations[update_mask]
            best_energy[update_mask] = current_energy[update_mask]
            if best_prediction_logits is None:
                best_prediction_logits = current_logits.clone()
            else:
                best_prediction_logits[update_mask] = current_logits[update_mask]

            # If we fail to attack, we must increase our topk coef
            topk_coefs_lower[~current_attack_succeeded] = current_topk_coefs[~current_attack_succeeded]
            # If we succeed, we must lower to seek a more frugal attack
            topk_coefs_upper[current_attack_succeeded] = current_topk_coefs[current_attack_succeeded]

        idist.barrier()
        return best_prediction_logits, best_perturbations