import torch
import types
import timm
import requests
import random
import yaml
import gradio as gr
from PIL import Image
from timm import create_model
from torchvision import transforms
from timm.data import resolve_data_config
from modelguidedattacks.guides.unguided import Unguided
from timm.data.transforms_factory import create_transform
from modelguidedattacks.cls_models.registry import TimmPretrainModelWrapper
# Download human-readable labels for ImageNet (one label per line).
IMAGENET_LABELS_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
# Bound the request and fail loudly on an HTTP error; otherwise an error page
# would be split into lines and silently used as the label list.
_labels_response = requests.get(IMAGENET_LABELS_URL, timeout=30)
_labels_response.raise_for_status()
LABELS = _labels_response.text.strip().split("\n")
# Case-insensitive alphabetical copy of LABELS; LABELS itself keeps file order
# because other code looks labels up by their position in it.
SORTED_LABELS = sorted(LABELS, key=str.lower)
def get_timm_model(name):
    """Retrieve a model from the timm library by name with weights loaded.

    Args:
        name: timm model identifier forwarded to ``create_model``.

    Returns:
        A ``(model, transform)`` tuple: the model switched to eval mode, and
        the preprocessing transform built from the model's resolved data
        config.
    """
    # Pass a real boolean: the previous string "true" only worked because any
    # non-empty string is truthy.
    model = create_model(name, pretrained=True)
    transform = create_transform(**resolve_data_config({}, model=model))
    model = model.eval()
    return model, transform
def create_attacker(model, transform, iterations):
    """Instantiate a QuadAttack model around ``model``.

    The attack configuration is read from ``base_config.yaml``, resolved
    relative to the current working directory.

    Args:
        model: Classifier to attack.
        transform: Transform handed to ``TimmPretrainModelWrapper`` together
            with the model.
        iterations: Number of attack iterations, forwarded to ``Unguided``.

    Returns:
        The configured ``Unguided`` attacker.

    Raises:
        OSError: If ``base_config.yaml`` cannot be opened.
    """
    with open("base_config.yaml", encoding="utf-8") as f:
        # An empty YAML document loads as None; fall back to an empty mapping
        # so the ** expansion below does not fail.
        config_dict = yaml.safe_load(f) or {}
    config = types.SimpleNamespace(**config_dict)

    wrapped_model = TimmPretrainModelWrapper(model, transform, "", "", "")
    attacker = Unguided(
        wrapped_model,
        config,
        iterations=iterations,
        lr=0.002,
        topk_loss_coef_upper=10,
    )
    return attacker
def predict_topk_accuracies(img, k, iters, model_name, desired_labels, button=None, progress=gr.Progress(track_tqdm=True)):
    """Predict the top K results using the base model and the attacker model.

    Args:
        img: Input PIL image.
        k: Number of top predictions to report and number of attack targets.
        iters: Number of attack iterations, forwarded to ``create_attacker``.
        model_name: timm model name, forwarded to ``get_timm_model``.
        desired_labels: Labels (entries of ``LABELS``) requested as attack
            targets, in rank order. May be ``None`` or empty; missing targets
            are filled with random classes, extra ones beyond ``k`` are
            dropped.
        button: Unused; present so the UI can pass one value per input.
        progress: Unused in the body; Gradio progress tracker.

    Returns:
        A 5-tuple ``(original image, perturbed image, added noise,
        original top-k, attacked top-k)`` where the first three are PIL
        images and the last two map label -> probability.

    Raises:
        gr.Error: If no input image was provided.
        ValueError: If a requested label is not in ``LABELS``.
    """
    if img is None:
        raise gr.Error("Please provide an input image.")

    k = int(k)  # sample sizes and topk need a real int
    desired_inds = _sample_target_indices(desired_labels, k)
    tensorized_desired_inds = torch.tensor(desired_inds).unsqueeze(0)  # [B, K]

    model, transform = get_timm_model(model_name)

    # Split the timm pipeline: its last step is handed to the attacker, every
    # step before it is applied here to produce the attacker's input.
    # NOTE(review): presumably the last step is the Normalize transform, so
    # `preprocess` yields a [0, 1] tensor -- confirm against timm.
    normalization = transforms.Compose([transform.transforms[-1]])
    preprocess = transforms.Compose(transform.transforms[:-1])
    attacker = create_attacker(model, normalization, iters)

    img = img.convert('RGB')
    orig_img = preprocess(img.copy()).unsqueeze(0)
    img = transform(img).unsqueeze(0)

    with torch.no_grad():
        outputs = model(img)
        # NOTE(review): kept inside no_grad as the statement order of the
        # flattened source suggests -- confirm that Unguided re-enables
        # gradients internally for its optimisation.
        attack_outputs, attack_img = attacker(orig_img, tensorized_desired_inds, None)

    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
    attacker_probs = torch.nn.functional.softmax(attack_outputs[0], dim=0)
    values, indices = torch.topk(probabilities, k)
    attack_vals, attack_inds = torch.topk(attacker_probs, k)

    # Clamp only the displayed perturbed image so out-of-range values do not
    # wrap around in the uint8 conversion.
    attack_img_out = (orig_img + attack_img).clamp(0, 1)  # B C H W

    original_topk = {LABELS[i]: v for i, v in zip(indices.tolist(), values.tolist())}
    attacked_topk = {LABELS[i]: v for i, v in zip(attack_inds.tolist(), attack_vals.tolist())}

    return (
        _chw_to_pil(orig_img.squeeze(0)),
        _chw_to_pil(attack_img_out.squeeze(0)),
        _chw_to_pil(attack_img.squeeze(0)),
        original_topk,
        attacked_topk,
    )


