thomaspaniagua commited on
Commit
71f183c
0 Parent(s):

QuadAttack release

Browse files
Files changed (35) hide show
  1. .gitignore +143 -0
  2. README.md +11 -0
  3. base_config.yaml +36 -0
  4. debug_test.py +185 -0
  5. modelguidedattacks/cls_models/__init__.py +4 -0
  6. modelguidedattacks/cls_models/accuracy.py +115 -0
  7. modelguidedattacks/cls_models/registry.py +163 -0
  8. modelguidedattacks/cls_models/setup.py +0 -0
  9. modelguidedattacks/data/__init__.py +4 -0
  10. modelguidedattacks/data/classification_wrapper.py +22 -0
  11. modelguidedattacks/data/imagenet_metadata.py +1000 -0
  12. modelguidedattacks/data/registry.py +108 -0
  13. modelguidedattacks/data/setup.py +170 -0
  14. modelguidedattacks/guides/instance_guide.py +109 -0
  15. modelguidedattacks/guides/unguided.py +314 -0
  16. modelguidedattacks/losses/__init__.py +4 -0
  17. modelguidedattacks/losses/_qp_solver_patch.py +170 -0
  18. modelguidedattacks/losses/adversarial_distillation/ad_loss.py +38 -0
  19. modelguidedattacks/losses/adversarial_distillation/adversarial_distribution.py +72 -0
  20. modelguidedattacks/losses/adversarial_distillation/glove.py +58 -0
  21. modelguidedattacks/losses/adversarial_distillation/glove_simi.py +126 -0
  22. modelguidedattacks/losses/boilerplate.py +59 -0
  23. modelguidedattacks/losses/cvx_proj.py +108 -0
  24. modelguidedattacks/losses/cw_extension.py +46 -0
  25. modelguidedattacks/losses/energy.py +80 -0
  26. modelguidedattacks/metrics/topk_accuracy.py +15 -0
  27. modelguidedattacks/models.py +32 -0
  28. modelguidedattacks/results.py +261 -0
  29. modelguidedattacks/run.py +140 -0
  30. modelguidedattacks/trainers.py +83 -0
  31. modelguidedattacks/utils.py +133 -0
  32. print_results.py +25 -0
  33. print_table.py +163 -0
  34. result_stats.py +90 -0
  35. setup.py +40 -0
