rgrainger committed on
Commit 3303c2f
1 Parent(s): 3d67e33

adding demo

app.py ADDED
@@ -0,0 +1,211 @@
+import torch
+import types
+import timm
+import requests
+import random
+import yaml
+import gradio as gr
+from PIL import Image
+from timm import create_model
+from torchvision import transforms
+from timm.data import resolve_data_config
+from modelguidedattacks.guides.unguided import Unguided
+from timm.data.transforms_factory import create_transform
+from modelguidedattacks.cls_models.registry import TimmPretrainModelWrapper
+
+
+# Download human-readable labels for ImageNet.
+IMAGENET_LABELS_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
+LABELS = requests.get(IMAGENET_LABELS_URL).text.strip().split("\n")
+SORTED_LABELS = sorted(LABELS.copy(), key=lambda s: s.lower())
+
+def get_timm_model(name):
+    """Retrieves a model from the timm library by name, with pretrained weights loaded."""
+    model = create_model(name, pretrained=True)
+    transform = create_transform(**resolve_data_config({}, model=model))
+    model = model.eval()
+    return model, transform
+
+def create_attacker(model, transform, iterations):
+    """Instantiates a QuadAttack model."""
+    with open("base_config.yaml") as f:
+        config_dict = yaml.safe_load(f)
+
+    config = types.SimpleNamespace(**config_dict)
+
+    attacker = Unguided(TimmPretrainModelWrapper(model, transform, "", "", ""), config,
+                        iterations=iterations, lr=0.002, topk_loss_coef_upper=10)
+
+    return attacker
+
+def predict_topk_accuracies(img, k, iters, model_name, desired_labels, button=None,
+                            progress=gr.Progress(track_tqdm=True)):
+    """Predicts the top-k results for the base model and for the attacked model."""
+    label_inds = list(range(1000))  # ImageNet label indices
+    # Convert the user's desired labels to label indices.
+    desired_inds = [LABELS.index(name) for name in desired_labels]
+    # Remove selected indices before randomly sampling the rest.
+    for ind in desired_inds:
+        label_inds.remove(ind)
+
+    # Fill the user's selections up to the top-k size with random classes.
+    desired_inds = desired_inds + random.sample(label_inds, k - len(desired_inds))
+    tensorized_desired_inds = torch.tensor(desired_inds).unsqueeze(0)  # [B, K]
+
+    model, transform = get_timm_model(model_name)
+
+    # Split the timm transform: the last step normalizes the tensor, while
+    # everything before it resizes/crops the PIL image and converts it to a tensor.
+    normalization = transforms.Compose([
+        transform.transforms[-1]
+    ])
+    preprocess = transforms.Compose(
+        transform.transforms[:-1]
+    )
+
+    attacker = create_attacker(model, normalization, iters)
+
+    img = img.convert('RGB')
+    orig_img = img.copy()
+    orig_img = preprocess(orig_img)
+    orig_img = orig_img.unsqueeze(0)
+    img = transform(img).unsqueeze(0)
+
+    with torch.no_grad():
+        outputs = model(img)
+    attack_outputs, attack_img = attacker(orig_img, tensorized_desired_inds, None)
+
+    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
+    attacker_probs = torch.nn.functional.softmax(attack_outputs[0], dim=0)
+
+    values, indices = torch.topk(probabilities, k)
+    attack_vals, attack_inds = torch.topk(attacker_probs, k)
+
+    attack_img_out = orig_img + attack_img  # [B, C, H, W]
+    # Convert the PyTorch tensors to NumPy arrays in HWC order.
+    attack_img_out = attack_img_out.squeeze(0)                 # [C, H, W]
+    attack_img_out = attack_img_out.permute(1, 2, 0).numpy()  # [H, W, C]
+
+    orig_img = orig_img.squeeze(0)
+    orig_img = orig_img.permute(1, 2, 0).numpy()
+
+    attack_img = attack_img.squeeze(0)
+    attack_img = attack_img.permute(1, 2, 0).numpy()
+
+    # Convert the NumPy arrays to PIL images.
+    attack_img_out = Image.fromarray((attack_img_out * 255).astype('uint8'))
+    orig_img = Image.fromarray((orig_img * 255).astype('uint8'))
+    attack_img = Image.fromarray((attack_img * 255).astype('uint8'))
+
+    return (orig_img, attack_img_out, attack_img,
+            {LABELS[i]: v.item() for i, v in zip(indices, values)},
+            {LABELS[i]: v.item() for i, v in zip(attack_inds, attack_vals)})
+
+def random_fill_classes(desired_labels, k):
+    """Fills the user's selected labels up to k with randomly sampled classes."""
+    label_inds = list(range(1000))  # ImageNet label indices
+    # Convert the user's desired labels to label indices.
+    if len(desired_labels) > k:
+        desired_labels = desired_labels[:k]
+    desired_inds = [LABELS.index(name) for name in desired_labels]
+    # Remove selected indices before randomly sampling the rest.
+    for ind in desired_inds:
+        label_inds.remove(ind)
+
+    # Fill the user's selections up to k entries.
+    desired_inds = desired_inds + random.sample(label_inds, k - len(desired_inds))
+
+    return [LABELS[ind] for ind in desired_inds]
+
+
+input_img = gr.Image(type='pil')
+top_k_slider = gr.Slider(2, 20, value=10, step=1, label="Top K predictions", info="Choose between 2 and 20")
+iteration_slider = gr.Slider(30, 1000, value=60, step=1, label="QuadAttack Iterations",
+                             info="Choose how many iterations to optimize with QuadAttack (usually <= 60 is enough).")
+model_choice_list = gr.Dropdown(
+    timm.list_models(), value="vit_base_patch16_224", label="timm model name",
+    info="Currently only timm models are supported. See the code for the models used in the paper."
+)
+desired_labels = gr.Dropdown(
+    SORTED_LABELS, max_choices=20, filterable=True, multiselect=True,
+    label="Desired Labels for QuadAttack",
+    info="Select the classes you want the attack to output. Classes are ranked in the "
+         "order listed and randomly filled up to K if fewer than K are selected."
+)
+button = gr.Button("Randomly fill Top-K attack classes.")
+
+desc = r'<div align="center">Authors: Thomas Paniagua, Ryan Grainger, Tianfu Wu <p><a href="https://arxiv.org/abs/2312.11510">Paper</a><br><a href="https://github.com/thomaspaniagua/quadattack">Code</a></p> </div>'
+with gr.Interface(predict_topk_accuracies,
+                  inputs=[input_img,
+                          top_k_slider,
+                          iteration_slider,
+                          model_choice_list,
+                          desired_labels,
+                          button],
+                  outputs=[
+                      gr.Image(type='pil', label="Input Image"),
+                      gr.Image(type='pil', label="Perturbed Image"),
+                      gr.Image(type='pil', label="Added Noise"),
+                      gr.Label(label="Original Top K"),
+                      gr.Label(label="QuadAttack Top K"),
+                  ],
+                  title='QuadAttack!',
+                  description=desc,
+                  cache_examples=False,
+                  allow_flagging="never",
+                  thumbnail="quadattack_pipeline.pdf",
+                  examples=[
+                      ["image_examples/RV.jpeg", 5, 30, "vit_base_patch16_224", None, None],
+                      # ["image_examples/biker.jpeg", 10, 60, "swinv2_cr_base_224", None, None],
+                      ["image_examples/mower.jpeg", 15, 100, "wide_resnet101_2", None, None],
+                      # ["image_examples/dog.jpeg", 20, 150, "xcit_small_12_p8_224", None, None],
+                      ["image_examples/fish.jpeg", 10, 100, "pvt_v2_b0", None, None],
+                  ]
+                  ).queue() as app:
+    # Hide the Clear button, since clearing erases globals.
+    for block in app.blocks:
+        if isinstance(app.blocks[block], gr.Button):
+            if app.blocks[block].value == "Clear":
+                app.blocks[block].visible = False
+    button.click(random_fill_classes, inputs=[desired_labels, top_k_slider], outputs=desired_labels)
+
+
+if __name__ == "__main__":
+    app.launch(server_port=9000)
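For anyone trying the demo outside the Gradio UI, the core function can be smoke-tested directly. A minimal sketch, assuming `base_config.yaml` and the example images added in this commit are present (note that importing `app` builds the Gradio interface and downloads the label file as a side effect; the label choices below are arbitrary):

```python
# Hypothetical smoke test for the demo's core function; paths and labels
# are assumptions based on the files added in this commit.
from PIL import Image
from app import predict_topk_accuracies

img = Image.open("image_examples/RV.jpeg")
orig, perturbed, noise, clean_topk, attacked_topk = predict_topk_accuracies(
    img, k=5, iters=30, model_name="vit_base_patch16_224",
    desired_labels=["lemon", "hay"],  # any lemma strings present in LABELS
)
print(attacked_topk)  # the desired classes should dominate the attacked top-k
```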
image_examples/RV.jpeg ADDED
image_examples/__init__.py ADDED
File without changes
image_examples/biker.jpeg ADDED
image_examples/dog.jpeg ADDED
image_examples/fish.jpeg ADDED
image_examples/mower.jpeg ADDED
modelguidedattacks/cls_models/registry.py CHANGED
@@ -31,6 +31,57 @@ class ClsModel(nn.Module):
 
         raise NotImplementedError("Forward not implemented for base class")
 
+class TimmPretrainModelWrapper(ClsModel):
+    """
+    Applies the model's data preprocessing before entering forward.
+    """
+    def __init__(self, model: nn.Module, transform, dataset_name: str, model_name: str, device: str) -> None:
+        super().__init__(dataset_name, model_name, device)
+        self.model = model
+        self.transform = transform
+
+    @property
+    def final_linear_layer(self):
+        # timm classifiers are exposed either as `head` (sometimes wrapping
+        # an inner `fc`) or as a bare `fc` attribute.
+        head = getattr(self.model, "head", None)
+        if head is not None:
+            if isinstance(head, torch.nn.Linear):
+                return head
+            return head.fc
+        return self.model.fc
+
+    def head_features(self):
+        return self.final_linear_layer.in_features
+
+    def num_classes(self):
+        return self.final_linear_layer.out_features
+
+    def head(self, feats):
+        return self.model.head((feats,))
+
+    def head_matrices(self):
+        return self.final_linear_layer.weight, self.final_linear_layer.bias
+
+    def forward(self, x, return_features=False):
+        x = self.transform(x)
+        if return_features:
+            feats = self.model.forward_features(x)
+            logits = self.model.forward_head(feats, pre_logits=True)
+            try:
+                preds = self.model.fc(logits)    # convnets
+            except AttributeError:
+                preds = self.model.head(logits)  # ViTs
+            return preds, logits
+        return self.model(x)
+
 class MMPretrainModelWrapper(ClsModel):
     """
     Calls data preprocessing for model before entering forward
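The new wrapper gives the attack a uniform view of any timm classifier: a single place to read the final linear layer's weights, and a forward pass that normalizes inputs internally. A minimal usage sketch, under the assumption that `ClsModel.__init__` accepts empty dataset/model/device strings the way the demo above passes them, and that `resnet18` layer names follow current timm:

```python
# Hypothetical usage sketch of TimmPretrainModelWrapper.
import timm
import torch
from torchvision import transforms
from modelguidedattacks.cls_models.registry import TimmPretrainModelWrapper

model = timm.create_model("resnet18", pretrained=False).eval()
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
wrapper = TimmPretrainModelWrapper(model, normalize, "", "", "")

weight, bias = wrapper.head_matrices()                 # classifier parameters
print(wrapper.head_features(), wrapper.num_classes())  # 512 1000

x = torch.randn(1, 3, 224, 224)
preds, feats = wrapper.forward(x, return_features=True)
print(preds.shape, feats.shape)  # torch.Size([1, 1000]) torch.Size([1, 512])
```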
modelguidedattacks/guides/unguided.py CHANGED
@@ -63,6 +63,10 @@ class Unguided(nn.Module):
 
         x_perturbation = nn.Parameter(torch.randn(x.shape,
                                                   device=x.device)*2e-3)
+
+        optimizer = self.optimizer([x_perturbation], lr=self.lr)
+
+        precomputed_state = self.loss.precompute(attack_targets, gt_labels, self.config)
 
         with torch.no_grad():
             prediction_logits_0, prediction_feats_0 \
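These added lines set up the per-image optimization: an optimizer over the perturbation parameter alone, plus loss state precomputed once rather than on every iteration. A runnable toy stand-in for the loop pattern they feed into (the linear model and the cross-entropy pull toward a target class are illustrative assumptions, not the repo's QP-based top-k loss):

```python
import torch
import torch.nn as nn

# Toy pieces standing in for the wrapped classifier and attack target.
model = nn.Linear(8, 4)
x = torch.randn(1, 8)
attack_target = torch.tensor([2])

# Mirror the diff: only the perturbation parameter is handed to the optimizer,
# so the model's weights are never updated.
x_perturbation = nn.Parameter(torch.randn_like(x) * 2e-3)
optimizer = torch.optim.Adam([x_perturbation], lr=0.002)

for _ in range(60):
    optimizer.zero_grad()
    logits = model(x + x_perturbation)
    loss = nn.functional.cross_entropy(logits, attack_target)
    loss.backward()
    optimizer.step()

print(model(x + x_perturbation).argmax(dim=-1))  # tensor([2]) once converged
```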
quadattack_pipeline.pdf ADDED
Binary file (111 kB)
 
testing.md ADDED
@@ -0,0 +1 @@
+# QuadAttack