def _sample_target_indices(desired_labels, k, num_classes=1000):
    """Return ``k`` distinct class indices, user-selected labels first."""
    chosen = []
    for name in (desired_labels or [])[:k]:
        ind = LABELS.index(name)
        if ind not in chosen:
            chosen.append(ind)
    # Remove the selected classes before randomly sampling the rest.
    taken = set(chosen)
    remaining = [ind for ind in range(num_classes) if ind not in taken]
    return chosen + random.sample(remaining, k - len(chosen))


def _chw_to_pil(chw_tensor):
    """Convert a C x H x W float tensor to an 8-bit PIL image."""
    hwc = chw_tensor.permute(1, 2, 0).numpy()  # H W C
    return Image.fromarray((hwc * 255).astype('uint8'))
def random_fill_classes(desired_labels, k):
    """Fill the user's label selection up to ``k`` labels with random classes.

    Args:
        desired_labels: Labels already selected (entries of ``LABELS``), in
            rank order. May be ``None`` or empty.
        k: Number of labels to return.

    Returns:
        A list of ``k`` distinct labels: the first ``k`` selected labels in
        their given order, followed by randomly sampled labels that were not
        selected.

    Raises:
        ValueError: If a selected label is not in ``LABELS``.
    """
    k = int(k)  # slicing and sampling need a real int
    label_inds = list(range(0, 1000))  # label indices

    # Convert the (truncated) user selection to indices, ignoring repeats so
    # each index is removed from the sampling pool exactly once.
    desired_inds = []
    for name in (desired_labels or [])[:k]:
        ind = LABELS.index(name)
        if ind not in desired_inds:
            desired_inds.append(ind)

    # Remove selected before randomly sampling the rest.
    for ind in desired_inds:
        label_inds.remove(ind)

    # Fill up user selections to top k results.
    desired_inds = desired_inds + random.sample(label_inds, k - len(desired_inds))
    return [LABELS[ind] for ind in desired_inds]
# --- UI components shared by the interface and the fill-in button handler ---
input_img = gr.Image(type='pil')
top_k_slider = gr.Slider(2, 20, value=10, step=1, label="Top K predictions", info="Choose between 2 and 20")
iteration_slider = gr.Slider(30, 1000, value=60, step=1, label="QuadAttack Iterations", info="Choose how many iterations to optimize using QuadAttack! (Usually <= 60 is enough)")
model_choice_list = gr.Dropdown(
    timm.list_models(),
    value="vit_base_patch16_224",
    label="timm model name",
    info="Currently only supporting timm models! See code for models used in paper.",
)
# Choices are shown alphabetically (SORTED_LABELS); at most 20 can be picked,
# matching the upper bound of the Top K slider.
desired_labels = gr.Dropdown(
    SORTED_LABELS,
    max_choices=20,
    filterable=True,
    multiselect=True,
    label="Desired Labels for QuadAttack",
    info=(
        "Select classes you wish to output from an attack. "
        "Classes will be ranked in order listed and randomly filled up to "
        "K if < K options are selected."
    ),
)
button = gr.Button("Randomly fill Top-K attack classes.")
# HTML shown under the interface title: authors plus paper and code links.
desc = r'<div align="center">Authors: Thomas Paniagua, Ryan Grainger, Tianfu Wu <p><a href="https://arxiv.org/abs/2312.11510">Paper</a><br><a href="https://github.com/thomaspaniagua/quadattack">Code</a></p> </div>'
# Build the demo interface around predict_topk_accuracies.
# NOTE(review): the input list is reconstructed from the six-value example
# rows and the handler's parameters (img, k, iters, model_name,
# desired_labels, button) -- confirm against the deployed app.
with gr.Interface(
    predict_topk_accuracies,
    [
        input_img,
        top_k_slider,
        iteration_slider,
        model_choice_list,
        desired_labels,
        button,
    ],
    # One output component per element of the handler's returned 5-tuple.
    [
        gr.Image(type='pil', label="Input Image"),
        gr.Image(type='pil', label="Perturbed Image"),
        gr.Image(type='pil', label="Added Noise"),
        gr.Label(label="Original Top K"),
        gr.Label(label="QuadAttack Top K"),
    ],
    description=desc,
    thumbnail="quadattack_pipeline.pdf",
    # Each row: image path, k, iterations, model name, desired labels, button.
    examples=[
        ["image_examples/RV.jpeg", 5, 30, "vit_base_patch16_224", None, None],
        ["image_examples/mower.jpeg", 15, 100, "wide_resnet101_2", None, None],
        ["image_examples/fish.jpeg", 10, 100, "pvt_v2_b0", None, None],
    ],
).queue() as app:
    # Turn off the clear button as it erases globals.
    # NOTE(review): hiding the button is inferred from the comment above; the
    # original statement under this condition is not visible.
    for component in app.blocks.values():
        if isinstance(component, gr.Button) and component.value == "Clear":
            component.visible = False

    # Clicking the fill-in button tops the label selection up to K entries.
    button.click(random_fill_classes, inputs=[desired_labels, top_k_slider], outputs=desired_labels)
if __name__ == "__main__":