.gitignore ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+
132
+ *.tar.gz
133
+ cifar-10-batches-py
134
+ logs/**/**
135
+ logs_old/**/**
136
+ config-lock.yaml
137
+
138
+ *.png
139
+ *.p
140
+ datasets/*
141
+ data/*
142
+ *.save
143
+ *.txt
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Getting Started
2
+
3
+ Install the dependencies with `pip`:
4
+ pip install -r requirements.txt
5
+
6
+ python setup.py develop
7
+
8
+ Adjust configuration in base_config.yaml
9
+
10
+ Run current configuration by
11
+ python modelguidedattacks/run.py base_config.yaml
base_config.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 10
2
+ nproc_per_node: 1
3
+ k: 5
4
+ data_path: ./
5
+ train_batch_size: 64
6
+ eval_batch_size: 64
7
+ num_workers: 4
8
+ max_epochs: 1000
9
+ train_epoch_length: null
10
+ eval_epoch_length: null
11
+ lr: 0.001
12
+ unguided_lr: 0.0022
13
+ use_amp: false
14
+ debug: false
15
+ model: resnet50
16
+ dataset: imagenet
17
+ output_dir: ./logs
18
+ log_every_iters: 1
19
+ unguided_iterations: 30
20
+ overfit: false
21
+ guide_model: "unguided"
22
+ loss: "cvxproj"
23
+ out_dir: ""
24
+ attack_sampling: "random"
25
+
26
+ cvx_proj_margin: 0.2
27
+ topk_loss_coef_upper: 20
28
+ opt_warmup_its: 5
29
+ binary_search_steps: 1
30
+
31
+ dump_plots: false
32
+ plot_idx: find # or specific batch idx
33
+ plot_out: "myplots/mymethod"
34
+
35
+ # List of models used to find correct subset
36
+ compare_models: ["resnet50", "deit_small", "vit_base", "densenet121"]
debug_test.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cvxpy as cp
2
+ from cvxpylayers.torch import CvxpyLayer
3
+
4
+ from torch.nn import functional as F
5
+ import torch
6
+ from modelguidedattacks import cls_models
7
+ import time
8
+
9
+ torch.manual_seed(0)
10
+ device = "cuda"
11
+ # model = cls_models.get_model("imagenet", "resnet18", device)
12
+
13
+ rand_feats = torch.randn(1, 512, device=device)
14
+ attack_targets = [4, 7, 5, 9, 2]
15
+
16
+ # # pred_logits = model.head(rand_feats)
17
+
18
+ # # head_W, head_bias = model.head_matrices()
19
+
20
+ (head_W, head_bias, pred_logits) = torch.load("debugsaveimagenet.save")
21
+
22
+ rand_feats, rand_logits, attack_targets = torch.load("attack_case.p", map_location=device)
23
+ reconstructed_logits = rand_feats@head_W.T + head_bias
24
+
25
+ num_feats = head_W.shape[1]
26
+ num_classes = head_W.shape[0]
27
+ x = cp.Variable(num_feats)
28
+
29
+ anchor_feats = cp.Parameter(x.shape)
30
+ A = cp.Parameter(head_W.shape)
31
+ b = cp.Parameter(head_bias.shape)
32
+
33
+ logits = A@x + b
34
+
35
+ MARGIN = 0.1
36
+
37
+ # constraints = []
38
+ # for i in range(len(attack_targets) - 1):
39
+ # constraints.append( logits[attack_targets[i]] - logits[attack_targets[i+1]] >= MARGIN)
40
+
41
+ # for i in range(num_classes):
42
+ # if i in attack_targets:
43
+ # continue
44
+
45
+ # constraints.append(logits[attack_targets[-1]] - logits[i] >= MARGIN )
46
+
47
+ # objective = cp.Minimize(0.5 * cp.pnorm(x - anchor_feats, p=2))
48
+ # problem = cp.Problem(objective, constraints)
49
+
50
+ # anchor_feats.value = rand_feats[0].cpu().numpy()
51
+ # A.value = head_W.detach().cpu().numpy()
52
+ # b.value = head_bias.detach().cpu().numpy()
53
+
54
+ # start_time = time.time()
55
+ # problem.solve()
56
+ # print ("Non vectorized sol", time.time() - start_time)
57
+
58
+ # logits_sol_torch = torch.from_numpy(logits.value)
59
+ # logits_check = logits_sol_torch.argsort(descending=True)
60
+
61
+ # feats_sol = torch.from_numpy(x.value[:, None]).float().to(rand_feats)
62
+ # sol_feat_norm = (feats_sol[:, 0].cpu() - rand_feats[0].cpu()).norm(dim=-1)
63
+ # sol_logits = head_W@feats_sol + head_bias[:, None]
64
+ # sol_sort = sol_logits.argsort(dim=0, descending=True)
65
+
66
+
67
+ # Constraint matrix
68
+ num_constraints = num_classes - 1
69
+ D = torch.zeros((num_classes), num_constraints)
70
+
71
+ non_attack_targets = list(set(range(num_classes)) - set(attack_targets))
72
+
73
+ for constraint_cursor in range(num_constraints):
74
+ if constraint_cursor < len(attack_targets) - 1:
75
+ D[attack_targets[constraint_cursor], constraint_cursor] = 1
76
+ D[attack_targets[constraint_cursor + 1], constraint_cursor] = -1
77
+ else:
78
+ non_attack_i = constraint_cursor - len(attack_targets) + 1
79
+ D[attack_targets[-1], constraint_cursor] = 1
80
+ D[non_attack_targets[non_attack_i], constraint_cursor] = -1
81
+
82
+ D = D.T
83
+ # vectorized_differences = D @ logits
84
+ # vectorized_constraint = vectorized_differences >= torch.full(vectorized_differences.shape, fill_value=MARGIN).numpy()
85
+
86
+ # Q = 2*torch.eye(x.shape[0]).numpy()
87
+ # P = -2*anchor_feats
88
+
89
+ # G = D@A
90
+ # H = MARGIN - D @ b
91
+
92
+ # G = -G
93
+ # H = -H
94
+
95
+ # vectorized_constraint = G@x <= H
96
+
97
+ # objective = cp.Minimize((1/2)*cp.quad_form(x, Q) + P.T@x)
98
+ # problem = cp.Problem(objective, [vectorized_constraint])
99
+
100
+ # anchor_feats.value = rand_feats[0].cpu().numpy()
101
+ # A.value = head_W.detach().cpu().numpy()
102
+ # b.value = head_bias.detach().cpu().numpy()
103
+
104
+ # start_time = time.time()
105
+ # problem.solve()
106
+ # print ("vectorized sol", time.time() - start_time)
107
+
108
+ # logits_sol_torch = torch.from_numpy(logits.value)
109
+ # logits_check = logits_sol_torch.argsort(descending=True)
110
+ # feats_sol = torch.from_numpy(x.value[:, None]).float().to(rand_feats)
111
+ # sol_feat_norm = (feats_sol[:, 0].cpu() - rand_feats[0].cpu()).norm(dim=-1)
112
+ # sol_logits = head_W@feats_sol + head_bias[:, None]
113
+ # sol_sort = sol_logits.argsort(dim=0, descending=True)
114
+
115
+ import qpth
116
+
117
+
118
+ B = 2
119
+ nz = num_feats
120
+ nineq = num_constraints
121
+ device = "cuda"
122
+
123
+ attack_targets = attack_targets.expand(B, -1)
124
+ K = attack_targets.shape[-1]
125
+
126
+ # Start with all classes should be less than smallest attack target
127
+ D = -torch.eye(num_classes, device=device)[None].repeat(B, 1, 1)
128
+ attack_targets_write = attack_targets[:, -1][:, None, None].expand(-1, D.shape[1], -1)
129
+ D.scatter_(dim=2, index=attack_targets_write, src=torch.ones(attack_targets_write.shape, device=device))
130
+
131
+ # Clear out the constraint row for each item in the attack targets
132
+ attack_targets_clear = attack_targets[:, :, None].expand(-1, -1, D.shape[-1])
133
+ D.scatter_(dim=1, index=attack_targets_clear, src=torch.zeros(attack_targets_clear.shape, device=device))
134
+
135
+ batch_inds = torch.arange(B, device=device)[:, None].expand(-1, K - 1)
136
+ attack_targets_pos = attack_targets[:, :-1] # [B, K-1]
137
+ attack_targets_neg = attack_targets[:, 1:] # [B, K-1]
138
+
139
+ attack_targets_neg_inds = torch.stack((
140
+ batch_inds,
141
+ attack_targets_neg,
142
+ attack_targets_neg
143
+ ), dim=0) # [3, B, K - 1]
144
+ attack_targets_neg_inds = attack_targets_neg_inds.view(3, -1)
145
+
146
+ D[attack_targets_neg_inds[0], attack_targets_neg_inds[1], attack_targets_neg_inds[2]] = -1
147
+
148
+ attack_targets_pos_inds = torch.stack((
149
+ batch_inds,
150
+ attack_targets_neg,
151
+ attack_targets_pos
152
+ ), dim=0) # [3, B, K - 1]
153
+
154
+ D[attack_targets_pos_inds[0], attack_targets_pos_inds[1], attack_targets_pos_inds[2]] = 1
155
+
156
+ A = head_W.detach().to(device)
157
+ b = head_bias.detach().to(device)
158
+ D = D.to(device)
159
+
160
+ #rand_feats: [B, num_features]
161
+ Q = 2*torch.eye(nz, device=device)[None].expand(B, -1, -1)
162
+ P = -2*rand_feats.to(device).expand(B, -1)
163
+
164
+ # G = torch.randn(B, nineq, nz, device=device)
165
+ G = -D@A
166
+
167
+ # h = torch.randn(B, nineq)
168
+ H = -(MARGIN - D @ b)
169
+
170
+ # Constraints are indexed by smaller logit
171
+ # First attack target isn't smaller than any logit, so its
172
+ # constraint index is redundant, but we keep it for easier parallelization
173
+ # Make this constraint all 0s
174
+ zero_inds = attack_targets[:, 0:1] # [B, 1]
175
+ H.scatter_(dim=1, index=zero_inds, src=torch.zeros(zero_inds.shape, device=device))
176
+
177
+ e = torch.empty(0, device=device)
178
+
179
+ Q_t, P_t, G_t, H_t = torch.load("qpinputs.p", map_location=device)
180
+
181
+ z_sol = qpth.qp.QPFunction(verbose=True, check_Q_spd=False)(Q, P, G, H, e, e).T
182
+
183
+ logits = A@z_sol + b[:, None]
184
+
185
+ x = 5
modelguidedattacks/cls_models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .registry import register_default_models, get_model
2
+ from .accuracy import get_correct_subset, get_correct_subset_for_models
3
+
4
+ register_default_models()
modelguidedattacks/cls_models/accuracy.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ from tqdm import tqdm
7
+
8
+ from modelguidedattacks.data import get_dataset
9
+ from . import get_model
10
+
11
+ from .registry import ClsModel
12
+ from typing import Optional, List
13
+
14
+ DATASET_METADATA_DIR = "./dataset_metadata"
15
+
16
+ def correct_subset_cache_path(dataset_name: str, model_name: str, train: bool):
17
+ filename_train_val = "train" if train else "val"
18
+ subset_cache_filename = f"{dataset_name}_{model_name}_{filename_train_val}.p"
19
+ subset_cache_path = os.path.join(DATASET_METADATA_DIR, subset_cache_filename)
20
+
21
+ return subset_cache_path
22
+
23
+ @torch.no_grad()
24
+ def get_correct_subset(model: Optional[ClsModel]=None, dataset_name: Optional[str]=None,
25
+ model_name: Optional[str]=None, train=True, batch_size=256,
26
+ force_cache=False, device="cuda"):
27
+ """
28
+ model: Model to evaluate
29
+ dataset_name: Name of dataset (not needed if model is provided)
30
+ model_name: Name of model (not needed if model is provided)
31
+ train: Use training dataset
32
+ batch_size: Batch size to use while evaluating
33
+ force_cache: Only read from cache and fail if not available
34
+
35
+ Returns indices in dataset of correctly classified items
36
+ """
37
+
38
+ if model is not None:
39
+ assert dataset_name is None
40
+ assert model_name is None
41
+
42
+ if dataset_name is not None or model_name is not None:
43
+ assert dataset_name is not None
44
+ assert model_name is not None
45
+ assert model is None
46
+
47
+ if dataset_name is None:
48
+ dataset_name = model.dataset_name
49
+
50
+ if model_name is None:
51
+ model_name = model.model_name
52
+
53
+ filename_train_val = "train" if train else "val"
54
+ subset_cache_filename = f"{dataset_name}_{model_name}_{filename_train_val}.p"
55
+ subset_cache_path = os.path.join(DATASET_METADATA_DIR, subset_cache_filename)
56
+
57
+ os.makedirs(DATASET_METADATA_DIR, exist_ok=True)
58
+
59
+ if os.path.exists(subset_cache_path):
60
+ correct_subset = torch.load(subset_cache_path)
61
+ return correct_subset
62
+
63
+ if force_cache:
64
+ raise Exception("Cache not found and requested for cached correct subset.")
65
+
66
+ logging.info(f"No cache found. Computing correct subset for {dataset_name}-{model_name} Train: {train}")
67
+
68
+ device = device if model is None else model.device
69
+
70
+ if model is None:
71
+ model = get_model(dataset_name, model_name, device)
72
+
73
+ model.eval()
74
+
75
+ train_dataset, val_dataset = get_dataset(dataset_name)
76
+
77
+ dataset = train_dataset
78
+
79
+ if not train:
80
+ dataset = val_dataset
81
+
82
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
83
+
84
+ correct_indices = []
85
+
86
+ for batch_i, (batch_imgs, batch_gt_class) in tqdm(enumerate(dataloader), total=len(dataloader)):
87
+ if torch.device(model.device).type.startswith("cuda"):
88
+ torch.cuda.synchronize(model.device)
89
+
90
+ data_start_index = batch_i * batch_size
91
+ predictions = model(batch_imgs.to(model.device)) # [B, C]
92
+ prediction_class_idx = predictions.argmax(dim=-1) # [B] (long)
93
+ prediction_correct = prediction_class_idx == batch_gt_class.to(model.device)
94
+ batch_correct_idxs = data_start_index + prediction_correct.nonzero()[:, 0]
95
+ batch_correct_idxs = batch_correct_idxs.tolist()
96
+
97
+ correct_indices.extend(batch_correct_idxs)
98
+
99
+ correct_subset = set(correct_indices)
100
+ torch.save(correct_subset, subset_cache_path)
101
+
102
+ return set(correct_indices)
103
+
104
+ def get_correct_subset_for_models(model_names: List[str], dataset_name, device, train):
105
+ correct_intersection = None
106
+ for model_name in model_names:
107
+ model_correct_subset = get_correct_subset(model_name=model_name, dataset_name=dataset_name,
108
+ device=device, train=train)
109
+
110
+ if correct_intersection is None:
111
+ correct_intersection = model_correct_subset
112
+ else:
113
+ correct_intersection = model_correct_subset.intersection(correct_intersection)
114
+
115
+ return list(correct_intersection)
modelguidedattacks/cls_models/registry.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mmpretrain
2
+ import torch
3
+ from torch import nn
4
+ from collections.abc import Iterable
5
+ from mmpretrain.models.utils.attention import MultiheadAttention
6
+ # This holds model instantiation functions by (dataset_name, model_name) tuple keys
7
+ MODEL_REGISTRY = {}
8
+
9
+ class ClsModel(nn.Module):
10
+ dataset_name: str
11
+ model_name: str
12
+
13
+ def __init__(self, dataset_name: str, model_name: str, device: str) -> None:
14
+ super().__init__()
15
+ self.dataset_name = dataset_name
16
+ self.model_name = model_name
17
+ self.device = device
18
+
19
+ def head_features(self):
20
+ pass
21
+
22
+ def num_classes(self):
23
+ pass
24
+
25
+ def forward(self, x):
26
+ """
27
+ x: [B, 3 (RGB), H, W] image (float) [0,1]
28
+
29
+ returns: [B, C] class logits
30
+ """
31
+
32
+ raise NotImplementedError("Forward not implemented for base class")
33
+
34
+ class MMPretrainModelWrapper(ClsModel):
35
+ """
36
+ Calls data preprocessing for model before entering forward
37
+ """
38
+ def __init__(self, model: nn.Module, dataset_name: str, model_name: str, device: str) -> None:
39
+ super().__init__(dataset_name, model_name, device)
40
+ self.model = model
41
+
42
+ @property
43
+ def final_linear_layer(self):
44
+ return self.model.head.fc
45
+
46
+ def head_features(self):
47
+ return self.final_linear_layer.in_features
48
+
49
+ def num_classes(self):
50
+ return self.final_linear_layer.out_features
51
+
52
+ def head(self, feats):
53
+ return self.model.head((feats,))
54
+
55
+ def head_matrices(self):
56
+ return self.final_linear_layer.weight, self.final_linear_layer.bias
57
+
58
+ def forward(self, x, return_features=False):
59
+ # Data preprocessor expects 0-255 range, but we don't want to cast to proper
60
+ # uint8 because we want to maintain differentiability
61
+ x = x * 255.
62
+ x = self.model.data_preprocessor({"inputs": x})["inputs"]
63
+
64
+ if return_features:
65
+ feats = self.model.extract_feat(x)
66
+
67
+ preds = self.model.head(feats)
68
+ if isinstance(feats, Iterable):
69
+ feats = feats[-1]
70
+
71
+ return preds, feats
72
+ else:
73
+ return self.model(x)
74
+
75
+ class MMPretrainVisualTransformerWrapper(MMPretrainModelWrapper):
76
+ def __init__(self, model, dataset_name: str, model_name: str, device: str) -> None:
77
+ super().__init__(model, dataset_name, model_name, device)
78
+
79
+ attn_layers = []
80
+
81
+ def find_mha(m: nn.Module):
82
+ if isinstance(m, MultiheadAttention):
83
+ attn_layers.append(m)
84
+
85
+ model.apply(find_mha)
86
+
87
+ self.attn_layers = attn_layers
88
+
89
+ @property
90
+ def final_linear_layer(self):
91
+ return self.model.head.layers.head
92
+
93
+ def get_attention_maps(self, x):
94
+ clean_forwards = []
95
+
96
+ attention_maps = []
97
+
98
+ for attn_layer in self.attn_layers:
99
+ clean_forward = attn_layer.forward
100
+ clean_forwards.append(clean_forward)
101
+
102
+ def scaled_dot_prod_attn(query,
103
+ key,
104
+ value,
105
+ attn_mask=None,
106
+ dropout_p=0.,
107
+ scale=None,
108
+ is_causal=False):
109
+ scale = scale or query.size(-1)**0.5
110
+ if is_causal and attn_mask is not None:
111
+ attn_mask = torch.ones(
112
+ query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
113
+ if attn_mask is not None and attn_mask.dtype == torch.bool:
114
+ attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))
115
+
116
+ attn_weight = query @ key.transpose(-2, -1) / scale
117
+ if attn_mask is not None:
118
+ attn_weight += attn_mask
119
+ attn_weight = torch.softmax(attn_weight, dim=-1)
120
+
121
+ attention_maps.append(attn_weight)
122
+
123
+ attn_weight = torch.dropout(attn_weight, dropout_p, True)
124
+ return attn_weight @ value
125
+
126
+ attn_layer.scaled_dot_product_attention = scaled_dot_prod_attn
127
+
128
+ ret_val = super().forward(x, False)
129
+
130
+ for attn_layer, clean_forward in zip(self.attn_layers, clean_forwards):
131
+ attn_layer.forward = clean_forward
132
+
133
+ return attention_maps
134
+
135
+ def register_mmcls_model(config_name, dataset_name, model_name,
136
+ wrapper_class=MMPretrainModelWrapper):
137
+ def instantiate_model(device):
138
+ model = mmpretrain.get_model(config_name, pretrained=True, device=device)
139
+ wrapper = wrapper_class(model, dataset_name, model_name, device)
140
+ return wrapper
141
+
142
+ MODEL_REGISTRY[(dataset_name, model_name)] = instantiate_model
143
+
144
+ def register_default_models():
145
+ register_mmcls_model("resnet18_8xb16_cifar10", "cifar10", "resnet18")
146
+ register_mmcls_model("resnet34_8xb16_cifar10", "cifar10", "resnet34")
147
+ register_mmcls_model("resnet18_8xb32_in1k", "imagenet", "resnet18")
148
+ register_mmcls_model("resnet50_8xb16_cifar100", "cifar100", "resnet50")
149
+ register_mmcls_model("resnet50_8xb32_in1k", "imagenet", "resnet50")
150
+ register_mmcls_model("densenet121_3rdparty_in1k", "imagenet", "densenet121")
151
+
152
+ register_mmcls_model("deit-small_4xb256_in1k", "imagenet", "deit_small",
153
+ wrapper_class=MMPretrainVisualTransformerWrapper)
154
+
155
+ register_mmcls_model("vit-base-p16_32xb128-mae_in1k", "imagenet", "vit_base",
156
+ wrapper_class=MMPretrainVisualTransformerWrapper)
157
+
158
+ def get_model(dataset_name, model_name, device):
159
+ """
160
+ Returns instance of model pretrained with specified dataset
161
+ """
162
+
163
+ return MODEL_REGISTRY[(dataset_name, model_name)](device).eval()
modelguidedattacks/cls_models/setup.py ADDED
File without changes
modelguidedattacks/data/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .registry import register_default_datasets
2
+ from .registry import get_dataset
3
+
4
+ register_default_datasets()
modelguidedattacks/data/classification_wrapper.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils.data as data
3
+
4
+ class TopKClassificationWrapper(data.Dataset):
5
+ def __init__(self, dataset: data.Dataset, attack_labels, seed=0, k=1) -> None:
6
+ super().__init__()
7
+ self.generator = torch.Generator("cpu")
8
+ self.generator.manual_seed(seed)
9
+
10
+ # Pregenerate attack labels
11
+ num_classes = len(dataset.classes)
12
+
13
+ self.src_dataset = dataset
14
+ self.attack_labels = attack_labels
15
+
16
+ def __getitem__(self, index):
17
+ image, label = self.src_dataset[index]
18
+
19
+ return image, label, self.attack_labels[index], index
20
+
21
+ def __len__(self):
22
+ return len(self.src_dataset)
modelguidedattacks/data/imagenet_metadata.py ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ imgnet_idx_to_name = {0: 'tench, Tinca tinca',
2
+ 1: 'goldfish, Carassius auratus',
3
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
4
+ 3: 'tiger shark, Galeocerdo cuvieri',
5
+ 4: 'hammerhead, hammerhead shark',
6
+ 5: 'electric ray, crampfish, numbfish, torpedo',
7
+ 6: 'stingray',
8
+ 7: 'cock',
9
+ 8: 'hen',
10
+ 9: 'ostrich, Struthio camelus',
11
+ 10: 'brambling, Fringilla montifringilla',
12
+ 11: 'goldfinch, Carduelis carduelis',
13
+ 12: 'house finch, linnet, Carpodacus mexicanus',
14
+ 13: 'junco, snowbird',
15
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
16
+ 15: 'robin, American robin, Turdus migratorius',
17
+ 16: 'bulbul',
18
+ 17: 'jay',
19
+ 18: 'magpie',
20
+ 19: 'chickadee',
21
+ 20: 'water ouzel, dipper',
22
+ 21: 'kite',
23
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
24
+ 23: 'vulture',
25
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
26
+ 25: 'European fire salamander, Salamandra salamandra',
27
+ 26: 'common newt, Triturus vulgaris',
28
+ 27: 'eft',
29
+ 28: 'spotted salamander, Ambystoma maculatum',
30
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
31
+ 30: 'bullfrog, Rana catesbeiana',
32
+ 31: 'tree frog, tree-frog',
33
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
34
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
35
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
36
+ 35: 'mud turtle',
37
+ 36: 'terrapin',
38
+ 37: 'box turtle, box tortoise',
39
+ 38: 'banded gecko',
40
+ 39: 'common iguana, iguana, Iguana iguana',
41
+ 40: 'American chameleon, anole, Anolis carolinensis',
42
+ 41: 'whiptail, whiptail lizard',
43
+ 42: 'agama',
44
+ 43: 'frilled lizard, Chlamydosaurus kingi',
45
+ 44: 'alligator lizard',
46
+ 45: 'Gila monster, Heloderma suspectum',
47
+ 46: 'green lizard, Lacerta viridis',
48
+ 47: 'African chameleon, Chamaeleo chamaeleon',
49
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
50
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
51
+ 50: 'American alligator, Alligator mississipiensis',
52
+ 51: 'triceratops',
53
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
54
+ 53: 'ringneck snake, ring-necked snake, ring snake',
55
+ 54: 'hognose snake, puff adder, sand viper',
56
+ 55: 'green snake, grass snake',
57
+ 56: 'king snake, kingsnake',
58
+ 57: 'garter snake, grass snake',
59
+ 58: 'water snake',
60
+ 59: 'vine snake',
61
+ 60: 'night snake, Hypsiglena torquata',
62
+ 61: 'boa constrictor, Constrictor constrictor',
63
+ 62: 'rock python, rock snake, Python sebae',
64
+ 63: 'Indian cobra, Naja naja',
65
+ 64: 'green mamba',
66
+ 65: 'sea snake',
67
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
68
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
69
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
70
+ 69: 'trilobite',
71
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
72
+ 71: 'scorpion',
73
+ 72: 'black and gold garden spider, Argiope aurantia',
74
+ 73: 'barn spider, Araneus cavaticus',
75
+ 74: 'garden spider, Aranea diademata',
76
+ 75: 'black widow, Latrodectus mactans',
77
+ 76: 'tarantula',
78
+ 77: 'wolf spider, hunting spider',
79
+ 78: 'tick',
80
+ 79: 'centipede',
81
+ 80: 'black grouse',
82
+ 81: 'ptarmigan',
83
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
84
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
85
+ 84: 'peacock',
86
+ 85: 'quail',
87
+ 86: 'partridge',
88
+ 87: 'African grey, African gray, Psittacus erithacus',
89
+ 88: 'macaw',
90
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
91
+ 90: 'lorikeet',
92
+ 91: 'coucal',
93
+ 92: 'bee eater',
94
+ 93: 'hornbill',
95
+ 94: 'hummingbird',
96
+ 95: 'jacamar',
97
+ 96: 'toucan',
98
+ 97: 'drake',
99
+ 98: 'red-breasted merganser, Mergus serrator',
100
+ 99: 'goose',
101
+ 100: 'black swan, Cygnus atratus',
102
+ 101: 'tusker',
103
+ 102: 'echidna, spiny anteater, anteater',
104
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
105
+ 104: 'wallaby, brush kangaroo',
106
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
107
+ 106: 'wombat',
108
+ 107: 'jellyfish',
109
+ 108: 'sea anemone, anemone',
110
+ 109: 'brain coral',
111
+ 110: 'flatworm, platyhelminth',
112
+ 111: 'nematode, nematode worm, roundworm',
113
+ 112: 'conch',
114
+ 113: 'snail',
115
+ 114: 'slug',
116
+ 115: 'sea slug, nudibranch',
117
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
118
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
119
+ 118: 'Dungeness crab, Cancer magister',
120
+ 119: 'rock crab, Cancer irroratus',
121
+ 120: 'fiddler crab',
122
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
123
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
124
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
125
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
126
+ 125: 'hermit crab',
127
+ 126: 'isopod',
128
+ 127: 'white stork, Ciconia ciconia',
129
+ 128: 'black stork, Ciconia nigra',
130
+ 129: 'spoonbill',
131
+ 130: 'flamingo',
132
+ 131: 'little blue heron, Egretta caerulea',
133
+ 132: 'American egret, great white heron, Egretta albus',
134
+ 133: 'bittern',
135
+ 134: 'crane',
136
+ 135: 'limpkin, Aramus pictus',
137
+ 136: 'European gallinule, Porphyrio porphyrio',
138
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
139
+ 138: 'bustard',
140
+ 139: 'ruddy turnstone, Arenaria interpres',
141
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
142
+ 141: 'redshank, Tringa totanus',
143
+ 142: 'dowitcher',
144
+ 143: 'oystercatcher, oyster catcher',
145
+ 144: 'pelican',
146
+ 145: 'king penguin, Aptenodytes patagonica',
147
+ 146: 'albatross, mollymawk',
148
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
149
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
150
+ 149: 'dugong, Dugong dugon',
151
+ 150: 'sea lion',
152
+ 151: 'Chihuahua',
153
+ 152: 'Japanese spaniel',
154
+ 153: 'Maltese dog, Maltese terrier, Maltese',
155
+ 154: 'Pekinese, Pekingese, Peke',
156
+ 155: 'Shih-Tzu',
157
+ 156: 'Blenheim spaniel',
158
+ 157: 'papillon',
159
+ 158: 'toy terrier',
160
+ 159: 'Rhodesian ridgeback',
161
+ 160: 'Afghan hound, Afghan',
162
+ 161: 'basset, basset hound',
163
+ 162: 'beagle',
164
+ 163: 'bloodhound, sleuthhound',
165
+ 164: 'bluetick',
166
+ 165: 'black-and-tan coonhound',
167
+ 166: 'Walker hound, Walker foxhound',
168
+ 167: 'English foxhound',
169
+ 168: 'redbone',
170
+ 169: 'borzoi, Russian wolfhound',
171
+ 170: 'Irish wolfhound',
172
+ 171: 'Italian greyhound',
173
+ 172: 'whippet',
174
+ 173: 'Ibizan hound, Ibizan Podenco',
175
+ 174: 'Norwegian elkhound, elkhound',
176
+ 175: 'otterhound, otter hound',
177
+ 176: 'Saluki, gazelle hound',
178
+ 177: 'Scottish deerhound, deerhound',
179
+ 178: 'Weimaraner',
180
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
181
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
182
+ 181: 'Bedlington terrier',
183
+ 182: 'Border terrier',
184
+ 183: 'Kerry blue terrier',
185
+ 184: 'Irish terrier',
186
+ 185: 'Norfolk terrier',
187
+ 186: 'Norwich terrier',
188
+ 187: 'Yorkshire terrier',
189
+ 188: 'wire-haired fox terrier',
190
+ 189: 'Lakeland terrier',
191
+ 190: 'Sealyham terrier, Sealyham',
192
+ 191: 'Airedale, Airedale terrier',
193
+ 192: 'cairn, cairn terrier',
194
+ 193: 'Australian terrier',
195
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
196
+ 195: 'Boston bull, Boston terrier',
197
+ 196: 'miniature schnauzer',
198
+ 197: 'giant schnauzer',
199
+ 198: 'standard schnauzer',
200
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
201
+ 200: 'Tibetan terrier, chrysanthemum dog',
202
+ 201: 'silky terrier, Sydney silky',
203
+ 202: 'soft-coated wheaten terrier',
204
+ 203: 'West Highland white terrier',
205
+ 204: 'Lhasa, Lhasa apso',
206
+ 205: 'flat-coated retriever',
207
+ 206: 'curly-coated retriever',
208
+ 207: 'golden retriever',
209
+ 208: 'Labrador retriever',
210
+ 209: 'Chesapeake Bay retriever',
211
+ 210: 'German short-haired pointer',
212
+ 211: 'vizsla, Hungarian pointer',
213
+ 212: 'English setter',
214
+ 213: 'Irish setter, red setter',
215
+ 214: 'Gordon setter',
216
+ 215: 'Brittany spaniel',
217
+ 216: 'clumber, clumber spaniel',
218
+ 217: 'English springer, English springer spaniel',
219
+ 218: 'Welsh springer spaniel',
220
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
221
+ 220: 'Sussex spaniel',
222
+ 221: 'Irish water spaniel',
223
+ 222: 'kuvasz',
224
+ 223: 'schipperke',
225
+ 224: 'groenendael',
226
+ 225: 'malinois',
227
+ 226: 'briard',
228
+ 227: 'kelpie',
229
+ 228: 'komondor',
230
+ 229: 'Old English sheepdog, bobtail',
231
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
232
+ 231: 'collie',
233
+ 232: 'Border collie',
234
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
235
+ 234: 'Rottweiler',
236
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
237
+ 236: 'Doberman, Doberman pinscher',
238
+ 237: 'miniature pinscher',
239
+ 238: 'Greater Swiss Mountain dog',
240
+ 239: 'Bernese mountain dog',
241
+ 240: 'Appenzeller',
242
+ 241: 'EntleBucher',
243
+ 242: 'boxer',
244
+ 243: 'bull mastiff',
245
+ 244: 'Tibetan mastiff',
246
+ 245: 'French bulldog',
247
+ 246: 'Great Dane',
248
+ 247: 'Saint Bernard, St Bernard',
249
+ 248: 'Eskimo dog, husky',
250
+ 249: 'malamute, malemute, Alaskan malamute',
251
+ 250: 'Siberian husky',
252
+ 251: 'dalmatian, coach dog, carriage dog',
253
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
254
+ 253: 'basenji',
255
+ 254: 'pug, pug-dog',
256
+ 255: 'Leonberg',
257
+ 256: 'Newfoundland, Newfoundland dog',
258
+ 257: 'Great Pyrenees',
259
+ 258: 'Samoyed, Samoyede',
260
+ 259: 'Pomeranian',
261
+ 260: 'chow, chow chow',
262
+ 261: 'keeshond',
263
+ 262: 'Brabancon griffon',
264
+ 263: 'Pembroke, Pembroke Welsh corgi',
265
+ 264: 'Cardigan, Cardigan Welsh corgi',
266
+ 265: 'toy poodle',
267
+ 266: 'miniature poodle',
268
+ 267: 'standard poodle',
269
+ 268: 'Mexican hairless',
270
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
271
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
272
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
273
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
274
+ 273: 'dingo, warrigal, warragal, Canis dingo',
275
+ 274: 'dhole, Cuon alpinus',
276
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
277
+ 276: 'hyena, hyaena',
278
+ 277: 'red fox, Vulpes vulpes',
279
+ 278: 'kit fox, Vulpes macrotis',
280
+ 279: 'Arctic fox, white fox, Alopex lagopus',
281
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
282
+ 281: 'tabby, tabby cat',
283
+ 282: 'tiger cat',
284
+ 283: 'Persian cat',
285
+ 284: 'Siamese cat, Siamese',
286
+ 285: 'Egyptian cat',
287
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
288
+ 287: 'lynx, catamount',
289
+ 288: 'leopard, Panthera pardus',
290
+ 289: 'snow leopard, ounce, Panthera uncia',
291
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
292
+ 291: 'lion, king of beasts, Panthera leo',
293
+ 292: 'tiger, Panthera tigris',
294
+ 293: 'cheetah, chetah, Acinonyx jubatus',
295
+ 294: 'brown bear, bruin, Ursus arctos',
296
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
297
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
298
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
299
+ 298: 'mongoose',
300
+ 299: 'meerkat, mierkat',
301
+ 300: 'tiger beetle',
302
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
303
+ 302: 'ground beetle, carabid beetle',
304
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
305
+ 304: 'leaf beetle, chrysomelid',
306
+ 305: 'dung beetle',
307
+ 306: 'rhinoceros beetle',
308
+ 307: 'weevil',
309
+ 308: 'fly',
310
+ 309: 'bee',
311
+ 310: 'ant, emmet, pismire',
312
+ 311: 'grasshopper, hopper',
313
+ 312: 'cricket',
314
+ 313: 'walking stick, walkingstick, stick insect',
315
+ 314: 'cockroach, roach',
316
+ 315: 'mantis, mantid',
317
+ 316: 'cicada, cicala',
318
+ 317: 'leafhopper',
319
+ 318: 'lacewing, lacewing fly',
320
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
321
+ 320: 'damselfly',
322
+ 321: 'admiral',
323
+ 322: 'ringlet, ringlet butterfly',
324
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
325
+ 324: 'cabbage butterfly',
326
+ 325: 'sulphur butterfly, sulfur butterfly',
327
+ 326: 'lycaenid, lycaenid butterfly',
328
+ 327: 'starfish, sea star',
329
+ 328: 'sea urchin',
330
+ 329: 'sea cucumber, holothurian',
331
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
332
+ 331: 'hare',
333
+ 332: 'Angora, Angora rabbit',
334
+ 333: 'hamster',
335
+ 334: 'porcupine, hedgehog',
336
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
337
+ 336: 'marmot',
338
+ 337: 'beaver',
339
+ 338: 'guinea pig, Cavia cobaya',
340
+ 339: 'sorrel',
341
+ 340: 'zebra',
342
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
343
+ 342: 'wild boar, boar, Sus scrofa',
344
+ 343: 'warthog',
345
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
346
+ 345: 'ox',
347
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
348
+ 347: 'bison',
349
+ 348: 'ram, tup',
350
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
351
+ 350: 'ibex, Capra ibex',
352
+ 351: 'hartebeest',
353
+ 352: 'impala, Aepyceros melampus',
354
+ 353: 'gazelle',
355
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
356
+ 355: 'llama',
357
+ 356: 'weasel',
358
+ 357: 'mink',
359
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
360
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
361
+ 360: 'otter',
362
+ 361: 'skunk, polecat, wood pussy',
363
+ 362: 'badger',
364
+ 363: 'armadillo',
365
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
366
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
367
+ 366: 'gorilla, Gorilla gorilla',
368
+ 367: 'chimpanzee, chimp, Pan troglodytes',
369
+ 368: 'gibbon, Hylobates lar',
370
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
371
+ 370: 'guenon, guenon monkey',
372
+ 371: 'patas, hussar monkey, Erythrocebus patas',
373
+ 372: 'baboon',
374
+ 373: 'macaque',
375
+ 374: 'langur',
376
+ 375: 'colobus, colobus monkey',
377
+ 376: 'proboscis monkey, Nasalis larvatus',
378
+ 377: 'marmoset',
379
+ 378: 'capuchin, ringtail, Cebus capucinus',
380
+ 379: 'howler monkey, howler',
381
+ 380: 'titi, titi monkey',
382
+ 381: 'spider monkey, Ateles geoffroyi',
383
+ 382: 'squirrel monkey, Saimiri sciureus',
384
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
385
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
386
+ 385: 'Indian elephant, Elephas maximus',
387
+ 386: 'African elephant, Loxodonta africana',
388
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
389
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
390
+ 389: 'barracouta, snoek',
391
+ 390: 'eel',
392
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
393
+ 392: 'rock beauty, Holocanthus tricolor',
394
+ 393: 'anemone fish',
395
+ 394: 'sturgeon',
396
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
397
+ 396: 'lionfish',
398
+ 397: 'puffer, pufferfish, blowfish, globefish',
399
+ 398: 'abacus',
400
+ 399: 'abaya',
401
+ 400: "academic gown, academic robe, judge's robe",
402
+ 401: 'accordion, piano accordion, squeeze box',
403
+ 402: 'acoustic guitar',
404
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
405
+ 404: 'airliner',
406
+ 405: 'airship, dirigible',
407
+ 406: 'altar',
408
+ 407: 'ambulance',
409
+ 408: 'amphibian, amphibious vehicle',
410
+ 409: 'analog clock',
411
+ 410: 'apiary, bee house',
412
+ 411: 'apron',
413
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
414
+ 413: 'assault rifle, assault gun',
415
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
416
+ 415: 'bakery, bakeshop, bakehouse',
417
+ 416: 'balance beam, beam',
418
+ 417: 'balloon',
419
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
420
+ 419: 'Band Aid',
421
+ 420: 'banjo',
422
+ 421: 'bannister, banister, balustrade, balusters, handrail',
423
+ 422: 'barbell',
424
+ 423: 'barber chair',
425
+ 424: 'barbershop',
426
+ 425: 'barn',
427
+ 426: 'barometer',
428
+ 427: 'barrel, cask',
429
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
430
+ 429: 'baseball',
431
+ 430: 'basketball',
432
+ 431: 'bassinet',
433
+ 432: 'bassoon',
434
+ 433: 'bathing cap, swimming cap',
435
+ 434: 'bath towel',
436
+ 435: 'bathtub, bathing tub, bath, tub',
437
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
438
+ 437: 'beacon, lighthouse, beacon light, pharos',
439
+ 438: 'beaker',
440
+ 439: 'bearskin, busby, shako',
441
+ 440: 'beer bottle',
442
+ 441: 'beer glass',
443
+ 442: 'bell cote, bell cot',
444
+ 443: 'bib',
445
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
446
+ 445: 'bikini, two-piece',
447
+ 446: 'binder, ring-binder',
448
+ 447: 'binoculars, field glasses, opera glasses',
449
+ 448: 'birdhouse',
450
+ 449: 'boathouse',
451
+ 450: 'bobsled, bobsleigh, bob',
452
+ 451: 'bolo tie, bolo, bola tie, bola',
453
+ 452: 'bonnet, poke bonnet',
454
+ 453: 'bookcase',
455
+ 454: 'bookshop, bookstore, bookstall',
456
+ 455: 'bottlecap',
457
+ 456: 'bow',
458
+ 457: 'bow tie, bow-tie, bowtie',
459
+ 458: 'brass, memorial tablet, plaque',
460
+ 459: 'brassiere, bra, bandeau',
461
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
462
+ 461: 'breastplate, aegis, egis',
463
+ 462: 'broom',
464
+ 463: 'bucket, pail',
465
+ 464: 'buckle',
466
+ 465: 'bulletproof vest',
467
+ 466: 'bullet train, bullet',
468
+ 467: 'butcher shop, meat market',
469
+ 468: 'cab, hack, taxi, taxicab',
470
+ 469: 'caldron, cauldron',
471
+ 470: 'candle, taper, wax light',
472
+ 471: 'cannon',
473
+ 472: 'canoe',
474
+ 473: 'can opener, tin opener',
475
+ 474: 'cardigan',
476
+ 475: 'car mirror',
477
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
478
+ 477: "carpenter's kit, tool kit",
479
+ 478: 'carton',
480
+ 479: 'car wheel',
481
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
482
+ 481: 'cassette',
483
+ 482: 'cassette player',
484
+ 483: 'castle',
485
+ 484: 'catamaran',
486
+ 485: 'CD player',
487
+ 486: 'cello, violoncello',
488
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
489
+ 488: 'chain',
490
+ 489: 'chainlink fence',
491
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
492
+ 491: 'chain saw, chainsaw',
493
+ 492: 'chest',
494
+ 493: 'chiffonier, commode',
495
+ 494: 'chime, bell, gong',
496
+ 495: 'china cabinet, china closet',
497
+ 496: 'Christmas stocking',
498
+ 497: 'church, church building',
499
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
500
+ 499: 'cleaver, meat cleaver, chopper',
501
+ 500: 'cliff dwelling',
502
+ 501: 'cloak',
503
+ 502: 'clog, geta, patten, sabot',
504
+ 503: 'cocktail shaker',
505
+ 504: 'coffee mug',
506
+ 505: 'coffeepot',
507
+ 506: 'coil, spiral, volute, whorl, helix',
508
+ 507: 'combination lock',
509
+ 508: 'computer keyboard, keypad',
510
+ 509: 'confectionery, confectionary, candy store',
511
+ 510: 'container ship, containership, container vessel',
512
+ 511: 'convertible',
513
+ 512: 'corkscrew, bottle screw',
514
+ 513: 'cornet, horn, trumpet, trump',
515
+ 514: 'cowboy boot',
516
+ 515: 'cowboy hat, ten-gallon hat',
517
+ 516: 'cradle',
518
+ 517: 'crane',
519
+ 518: 'crash helmet',
520
+ 519: 'crate',
521
+ 520: 'crib, cot',
522
+ 521: 'Crock Pot',
523
+ 522: 'croquet ball',
524
+ 523: 'crutch',
525
+ 524: 'cuirass',
526
+ 525: 'dam, dike, dyke',
527
+ 526: 'desk',
528
+ 527: 'desktop computer',
529
+ 528: 'dial telephone, dial phone',
530
+ 529: 'diaper, nappy, napkin',
531
+ 530: 'digital clock',
532
+ 531: 'digital watch',
533
+ 532: 'dining table, board',
534
+ 533: 'dishrag, dishcloth',
535
+ 534: 'dishwasher, dish washer, dishwashing machine',
536
+ 535: 'disk brake, disc brake',
537
+ 536: 'dock, dockage, docking facility',
538
+ 537: 'dogsled, dog sled, dog sleigh',
539
+ 538: 'dome',
540
+ 539: 'doormat, welcome mat',
541
+ 540: 'drilling platform, offshore rig',
542
+ 541: 'drum, membranophone, tympan',
543
+ 542: 'drumstick',
544
+ 543: 'dumbbell',
545
+ 544: 'Dutch oven',
546
+ 545: 'electric fan, blower',
547
+ 546: 'electric guitar',
548
+ 547: 'electric locomotive',
549
+ 548: 'entertainment center',
550
+ 549: 'envelope',
551
+ 550: 'espresso maker',
552
+ 551: 'face powder',
553
+ 552: 'feather boa, boa',
554
+ 553: 'file, file cabinet, filing cabinet',
555
+ 554: 'fireboat',
556
+ 555: 'fire engine, fire truck',
557
+ 556: 'fire screen, fireguard',
558
+ 557: 'flagpole, flagstaff',
559
+ 558: 'flute, transverse flute',
560
+ 559: 'folding chair',
561
+ 560: 'football helmet',
562
+ 561: 'forklift',
563
+ 562: 'fountain',
564
+ 563: 'fountain pen',
565
+ 564: 'four-poster',
566
+ 565: 'freight car',
567
+ 566: 'French horn, horn',
568
+ 567: 'frying pan, frypan, skillet',
569
+ 568: 'fur coat',
570
+ 569: 'garbage truck, dustcart',
571
+ 570: 'gasmask, respirator, gas helmet',
572
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
573
+ 572: 'goblet',
574
+ 573: 'go-kart',
575
+ 574: 'golf ball',
576
+ 575: 'golfcart, golf cart',
577
+ 576: 'gondola',
578
+ 577: 'gong, tam-tam',
579
+ 578: 'gown',
580
+ 579: 'grand piano, grand',
581
+ 580: 'greenhouse, nursery, glasshouse',
582
+ 581: 'grille, radiator grille',
583
+ 582: 'grocery store, grocery, food market, market',
584
+ 583: 'guillotine',
585
+ 584: 'hair slide',
586
+ 585: 'hair spray',
587
+ 586: 'half track',
588
+ 587: 'hammer',
589
+ 588: 'hamper',
590
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
591
+ 590: 'hand-held computer, hand-held microcomputer',
592
+ 591: 'handkerchief, hankie, hanky, hankey',
593
+ 592: 'hard disc, hard disk, fixed disk',
594
+ 593: 'harmonica, mouth organ, harp, mouth harp',
595
+ 594: 'harp',
596
+ 595: 'harvester, reaper',
597
+ 596: 'hatchet',
598
+ 597: 'holster',
599
+ 598: 'home theater, home theatre',
600
+ 599: 'honeycomb',
601
+ 600: 'hook, claw',
602
+ 601: 'hoopskirt, crinoline',
603
+ 602: 'horizontal bar, high bar',
604
+ 603: 'horse cart, horse-cart',
605
+ 604: 'hourglass',
606
+ 605: 'iPod',
607
+ 606: 'iron, smoothing iron',
608
+ 607: "jack-o'-lantern",
609
+ 608: 'jean, blue jean, denim',
610
+ 609: 'jeep, landrover',
611
+ 610: 'jersey, T-shirt, tee shirt',
612
+ 611: 'jigsaw puzzle',
613
+ 612: 'jinrikisha, ricksha, rickshaw',
614
+ 613: 'joystick',
615
+ 614: 'kimono',
616
+ 615: 'knee pad',
617
+ 616: 'knot',
618
+ 617: 'lab coat, laboratory coat',
619
+ 618: 'ladle',
620
+ 619: 'lampshade, lamp shade',
621
+ 620: 'laptop, laptop computer',
622
+ 621: 'lawn mower, mower',
623
+ 622: 'lens cap, lens cover',
624
+ 623: 'letter opener, paper knife, paperknife',
625
+ 624: 'library',
626
+ 625: 'lifeboat',
627
+ 626: 'lighter, light, igniter, ignitor',
628
+ 627: 'limousine, limo',
629
+ 628: 'liner, ocean liner',
630
+ 629: 'lipstick, lip rouge',
631
+ 630: 'Loafer',
632
+ 631: 'lotion',
633
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
634
+ 633: "loupe, jeweler's loupe",
635
+ 634: 'lumbermill, sawmill',
636
+ 635: 'magnetic compass',
637
+ 636: 'mailbag, postbag',
638
+ 637: 'mailbox, letter box',
639
+ 638: 'maillot',
640
+ 639: 'maillot, tank suit',
641
+ 640: 'manhole cover',
642
+ 641: 'maraca',
643
+ 642: 'marimba, xylophone',
644
+ 643: 'mask',
645
+ 644: 'matchstick',
646
+ 645: 'maypole',
647
+ 646: 'maze, labyrinth',
648
+ 647: 'measuring cup',
649
+ 648: 'medicine chest, medicine cabinet',
650
+ 649: 'megalith, megalithic structure',
651
+ 650: 'microphone, mike',
652
+ 651: 'microwave, microwave oven',
653
+ 652: 'military uniform',
654
+ 653: 'milk can',
655
+ 654: 'minibus',
656
+ 655: 'miniskirt, mini',
657
+ 656: 'minivan',
658
+ 657: 'missile',
659
+ 658: 'mitten',
660
+ 659: 'mixing bowl',
661
+ 660: 'mobile home, manufactured home',
662
+ 661: 'Model T',
663
+ 662: 'modem',
664
+ 663: 'monastery',
665
+ 664: 'monitor',
666
+ 665: 'moped',
667
+ 666: 'mortar',
668
+ 667: 'mortarboard',
669
+ 668: 'mosque',
670
+ 669: 'mosquito net',
671
+ 670: 'motor scooter, scooter',
672
+ 671: 'mountain bike, all-terrain bike, off-roader',
673
+ 672: 'mountain tent',
674
+ 673: 'mouse, computer mouse',
675
+ 674: 'mousetrap',
676
+ 675: 'moving van',
677
+ 676: 'muzzle',
678
+ 677: 'nail',
679
+ 678: 'neck brace',
680
+ 679: 'necklace',
681
+ 680: 'nipple',
682
+ 681: 'notebook, notebook computer',
683
+ 682: 'obelisk',
684
+ 683: 'oboe, hautboy, hautbois',
685
+ 684: 'ocarina, sweet potato',
686
+ 685: 'odometer, hodometer, mileometer, milometer',
687
+ 686: 'oil filter',
688
+ 687: 'organ, pipe organ',
689
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
690
+ 689: 'overskirt',
691
+ 690: 'oxcart',
692
+ 691: 'oxygen mask',
693
+ 692: 'packet',
694
+ 693: 'paddle, boat paddle',
695
+ 694: 'paddlewheel, paddle wheel',
696
+ 695: 'padlock',
697
+ 696: 'paintbrush',
698
+ 697: "pajama, pyjama, pj's, jammies",
699
+ 698: 'palace',
700
+ 699: 'panpipe, pandean pipe, syrinx',
701
+ 700: 'paper towel',
702
+ 701: 'parachute, chute',
703
+ 702: 'parallel bars, bars',
704
+ 703: 'park bench',
705
+ 704: 'parking meter',
706
+ 705: 'passenger car, coach, carriage',
707
+ 706: 'patio, terrace',
708
+ 707: 'pay-phone, pay-station',
709
+ 708: 'pedestal, plinth, footstall',
710
+ 709: 'pencil box, pencil case',
711
+ 710: 'pencil sharpener',
712
+ 711: 'perfume, essence',
713
+ 712: 'Petri dish',
714
+ 713: 'photocopier',
715
+ 714: 'pick, plectrum, plectron',
716
+ 715: 'pickelhaube',
717
+ 716: 'picket fence, paling',
718
+ 717: 'pickup, pickup truck',
719
+ 718: 'pier',
720
+ 719: 'piggy bank, penny bank',
721
+ 720: 'pill bottle',
722
+ 721: 'pillow',
723
+ 722: 'ping-pong ball',
724
+ 723: 'pinwheel',
725
+ 724: 'pirate, pirate ship',
726
+ 725: 'pitcher, ewer',
727
+ 726: "plane, carpenter's plane, woodworking plane",
728
+ 727: 'planetarium',
729
+ 728: 'plastic bag',
730
+ 729: 'plate rack',
731
+ 730: 'plow, plough',
732
+ 731: "plunger, plumber's helper",
733
+ 732: 'Polaroid camera, Polaroid Land camera',
734
+ 733: 'pole',
735
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
736
+ 735: 'poncho',
737
+ 736: 'pool table, billiard table, snooker table',
738
+ 737: 'pop bottle, soda bottle',
739
+ 738: 'pot, flowerpot',
740
+ 739: "potter's wheel",
741
+ 740: 'power drill',
742
+ 741: 'prayer rug, prayer mat',
743
+ 742: 'printer',
744
+ 743: 'prison, prison house',
745
+ 744: 'projectile, missile',
746
+ 745: 'projector',
747
+ 746: 'puck, hockey puck',
748
+ 747: 'punching bag, punch bag, punching ball, punchball',
749
+ 748: 'purse',
750
+ 749: 'quill, quill pen',
751
+ 750: 'quilt, comforter, comfort, puff',
752
+ 751: 'racer, race car, racing car',
753
+ 752: 'racket, racquet',
754
+ 753: 'radiator',
755
+ 754: 'radio, wireless',
756
+ 755: 'radio telescope, radio reflector',
757
+ 756: 'rain barrel',
758
+ 757: 'recreational vehicle, RV, R.V.',
759
+ 758: 'reel',
760
+ 759: 'reflex camera',
761
+ 760: 'refrigerator, icebox',
762
+ 761: 'remote control, remote',
763
+ 762: 'restaurant, eating house, eating place, eatery',
764
+ 763: 'revolver, six-gun, six-shooter',
765
+ 764: 'rifle',
766
+ 765: 'rocking chair, rocker',
767
+ 766: 'rotisserie',
768
+ 767: 'rubber eraser, rubber, pencil eraser',
769
+ 768: 'rugby ball',
770
+ 769: 'rule, ruler',
771
+ 770: 'running shoe',
772
+ 771: 'safe',
773
+ 772: 'safety pin',
774
+ 773: 'saltshaker, salt shaker',
775
+ 774: 'sandal',
776
+ 775: 'sarong',
777
+ 776: 'sax, saxophone',
778
+ 777: 'scabbard',
779
+ 778: 'scale, weighing machine',
780
+ 779: 'school bus',
781
+ 780: 'schooner',
782
+ 781: 'scoreboard',
783
+ 782: 'screen, CRT screen',
784
+ 783: 'screw',
785
+ 784: 'screwdriver',
786
+ 785: 'seat belt, seatbelt',
787
+ 786: 'sewing machine',
788
+ 787: 'shield, buckler',
789
+ 788: 'shoe shop, shoe-shop, shoe store',
790
+ 789: 'shoji',
791
+ 790: 'shopping basket',
792
+ 791: 'shopping cart',
793
+ 792: 'shovel',
794
+ 793: 'shower cap',
795
+ 794: 'shower curtain',
796
+ 795: 'ski',
797
+ 796: 'ski mask',
798
+ 797: 'sleeping bag',
799
+ 798: 'slide rule, slipstick',
800
+ 799: 'sliding door',
801
+ 800: 'slot, one-armed bandit',
802
+ 801: 'snorkel',
803
+ 802: 'snowmobile',
804
+ 803: 'snowplow, snowplough',
805
+ 804: 'soap dispenser',
806
+ 805: 'soccer ball',
807
+ 806: 'sock',
808
+ 807: 'solar dish, solar collector, solar furnace',
809
+ 808: 'sombrero',
810
+ 809: 'soup bowl',
811
+ 810: 'space bar',
812
+ 811: 'space heater',
813
+ 812: 'space shuttle',
814
+ 813: 'spatula',
815
+ 814: 'speedboat',
816
+ 815: "spider web, spider's web",
817
+ 816: 'spindle',
818
+ 817: 'sports car, sport car',
819
+ 818: 'spotlight, spot',
820
+ 819: 'stage',
821
+ 820: 'steam locomotive',
822
+ 821: 'steel arch bridge',
823
+ 822: 'steel drum',
824
+ 823: 'stethoscope',
825
+ 824: 'stole',
826
+ 825: 'stone wall',
827
+ 826: 'stopwatch, stop watch',
828
+ 827: 'stove',
829
+ 828: 'strainer',
830
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
831
+ 830: 'stretcher',
832
+ 831: 'studio couch, day bed',
833
+ 832: 'stupa, tope',
834
+ 833: 'submarine, pigboat, sub, U-boat',
835
+ 834: 'suit, suit of clothes',
836
+ 835: 'sundial',
837
+ 836: 'sunglass',
838
+ 837: 'sunglasses, dark glasses, shades',
839
+ 838: 'sunscreen, sunblock, sun blocker',
840
+ 839: 'suspension bridge',
841
+ 840: 'swab, swob, mop',
842
+ 841: 'sweatshirt',
843
+ 842: 'swimming trunks, bathing trunks',
844
+ 843: 'swing',
845
+ 844: 'switch, electric switch, electrical switch',
846
+ 845: 'syringe',
847
+ 846: 'table lamp',
848
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
849
+ 848: 'tape player',
850
+ 849: 'teapot',
851
+ 850: 'teddy, teddy bear',
852
+ 851: 'television, television system',
853
+ 852: 'tennis ball',
854
+ 853: 'thatch, thatched roof',
855
+ 854: 'theater curtain, theatre curtain',
856
+ 855: 'thimble',
857
+ 856: 'thresher, thrasher, threshing machine',
858
+ 857: 'throne',
859
+ 858: 'tile roof',
860
+ 859: 'toaster',
861
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
862
+ 861: 'toilet seat',
863
+ 862: 'torch',
864
+ 863: 'totem pole',
865
+ 864: 'tow truck, tow car, wrecker',
866
+ 865: 'toyshop',
867
+ 866: 'tractor',
868
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
869
+ 868: 'tray',
870
+ 869: 'trench coat',
871
+ 870: 'tricycle, trike, velocipede',
872
+ 871: 'trimaran',
873
+ 872: 'tripod',
874
+ 873: 'triumphal arch',
875
+ 874: 'trolleybus, trolley coach, trackless trolley',
876
+ 875: 'trombone',
877
+ 876: 'tub, vat',
878
+ 877: 'turnstile',
879
+ 878: 'typewriter keyboard',
880
+ 879: 'umbrella',
881
+ 880: 'unicycle, monocycle',
882
+ 881: 'upright, upright piano',
883
+ 882: 'vacuum, vacuum cleaner',
884
+ 883: 'vase',
885
+ 884: 'vault',
886
+ 885: 'velvet',
887
+ 886: 'vending machine',
888
+ 887: 'vestment',
889
+ 888: 'viaduct',
890
+ 889: 'violin, fiddle',
891
+ 890: 'volleyball',
892
+ 891: 'waffle iron',
893
+ 892: 'wall clock',
894
+ 893: 'wallet, billfold, notecase, pocketbook',
895
+ 894: 'wardrobe, closet, press',
896
+ 895: 'warplane, military plane',
897
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
898
+ 897: 'washer, automatic washer, washing machine',
899
+ 898: 'water bottle',
900
+ 899: 'water jug',
901
+ 900: 'water tower',
902
+ 901: 'whiskey jug',
903
+ 902: 'whistle',
904
+ 903: 'wig',
905
+ 904: 'window screen',
906
+ 905: 'window shade',
907
+ 906: 'Windsor tie',
908
+ 907: 'wine bottle',
909
+ 908: 'wing',
910
+ 909: 'wok',
911
+ 910: 'wooden spoon',
912
+ 911: 'wool, woolen, woollen',
913
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
914
+ 913: 'wreck',
915
+ 914: 'yawl',
916
+ 915: 'yurt',
917
+ 916: 'web site, website, internet site, site',
918
+ 917: 'comic book',
919
+ 918: 'crossword puzzle, crossword',
920
+ 919: 'street sign',
921
+ 920: 'traffic light, traffic signal, stoplight',
922
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
923
+ 922: 'menu',
924
+ 923: 'plate',
925
+ 924: 'guacamole',
926
+ 925: 'consomme',
927
+ 926: 'hot pot, hotpot',
928
+ 927: 'trifle',
929
+ 928: 'ice cream, icecream',
930
+ 929: 'ice lolly, lolly, lollipop, popsicle',
931
+ 930: 'French loaf',
932
+ 931: 'bagel, beigel',
933
+ 932: 'pretzel',
934
+ 933: 'cheeseburger',
935
+ 934: 'hotdog, hot dog, red hot',
936
+ 935: 'mashed potato',
937
+ 936: 'head cabbage',
938
+ 937: 'broccoli',
939
+ 938: 'cauliflower',
940
+ 939: 'zucchini, courgette',
941
+ 940: 'spaghetti squash',
942
+ 941: 'acorn squash',
943
+ 942: 'butternut squash',
944
+ 943: 'cucumber, cuke',
945
+ 944: 'artichoke, globe artichoke',
946
+ 945: 'bell pepper',
947
+ 946: 'cardoon',
948
+ 947: 'mushroom',
949
+ 948: 'Granny Smith',
950
+ 949: 'strawberry',
951
+ 950: 'orange',
952
+ 951: 'lemon',
953
+ 952: 'fig',
954
+ 953: 'pineapple, ananas',
955
+ 954: 'banana',
956
+ 955: 'jackfruit, jak, jack',
957
+ 956: 'custard apple',
958
+ 957: 'pomegranate',
959
+ 958: 'hay',
960
+ 959: 'carbonara',
961
+ 960: 'chocolate sauce, chocolate syrup',
962
+ 961: 'dough',
963
+ 962: 'meat loaf, meatloaf',
964
+ 963: 'pizza, pizza pie',
965
+ 964: 'potpie',
966
+ 965: 'burrito',
967
+ 966: 'red wine',
968
+ 967: 'espresso',
969
+ 968: 'cup',
970
+ 969: 'eggnog',
971
+ 970: 'alp',
972
+ 971: 'bubble',
973
+ 972: 'cliff, drop, drop-off',
974
+ 973: 'coral reef',
975
+ 974: 'geyser',
976
+ 975: 'lakeside, lakeshore',
977
+ 976: 'promontory, headland, head, foreland',
978
+ 977: 'sandbar, sand bar',
979
+ 978: 'seashore, coast, seacoast, sea-coast',
980
+ 979: 'valley, vale',
981
+ 980: 'volcano',
982
+ 981: 'ballplayer, baseball player',
983
+ 982: 'groom, bridegroom',
984
+ 983: 'scuba diver',
985
+ 984: 'rapeseed',
986
+ 985: 'daisy',
987
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
988
+ 987: 'corn',
989
+ 988: 'acorn',
990
+ 989: 'hip, rose hip, rosehip',
991
+ 990: 'buckeye, horse chestnut, conker',
992
+ 991: 'coral fungus',
993
+ 992: 'agaric',
994
+ 993: 'gyromitra',
995
+ 994: 'stinkhorn, carrion fungus',
996
+ 995: 'earthstar',
997
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
998
+ 997: 'bolete',
999
+ 998: 'ear, spike, capitulum',
1000
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
modelguidedattacks/data/registry.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import torchvision.transforms as T
4
+ from mmpretrain import datasets as mmdatasets
5
+ from mmpretrain.registry import TRANSFORMS
6
+ from mmengine.dataset import Compose
7
+
8
+ from torch import nn
9
+ from torch.utils.data import Dataset as TorchDataset
10
+
11
+ # This holds dataset instantiation functions by (dataset_name) tuple keys
12
+ DATASET_REGISTRY = {}
13
+ DATASET_PATH = "./datasets"
14
+
15
+ class MMPretrainWrapper(TorchDataset):
16
+ def __init__(self, mmdataset) -> None:
17
+ super().__init__()
18
+ self.mmdataset = mmdataset
19
+
20
+ test_pipeline = [
21
+ dict(type='LoadImageFromFile'),
22
+ dict(type='ResizeEdge', scale=256, edge='short'),
23
+ dict(type='CenterCrop', crop_size=224),
24
+ dict(type='PackInputs'),
25
+ ]
26
+
27
+ self.pipeline = self.init_pipeline(test_pipeline)
28
+
29
+ def init_pipeline(self, pipeline_cfg):
30
+ pipeline = Compose(
31
+ [TRANSFORMS.build(t) for t in pipeline_cfg])
32
+ return pipeline
33
+
34
+ @property
35
+ def classes(self):
36
+ return self.mmdataset.CLASSES
37
+
38
+ def __len__(self):
39
+ return len(self.mmdataset)
40
+
41
+ def __getitem__(self, index):
42
+ sample = self.mmdataset[index]
43
+ sample = self.pipeline(sample)
44
+
45
+ # Our interface expects images in [0-1]
46
+ img = sample["inputs"].float() / 255
47
+
48
+ return img, sample["data_samples"].gt_label.item()
49
+
50
+
51
+ def register_torchvision_dataset(dataset_name, dataset_cls, dataset_kwargs_train={}, dataset_kwargs_val={}):
52
+ def instantiate_dataset():
53
+ train_data = dataset_cls(
54
+ root=DATASET_PATH,
55
+ train=True,
56
+ download=True,
57
+ transform=T.ToTensor()
58
+ )
59
+
60
+ val_data = dataset_cls(
61
+ root=DATASET_PATH,
62
+ train=False,
63
+ download=True,
64
+ transform=T.ToTensor()
65
+ )
66
+
67
+ return train_data, val_data
68
+
69
+ DATASET_REGISTRY[dataset_name] = instantiate_dataset
70
+
71
+ def register_mmpretrain_dataset(dataset_name, dataset_cls, dataset_kwargs_train={}, dataset_kwargs_val={}):
72
+ def instantiate_dataset():
73
+ train_data = dataset_cls(**dataset_kwargs_train)
74
+ val_data = dataset_cls(**dataset_kwargs_val)
75
+
76
+ train_data = MMPretrainWrapper(train_data)
77
+ val_data = MMPretrainWrapper(val_data)
78
+
79
+ return train_data, val_data
80
+
81
+ DATASET_REGISTRY[dataset_name] = instantiate_dataset
82
+
83
+ def register_default_datasets():
84
+ register_torchvision_dataset("cifar10", torchvision.datasets.CIFAR10)
85
+ register_torchvision_dataset("cifar100", torchvision.datasets.CIFAR100)
86
+ register_mmpretrain_dataset("imagenet", mmdatasets.ImageNet,
87
+ dataset_kwargs_train=dict(
88
+ data_root = "data/imagenet",
89
+ data_prefix = "val",
90
+ ann_file = "meta/val.txt"
91
+ ),
92
+ dataset_kwargs_val=dict(
93
+ data_root = "data/imagenet",
94
+ data_prefix = "val",
95
+ ann_file = "meta/val.txt"
96
+ ))
97
+
98
+ def get_dataset(dataset_name):
99
+ """
100
+ Returns an instance of a dataset
101
+
102
+ dataset_name: Name of desired dataset
103
+ """
104
+
105
+ if dataset_name not in DATASET_REGISTRY:
106
+ raise Exception("Requested dataset not in registry")
107
+
108
+ return DATASET_REGISTRY[dataset_name]()
modelguidedattacks/data/setup.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import os
3
+ import torch
4
+ import ignite.distributed as idist
5
+ import torchvision
6
+ import torchvision.transforms as T
7
+ from torch.utils import data as torch_data
8
+
9
+ from .classification_wrapper import TopKClassificationWrapper
10
+ from torch.utils.data import Subset
11
+ from modelguidedattacks.data import get_dataset
12
+ from modelguidedattacks.cls_models.accuracy import get_correct_subset_for_models, DATASET_METADATA_DIR
13
+
14
+ from tqdm import tqdm
15
+
16
+ def get_gt_labels(dataset: TopKClassificationWrapper, train:bool, dataset_name:str):
17
+ training_str = "train" if train else "val"
18
+ save_name = os.path.join(DATASET_METADATA_DIR, f"{dataset_name}_labels_{training_str}.p")
19
+
20
+ if os.path.exists(save_name):
21
+ print ("Found labels cache")
22
+ return torch.load(save_name)
23
+
24
+ dataloader = torch_data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)
25
+
26
+ gt_labels = []
27
+
28
+ for batch in tqdm(dataloader):
29
+ gt_labels.extend(batch[1].tolist())
30
+
31
+ gt_labels = torch.tensor(gt_labels)
32
+
33
+ torch.save(gt_labels, save_name)
34
+
35
+ return gt_labels
36
+
37
+ def class_balanced_sampling(dataset, gt_labels: torch.Tensor,
38
+ correct_labels: list, total_samples=1000):
39
+ num_classes = len(dataset.classes)
40
+
41
+ correct_labels = torch.tensor(correct_labels)
42
+ correct_mask = torch.zeros((len(dataset), ), dtype=torch.bool)
43
+ correct_mask[correct_labels] = True
44
+
45
+ sampled_indices = 0
46
+
47
+ total_sampled_indices = 0
48
+ sampled_indices = [[] for i in range(num_classes)]
49
+
50
+ shuffled_inds = torch.randperm(len(dataset))
51
+
52
+ for sample_cnt, sample_i in enumerate(shuffled_inds):
53
+ if not correct_mask[sample_i]:
54
+ continue
55
+
56
+ sample_class = gt_labels[sample_i]
57
+ desired_samples_in_class = (total_sampled_indices // num_classes) + 1
58
+
59
+ if len(sampled_indices[sample_class]) < desired_samples_in_class:
60
+ sampled_indices[sample_class].append(sample_i.item())
61
+ total_sampled_indices += 1
62
+
63
+ if total_sampled_indices >= total_samples:
64
+ break
65
+
66
+ flattened_indices = []
67
+ for class_samples in sampled_indices:
68
+ flattened_indices.extend(class_samples)
69
+
70
+ return torch.tensor(flattened_indices)
71
+
72
+ def sample_attack_labels(dataset, gt_labels, k, sampler):
73
+ """
74
+ dataset: Dataset we're generating attack labels for
75
+ gt_labels: List of gt idx for each sample in a dataset
76
+ k: attack size
77
+ sampler: ["random"]
78
+ """
79
+
80
+ # Sample from uniform and argsort to simulate
81
+ # a batched randperm
82
+ attack_label_uniforms = torch.rand((len(gt_labels), len(dataset.classes)))
83
+
84
+ # We don't want to sample the gt class for any samples
85
+ batch_inds = torch.arange(len(gt_labels))
86
+ attack_label_uniforms[batch_inds, gt_labels] = -1.
87
+
88
+ attack_labels = attack_label_uniforms.argsort(dim=-1, descending=True)[:, :k]
89
+
90
+ return attack_labels
91
+
92
+ def setup_data(config: Any, rank):
93
+ """Download datasets and create dataloaders
94
+
95
+ Parameters
96
+ ----------
97
+ config: needs to contain `data_path`, `train_batch_size`, `eval_batch_size`, and `num_workers`
98
+ """
99
+
100
+ dataset_train, dataset_eval = get_dataset(config.dataset)
101
+
102
+ train_subset = None
103
+ val_subset = None
104
+
105
+ attack_labels_train = None
106
+ attack_labels_val = None
107
+
108
+ if rank == 0:
109
+ gt_labels_train = get_gt_labels(dataset_train, True, config.dataset)
110
+ gt_labels_val = get_gt_labels(dataset_eval, False, config.dataset)
111
+
112
+ attack_labels_train = sample_attack_labels(dataset_train, gt_labels_train, k=config.k,
113
+ sampler=config.attack_sampling)
114
+ attack_labels_val = sample_attack_labels(dataset_eval, gt_labels_val, k=config.k,
115
+ sampler=config.attack_sampling)
116
+
117
+ device = "cuda" if torch.cuda.is_available() else "cpu"
118
+ correct_train_set = get_correct_subset_for_models(config.compare_models,
119
+ config.dataset, device,
120
+ train=True)
121
+
122
+ correct_eval_set = get_correct_subset_for_models(config.compare_models,
123
+ config.dataset, device,
124
+ train=False)
125
+
126
+ # Balanced sampling
127
+ train_subset = class_balanced_sampling(dataset_train, gt_labels_train,
128
+ correct_train_set)
129
+
130
+ val_subset = class_balanced_sampling(dataset_eval, gt_labels_val,
131
+ correct_eval_set)
132
+
133
+ if config.overfit:
134
+ rand_inds = torch.randperm(len(val_subset))[:16]
135
+ train_subset = train_subset[rand_inds]
136
+ val_subset = val_subset[rand_inds]
137
+
138
+ train_subset = idist.broadcast(train_subset, safe_mode=True)
139
+ val_subset = idist.broadcast(val_subset, safe_mode=True)
140
+
141
+ attack_labels_train = idist.broadcast(attack_labels_train, safe_mode=True)
142
+ attack_labels_val = idist.broadcast(attack_labels_val, safe_mode=True)
143
+
144
+ dataset_train = TopKClassificationWrapper(dataset_train, k=config.k,
145
+ attack_labels=attack_labels_train)
146
+ dataset_eval = TopKClassificationWrapper(dataset_eval, k=config.k,
147
+ attack_labels=attack_labels_val)
148
+
149
+ dataset_train = Subset(dataset_train, train_subset)
150
+ dataset_eval = Subset(dataset_eval, val_subset)
151
+
152
+ # if config.overfit:
153
+ # dataset_train = Subset(dataset_train, range(2))
154
+ # dataset_eval = dataset_train
155
+ # else:
156
+ # dataset_eval = Subset(dataset_eval, torch.randperm(len(dataset_eval))[:1000].tolist() )
157
+
158
+ dataloader_train = idist.auto_dataloader(
159
+ dataset_train,
160
+ batch_size=config.train_batch_size,
161
+ shuffle=not config.overfit,
162
+ num_workers=config.num_workers,
163
+ )
164
+ dataloader_eval = idist.auto_dataloader(
165
+ dataset_eval,
166
+ batch_size=config.eval_batch_size,
167
+ shuffle=True,
168
+ num_workers=config.num_workers,
169
+ )
170
+ return dataloader_train, dataloader_eval
modelguidedattacks/guides/instance_guide.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torchvision.ops import MLP
4
+
5
+ from .. import losses
6
+
7
+ class InstanceGuide(nn.Module):
8
+ def __init__(self, model: nn.Module, optimizer=torch.optim.AdamW, loss_fn=losses.CWExtensionLoss) -> None:
9
+ super().__init__()
10
+
11
+ self.guided = True
12
+ self.model = model
13
+
14
+
15
+ for p in self.model.parameters():
16
+ p.requires_grad_(False)
17
+
18
+ self.loss = loss_fn()
19
+ self.optimizer = optimizer
20
+
21
+ self.epochs = 30
22
+ self.mlp_iterations = 5
23
+ self.perturbation_iterations = 5
24
+
25
+ def surject_perturbation(self, x):
26
+ return x
27
+
28
+ def forward(self, x, attack_targets):
29
+ """
30
+ x: [B, channels, H, W]
31
+ attack_targets: [B, K]
32
+ """
33
+
34
+ B = x.shape[0]
35
+ K = attack_targets.shape[-1]
36
+ C = self.model.num_classes()
37
+
38
+ with torch.no_grad():
39
+ pred_clean, feats = self.model(x, return_features=True)
40
+
41
+ # We are assuming the clean predictions are ground truth since we make that
42
+ # constraint on the dataset side
43
+ attack_ground_truth = pred_clean.argmax(dim=-1) # [B]
44
+
45
+ mlp = MLP(self.model.head_features(),
46
+ [self.model.head_features()]*3 + [self.model.head_features()],
47
+ activation_layer=nn.GELU, inplace=None).to(x.device)
48
+
49
+ x_perturbation = nn.Parameter(torch.randn(x.shape,
50
+ device=x.device)*1e-3)
51
+
52
+ perturbation_optimizer = self.optimizer([x_perturbation], lr=1e-1)
53
+
54
+ mlp_optimizer = self.optimizer(mlp.parameters(), lr=1e-3)
55
+
56
+ logits_target_best = pred_clean
57
+ feats_target_best = feats
58
+
59
+ with torch.enable_grad():
60
+ for i in range(self.epochs):
61
+ for _ in range(self.mlp_iterations):
62
+ torch.cuda.synchronize()
63
+
64
+ feature_offset = mlp(feats)
65
+ feats_target_pred = feature_offset + feats
66
+ logits_target_pred = self.model.head(feats_target_pred)
67
+ # logits_target_pred = pred_logits
68
+ pred_classes = logits_target_pred.argsort(dim=-1, descending=True) # [B, C]
69
+ attack_successful = (pred_classes[:, :K] == attack_targets).all(dim=-1) # [B]
70
+
71
+ with torch.no_grad():
72
+ logits_target_best = torch.where(
73
+ attack_successful[:, None].expand(-1, C),
74
+ logits_target_pred,
75
+ logits_target_best
76
+ )
77
+
78
+ feats_target_best = torch.where(
79
+ attack_successful[:, None].expand(-1, self.model.head_features()),
80
+ feats_target_pred,
81
+ feats_target_best
82
+ )
83
+
84
+ mlp_loss = self.loss(logits_pred=logits_target_pred,
85
+ prediction_feats=feats_target_pred,
86
+ attack_targets=attack_targets,
87
+ attack_ground_truth=attack_ground_truth,
88
+ model=self.model)
89
+ mlp_loss = mlp_loss.mean() + feature_offset.view(B, -1).norm(dim=-1, p=2)*1
90
+
91
+ mlp_optimizer.zero_grad()
92
+ mlp_loss.backward()
93
+ mlp_optimizer.step()
94
+
95
+ feats_target_best = feats_target_best.detach()
96
+
97
+ for _ in range(self.perturbation_iterations):
98
+ x_perturbed = x + self.surject_perturbation(x_perturbation)
99
+ prediction, perturbed_feats = self.model(x_perturbed, return_features=True)
100
+ pred_classes = prediction.argsort(dim=-1, descending=True) # [B, C]
101
+ attack_successful = (pred_classes[:, :K] == attack_targets).all(dim=-1) # [B]
102
+
103
+ perturbation_loss = (prediction - logits_target_best).view(B, -1).norm(dim=-1).mean()
104
+
105
+ perturbation_optimizer.zero_grad()
106
+ perturbation_loss.backward()
107
+ perturbation_optimizer.step()
108
+
109
+ return prediction
modelguidedattacks/guides/unguided.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from .. import losses
4
+ import ignite.distributed as idist
5
+ import torch_optimizer
6
+ from tqdm import tqdm
7
+ import matplotlib.pyplot as plt
8
+ from torch.nn import functional as F
9
+ import os
10
+
11
+ import shutil
12
+ from modelguidedattacks.cls_models.registry import MMPretrainVisualTransformerWrapper
13
+ from modelguidedattacks.data.imagenet_metadata import imgnet_idx_to_name
14
+
15
+ class Unguided(nn.Module):
16
+ def __init__(self, model: nn.Module, config, optimizer=torch.optim.AdamW, seed=0, iterations=1000,
17
+ loss_fn=losses.CVXProjLoss, lr=1e-3,
18
+ binary_search_steps=1, topk_loss_coef_upper=10.,
19
+ topk_loss_coef_lower=0.) -> None:
20
+ super().__init__()
21
+
22
+ self.guided = False
23
+ self.model = model
24
+ self.seed = seed
25
+ self.iterations = iterations
26
+ self.loss = loss_fn()
27
+ self.optimizer = optimizer
28
+ self.lr = lr
29
+
30
+ self.binary_search_steps = binary_search_steps
31
+ self.topk_loss_coef_upper = topk_loss_coef_upper
32
+ self.topk_loss_coef_lower = topk_loss_coef_lower
33
+ self.config = config
34
+
35
+ def surject_perturbation(self, x, max_norm=5.):
36
+ x_shape = x.shape
37
+
38
+ x = x.flatten(1)
39
+ x_norm = x.norm(dim=-1)
40
+ x_unit = x / x_norm[:, None]
41
+
42
+ x_norm_outside = x_norm > max_norm
43
+ x_norm_outside = x_norm_outside.expand_as(x)
44
+
45
+ x = torch.where(x_norm_outside, x_unit*max_norm, x)
46
+
47
+ return x.view(x_shape)
48
+
49
+ @torch.enable_grad()
50
+ def attack(self, x, attack_targets, gt_labels, topk_coefs):
51
+ """
52
+ For a given set of topk coefficients, this function computes
53
+ best energy attack in the given number of iterations and configuration
54
+
55
+ x: [B, C, H, W] [0-1 for colors]
56
+ attack_targets: [B, K] (long)
57
+ gt_labels: [B] (long)
58
+ topk_coefs: [B] (floats)
59
+ """
60
+
61
+ topk_coefs = topk_coefs.clone()
62
+ K = attack_targets.shape[-1]
63
+
64
+ x_perturbation = nn.Parameter(torch.randn(x.shape,
65
+ device=x.device)*2e-3)
66
+
67
+ with torch.no_grad():
68
+ prediction_logits_0, prediction_feats_0 \
69
+ = self.model(x, return_features=True)
70
+
71
+ best_perturbations = torch.zeros_like(x) # [B, 3, H, W]
72
+ has_successful_attack = torch.zeros(x.shape[0], dtype=torch.long, device=x.device) # [B]
73
+ best_energy = torch.full((x.shape[0],), float('inf'), device=x.device) # [B]
74
+
75
+ pbar = tqdm(range(self.iterations))
76
+
77
+ for i in pbar:
78
+
79
+ if i == self.config.opt_warmup_its:
80
+ # Reset optimizer state
81
+ optimizer = self.optimizer([x_perturbation], lr=self.lr)
82
+
83
+ x_perturbed = x + x_perturbation#self.surject_perturbation(x_perturbation)
84
+ prediction_logits, prediction_feats = self.model(x_perturbed, return_features=True)
85
+
86
+ pred_classes = prediction_logits.argsort(dim=-1, descending=True) # [B, C]
87
+ attack_successful = (pred_classes[:, :K] == attack_targets).all(dim=-1) # [B]
88
+ attack_energy = x_perturbation.flatten(1).norm(dim=-1) # [B]
89
+
90
+ attack_improved = attack_successful & (attack_energy <= best_energy)
91
+
92
+ best_perturbations[attack_improved] = x_perturbation[attack_improved]
93
+ has_successful_attack[attack_improved] = True
94
+ best_energy[attack_improved] = attack_energy[attack_improved]
95
+
96
+ loss = self.loss(logits_pred=prediction_logits,
97
+ feats_pred=prediction_feats,
98
+ feats_pred_0=prediction_feats_0,
99
+ attack_targets=attack_targets,
100
+ model=self.model, **precomputed_state)
101
+
102
+ loss = loss * topk_coefs
103
+
104
+ loss = loss.sum()
105
+
106
+ pbar.set_description(f"Loss: {loss.item():.3f}")
107
+
108
+ loss = loss + x_perturbation.flatten(1).square().sum()
109
+
110
+ optimizer.zero_grad()
111
+ loss.backward()
112
+ optimizer.step()
113
+
114
+ # If we were successfull let's start taking the norm down
115
+ topk_coefs[attack_improved] *= 0.75
116
+
117
+ # Project perturbation to be within image limits
118
+ with torch.no_grad():
119
+ x_perturbed = x + x_perturbation
120
+ x_perturbed = x_perturbed.clamp_(min=0., max=1.)
121
+
122
+ x_perturbation.data = x_perturbed - x
123
+
124
+ x_perturbed_best = x + best_perturbations
125
+ prediction_logits, prediction_feats = self.model(x_perturbed_best, return_features=True)
126
+
127
+ if self.config.dump_plots:
128
+ if os.path.isdir(self.config.plot_out):
129
+ shutil.rmtree(self.config.plot_out)
130
+
131
+ if has_successful_attack.any():
132
+ def dump_random_map():
133
+ os.makedirs(self.config.plot_out, exist_ok=True)
134
+
135
+ # selected_idx = best_energy.argmin()
136
+ successful_idxs = has_successful_attack.nonzero()[:, 0]
137
+
138
+ if self.config.plot_idx == "find":
139
+ selected_idx = successful_idxs[torch.randperm(len(successful_idxs))[0]]
140
+ # selected_idx = best_energy.argmin()
141
+ else:
142
+ selected_idx = int(self.config.plot_idx)
143
+
144
+ print ("Selected idx", selected_idx)
145
+
146
+ top_classes = prediction_logits_0[selected_idx].argsort(dim=-1, descending=True)
147
+ attack_targets_selected = attack_targets[selected_idx]
148
+
149
+ def imgnet_names(idxs):
150
+ return [imgnet_idx_to_name[int(idx)].split(",")[0] for idx in idxs]
151
+
152
+ top_class_names = imgnet_names(top_classes)[:K]
153
+ attack_targets_selected_names = imgnet_names(attack_targets_selected)
154
+
155
+ def plot_attn_map(attn_map):
156
+ attn_map = attn_map[0].mean(dim=0)[1:] # [196] get class tokens
157
+ attn_map = attn_map.view(14, 14)
158
+ attn_map = F.interpolate(
159
+ attn_map[None, None],
160
+ x.shape[-2:],
161
+ mode="bilinear"
162
+ ).view(x.shape[-2:])
163
+
164
+ plt.imshow(attn_map.detach().cpu(), alpha=0.5)
165
+
166
+ plt.figure()
167
+ plt.imshow(x[selected_idx].permute(1,2,0).flip(dims=(-1,)).detach().cpu())
168
+ plt.axis("off")
169
+ plt.savefig(f"{self.config.plot_out}/clean_image.png", bbox_inches="tight", pad_inches=0)
170
+
171
+ plt.figure()
172
+ plt.imshow(x_perturbed_best[selected_idx].permute(1,2,0).flip(dims=(-1,)).detach().cpu())
173
+ plt.axis("off")
174
+ plt.savefig(f"{self.config.plot_out}/perturbed_image.png", bbox_inches="tight", pad_inches=0)
175
+
176
+ plt.figure()
177
+ plt.imshow(best_perturbations[selected_idx].mean(dim=0).abs().detach().cpu(), cmap="hot")
178
+ plt.colorbar()
179
+ plt.savefig(f"{self.config.plot_out}/perturbation.png", bbox_inches="tight")
180
+
181
+ if isinstance(self.model, MMPretrainVisualTransformerWrapper):
182
+ attn_maps_clean = self.model.get_attention_maps(x)[-1][selected_idx]
183
+ attn_maps_attacked = self.model.get_attention_maps(x_perturbed_best)[-1][selected_idx]
184
+
185
+ plt.figure()
186
+ plt.imshow(x[selected_idx].permute(1,2,0).flip(dims=(-1,)).detach().cpu())
187
+ plot_attn_map(attn_maps_clean)
188
+ plt.axis("off")
189
+ plt.savefig(f"{self.config.plot_out}/clean_map.png", bbox_inches="tight", pad_inches=0)
190
+
191
+ plt.figure()
192
+ plt.imshow(x[selected_idx].permute(1,2,0).flip(dims=(-1,)).detach().cpu())
193
+ plot_attn_map(attn_maps_attacked)
194
+ plt.axis("off")
195
+ plt.savefig(f"{self.config.plot_out}/attacked_map.png", bbox_inches="tight", pad_inches=0)
196
+
197
+ with open(f'{self.config.plot_out}/clean_classes_names.txt', 'w') as f:
198
+ f.write(", ".join(top_class_names))
199
+
200
+ with open(f'{self.config.plot_out}/attack_targets_names.txt', 'w') as f:
201
+ f.write(", ".join(attack_targets_selected_names))
202
+
203
+ with open(f'{self.config.plot_out}/clean_classes_names.txt', 'w') as f:
204
+ f.write(", ".join(top_class_names))
205
+
206
+ with open(f'{self.config.plot_out}/selected_idx.txt', 'w') as f:
207
+ if isinstance(selected_idx, torch.Tensor):
208
+ selected_idx = selected_idx.item()
209
+
210
+ f.write(str(selected_idx))
211
+
212
+ with open(f'{self.config.plot_out}/energy.txt', 'w') as f:
213
+ f.write(str(best_energy[selected_idx].item()))
214
+
215
+ C = prediction_logits_0.shape[-1]
216
+ class_idxs = torch.arange(C) + 1
217
+ clean_probs = prediction_logits_0[selected_idx].detach().cpu().softmax(dim=-1)
218
+ attacked_probs = prediction_logits[selected_idx].detach().cpu().softmax(dim=-1)
219
+
220
+ def label_classes(bars):
221
+ adjusted_heights = {}
222
+ for i, cls_idx in enumerate(attack_targets_selected.tolist()):
223
+ bar = bars[cls_idx]
224
+ height = bar.get_height()
225
+ ann_x = bar.get_x() + bar.get_width()
226
+
227
+ rotation = 90
228
+ font_size = 10
229
+
230
+ max_neighboring_height = -1
231
+ for other_cls_idx in attack_targets_selected.tolist():
232
+ if abs(cls_idx - other_cls_idx) <= 40 and cls_idx != other_cls_idx:
233
+ if other_cls_idx in adjusted_heights and adjusted_heights[other_cls_idx] > max_neighboring_height:
234
+ max_neighboring_height = adjusted_heights[other_cls_idx]
235
+
236
+ if max_neighboring_height > 0:
237
+ height = max_neighboring_height + 0.05
238
+
239
+ adjusted_heights[cls_idx] = height
240
+
241
+ plt.text(ann_x, height, f"[{i}]", rotation=rotation,
242
+ ha='center', va='bottom', fontsize=font_size, color='red')#.get_bbox_patch().get_height()
243
+
244
+
245
+ plt.figure()
246
+ bars_clean = plt.bar(class_idxs, clean_probs, width=4)
247
+ plt.ylim(0,1)
248
+ label_classes(bars_clean)
249
+ plt.savefig(f"{self.config.plot_out}/clean_probs.png", bbox_inches="tight", pad_inches=0)
250
+
251
+ plt.figure()
252
+ bars_attacked = plt.bar(class_idxs, attacked_probs, width=4)
253
+ plt.ylim(0,1)
254
+ label_classes(bars_attacked)
255
+ plt.savefig(f"{self.config.plot_out}/attacked_probs.png", bbox_inches="tight", pad_inches=0)
256
+
257
+ print ("Idx", selected_idx)
258
+ print (best_energy[selected_idx])
259
+ print ("Finished plotting")
260
+
261
+ dump_random_map()
262
+ import sys
263
+ sys.exit(1)
264
+ print ("Dumped attention map")
265
+
266
+
267
+ return prediction_logits, best_perturbations, best_energy
268
+
269
+ def forward(self, x, attack_targets, gt_labels):
270
+ """
271
+ This function is in charge of performing a binary search through
272
+ topk loss coefficients and running attacks on each.
273
+ """
274
+ B = x.shape[0]
275
+ device = x.device
276
+ topk_coefs_lower = torch.full((B,), fill_value=self.topk_loss_coef_lower,
277
+ device=device, dtype=torch.float)
278
+
279
+ topk_coefs_upper = torch.full((B,), fill_value=self.topk_loss_coef_upper,
280
+ device=device, dtype=torch.float)
281
+
282
+ best_perturbations = torch.zeros_like(x) # [B, 3, H, W]
283
+ best_energy = torch.full((B,), float('inf'), device=device) # [B]
284
+ best_prediction_logits = None
285
+
286
+ for search_step_i in range(self.binary_search_steps):
287
+ if x.device.index is None or x.device.index == 0:
288
+ print ("Running binary search step", search_step_i + 1)
289
+
290
+ current_topk_coefs = (topk_coefs_lower + topk_coefs_upper) / 2
291
+ current_logits, current_perturbations, current_energy = \
292
+ self.attack(x, attack_targets, gt_labels, current_topk_coefs)
293
+
294
+ current_attack_suceeded = ~torch.isinf(current_energy)
295
+
296
+ update_mask = current_energy < best_energy
297
+
298
+ best_perturbations[update_mask] = current_perturbations[update_mask]
299
+ best_energy[update_mask] = current_energy[update_mask]
300
+
301
+ if best_prediction_logits is None:
302
+ best_prediction_logits = current_logits.clone()
303
+ else:
304
+ best_prediction_logits[update_mask] = current_logits[update_mask]
305
+
306
+ # If we fail to attack, we must increase our topk coef
307
+ topk_coefs_lower[~current_attack_suceeded] = current_topk_coefs[~current_attack_suceeded]
308
+
309
+ # If we succeed, we must lower to seek a more frugal attack
310
+ topk_coefs_upper[current_attack_suceeded] = current_topk_coefs[current_attack_suceeded]
311
+
312
+ idist.barrier()
313
+
314
+ return best_prediction_logits, best_perturbations
modelguidedattacks/losses/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .boilerplate import BoilerplateLoss
2
+ from .cw_extension import CWExtensionLoss
3
+ from .cvx_proj import CVXProjLoss
4
+ from .adversarial_distillation.ad_loss import AdversarialDistillationLoss
modelguidedattacks/losses/_qp_solver_patch.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import qpth
2
+ from qpth.solvers.pdipm import batch as pdipm_b
3
+ from qpth.solvers.pdipm.batch import *
4
+
5
+ def reduce_stats(z):
6
+ return z[~z.isnan()].median()
7
+
8
+ def forward(Q, p, G, h, A, b, Q_LU, S_LU, R, eps=1e-12, verbose=0, notImprovedLim=3,
9
+ maxIter=20, solver=KKTSolvers.LU_PARTIAL):
10
+ """
11
+ Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)
12
+ """
13
+ nineq, nz, neq, nBatch = get_sizes(G, A)
14
+
15
+ # Find initial values
16
+ if solver == KKTSolvers.LU_FULL:
17
+ D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
18
+ x, s, z, y = factor_solve_kkt(
19
+ Q, D, G, A, p,
20
+ torch.zeros(nBatch, nineq).type_as(Q),
21
+ -h, -b if b is not None else None)
22
+ elif solver == KKTSolvers.LU_PARTIAL:
23
+ d = torch.ones(nBatch, nineq).type_as(Q)
24
+ factor_kkt(S_LU, R, d)
25
+ x, s, z, y = solve_kkt(
26
+ Q_LU, d, G, A, S_LU,
27
+ p, torch.zeros(nBatch, nineq).type_as(Q),
28
+ -h, -b if neq > 0 else None)
29
+ elif solver == KKTSolvers.IR_UNOPT:
30
+ D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
31
+ x, s, z, y = solve_kkt_ir(
32
+ Q, D, G, A, p,
33
+ torch.zeros(nBatch, nineq).type_as(Q),
34
+ -h, -b if b is not None else None)
35
+ else:
36
+ assert False
37
+
38
+ # Make all of the slack variables >= 1.
39
+ M = torch.min(s, 1)[0]
40
+ M = M.view(M.size(0), 1).repeat(1, nineq)
41
+ I = M < 0
42
+ s[I] -= M[I] - 1
43
+
44
+ # Make all of the inequality dual variables >= 1.
45
+ M = torch.min(z, 1)[0]
46
+ M = M.view(M.size(0), 1).repeat(1, nineq)
47
+ I = M < 0
48
+ z[I] -= M[I] - 1
49
+
50
+ best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
51
+ nNotImproved = 0
52
+
53
+ for i in range(maxIter):
54
+ # affine scaling direction
55
+ rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
56
+ torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
57
+ torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
58
+ p
59
+ rs = z
60
+ rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
61
+ ry = torch.bmm(x.unsqueeze(1), A.transpose(
62
+ 1, 2)).squeeze(1) - b if neq > 0 else 0.0
63
+ mu = torch.abs((s * z).sum(1).squeeze() / nineq)
64
+ z_resid = torch.norm(rz, 2, 1).squeeze()
65
+ y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
66
+ pri_resid = y_resid + z_resid
67
+ dual_resid = torch.norm(rx, 2, 1).squeeze()
68
+ resids = pri_resid + dual_resid + nineq * mu
69
+
70
+ d = z / s
71
+ try:
72
+ factor_kkt(S_LU, R, d)
73
+ except:
74
+ return best['x'], best['y'], best['z'], best['s']
75
+
76
+ if verbose == 1:
77
+ print('iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.format(
78
+ i, reduce_stats(pri_resid), reduce_stats(dual_resid), reduce_stats(mu)))
79
+ if best['resids'] is None:
80
+ best['resids'] = resids
81
+ best['x'] = x.clone()
82
+ best['z'] = z.clone()
83
+ best['s'] = s.clone()
84
+ best['y'] = y.clone() if y is not None else None
85
+ nNotImproved = 0
86
+ else:
87
+ I = resids < best['resids']
88
+ if I.sum() > 0:
89
+ nNotImproved = 0
90
+ else:
91
+ nNotImproved += 1
92
+ I_nz = I.repeat(nz, 1).t()
93
+ I_nineq = I.repeat(nineq, 1).t()
94
+ best['resids'][I] = resids[I]
95
+ best['x'][I_nz] = x[I_nz]
96
+ best['z'][I_nineq] = z[I_nineq]
97
+ best['s'][I_nineq] = s[I_nineq]
98
+ if neq > 0:
99
+ I_neq = I.repeat(neq, 1).t()
100
+ best['y'][I_neq] = y[I_neq]
101
+ if nNotImproved == notImprovedLim or reduce_stats(pri_resid) < eps or mu.min() > 1e32:
102
+ if best['resids'].max() > 1. and verbose >= 0:
103
+ print(INACC_ERR)
104
+ return best['x'], best['y'], best['z'], best['s']
105
+
106
+ if solver == KKTSolvers.LU_FULL:
107
+ D = bdiag(d)
108
+ dx_aff, ds_aff, dz_aff, dy_aff = factor_solve_kkt(
109
+ Q, D, G, A, rx, rs, rz, ry)
110
+ elif solver == KKTSolvers.LU_PARTIAL:
111
+ dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
112
+ Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
113
+ elif solver == KKTSolvers.IR_UNOPT:
114
+ D = bdiag(d)
115
+ dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir(
116
+ Q, D, G, A, rx, rs, rz, ry)
117
+ else:
118
+ assert False
119
+
120
+ # compute centering directions
121
+ alpha = torch.min(torch.min(get_step(z, dz_aff),
122
+ get_step(s, ds_aff)),
123
+ torch.ones(nBatch).type_as(Q))
124
+ alpha_nineq = alpha.repeat(nineq, 1).t()
125
+ t1 = s + alpha_nineq * ds_aff
126
+ t2 = z + alpha_nineq * dz_aff
127
+ t3 = torch.sum(t1 * t2, 1).squeeze()
128
+ t4 = torch.sum(s * z, 1).squeeze()
129
+ sig = (t3 / t4)**3
130
+
131
+ rx = torch.zeros(nBatch, nz).type_as(Q)
132
+ rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
133
+ rz = torch.zeros(nBatch, nineq).type_as(Q)
134
+ ry = torch.zeros(nBatch, neq).type_as(Q) if neq > 0 else torch.Tensor()
135
+
136
+ if solver == KKTSolvers.LU_FULL:
137
+ D = bdiag(d)
138
+ dx_cor, ds_cor, dz_cor, dy_cor = factor_solve_kkt(
139
+ Q, D, G, A, rx, rs, rz, ry)
140
+ elif solver == KKTSolvers.LU_PARTIAL:
141
+ dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
142
+ Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
143
+ elif solver == KKTSolvers.IR_UNOPT:
144
+ D = bdiag(d)
145
+ dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir(
146
+ Q, D, G, A, rx, rs, rz, ry)
147
+ else:
148
+ assert False
149
+
150
+ dx = dx_aff + dx_cor
151
+ ds = ds_aff + ds_cor
152
+ dz = dz_aff + dz_cor
153
+ dy = dy_aff + dy_cor if neq > 0 else None
154
+ alpha = torch.min(0.999 * torch.min(get_step(z, dz),
155
+ get_step(s, ds)),
156
+ torch.ones(nBatch).type_as(Q))
157
+ alpha_nineq = alpha.repeat(nineq, 1).t()
158
+ alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
159
+ alpha_nz = alpha.repeat(nz, 1).t()
160
+
161
+ x += alpha_nz * dx
162
+ s += alpha_nineq * ds
163
+ z += alpha_nineq * dz
164
+ y = y + alpha_neq * dy if neq > 0 else None
165
+
166
+ if best['resids'].max() > 1. and verbose >= 0:
167
+ print(INACC_ERR)
168
+ return best['x'], best['y'], best['z'], best['s']
169
+
170
+ pdipm_b.forward = forward
modelguidedattacks/losses/adversarial_distillation/ad_loss.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from .adversarial_distribution import AD_Distribution
4
+
5
+ class AdversarialDistillationLoss(nn.Module):
6
+ def __init__(self, confidence=0, alpha=10, beta=0.3):
7
+ super().__init__()
8
+
9
+ self.alpha = alpha
10
+ self.beta = beta
11
+
12
+ self.distri_generator = AD_Distribution(simi_name='glove',
13
+ alpha=self.alpha, beta=self.beta)
14
+
15
+ self.kl = nn.KLDivLoss(reduction='none')
16
+ self.logsoftmax = nn.LogSoftmax(dim=-1)
17
+
18
+ def precompute(self, attack_targets, gt_labels, config):
19
+ device = attack_targets.device
20
+
21
+ target_distribution = self.distri_generator.generate_distribution(gt_labels.cpu(), attack_targets.cpu())
22
+ target_distribution = torch.from_numpy(target_distribution).float().to(device)
23
+
24
+ K = attack_targets.shape[-1]
25
+ target_distribution_topk = target_distribution.argsort(dim=-1, descending=True)[:, :K]
26
+
27
+ assert (target_distribution_topk == attack_targets).all()
28
+
29
+ return {
30
+ "ad_distribution": target_distribution
31
+ }
32
+
33
+ def forward(self, logits_pred, feats_pred, feats_pred_0, attack_targets, model, ad_distribution, **kwargs):
34
+ log_logits = self.logsoftmax(logits_pred)
35
+ loss_kl = self.kl(log_logits, ad_distribution)
36
+ loss_kl = torch.sum(loss_kl, dim = -1)
37
+
38
+ return loss_kl
modelguidedattacks/losses/adversarial_distillation/adversarial_distribution.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from .glove_simi import generate_glove, AD_DIRECTORY
4
+
5
+ class AD_Distribution():
6
+ def __init__(self, simi_name, alpha, beta):
7
+ print('using ........ ', simi_name, ' .... knowledge')
8
+ path_simi_name = os.path.join(AD_DIRECTORY, 'imagenet_cos_similarity')
9
+ file_simi_name = path_simi_name +'_'+ simi_name + '.npy'
10
+ if os.path.exists(file_simi_name):
11
+ self.cos_similarity = np.load(file_simi_name)
12
+ print(simi_name+" cos_similarity loaded")
13
+ else:
14
+ self.cos_similarity = self.generate_similarity(simi_name)
15
+
16
+ self.alpha = alpha
17
+ self.beta = beta
18
+
19
+ def generate_similarity(self,simi_name):
20
+ if simi_name == 'glove':
21
+ similarity = generate_glove()
22
+ else:
23
+ print(simi_name + 'not implemented yet')
24
+ return similarity
25
+
26
+ def generate_distribution(self, gt_label, target):
27
+ distribution=[]
28
+
29
+ for i in range(len(target)):
30
+ distri = self.single_distribution_build(i, target[i], gt_label[i])
31
+ distribution.append(distri)
32
+
33
+ distribution = np.array(distribution)
34
+ return distribution
35
+
36
+ def single_distribution_build(self,index, target_id, gt_id):
37
+ if target_id.shape == ():
38
+ target_id = np.array([target_id])
39
+
40
+ simil_logits = np.zeros(self.cos_similarity[target_id[0],:].shape)
41
+ for i in range(len(target_id)):
42
+ simil_logits += self.cos_similarity[target_id[i],:]
43
+
44
+ simil_logits = (simil_logits)/ len(target_id)
45
+ logit_value = self.alpha
46
+
47
+ for i in range(len(target_id)):
48
+ simil_logits[target_id[i]] = logit_value
49
+ logit_value = logit_value - self.beta
50
+
51
+ if not self.check_oreder_target_no_groundtruth(simil_logits, target_id):
52
+ print('fail to generate distribution for index: ', index)
53
+
54
+ logits = self.softmax(simil_logits)
55
+ return logits
56
+
57
+ def check_oreder_target_no_groundtruth(self, probs, target_id):
58
+ sort_labels = np.argsort(probs)
59
+ cnt = 0
60
+ for i in range(len(target_id)):
61
+ if target_id[-(i+1)] == sort_labels[-(len(target_id)-i)]:
62
+ cnt +=1
63
+ if (cnt == len(target_id)):
64
+ return True
65
+ else:
66
+ return False
67
+
68
+ def softmax(self,logits):
69
+
70
+ prob=np.exp(logits) / np.sum(np.exp(logits))
71
+ return prob
72
+
modelguidedattacks/losses/adversarial_distillation/glove.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+
4
+ class GloVe():
5
+
6
+ def __init__(self, file_path):
7
+ self.dimension = None
8
+ self.embedding = dict()
9
+ with open(file_path, 'r') as f:
10
+ for line in tqdm(f.readlines()):
11
+ strs = line.strip().split(' ')
12
+ word = strs[0]
13
+ vector = torch.FloatTensor(list(map(float, strs[1:])))
14
+ self.embedding[word] = vector
15
+ if self.dimension is None:
16
+ self.dimension = len(vector)
17
+
18
+ def _fix_word(self, word):
19
+ terms = word.replace('_', ' ').split(' ')
20
+ ret = self.zeros()
21
+ cnt = 0
22
+ for term in terms:
23
+ v = self.embedding.get(term)
24
+ if v is None:
25
+ subterms = term.split('-')
26
+ subterm_sum = self.zeros()
27
+ subterm_cnt = 0
28
+ for subterm in subterms:
29
+ subv = self.embedding.get(subterm)
30
+ if subv is not None:
31
+ subterm_sum += subv
32
+ subterm_cnt += 1
33
+ if subterm_cnt > 0:
34
+ v = subterm_sum / subterm_cnt
35
+ if v is not None:
36
+ ret += v
37
+ cnt += 1
38
+ return ret / cnt if cnt > 0 else None
39
+
40
+ def __getitem__(self, words):
41
+ if type(words) is str:
42
+ words = [words]
43
+ ret = self.zeros()
44
+ cnt = 0
45
+ for word in words:
46
+ v = self.embedding.get(word)
47
+ if v is None:
48
+ v = self._fix_word(word)
49
+ if v is not None:
50
+ ret += v
51
+ cnt += 1
52
+ if cnt > 0:
53
+ return ret / cnt
54
+ else:
55
+ return self.zeros()
56
+
57
+ def zeros(self):
58
+ return torch.zeros(self.dimension)
modelguidedattacks/losses/adversarial_distillation/glove_simi.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import requests
4
+ import zipfile
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import optim
10
+ from torch.autograd import Variable
11
+ from torchvision import datasets, transforms
12
+ import torchvision
13
+ from PIL import Image
14
+ from .glove import GloVe
15
+ from tqdm import tqdm
16
+
17
+ AD_DIRECTORY = os.path.dirname(__file__)
18
+
19
+ def obtain_vector(inputs, glove):
20
+ vector_im = glove.embedding.get(inputs)
21
+ if vector_im is None:
22
+ vector_im = glove.embedding.get(inputs.lower())
23
+ if vector_im is None:
24
+ vector_im = glove.embedding.get(inputs.title())
25
+ if vector_im is None:
26
+ vector_im = glove.embedding.get(inputs.upper())
27
+ return vector_im
28
+
29
+ def generate_glove():
30
+ print("Generating glove similarity...")
31
+ #download glove file
32
+ os.makedirs("./knowledge", exist_ok=True)
33
+ glove_file = './knowledge/glove.840B.300d.txt'
34
+ if not os.path.exists(glove_file):
35
+ print("Downloading glove files...")
36
+ print("")
37
+ print("Gonna take a while")
38
+ print("")
39
+ url_path = "http://nlp.stanford.edu/data/glove.840B.300d.zip"
40
+ r = requests.get(url_path)
41
+ with open("./knowledge/glove.840B.300d.zip","wb") as f:
42
+ f.write(r.content)
43
+ filename = './knowledge/glove.840B.300d.zip'
44
+ fz = zipfile.ZipFile(filename, 'r')
45
+ for file in fz.namelist():
46
+ fz.extract(file, './knowledge/.')
47
+ if os.path.exists(filename):
48
+ os.remove(filename)
49
+
50
+ glove = GloVe('./knowledge/glove.840B.300d.txt')
51
+ filepath = os.path.join(AD_DIRECTORY, "label_name.txt")
52
+ vec_list = []
53
+ vec_list_np = []
54
+ cos_similarity = np.zeros((1000,1000))
55
+
56
+ index = 0
57
+
58
+ #the labels could be a word or a phrase with multi words
59
+ #we also tested on average of every words
60
+ #But we assume the last word should be more important, so in our final version
61
+ #we assign a higher weight to last word in a phrase and average the fornt words
62
+
63
+ #w2v for last word of multi-words
64
+ for line in tqdm(open(filepath)):
65
+ a = line.strip('\n')
66
+
67
+ b = a.split(',')
68
+ cnt = 0
69
+ vector = torch.zeros(300)
70
+ vec_front = torch.zeros(300)
71
+ vec_b_average = torch.zeros(300)
72
+ cnt_b = 0
73
+ for i in range(len(b)):
74
+ b[i] = b[i].lstrip()
75
+ c = b[i].split(' ')
76
+ if obtain_vector(c[-1], glove) is not None:
77
+ vec_b_average += obtain_vector(c[-1], glove)
78
+ cnt_b += 1
79
+ if cnt_b == 0:
80
+ print('index ', index,' generatint word_vector failure')
81
+ continue
82
+ vec_b_average = vec_b_average / cnt_b
83
+
84
+ for i in range(len(b)):
85
+ b[i] = b[i].lstrip()
86
+ c = b[i].split(' ')
87
+ cnt_f = 0
88
+ for j in range(len(c) - 1):
89
+ if obtain_vector(c[j], glove) is not None:
90
+ vec_front += obtain_vector(c[j], glove)
91
+ cnt_f += 1
92
+ if obtain_vector(c[-1], glove) is not None:
93
+ vec_back =obtain_vector(c[-1], glove)
94
+ else:
95
+ vec_back = vec_b_average
96
+ if cnt_f == 0:
97
+ vector += vec_back
98
+ else:
99
+ vector += (vec_front / cnt_f )* 0.1 + vec_back * 0.9
100
+ cnt += 1
101
+
102
+ vector = torch.div(vector,cnt)
103
+
104
+ vec_list_np.append(np.array(vector))
105
+ vec_list.append(vector)
106
+ index += 1
107
+
108
+
109
+ vec_list_np_stacked = np.stack(vec_list_np)
110
+ vec_list_torch = torch.from_numpy(vec_list_np_stacked)
111
+
112
+ cos_similarity = F.cosine_similarity(vec_list_torch[None, :], vec_list_torch[:, None], dim=-1)
113
+ cos_similarity = cos_similarity.numpy()
114
+
115
+ # np.save('./knowledge/golve_vec_list', np.array(vec_list_np))
116
+ # for i in range(len(vec_list)):
117
+ # for j in range(len(vec_list)):
118
+ # cos_similarity[i,j] = F.cosine_similarity(vec_list[i], vec_list[j],dim=0).type(torch.half)
119
+ # if i != j:
120
+ # cos_similarity[i,j] = cos_similarity[i,j]
121
+ # cos_similarity = np.array(cos_similarity)
122
+
123
+
124
+ np.save(os.path.join(AD_DIRECTORY, "imagenet_cos_similarity_glove"), cos_similarity)
125
+ print("Glove cos_similarity finished")
126
+ return cos_similarity
modelguidedattacks/losses/boilerplate.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ def generalized_mean(x, p, dim):
5
+ x_type = x.dtype
6
+ x = x.to(torch.double)
7
+ x = x**p
8
+ x = x.mean(dim=dim)
9
+ x = x**(1/p)
10
+ return x.to(x_type)
11
+
12
+ def surject_to_positive(x, c=5):
13
+ assert x.min() >= -1
14
+ assert x.max() <= 1
15
+
16
+ return c + c * x
17
+
18
+ def surject_from_positive(x, c=5):
19
+ return (x - c) / c
20
+
21
+ class BoilerplateLoss(nn.Module):
22
+ def __init__(self) -> None:
23
+ super().__init__()
24
+ self.p = 9
25
+
26
+ def forward(self, y_pred, y_attack, **kwargs):
27
+ y_pred = y_pred.softmax(dim=-1)
28
+
29
+ C = y_pred.shape[1]
30
+ K = y_attack.shape[1]
31
+ desired_mask = torch.zeros_like(y_pred, dtype=torch.bool)
32
+ desired_mask.scatter_(dim=1, index=y_attack,
33
+ src=torch.ones_like(y_attack, dtype=torch.bool))
34
+
35
+ y_not_in_attack = (~desired_mask).nonzero()[:, 1].view(-1, C - K)
36
+
37
+ y_pred_in_attack = torch.gather(y_pred, dim=1, index=y_attack)
38
+ y_pred_not_in_attack = torch.gather(y_pred, dim=1, index=y_not_in_attack)
39
+
40
+ y_pred_in_attack_min = y_pred_in_attack.min(dim=-1).values #generalized_mean(y_pred_in_attack, -self.p, dim=1)
41
+ y_pred_not_in_attack_max = y_pred_not_in_attack.max(dim=-1).values #generalized_mean(y_pred_not_in_attack, self.p, dim=1)
42
+
43
+ macro_loss = (y_pred_not_in_attack_max - y_pred_in_attack_min)
44
+ sorting_loss = y_pred_in_attack.diff(dim=-1)
45
+
46
+ # Surject sorting_loss to positive domain, since it goes [-1,1] we can just shift by 1
47
+ sorting_loss = surject_to_positive(sorting_loss)
48
+ sorting_loss = generalized_mean(sorting_loss, p=9, dim=-1)
49
+
50
+ # Surject back
51
+ sorting_loss = surject_from_positive(sorting_loss)
52
+
53
+ catted_loss = torch.stack([macro_loss, sorting_loss], dim=-1)
54
+ catted_loss_pos = surject_to_positive(catted_loss)
55
+
56
+ final_loss_pos = generalized_mean(catted_loss_pos, p=10, dim=-1)
57
+ final_loss = surject_from_positive(final_loss_pos)
58
+
59
+ return final_loss
modelguidedattacks/losses/cvx_proj.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import cvxpy as cp
5
+ import qpth
6
+ from . import _qp_solver_patch
7
+
8
+ def solve_qp(Q, P, G, H):
9
+ B = Q.shape[0]
10
+
11
+ if B == 1:
12
+ # Batch size of 1 has weird instabilities
13
+ # I imagine there is a .squeeze() or something inside the QP solver's code
14
+ # that messes up broadcasting dimensions when batch dimension is 1 so let's
15
+ # artificially make 2 solutions when we need 1
16
+
17
+ Q = Q.expand(2, -1, -1)
18
+ P = P.expand(2, -1)
19
+ G = G.expand(2, -1, -1)
20
+ H = H.expand(2, -1)
21
+
22
+ e = torch.empty(0, device=Q.device)
23
+ z_sol = qpth.qp.QPFunction(verbose=-1, eps=1e-2, check_Q_spd=False)(Q, P, G, H, e, e)
24
+
25
+ if B == 1:
26
+ z_sol = z_sol[:1]
27
+
28
+ return z_sol
29
+
30
+ class CVXProjLoss(nn.Module):
31
+ def __init__(self, confidence=0):
32
+ super().__init__()
33
+ self.confidence = confidence
34
+
35
+ def precompute(self, attack_targets, gt_labels, config):
36
+ return {
37
+ "margin": config.cvx_proj_margin
38
+ }
39
+
40
+ def forward(self, logits_pred, feats_pred, feats_pred_0, attack_targets, model, margin, **kwargs):
41
+ device = logits_pred.device
42
+ head_W, head_bias = model.head_matrices()
43
+
44
+ num_feats = head_W.shape[1]
45
+ num_classes = head_W.shape[0]
46
+
47
+ K = attack_targets.shape[-1]
48
+ B = logits_pred.shape[0]
49
+
50
+ # Start with all classes should be less than smallest attack target
51
+ D = -torch.eye(num_classes, device=device)[None].repeat(B, 1, 1) # [B, C, C]
52
+ attack_targets_write = attack_targets[:, -1][:, None, None].expand(-1, D.shape[1], -1)
53
+ D.scatter_(dim=2, index=attack_targets_write, src=torch.ones(attack_targets_write.shape, device=device))
54
+
55
+ # Clear out the constraint row for each item in the attack targets
56
+ attack_targets_clear = attack_targets[:, :, None].expand(-1, -1, D.shape[-1])
57
+ D.scatter_(dim=1, index=attack_targets_clear, src=torch.zeros(attack_targets_clear.shape, device=device))
58
+
59
+ batch_inds = torch.arange(B, device=device)[:, None].expand(-1, K - 1)
60
+ attack_targets_pos = attack_targets[:, :-1] # [B, K-1]
61
+ attack_targets_neg = attack_targets[:, 1:] # [B, K-1]
62
+
63
+ attack_targets_neg_inds = torch.stack((
64
+ batch_inds,
65
+ attack_targets_neg,
66
+ attack_targets_neg
67
+ ), dim=0) # [3, B, K - 1]
68
+ attack_targets_neg_inds = attack_targets_neg_inds.view(3, -1)
69
+
70
+ D[attack_targets_neg_inds[0], attack_targets_neg_inds[1], attack_targets_neg_inds[2]] = -1
71
+
72
+ attack_targets_pos_inds = torch.stack((
73
+ batch_inds,
74
+ attack_targets_neg,
75
+ attack_targets_pos
76
+ ), dim=0) # [3, B, K - 1]
77
+
78
+ D[attack_targets_pos_inds[0], attack_targets_pos_inds[1], attack_targets_pos_inds[2]] = 1
79
+
80
+ A = head_W
81
+ b = head_bias
82
+
83
+ Q = 2*torch.eye(feats_pred.shape[1], device=device)[None].expand(B, -1, -1)
84
+
85
+ # We want the solution features to be as close as possible
86
+ # to the current features but also head on the direction of
87
+ # the smallest possible perturbation from the initial predicted
88
+ # features
89
+ anchor_feats = feats_pred
90
+
91
+ P = -2*anchor_feats.expand(B, -1)
92
+
93
+ G = -D@A
94
+ H = -(margin - D @ b)
95
+
96
+ # Constraints are indexed by smaller logit
97
+ # First attack target isn't smaller than any logit, so its
98
+ # constraint index is redundant, but we keep it for easier parallelization
99
+ # Make this constraint all 0s
100
+ zero_inds = attack_targets[:, 0:1] # [B, 1]
101
+ H.scatter_(dim=1, index=zero_inds, src=torch.zeros(zero_inds.shape, device=device))
102
+
103
+ z_sol = solve_qp(Q, P, G, H)
104
+
105
+ loss = (feats_pred - z_sol).square().sum(dim=-1)
106
+
107
+ # loss_check = self.forward_check(logits_pred, feats_pred, attack_targets, model, **kwargs)
108
+ return loss
modelguidedattacks/losses/cw_extension.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ CARLINI_COEFF_UPPER = 1e10
5
+
6
+ class CWExtensionLoss(nn.Module):
7
+ def __init__(self, confidence=0):
8
+ super().__init__()
9
+ self.confidence = confidence
10
+
11
+ def precompute(self, *args, **kwargs):
12
+ return {}
13
+
14
+ def forward(self, logits_pred, attack_targets, **kwargs):
15
+ #orign cw attack loss
16
+ if attack_targets.dim() == 1:
17
+ mask_logits = F.one_hot(attack_targets, logits_pred.shape[1]).float()
18
+
19
+ real = (mask_logits * logits_pred).sum(dim=1)
20
+ other = ((1.0 - mask_logits) * logits_pred - (mask_logits * 10000.0)
21
+ ).max(1)[0]
22
+ loss_cw = torch.clamp(other - real + self.confidence, min=0.)
23
+ return loss_cw
24
+
25
+ #extended cw loss for topk attack tasks
26
+ else:
27
+ mask_logits = torch.zeros([logits_pred.shape[0], logits_pred.shape[1]], device=logits_pred.device)
28
+ min_values = torch.ones(attack_targets.shape[0], dtype=torch.float, device=logits_pred.device) * 1e10
29
+ loss_cw_topk = 0
30
+
31
+ for i in range(attack_targets.shape[1]):
32
+ other = ((1.0 - mask_logits) * logits_pred - (mask_logits * 10000.0)
33
+ ).max(1)[0]
34
+
35
+
36
+ loss_cw_topk += torch.clamp(other - min_values + self.confidence, min=0.)
37
+ mask_logits[torch.arange(len(attack_targets)), attack_targets[:,i]] = 1
38
+ min_values = torch.min(logits_pred[torch.arange(len(attack_targets)), attack_targets[:,i]], min_values)
39
+
40
+ real = min_values
41
+ other = ((1.0 - mask_logits) * logits_pred - (mask_logits * 10000.0)
42
+ ).max(1)[0]
43
+ loss_cw_topk += torch.clamp(other - real + self.confidence, min=0.)
44
+ constant = attack_targets.shape[1]
45
+
46
+ return (loss_cw_topk / constant)
modelguidedattacks/losses/energy.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from ignite.metrics import Loss
4
+ from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
5
+ from typing import Callable, cast, Dict, Sequence, Tuple, Union
6
+
7
+ def get_correct_mask(y_pred, y_attack):
8
+ k = y_attack.shape[-1]
9
+
10
+ y_pred_indices = y_pred.argsort(dim=-1, descending=True) # [N, C]
11
+
12
+ correct = (y_pred_indices[:, :k] == y_attack).all(dim=-1)
13
+ return correct
14
+
15
+ class EnergyLoss(Loss):
16
+ def __init__(self, loss_fn, reduction="mean", device = ...):
17
+ super().__init__(loss_fn, device=device)
18
+ self.reduction = reduction
19
+
20
+ @reinit__is_reduced
21
+ def reset(self) -> None:
22
+ self._sum = torch.tensor(0.0, device=self._device)
23
+ self._min = torch.tensor(torch.inf, device=self._device)
24
+ self._max = torch.tensor(0.0, device=self._device)
25
+ self._num_examples = 0
26
+
27
+ @reinit__is_reduced
28
+ def update(self, output: Sequence[Union[torch.Tensor, Dict]]) -> None:
29
+ if len(output) == 2:
30
+ y_pred, y = cast(Tuple[torch.Tensor, torch.Tensor], output)
31
+ kwargs: Dict = {}
32
+ else:
33
+ y_pred, y, kwargs = cast(Tuple[torch.Tensor, torch.Tensor, Dict], output)
34
+
35
+ sample_energies = self._loss_fn(y_pred, y, **kwargs).detach()
36
+
37
+ n = len(sample_energies)
38
+
39
+ if n > 0:
40
+ self._sum += sample_energies.sum()
41
+ self._min = torch.minimum(self._min, sample_energies.min())
42
+ self._max = torch.maximum(self._max, sample_energies.max())
43
+ self._num_examples += n
44
+
45
+ @sync_all_reduce("_sum", "_num_examples", "_min:MIN", "_max:MAX")
46
+ def compute(self) -> float:
47
+
48
+ if self.reduction == "mean":
49
+ if self._num_examples == 0:
50
+ return torch.inf
51
+
52
+ return self._sum.item() / self._num_examples
53
+ elif self.reduction == "max":
54
+ if self._num_examples == 0:
55
+ return torch.nan
56
+
57
+ return self._max.item()
58
+ elif self.reduction == "min":
59
+ if self._num_examples == 0:
60
+ return torch.inf
61
+
62
+ return self._min.item()
63
+ else:
64
+ assert False
65
+
66
+ class Energy(nn.Module):
67
+ def __init__(self, p="2") -> None:
68
+ super().__init__()
69
+ self.p = p
70
+
71
+ def forward(self, y_pred, y_attack, perturbations, **kwargs):
72
+ correct = get_correct_mask(y_pred, y_attack)
73
+
74
+ # Don't want to take into account perturbations of
75
+ # unsuccessful attacks
76
+
77
+ perturbations = perturbations[correct]
78
+ perturbations = perturbations.flatten(1)
79
+
80
+ return torch.linalg.vector_norm(perturbations, dim=-1, ord=self.p)
modelguidedattacks/metrics/topk_accuracy.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ignite.metrics import Accuracy, Loss
3
+ from typing import Sequence
4
+
5
+ class TopKAccuracy(Accuracy):
6
+ def update(self, output: Sequence[torch.Tensor], **kwargs) -> None:
7
+ y_pred, y_attack = output[0].detach(), output[1].detach()
8
+ k = y_attack.shape[-1]
9
+
10
+ y_pred_indices = y_pred.argsort(dim=-1, descending=True) # [N, C]
11
+
12
+ correct = (y_pred_indices[:, :k] == y_attack).all(dim=-1)
13
+
14
+ self._num_correct += torch.sum(correct).to(self._device)
15
+ self._num_examples += correct.shape[0]
modelguidedattacks/models.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import models
2
+
3
+ from modelguidedattacks.guides.instance_guide import InstanceGuide
4
+ from modelguidedattacks.guides.unguided import Unguided
5
+ from modelguidedattacks import losses
6
+
7
+ from .cls_models.registry import get_model
8
+
9
+ guide_model_registry = {
10
+ "instance_guided": InstanceGuide,
11
+ "unguided": Unguided
12
+ }
13
+
14
+ loss_registry = {
15
+ "cvxproj": losses.CVXProjLoss,
16
+ "cwk": losses.CWExtensionLoss,
17
+ "ad": losses.AdversarialDistillationLoss
18
+ }
19
+
20
+ def setup_model(config, device):
21
+ model = get_model(config.dataset, config.model, device)
22
+
23
+ kwargs = {}
24
+
25
+ if config.guide_model == "unguided":
26
+ kwargs["iterations"] = config.unguided_iterations
27
+ kwargs["lr"] = config.unguided_lr
28
+ kwargs["loss_fn"] = loss_registry[config.loss]
29
+ kwargs["binary_search_steps"] = config.binary_search_steps
30
+ kwargs["topk_loss_coef_upper"] = config.topk_loss_coef_upper
31
+
32
+ return guide_model_registry[config.guide_model](model, config, **kwargs)
modelguidedattacks/results.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from pathlib import Path
4
+ import numpy as np
5
+ import copy
6
+ from collections import OrderedDict
7
+
8
+ config_parameter_keys = ["loss", "unguided_lr", "model", "k", "binary_search_steps",
9
+ "unguided_iterations", "topk_loss_coef_upper", "seed",
10
+ "opt_warmup_its", "cvx_proj_margin",
11
+ "topk_loss_coef_upper", "binary_search_steps"]
12
+
13
+ def config_to_dict(config):
14
+ result_keys = config_parameter_keys
15
+
16
+ result_dict = {
17
+ key: getattr(config, key) for key in result_keys
18
+ }
19
+
20
+ if hasattr(config, "cvx_proj_margin"):
21
+ result_dict["cvx_proj_margin"] = config.cvx_proj_margin
22
+ else:
23
+ result_dict["cvx_proj_margin"] = 0.2
24
+
25
+ return result_dict
26
+
27
+ def load_all_results(load_min_max=False):
28
+ if not os.path.isdir("results_rebuttal"):
29
+ return []
30
+
31
+ all_result_files = Path('results_rebuttal').rglob('*.save')
32
+
33
+ results_list = []
34
+ for result_file in all_result_files:
35
+ result = torch.load(result_file)
36
+ config = result["config"]
37
+
38
+ result_dict = config_to_dict(config)
39
+
40
+ result_dict["ASR"] = result["ASR"]
41
+ result_dict["L1"] = result["L1 Energy"]
42
+ result_dict["L2"] = result["L2 Energy"]
43
+ result_dict["L_inf"] = result["L_inf Energy"]
44
+
45
+ if "L2 Energy Max" in result and load_min_max:
46
+ result_dict["L1 Max"] = result["L1 Energy Max"]
47
+ result_dict["L2 Max"] = result["L2 Energy Max"]
48
+ result_dict["L_inf Max"] = result["L_inf Energy Max"]
49
+
50
+ result_dict["L1 Min"] = result["L1 Energy Min"]
51
+ result_dict["L2 Min"] = result["L2 Energy Min"]
52
+ result_dict["L_inf Min"] = result["L_inf Energy Min"]
53
+
54
+ results_list.append(result_dict)
55
+
56
+ return results_list
57
+
58
+ def close(target, eps=1e-5):
59
+ return lambda x: np.allclose(x, target, atol=eps)
60
+
61
+ def eq(target):
62
+ if isinstance(target, float):
63
+ return close(target)
64
+ else:
65
+ return lambda x: x == target
66
+
67
+ def gte(target):
68
+ return lambda x: float(x) >= target
69
+
70
+ def lte(target):
71
+ return lambda x: float(x) <= target
72
+
73
+ def in_set(target):
74
+ return lambda x: x in target
75
+
76
+ def filter_from_config(config):
77
+ config_dict = config_to_dict(config)
78
+
79
+ filter = {
80
+ key: eq(val) for (key, val) in config_dict.items()
81
+ }
82
+
83
+ return filter
84
+
85
+ def filter_results(filter, results_list, only_with_minmax=False):
86
+ filtered_results = []
87
+ for result in results_list:
88
+ pass_filter = True
89
+ for key, val in result.items():
90
+ if key not in filter:
91
+ continue
92
+
93
+ if not filter[key](val):
94
+ pass_filter = False
95
+ break
96
+
97
+ if only_with_minmax and "L2 Max" not in result:
98
+ continue
99
+
100
+ if pass_filter:
101
+ filtered_results.append(result)
102
+
103
+ return filtered_results
104
+
105
+ def resolve_nonunique_filter(filter, results_list, include_failed=False):
106
+ filtered_results = filter_results(filter, results_list)
107
+
108
+ unique_parameters = []
109
+ # Find unique parameter sets for results
110
+ for result in filtered_results:
111
+ result_parameters = {param_key:result[param_key] for param_key in config_parameter_keys}
112
+
113
+ # Round to avoid floating pt imprecision from messing with set uniqueness checks
114
+ for key in result_parameters.keys():
115
+ if isinstance(result_parameters[key], float):
116
+ result_parameters[key] = round(result_parameters[key], 5)
117
+
118
+ del result_parameters["seed"]
119
+ unique_parameters.append(result_parameters)
120
+
121
+ # Only keep unique dicts
122
+ unique_parameters = [dict(y) for y in set(tuple(x.items()) for x in unique_parameters)]
123
+
124
+ best_metric = -np.Infinity
125
+ best_param_set = None
126
+ best_result_list = None
127
+ for param_set in unique_parameters:
128
+ # Perform another search
129
+ unique_filter = {
130
+ param_name: eq(param_value) for param_name, param_value in param_set.items()
131
+ }
132
+
133
+ filtered_results = filter_results(unique_filter, results_list)
134
+
135
+ assert len(filtered_results) == 5
136
+
137
+ asrs = [result["ASR"] for result in filtered_results]
138
+ l2_energies = [result["L2"] for result in filtered_results]
139
+
140
+ mean_asr = np.mean(np.array(asrs)[np.isfinite(asrs)])
141
+ mean_l2 = np.mean(np.array(l2_energies)[np.isfinite(l2_energies)])
142
+
143
+ # Arbitrary point in tradeoff curve
144
+ result_goodness = -mean_l2 + mean_asr * 100
145
+
146
+ if (mean_asr > 0 and mean_asr < 0.025) and not include_failed:
147
+ # Irrelevant result and associated energies
148
+ continue
149
+
150
+ if result_goodness > best_metric or (include_failed and best_param_set is None):
151
+ best_param_set = param_set
152
+ best_result_list = filtered_results
153
+ best_metric = result_goodness
154
+
155
+ return best_param_set, best_result_list
156
+
157
+ def get_combined_results(filtered_results):
158
+ combined_results = {}
159
+
160
+ for result in filtered_results:
161
+ for key in result:
162
+ if key not in combined_results:
163
+ combined_results[key] = []
164
+
165
+ combined_results[key].append(result[key])
166
+
167
+ unique_runs = len(np.unique(combined_results["seed"]))
168
+ # assert len(combined_results["seed"]) == unique_runs
169
+
170
+ for key, val in list(combined_results.items()):
171
+ if key in ["ASR", "L1", "L2", "L_inf"]:
172
+ val = np.array(val)
173
+ combined_results[f"{key}_mean"] = np.mean(val[np.isfinite(val)])
174
+ combined_results[f"{key}_median"] = np.median(val[np.isfinite(val)])
175
+
176
+ # Coupled results
177
+ best_asr_idx = np.argmax(combined_results["ASR"])
178
+ best_asr = combined_results["ASR"][best_asr_idx]
179
+ best_l1 = combined_results["L1"][best_asr_idx]
180
+ best_l2 = combined_results["L2"][best_asr_idx]
181
+ best_linf = combined_results["L_inf"][best_asr_idx]
182
+
183
+ combined_results["ASR_best"] = best_asr
184
+ combined_results["L1_best"] = best_l1
185
+ combined_results["L2_best"] = best_l2
186
+ combined_results["L_inf_best"] = best_linf
187
+
188
+ worst_asr_idx = np.argmin(combined_results["ASR"])
189
+ worst_asr = combined_results["ASR"][worst_asr_idx]
190
+ worst_l1 = combined_results["L1"][worst_asr_idx]
191
+ worst_l2 = combined_results["L2"][worst_asr_idx]
192
+ worst_linf = combined_results["L_inf"][worst_asr_idx]
193
+
194
+ combined_results["ASR_worst"] = worst_asr
195
+ combined_results["L1_worst"] = worst_l1
196
+ combined_results["L2_worst"] = worst_l2
197
+ combined_results["L_inf_worst"] = worst_linf
198
+
199
+ return combined_results
200
+
201
+ def build_full_results_dict(model_name="resnet50", verbose=False,
202
+ all_k=[20, 15, 10, 5, 1],
203
+ all_num_iter=[60, 30],
204
+ all_search_steps=[1, 9],
205
+ all_methods=["cwk", "ad", "cvxproj"]):
206
+
207
+ if verbose:
208
+ print ("-" * 100)
209
+ print ("Results for", model_name)
210
+
211
+ results_list = load_all_results()
212
+ results = OrderedDict()
213
+
214
+ for k in all_k:
215
+ results[k] = OrderedDict()
216
+
217
+ for num_binary_search_steps in all_search_steps:
218
+ results[k][num_binary_search_steps] = OrderedDict()
219
+
220
+ for num_iter in all_num_iter:
221
+ results[k][num_binary_search_steps][num_iter] = OrderedDict()
222
+
223
+ for method_name in all_methods:
224
+ filter = {
225
+ "loss": eq(method_name),
226
+ "model": eq(model_name),
227
+ "k": eq(k),
228
+ "unguided_iterations": eq(num_iter),
229
+ "binary_search_steps": eq(num_binary_search_steps)
230
+ }
231
+
232
+ best_param_set, filtered_results = resolve_nonunique_filter(filter, results_list)
233
+
234
+ if verbose and best_param_set is not None:
235
+ print (f"K={k} Lr={best_param_set['unguided_lr']} and loss_coef={best_param_set['topk_loss_coef_upper']} ")
236
+
237
+ if best_param_set is None:
238
+ continue
239
+
240
+ assert len(filtered_results) == 5
241
+
242
+ combined_results = get_combined_results(filtered_results)
243
+
244
+ for key in list(combined_results):
245
+ if "L1" not in key and "L2" not in key and "L_inf" not in key and "ASR" not in key:
246
+ del combined_results[key]
247
+
248
+ for key in list(combined_results):
249
+ if "mean" not in key and "worst" not in key and "best" not in key:
250
+ del combined_results[key]
251
+
252
+ results[k][num_binary_search_steps][num_iter][method_name] = combined_results
253
+
254
+ return results
255
+
256
+ if __name__ == "__main__":
257
+ build_full_results_dict(model_name="resnet50", verbose=True, all_search_steps=[1], all_methods=["cvxproj"])
258
+ build_full_results_dict(model_name="densenet121", verbose=True, all_search_steps=[1], all_methods=["cvxproj"])
259
+ build_full_results_dict(model_name="deit_small", verbose=True, all_search_steps=[1], all_methods=["cvxproj"])
260
+ build_full_results_dict(model_name="vit_base", verbose=True, all_search_steps=[1], all_methods=["cvxproj"])
261
+ x = 5
modelguidedattacks/run.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pprint import pformat
3
+ from typing import Any
4
+ import os
5
+ import torch
6
+
7
+ import ignite.distributed as idist
8
+ import yaml
9
+ from ignite.engine import Events
10
+ from ignite.metrics import Accuracy, Loss
11
+ from ignite.utils import manual_seed
12
+ from torch import nn, optim
13
+
14
+ from modelguidedattacks.data.setup import setup_data
15
+ from modelguidedattacks.losses.boilerplate import BoilerplateLoss
16
+ from modelguidedattacks.losses.energy import Energy, EnergyLoss
17
+ from modelguidedattacks.metrics.topk_accuracy import TopKAccuracy
18
+ from modelguidedattacks.models import setup_model
19
+ from modelguidedattacks.trainers import setup_evaluator, setup_trainer
20
+ from modelguidedattacks.utils import setup_parser, setup_output_dir
21
+ from modelguidedattacks.utils import setup_logging, log_metrics, Engine
22
+
23
+ def run(local_rank: int, config: Any):
24
+
25
+ print ("Running ", local_rank)
26
+ # make a certain seed
27
+ rank = idist.get_rank()
28
+ manual_seed(config.seed + rank)
29
+
30
+ # create output folder
31
+ config.output_dir = setup_output_dir(config, rank)
32
+
33
+ # setup engines logger with python logging
34
+ # print training configurations
35
+ logger = setup_logging(config)
36
+ logger.info("Configuration: \n%s", pformat(vars(config)))
37
+ (config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
38
+
39
+ # donwload datasets and create dataloaders
40
+ dataloader_train, dataloader_eval = setup_data(config, rank)
41
+
42
+ # model, optimizer, loss function, device
43
+ device = idist.device()
44
+ model = idist.auto_model(setup_model(config, idist.device()))
45
+ loss_fn = BoilerplateLoss().to(device=device)
46
+ l2_energy_loss = Energy(p=2).to(device)
47
+ l1_energy_loss = Energy(p=1).to(device)
48
+ l_inf_energy_loss = Energy(p=torch.inf).to(device)
49
+
50
+ evaluator = setup_evaluator(config, model, device)
51
+ evaluator.logger = logger
52
+
53
+ # attach metrics to evaluator
54
+ accuracy = TopKAccuracy(device=device)
55
+ metrics = {
56
+ "ASR": accuracy,
57
+ "L2 Energy": EnergyLoss(l2_energy_loss, device=device),
58
+ "L1 Energy": EnergyLoss(l1_energy_loss, device=device),
59
+ "L_inf Energy": EnergyLoss(l_inf_energy_loss, device=device),
60
+
61
+ "L2 Energy Min": EnergyLoss(l2_energy_loss, reduction="min", device=device),
62
+ "L1 Energy Min": EnergyLoss(l1_energy_loss, reduction="min", device=device),
63
+ "L_inf Energy Min": EnergyLoss(l_inf_energy_loss, reduction="min", device=device),
64
+
65
+ "L2 Energy Max": EnergyLoss(l2_energy_loss, reduction="max", device=device),
66
+ "L1 Energy Max": EnergyLoss(l1_energy_loss, reduction="max", device=device),
67
+ "L_inf Energy Max": EnergyLoss(l_inf_energy_loss, reduction="max", device=device)
68
+ }
69
+ for name, metric in metrics.items():
70
+ metric.attach(evaluator, name)
71
+
72
+ if config.guide_model in ["unguided", "instance_guided"]:
73
+
74
+ first_batch_passed = False
75
+ early_stopped = False
76
+
77
+ def compute_metrics(engine: Engine, tag: str):
78
+ nonlocal first_batch_passed
79
+ nonlocal early_stopped
80
+
81
+ for name, metric in metrics.items():
82
+ metric.completed(engine, name)
83
+
84
+ if not first_batch_passed:
85
+ if engine.state.metrics["ASR"] < 1e-3:
86
+ print ("Early stop, assuming no success throughout")
87
+ early_stopped = True
88
+ engine.terminate()
89
+ else:
90
+ first_batch_passed = True
91
+
92
+ evaluator.add_event_handler(
93
+ Events.ITERATION_COMPLETED(every=config.log_every_iters),
94
+ compute_metrics,
95
+ tag="eval",
96
+ )
97
+
98
+ evaluator.add_event_handler(
99
+ Events.ITERATION_COMPLETED(every=config.log_every_iters),
100
+ log_metrics,
101
+ tag="eval",
102
+ )
103
+
104
+ evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
105
+ log_metrics(evaluator, "eval")
106
+
107
+ if len(config.out_dir) > 0:
108
+ # Store results in out_dir
109
+ os.makedirs(config.out_dir, exist_ok=True)
110
+ metrics_dict = evaluator.state.metrics
111
+ metrics_dict["config"] = config
112
+ metrics_dict["early_stopped"] = early_stopped
113
+
114
+ metrics_file_path = os.path.join(config.out_dir, "results.save")
115
+ torch.save(metrics_dict, metrics_file_path)
116
+
117
+ # No need to train with an unguided model
118
+ return
119
+
120
+ assert False, "This code path is for the future"
121
+
122
+ # main entrypoint
123
+ def launch(config=None):
124
+ if config is None:
125
+ config_path = sys.argv[1]
126
+ config = setup_parser(config_path).parse_args(sys.argv[2:])
127
+
128
+ backend = config.backend
129
+ nproc_per_node = config.nproc_per_node
130
+
131
+ if nproc_per_node == 0 or backend is None:
132
+ backend = None
133
+ nproc_per_node = None
134
+
135
+ with idist.Parallel(backend, nproc_per_node) as p:
136
+ p.run(run, config=config)
137
+
138
+
139
+ if __name__ == "__main__":
140
+ launch()
modelguidedattacks/trainers.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Union
2
+
3
+ import ignite.distributed as idist
4
+ import torch
5
+ from ignite.engine import DeterministicEngine, Engine, Events
6
+ from torch.cuda.amp import autocast
7
+ from torch.nn import Module
8
+ from torch.optim import Optimizer
9
+ from torch.utils.data import DistributedSampler, Sampler
10
+
11
+
12
+ def setup_trainer(
13
+ config: Any,
14
+ model: Module,
15
+ optimizer: Optimizer,
16
+ loss_fn: Module,
17
+ device: Union[str, torch.device],
18
+ train_sampler: Sampler,
19
+ ) -> Union[Engine, DeterministicEngine]:
20
+ def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
21
+ if config.overfit:
22
+ # No batch norm
23
+ model.eval()
24
+ else:
25
+ model.train()
26
+
27
+ samples = batch[0].to(device, non_blocking=True)
28
+ targets = batch[1].to(device, non_blocking=True)
29
+ attack_targets = batch[2].to(device, non_blocking=True)
30
+ sample_ids = batch[3].to(device, non_blocking=True)
31
+
32
+ with autocast(config.use_amp):
33
+ outputs = model(samples, attack_targets)
34
+ loss = loss_fn(outputs, attack_targets, targets)
35
+
36
+ loss.backward()
37
+ optimizer.step()
38
+ optimizer.zero_grad()
39
+
40
+ train_loss = loss.item()
41
+ engine.state.metrics = {
42
+ "epoch": engine.state.epoch,
43
+ "train_loss": train_loss,
44
+ }
45
+ return {"train_loss": train_loss}
46
+
47
+
48
+ trainer = Engine(train_function)
49
+
50
+ # set epoch for distributed sa5mpler
51
+ @trainer.on(Events.EPOCH_STARTED)
52
+ def set_epoch():
53
+ if idist.get_world_size() > 1 and isinstance(train_sampler, DistributedSampler):
54
+ train_sampler.set_epoch(trainer.state.epoch - 1)
55
+
56
+ return trainer
57
+
58
+
59
+ def setup_evaluator(
60
+ config: Any,
61
+ model: Module,
62
+ device: Union[str, torch.device],
63
+ ) -> Engine:
64
+ @torch.no_grad()
65
+ def eval_function(engine: Engine, batch: Any):
66
+ model.eval()
67
+
68
+ samples, gt_labels, attack_targets, sample_ids = batch
69
+
70
+ samples = samples.to(device, non_blocking=True)
71
+ gt_labels = gt_labels.to(device, non_blocking=True)
72
+ attack_targets = attack_targets.to(device, non_blocking=True)
73
+ sample_ids = sample_ids.to(device, non_blocking=True)
74
+
75
+ with autocast(config.use_amp):
76
+ outputs, perturbations = model(samples, attack_targets, gt_labels)
77
+
78
+ return outputs, attack_targets, {
79
+ "gt_targets": gt_labels,
80
+ "perturbations": perturbations
81
+ }
82
+
83
+ return Engine(eval_function)
modelguidedattacks/utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from argparse import ArgumentParser
3
+ from datetime import datetime
4
+ from logging import Logger
5
+ from pathlib import Path
6
+ from typing import Any, Mapping, Optional, Union
7
+
8
+ import ignite.distributed as idist
9
+ import torch
10
+ import yaml
11
+ from ignite.contrib.engines import common
12
+ from ignite.engine import Engine
13
+ from ignite.engine.events import Events
14
+ from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
15
+ from ignite.handlers.early_stopping import EarlyStopping
16
+ from ignite.handlers.terminate_on_nan import TerminateOnNan
17
+ from ignite.handlers.time_limit import TimeLimit
18
+ from ignite.utils import setup_logger
19
+
20
+
21
+ def setup_parser(config_path="base_config.yaml"):
22
+ with open(config_path, "r") as f:
23
+ config = yaml.safe_load(f.read())
24
+
25
+ parser = ArgumentParser()
26
+ parser.add_argument("--config", default=None, type=str)
27
+ parser.add_argument("--backend", default=None, type=str)
28
+ for k, v in config.items():
29
+ if isinstance(v, bool):
30
+ parser.add_argument(f"--{k}", action="store_true")
31
+ else:
32
+ parser.add_argument(f"--{k}", default=v, type=type(v))
33
+
34
+ return parser
35
+
36
+
37
+ def log_metrics(engine: Engine, tag: str) -> None:
38
+ """Log `engine.state.metrics` with given `engine` and `tag`.
39
+
40
+ Parameters
41
+ ----------
42
+ engine
43
+ instance of `Engine` which metrics to log.
44
+ tag
45
+ a string to add at the start of output.
46
+ """
47
+ metrics_format = "{0} [{1}/{2}]: {3}".format(
48
+ tag, engine.state.epoch, engine.state.iteration, engine.state.metrics
49
+ )
50
+
51
+ epoch_size = engine.state.epoch_length
52
+ local_iteration = engine.state.iteration - epoch_size * (engine.state.epoch - 1)
53
+ metrics_format = f"{tag} Epoch {engine.state.epoch} - [{local_iteration} / {epoch_size}] : {engine.state.metrics}"
54
+
55
+ engine.logger.info(metrics_format)
56
+
57
+
58
+ def resume_from(
59
+ to_load: Mapping,
60
+ checkpoint_fp: Union[str, Path],
61
+ logger: Logger,
62
+ strict: bool = True,
63
+ model_dir: Optional[str] = None,
64
+ ) -> None:
65
+ """Loads state dict from a checkpoint file to resume the training.
66
+
67
+ Parameters
68
+ ----------
69
+ to_load
70
+ a dictionary with objects, e.g. {“model”: model, “optimizer”: optimizer, ...}
71
+ checkpoint_fp
72
+ path to the checkpoint file
73
+ logger
74
+ to log info about resuming from a checkpoint
75
+ strict
76
+ whether to strictly enforce that the keys in `state_dict` match the keys
77
+ returned by this module’s `state_dict()` function. Default: True
78
+ model_dir
79
+ directory in which to save the object
80
+ """
81
+ if isinstance(checkpoint_fp, str) and checkpoint_fp.startswith("https://"):
82
+ checkpoint = torch.hub.load_state_dict_from_url(
83
+ checkpoint_fp,
84
+ model_dir=model_dir,
85
+ map_location="cpu",
86
+ check_hash=True,
87
+ )
88
+ else:
89
+ if isinstance(checkpoint_fp, str):
90
+ checkpoint_fp = Path(checkpoint_fp)
91
+
92
+ if not checkpoint_fp.exists():
93
+ raise FileNotFoundError(f"Given {str(checkpoint_fp)} does not exist.")
94
+ checkpoint = torch.load(checkpoint_fp, map_location="cpu")
95
+
96
+ Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint, strict=strict)
97
+ logger.info("Successfully resumed from a checkpoint: %s", checkpoint_fp)
98
+
99
+
100
+ def setup_output_dir(config: Any, rank: int) -> Path:
101
+ """Create output folder."""
102
+ if rank == 0:
103
+ now = datetime.now().strftime("%Y%m%d-%H%M%S")
104
+ name = f"{now}-backend-{config.backend}-lr-{config.lr}"
105
+ path = Path(config.output_dir, name)
106
+ path.mkdir(parents=True, exist_ok=True)
107
+ config.output_dir = path.as_posix()
108
+
109
+ return Path(idist.broadcast(config.output_dir, src=0))
110
+
111
+
112
+ def setup_logging(config: Any) -> Logger:
113
+ """Setup logger with `ignite.utils.setup_logger()`.
114
+
115
+ Parameters
116
+ ----------
117
+ config
118
+ config object. config has to contain `verbose` and `output_dir` attribute.
119
+
120
+ Returns
121
+ -------
122
+ logger
123
+ an instance of `Logger`
124
+ """
125
+ green = "\033[32m"
126
+ reset = "\033[0m"
127
+ logger = setup_logger(
128
+ name=f"{green}[ignite]{reset}",
129
+ level=logging.DEBUG if config.debug else logging.INFO,
130
+ format="%(name)s: %(message)s",
131
+ filepath=config.output_dir / "training-info.log",
132
+ )
133
+ return logger
print_results.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modelguidedattacks import results
2
+
3
+ results_list = results.load_all_results()
4
+
5
+ filter = {
6
+ "loss": results.in_set(["cwk"]),
7
+ "model": results.eq("vit_base"),
8
+ "k": results.eq(5),
9
+ "binary_search_steps": results.eq(1),
10
+ "unguided_iterations": results.eq(30),
11
+ # "topk_loss_coef_upper": results.eq(20),
12
+ # "unguided_lr": results.eq(0.002),
13
+ "cvx_proj_margin": results.eq(0.2),
14
+ # "seed": results.eq(10),
15
+ }
16
+
17
+ filtered_results = results.filter_results(filter, results_list)
18
+
19
+ print ("Found", len(filtered_results))
20
+
21
+ for result in filtered_results:
22
+ print ("-" * 30)
23
+ for key, val in result.items():
24
+ print (key, "=", val)
25
+ print ("-" * 30)
print_table.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pylatex import Document, Section, Subsection, Tabular, MultiColumn,\
2
+ MultiRow, NoEscape
3
+ from pylatex.math import Math
4
+ from collections import OrderedDict
5
+ import copy
6
+ import numpy as np
7
+
8
+ from modelguidedattacks.results import build_full_results_dict
9
+
10
+ def result_to_str(result, long=False):
11
+ if result is None or np.isinf(result) or np.isnan(result):
12
+ return "-"
13
+ elif long:
14
+ return f"{result:.4f}"
15
+ else:
16
+ return f"{result:.2f}"
17
+
18
+ """
19
+ results top level will be keyed by K
20
+ next level will be binary search steps
21
+ next level will be keyed by iterations
22
+ next level will be keyed by method
23
+ """
24
+
25
+ only_mean = False
26
+ model_name = "resnet50"
27
+ results = build_full_results_dict(model_name)
28
+
29
+ model_to_tex = {
30
+ "resnet50": "Resnet-50",
31
+ "densenet121": "Densenet121",
32
+ "deit_small": "DeiT-S",
33
+ "vit_base": "ViT$_{B}$"
34
+ }
35
+
36
+ # Preprocess all results and select bests
37
+ for top_k, bs_dict in results.items():
38
+ for num_bs, iter_dict in bs_dict.items():
39
+ for num_iter, method_dict in iter_dict.items():
40
+ metric_bests = {}
41
+ metrics_compared = {}
42
+
43
+ for method_name, method_results in method_dict.items():
44
+ for metric_name, metric_value in method_results.items():
45
+ reduction_func = max if "ASR" in metric_name else min
46
+
47
+ if metric_name not in metric_bests:
48
+ metric_bests[metric_name] = 0. if reduction_func is max else np.Infinity
49
+ metrics_compared[metric_name] = 0
50
+
51
+ if metric_value is not None:
52
+ metric_bests[metric_name] = reduction_func(metric_bests[metric_name], metric_value)
53
+ metrics_compared[metric_name] += 1
54
+
55
+ for method_name, method_results in method_dict.items():
56
+ for metric_name, metric_value in method_results.items():
57
+ method_results[metric_name] = result_to_str(metric_value, "inf" in metric_name or "ASR" in metric_name)
58
+
59
+ if metric_value is not None and np.allclose(metric_value, metric_bests[metric_name]) \
60
+ and metrics_compared[metric_name] > 1:
61
+ method_results[metric_name] = rf"\textbf{{ {method_results[metric_name]} }}"
62
+
63
+ method_tex = {
64
+ "cwk": r"CW^K",
65
+ "ad": r"AD",
66
+ "cvxproj": r"\textbf{QuadAttac$K$}"
67
+ }
68
+
69
+ doc = Document("multirow")
70
+
71
+ protocol_cols = 1
72
+ attack_method_cols = 1
73
+ best_cols = 4
74
+ mean_cols = 4
75
+ worst_cols = 4
76
+
77
+ if only_mean:
78
+ col_widths = [protocol_cols, attack_method_cols, mean_cols]
79
+ else:
80
+ col_widths = [protocol_cols, attack_method_cols, best_cols, mean_cols, worst_cols]
81
+
82
+ total_cols = sum(col_widths)
83
+ tabular_string = "|"
84
+
85
+ for w in col_widths:
86
+ tabular_string += "l" * w + "|"
87
+
88
+ table1 = Tabular(tabular_string)
89
+ table1.add_hline()
90
+ table1.add_row((MultiColumn(total_cols, align='|c|', data=NoEscape(model_to_tex[model_name])),))
91
+ table1.add_hline()
92
+
93
+ if only_mean:
94
+ table1.add_row((
95
+ MultiRow(2, data="Protocol"),
96
+ MultiRow(2, data="Attack Method"),
97
+ MultiColumn(mean_cols, align="|c|", data="Mean"),
98
+ ))
99
+ else:
100
+ table1.add_row((
101
+ MultiRow(2, data="Protocol"),
102
+ MultiRow(2, data="Attack Method"),
103
+ MultiColumn(best_cols, align="|c|", data="Best"),
104
+ MultiColumn(mean_cols, align="|c|", data="Mean"),
105
+ MultiColumn(worst_cols, align="|c|", data="Worst"),
106
+ ))
107
+
108
+ table1.add_hline(start=protocol_cols + attack_method_cols + 1)
109
+
110
+ num_result_colums = 1 if only_mean else 3
111
+ table1.add_row("",
112
+ "",
113
+ *(NoEscape(r"ASR$\uparrow$"),
114
+ NoEscape(r"$\ell_1 \downarrow$"),
115
+ NoEscape(r"$\ell_2 \downarrow$"),
116
+ NoEscape(r"$\ell_{\infty} \downarrow$"))*num_result_colums
117
+ )
118
+
119
+ table1.add_hline()
120
+
121
+ for top_k, bs_dict in results.items():
122
+
123
+ total_results = 0
124
+ # Count total results
125
+ for _, iter_dict in bs_dict.items():
126
+ for num_iter, method_dict in iter_dict.items():
127
+ total_results += len(method_dict)
128
+
129
+ top_k_latex_obj = MultiRow(total_results, data=f"Top-{top_k}")
130
+
131
+ shown_topk_obj = False
132
+
133
+ for bs_steps, iter_dict in bs_dict.items():
134
+ for num_iter, method_dict in iter_dict.items():
135
+ for method_name, method_results in method_dict.items():
136
+ first_obj = top_k_latex_obj if not shown_topk_obj else ""
137
+ shown_topk_obj = True
138
+
139
+ row_results = []
140
+
141
+ reduction_names = ["mean"] if only_mean else ["best", "mean", "worst"]
142
+ for reduction in reduction_names:
143
+ for metric in ["ASR", "L1", "L2", "L_inf"]:
144
+ result_key = f"{metric}_{reduction}"
145
+ long_result = "inf" in metric
146
+ row_results.append(
147
+ NoEscape(method_results[result_key])
148
+ )
149
+
150
+ table1.add_row(
151
+ first_obj,
152
+ NoEscape("$" + method_tex[method_name] +
153
+ f"_{{{bs_steps}x{num_iter}}}$"),
154
+ *row_results
155
+ )
156
+
157
+ table1.add_hline(start=protocol_cols + 1)
158
+
159
+ table1.add_hline()
160
+
161
+ # doc.append(table1)
162
+
163
+ print(table1.dumps())
result_stats.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import sys
3
+ from modelguidedattacks import results
4
+
5
+ results_list = results.load_all_results()
6
+
7
+ filter = {
8
+ "loss": results.in_set(["cvxproj"]),
9
+ "model": results.eq("resnet50"),
10
+ "k": results.eq(20),
11
+ "binary_search_steps": results.eq(1),
12
+ "unguided_iterations": results.eq(60),
13
+ # "topk_loss_coef_upper": results.eq(20),
14
+ # "unguided_lr": results.eq(0.002),
15
+ "cvx_proj_margin": results.eq(0.2),
16
+ "topk_loss_coef_upper": results.gte(12)
17
+ # "seed": results.eq(10),
18
+ }
19
+
20
+ filtered_results = results.filter_results(filter, results_list)
21
+
22
+ combined_results = {}
23
+
24
+ for result in filtered_results:
25
+ for key in result:
26
+ if key not in combined_results:
27
+ combined_results[key] = []
28
+
29
+ combined_results[key].append(result[key])
30
+
31
+ unique_runs = len(np.unique(combined_results["seed"]))
32
+ print ("Stats from", len(filtered_results))
33
+ # assert len(combined_results["seed"]) == unique_runs
34
+
35
+ for key, val in list(combined_results.items()):
36
+ if key in ["ASR", "L1", "L2", "L_inf"]:
37
+ val = np.array(val)
38
+ combined_results[f"{key}_mean"] = np.mean(val[np.isfinite(val)])
39
+ combined_results[f"{key}_median"] = np.median(val[np.isfinite(val)])
40
+
41
+ # Coupled results
42
+ best_asr_idx = np.argmax(combined_results["ASR"])
43
+ best_asr = combined_results["ASR"][best_asr_idx]
44
+ best_l1 = combined_results["L1"][best_asr_idx]
45
+ best_l2 = combined_results["L2"][best_asr_idx]
46
+ best_linf = combined_results["L_inf"][best_asr_idx]
47
+
48
+ combined_results["ASR_best"] = best_asr
49
+ combined_results["L1_best"] = best_l1
50
+ combined_results["L2_best"] = best_l2
51
+ combined_results["L_inf_best"] = best_linf
52
+
53
+ worst_asr_idx = np.argmin(combined_results["ASR"])
54
+ worst_asr = combined_results["ASR"][worst_asr_idx]
55
+ worst_l1 = combined_results["L1"][worst_asr_idx]
56
+ worst_l2 = combined_results["L2"][worst_asr_idx]
57
+ worst_linf = combined_results["L_inf"][worst_asr_idx]
58
+
59
+ combined_results["ASR_worst"] = worst_asr
60
+ combined_results["L1_worst"] = worst_l1
61
+ combined_results["L2_worst"] = worst_l2
62
+ combined_results["L_inf_worst"] = worst_linf
63
+
64
+ draw_keys = ["best", "mean", "worst"]
65
+ val_keys = ["ASR", "L1", "L2", "L_inf"]
66
+
67
+ print ("---------------")
68
+ for draw_key in draw_keys:
69
+ for val_key in val_keys:
70
+ key = val_key + "_" + draw_key
71
+ val = combined_results[key]
72
+
73
+ print (key, val)
74
+ print ("---------------")
75
+
76
+ for draw_key in draw_keys:
77
+ for val_key in val_keys:
78
+ key = val_key + "_" + draw_key
79
+ val = combined_results[key]
80
+
81
+ if np.isinf(val):
82
+ val = "N/A"
83
+
84
+ sep = "&"
85
+ if isinstance(val, str):
86
+ sys.stdout.write(f"{val} {sep} ")
87
+ elif "inf" in key:
88
+ sys.stdout.write(f"{val:.3f} {sep} ")
89
+ else:
90
+ sys.stdout.write(f"{val:.2f} {sep} ")
setup.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from setuptools import find_packages, setup
4
+
5
+ def read(*paths, **kwargs):
6
+ """Read the contents of a text file safely.
7
+ >>> read("project_name", "VERSION")
8
+ '0.1.0'
9
+ >>> read("README.md")
10
+ ...
11
+ """
12
+
13
+ content = ""
14
+ with io.open(
15
+ os.path.join(os.path.dirname(__file__), *paths),
16
+ encoding=kwargs.get("encoding", "utf8"),
17
+ ) as open_file:
18
+ content = open_file.read().strip()
19
+ return content
20
+
21
+
22
+ def read_requirements(path):
23
+ return [
24
+ line.strip()
25
+ for line in read(path).split("\n")
26
+ if not line.startswith(('"', "#", "-", "git+"))
27
+ ]
28
+
29
+
30
+ setup(
31
+ name="modelguidedattacks",
32
+ version="0.1",
33
+ description="Adversarial attacks",
34
+ url="",
35
+ long_description=read("README.md"),
36
+ long_description_content_type="text/markdown",
37
+ author="NCSU",
38
+ packages=["modelguidedattacks"],
39
+ install_requires=read_requirements("requirements.txt"),
40
+ )