'''
HEART Gradio Example App

To run:
- clone the repository
- execute: gradio examples/gradio_app.py or python examples/gradio_app.py
- navigate to the local URL, e.g. http://127.0.0.1:7860
'''
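# Note: when run as a script (see the __main__ block at the bottom of this file),
# the app installs HEART from GitLab and expects HEART_USER and HEART_TOKEN to be
# set in the environment.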
import torch
import numpy as np
import pandas as pd
# from carbon_theme import Carbon
import gradio as gr
import os
css = """ | |
.input-image { margin: auto !important } | |
.small-font span{ | |
font-size: 0.6em; | |
} | |
.df-padding { | |
padding-left: 50px !important; | |
padding-right: 50px !important; | |
} | |
""" | |


def basic_cifar10_model():
    '''
    Load an example CIFAR10 model.
    '''
    from heart.estimators.classification.pytorch import JaticPyTorchClassifier

    labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    path = './'

    class Model(torch.nn.Module):
        """
        Create a model for PyTorch. The model deliberately does not use max pooling,
        which is needed for the certification tests.
        """
        def __init__(self):
            super(Model, self).__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=(4, 4), dilation=(1, 1), padding=(0, 0), stride=(3, 3)
            )
            self.fullyconnected = torch.nn.Linear(in_features=1600, out_features=10)
            self.relu = torch.nn.ReLU()

            model_dir = os.path.join(os.path.dirname(path), "utils/resources/models")
            w_conv2d = np.load(os.path.join(model_dir, "W_CONV2D_NO_MPOOL_CIFAR10.npy"))
            b_conv2d = np.load(os.path.join(model_dir, "B_CONV2D_NO_MPOOL_CIFAR10.npy"))
            w_dense = np.load(os.path.join(model_dir, "W_DENSE_NO_MPOOL_CIFAR10.npy"))
            b_dense = np.load(os.path.join(model_dir, "B_DENSE_NO_MPOOL_CIFAR10.npy"))

            self.conv.weight = torch.nn.Parameter(torch.Tensor(w_conv2d))
            self.conv.bias = torch.nn.Parameter(torch.Tensor(b_conv2d))
            self.fullyconnected.weight = torch.nn.Parameter(torch.Tensor(w_dense))
            self.fullyconnected.bias = torch.nn.Parameter(torch.Tensor(b_dense))

        # pylint: disable=W0221
        # disable pylint because of API requirements for function
        def forward(self, x):
            """
            Forward function to evaluate the model.

            :param x: Input to the model
            :return: Prediction of the model
            """
            x = self.conv(x)
            x = self.relu(x)
            x = x.reshape(-1, 1600)
            x = self.fullyconnected(x)
            return x

    # Define the network
    model = Model()

    # Define a loss function and optimizer
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Get classifier
    jptc = JaticPyTorchClassifier(
        model=model, loss=loss_fn, optimizer=optimizer, input_shape=(3, 32, 32),
        nb_classes=10, clip_values=(0, 1), labels=labels
    )
    return jptc
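# Minimal usage sketch for the classifier above (assumes the CIFAR-10 weight .npy
# files exist under utils/resources/models; the returned object behaves as it is
# used later in this file):
#   jptc = basic_cifar10_model()
#   preds = jptc(np.zeros((1, 3, 32, 32), dtype=np.float32))
#   print(preds.logits)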


def clf_evasion_evaluate(*args):
    '''
    Run a classification task evaluation.
    '''
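    # The positional args mirror the `inputs=` lists of the Gradio click handlers
    # below: args[0:7] are shared (attack name and model settings), the
    # attack-specific sliders/textboxes follow from args[7], and the last four
    # entries are always the dataset settings and the locally uploaded image.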
    attack = args[0]
    model_type = args[1]
    model_path = args[2]
    model_channels = args[3]
    model_height = args[4]
    model_width = args[5]
    model_clip = args[6]

    dataset_type = args[-4]
    dataset_path = args[-3]
    dataset_split = args[-2]
    image = args[-1]
if dataset_type == "Example XView": | |
from maite import load_dataset | |
import torchvision | |
jatic_dataset = load_dataset( | |
provider="huggingface", | |
dataset_name="CDAO/xview-subset-classification", | |
task="image-classification", | |
split="test", | |
) | |
IMAGE_H, IMAGE_W = 224, 224 | |
transform = torchvision.transforms.Compose( | |
[ | |
torchvision.transforms.Resize((IMAGE_H, IMAGE_W)), | |
torchvision.transforms.ToTensor(), | |
] | |
) | |
jatic_dataset.set_transform(lambda x: {"image": transform(x["image"]), "label": x["label"]}) | |
image = {'image': [i['image'].numpy() for i in jatic_dataset], | |
'label': [i['label'] for i in jatic_dataset]} | |
elif dataset_type=="huggingface": | |
from maite import load_dataset | |
jatic_dataset = load_dataset( | |
provider=dataset_type, | |
dataset_name=dataset_path, | |
task="image-classification", | |
split=dataset_split, | |
drop_labels=False | |
) | |
image = {'image': [i['image'] for i in jatic_dataset], | |
'label': [i['label'] for i in jatic_dataset]} | |
elif dataset_type=="torchvision": | |
from maite import load_dataset | |
jatic_dataset = load_dataset( | |
provider=dataset_type, | |
dataset_name=dataset_path, | |
task="image-classification", | |
split=dataset_split, | |
root='./data/', | |
download=True | |
) | |
image = {'image': [i['image'] for i in jatic_dataset], | |
'label': [i['label'] for i in jatic_dataset]} | |
elif dataset_type=="Example CIFAR10": | |
from maite import load_dataset | |
jatic_dataset = load_dataset( | |
provider="torchvision", | |
dataset_name="CIFAR10", | |
task="image-classification", | |
split=dataset_split, | |
root='./data/', | |
download=True | |
) | |
image = {'image': [i['image'] for i in jatic_dataset][:100], | |
'label': [i['label'] for i in jatic_dataset][:100]} | |
if model_type == "Example CIFAR10": | |
jptc = basic_cifar10_model() | |
elif model_type == "Example XView": | |
import torchvision | |
from heart.estimators.classification.pytorch import JaticPyTorchClassifier | |
classes = { | |
0:'Building', | |
1:'Construction Site', | |
2:'Engineering Vehicle', | |
3:'Fishing Vessel', | |
4:'Oil Tanker', | |
5:'Vehicle Lot' | |
} | |
model = torchvision.models.resnet18(False) | |
num_ftrs = model.fc.in_features | |
model.fc = torch.nn.Linear(num_ftrs, len(classes.keys())) | |
model.load_state_dict(torch.load('./utils/resources/models/xview_model.pt')) | |
_ = model.eval() | |
jptc = JaticPyTorchClassifier( | |
model=model, loss = torch.nn.CrossEntropyLoss(), input_shape=(3, 224, 224), | |
nb_classes=len(classes), clip_values=(0, 1), labels=list(classes.values()) | |
) | |
elif model_type == "torchvision": | |
from maite.interop.torchvision import TorchVisionClassifier | |
from heart.estimators.classification.pytorch import JaticPyTorchClassifier | |
clf = TorchVisionClassifier.from_pretrained(model_path) | |
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum") | |
jptc = JaticPyTorchClassifier( | |
model=clf._model, loss=loss_fn, input_shape=(model_channels, model_height, model_width), | |
nb_classes=len(clf._labels), clip_values=(0, model_clip), labels=clf._labels | |
) | |
elif model_type == "huggingface": | |
from maite.interop.huggingface import HuggingFaceImageClassifier | |
from heart.estimators.classification.pytorch import JaticPyTorchClassifier | |
clf = HuggingFaceImageClassifier.from_pretrained(model_path) | |
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum") | |
jptc = JaticPyTorchClassifier( | |
model=clf._model, loss=loss_fn, input_shape=(model_channels, model_height, model_width), | |
nb_classes=len(clf._labels), clip_values=(0, model_clip), labels=clf._labels | |
) | |
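    # Note: `clf._model` and `clf._labels` reach into private attributes of the
    # MAITE interop wrappers; this works with the MAITE version this example was
    # written against but is not a stable public API.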
    if attack == "PGD":
        from art.attacks.evasion.projected_gradient_descent.projected_gradient_descent_pytorch import (
            ProjectedGradientDescentPyTorch,
        )
        from heart.attacks.attack import JaticAttack
        from heart.metrics import AccuracyPerturbationMetric
        from torch.nn.functional import softmax
        from maite.protocols import HasDataImage, is_typed_dict, ArrayLike

        pgd_attack = ProjectedGradientDescentPyTorch(estimator=jptc, max_iter=int(args[7]), eps=args[8],
                                                     eps_step=args[9], targeted=args[10] != "")
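        # PGD parameter sketch: `eps` is the total perturbation budget, `eps_step`
        # the per-iteration step size, and `max_iter` the iteration count (cast to
        # int above, since Gradio sliders can return floats). The attack runs
        # targeted when a target label is supplied below.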
        attack = JaticAttack(pgd_attack)

        preds = jptc(image)
        preds = softmax(torch.from_numpy(preds.logits), dim=1)
        labels = {}
        for i, label in enumerate(jptc.get_labels()):
            labels[label] = preds[0][i]

        if args[10] != "":
            if is_typed_dict(image, HasDataImage):
                data = {'image': image['image'], 'label': [args[10]] * len(image['image'])}
            else:
                data = {'image': image, 'label': [args[10]]}
        else:
            data = image

        x_adv = attack.run_attack(data=data)
        adv_preds = jptc(x_adv.adversarial_examples)
        adv_preds = softmax(torch.from_numpy(adv_preds.logits), dim=1)
        adv_labels = {}
        for i, label in enumerate(jptc.get_labels()):
            adv_labels[label] = adv_preds[0][i]
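        # AccuracyPerturbationMetric compares predictions on the clean and the
        # adversarial images; it reports clean accuracy, robust (post-attack)
        # accuracy, and the mean perturbation added by the attack.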
        metric = AccuracyPerturbationMetric()
        metric.update(jptc, jptc.device, image, x_adv.adversarial_examples)
        clean_accuracy, robust_accuracy, perturbation_added = metric.compute()
        metrics = pd.DataFrame([[clean_accuracy, robust_accuracy, perturbation_added]],
                               columns=['clean accuracy', 'robust accuracy', 'perturbation'])

        adv_imgs = [img.transpose(1, 2, 0) for img in x_adv.adversarial_examples]
        if is_typed_dict(image, HasDataImage):
            image = image['image']
        if not isinstance(image, list):
            image = [image]

        # In the case of multiple images, use argmax to get the predicted label and add it as a caption
        if dataset_type != "local":
            temp = []
            for i, img in enumerate(image):
                if isinstance(img, ArrayLike):
                    temp.append((img.transpose(1, 2, 0), str(jptc.get_labels()[np.argmax(preds[i])])))
                else:
                    temp.append((img, str(jptc.get_labels()[np.argmax(preds[i])])))
            image = temp

            temp = []
            for i, img in enumerate(adv_imgs):
                temp.append((img, str(jptc.get_labels()[np.argmax(adv_preds[i])])))
            adv_imgs = temp

        return [image, labels, adv_imgs, adv_labels, clean_accuracy, robust_accuracy, perturbation_added]
    elif attack == "Adversarial Patch":
        from art.attacks.evasion.adversarial_patch.adversarial_patch_pytorch import AdversarialPatchPyTorch
        from heart.attacks.attack import JaticAttack
        from heart.metrics import AccuracyPerturbationMetric
        from torch.nn.functional import softmax
        from maite.protocols import HasDataImage, is_typed_dict, ArrayLike

        batch_size = 16
        scale_min = 0.3
        scale_max = 1.0
        rotation_max = 0
        learning_rate = 5000.

        patch_attack = AdversarialPatchPyTorch(estimator=jptc, rotation_max=rotation_max,
                                               patch_location=(int(args[8]), int(args[9])),
                                               scale_min=scale_min, scale_max=scale_max, patch_type='square',
                                               learning_rate=learning_rate, max_iter=int(args[7]),
                                               batch_size=batch_size,
                                               patch_shape=(3, int(args[10]), int(args[11])),
                                               verbose=False, targeted=args[12] != "")
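        # The patch location, size (patch_shape is channels x height x width) and
        # iteration count come from the UI sliders; as above, slider values are
        # cast to int. The remaining hyperparameters are fixed example defaults.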
        attack = JaticAttack(patch_attack)

        preds = jptc(image)
        preds = softmax(torch.from_numpy(preds.logits), dim=1)
        labels = {}
        for i, label in enumerate(jptc.get_labels()):
            labels[label] = preds[0][i]

        if args[12] != "":
            if is_typed_dict(image, HasDataImage):
                data = {'image': image['image'], 'label': [args[12]] * len(image['image'])}
            else:
                data = {'image': image, 'label': [args[12]]}
        else:
            data = image

        attack_output = attack.run_attack(data=data)
        adv_preds = jptc(attack_output.adversarial_examples)
        adv_preds = softmax(torch.from_numpy(adv_preds.logits), dim=1)
        adv_labels = {}
        for i, label in enumerate(jptc.get_labels()):
            adv_labels[label] = adv_preds[0][i]

        metric = AccuracyPerturbationMetric()
        metric.update(jptc, jptc.device, image, attack_output.adversarial_examples)
        clean_accuracy, robust_accuracy, perturbation_added = metric.compute()
        metrics = pd.DataFrame([[clean_accuracy, robust_accuracy, perturbation_added]],
                               columns=['clean accuracy', 'robust accuracy', 'perturbation'])

        adv_imgs = [img.transpose(1, 2, 0) for img in attack_output.adversarial_examples]
        if is_typed_dict(image, HasDataImage):
            image = image['image']
        if not isinstance(image, list):
            image = [image]

        # In the case of multiple images, use argmax to get the predicted label and add it as a caption
        if dataset_type != "local":
            temp = []
            for i, img in enumerate(image):
                if isinstance(img, ArrayLike):
                    temp.append((img.transpose(1, 2, 0), str(jptc.get_labels()[np.argmax(preds[i])])))
                else:
                    temp.append((img, str(jptc.get_labels()[np.argmax(preds[i])])))
            image = temp

            temp = []
            for i, img in enumerate(adv_imgs):
                temp.append((img, str(jptc.get_labels()[np.argmax(adv_preds[i])])))
            adv_imgs = temp

        patch, patch_mask = attack_output.adversarial_patch
        patch_image = (patch * patch_mask).transpose(1, 2, 0)

        return [image, labels, adv_imgs, adv_labels, clean_accuracy, robust_accuracy, patch_image]


def show_model_params(model_type):
    '''
    Show model parameters based on the selected model type.
    '''
    if model_type != "Example CIFAR10" and model_type != "Example XView":
        return gr.Column(visible=True)
    return gr.Column(visible=False)


def show_dataset_params(dataset_type):
    '''
    Show dataset parameters based on the dataset type.
    '''
    if dataset_type == "Example CIFAR10" or dataset_type == "Example XView":
        return [gr.Column(visible=False), gr.Row(visible=False), gr.Row(visible=False)]
    elif dataset_type == "local":
        return [gr.Column(visible=True), gr.Row(visible=True), gr.Row(visible=False)]
    return [gr.Column(visible=True), gr.Row(visible=False), gr.Row(visible=True)]


def pgd_show_label_output(dataset_type):
    '''
    Show PGD output components based on the dataset type.
    '''
    if dataset_type == "local":
        return [gr.Label(visible=True), gr.Label(visible=True), gr.Number(visible=False),
                gr.Number(visible=False), gr.Number(visible=True)]
    return [gr.Label(visible=False), gr.Label(visible=False), gr.Number(visible=True),
            gr.Number(visible=True), gr.Number(visible=True)]
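

# A model with clip values of (0, 255) operates on raw pixel values rather than
# normalised ones, so the epsilon budget below is rescaled to match.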
def pgd_update_epsilon(clip_values):
    '''
    Update the max value of the PGD epsilon slider based on the model clip values.
    '''
    if clip_values == 255:
        return gr.Slider(minimum=0.0001, maximum=255, label="Epsilon", value=55)
    return gr.Slider(minimum=0.0001, maximum=1, label="Epsilon", value=0.05)


def patch_show_label_output(dataset_type):
    '''
    Show adversarial patch output components based on the dataset type.
    '''
    if dataset_type == "local":
        return [gr.Label(visible=True), gr.Label(visible=True), gr.Number(visible=False),
                gr.Number(visible=False), gr.Number(visible=True)]
    return [gr.Label(visible=False), gr.Label(visible=False), gr.Number(visible=True),
            gr.Number(visible=True), gr.Number(visible=True)]


def show_target_label_dataframe(dataset_type):
    '''
    Show the target-label mapping dataframe for the selected example dataset.
    '''
    if dataset_type == "Example CIFAR10":
        return gr.Dataframe(visible=True), gr.Dataframe(visible=False)
    elif dataset_type == "Example XView":
        return gr.Dataframe(visible=False), gr.Dataframe(visible=True)
    return gr.Dataframe(visible=False), gr.Dataframe(visible=False)


# e.g. to use a local alternative theme: carbon_theme = Carbon()
with gr.Blocks(css=css, theme='xiaobaiyuan/theme_brief') as demo:
gr.Markdown("<h1>HEART Adversarial Robustness Gradio Example</h1>") | |
with gr.Tab("Info"): | |
gr.Markdown('This is step 1. Using the tabs, select a task for evaluation.') | |
with gr.Tab("Classification", elem_classes="task-tab"): | |
gr.Markdown("Classifying images with a set of categories.") | |
# Model and Dataset Selection | |
with gr.Row(): | |
# Model and Dataset type e.g. Torchvision, HuggingFace, local etc. | |
with gr.Column(): | |
                model_type = gr.Radio(label="Model type",
                                      choices=["Example CIFAR10", "Example XView", "torchvision", "huggingface"],
                                      value="Example CIFAR10")
                dataset_type = gr.Radio(label="Dataset",
                                        choices=["Example CIFAR10", "Example XView", "local", "torchvision", "huggingface"],
                                        value="Example CIFAR10")

            # Model parameters, e.g. ResNet, ViT, input dimensions, clipping values
            with gr.Column(visible=False) as model_params:
                model_path = gr.Textbox(placeholder="URL", label="Model path")
                with gr.Row():
                    with gr.Column():
                        model_channels = gr.Textbox(placeholder="Integer, 3 for RGB images", label="Input Channels", value=3)
                    with gr.Column():
                        model_width = gr.Textbox(placeholder="Integer", label="Input Width", value=640)
                with gr.Row():
                    with gr.Column():
                        model_height = gr.Textbox(placeholder="Integer", label="Input Height", value=480)
                    with gr.Column():
                        model_clip = gr.Radio(choices=[1, 255], label="Pixel clip", value=1)

            # Dataset parameters, e.g. torchvision, Hugging Face, local
            with gr.Column(visible=False) as dataset_params:
                with gr.Row() as local_image:
                    image = gr.Image(sources=['upload'], type="pil", height=150, width=150, elem_classes="input-image")
                with gr.Row() as hosted_image:
                    dataset_path = gr.Textbox(placeholder="URL", label="Dataset path")
                    dataset_split = gr.Textbox(placeholder="test", label="Dataset split")

        model_type.change(show_model_params, model_type, model_params)
        dataset_type.change(show_dataset_params, dataset_type, [dataset_params, local_image, hosted_image])

        # Attack selection
        with gr.Row():
            with gr.Tab("Info"):
                gr.Markdown("This is step 2. Select the type of attack for evaluation.")

            with gr.Tab("White Box"):
                gr.Markdown("White box attacks assume the attacker has __full access__ to the model.")

                with gr.Tab("Info"):
                    gr.Markdown("This is step 3. Select the type of white-box attack to evaluate.")

                with gr.Tab("Evasion"):
                    gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.")

                    with gr.Tab("Info"):
                        gr.Markdown("This is step 4. Select the type of evasion attack to evaluate.")

                    with gr.Tab("Projected Gradient Descent"):
                        gr.Markdown("This attack uses PGD to identify adversarial examples.")

                        with gr.Row():
                            with gr.Column():
                                attack = gr.Textbox(visible=True, value="PGD", label="Attack", interactive=False)
                                max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=1000)
                                eps = gr.Slider(minimum=0.0001, maximum=1, label="Epsilon", value=0.05)
                                eps_steps = gr.Slider(minimum=0.001, maximum=1000, label="Epsilon steps", value=0.1)
                                targeted = gr.Textbox(placeholder="Target label (integer)", label="Target")
                                with gr.Accordion("Target mapping", open=False):
                                    cifar_labels = gr.Dataframe(pd.DataFrame(['airplane', 'automobile', 'bird', 'cat', 'deer',
                                                                              'dog', 'frog', 'horse', 'ship', 'truck'],
                                                                             columns=['label']).rename_axis('target').reset_index(),
                                                                visible=True, elem_classes=["small-font", "df-padding"],
                                                                type="pandas", interactive=False)
                                    xview_labels = gr.Dataframe(pd.DataFrame(['Building', 'Construction Site', 'Engineering Vehicle',
                                                                              'Fishing Vessel', 'Oil Tanker', 'Vehicle Lot'],
                                                                             columns=['label']).rename_axis('target').reset_index(),
                                                                visible=False, elem_classes=["small-font", "df-padding"],
                                                                type="pandas", interactive=False)

                                eval_btn_pgd = gr.Button("Evaluate")
                                model_clip.change(pgd_update_epsilon, model_clip, eps)
                                dataset_type.change(show_target_label_dataframe, dataset_type, [cifar_labels, xview_labels])

                            # Evaluation output: visualisations of the successes/failures of the evaluation attacks
                            with gr.Column():
                                with gr.Row():
                                    with gr.Column():
                                        original_gallery = gr.Gallery(label="Original", preview=True, height=600)
                                        benign_output = gr.Label(num_top_classes=3, visible=False)
                                        clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
                                    with gr.Column():
                                        adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, height=600)
                                        adversarial_output = gr.Label(num_top_classes=3, visible=False)
                                        robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)
                                        perturbation_added = gr.Number(label="Perturbation Added", precision=2)

                                dataset_type.change(pgd_show_label_output, dataset_type,
                                                    [benign_output, adversarial_output,
                                                     clean_accuracy, robust_accuracy, perturbation_added])
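                                # The order of `inputs` below must match the positional
                                # unpacking at the top of clf_evasion_evaluate.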
                                eval_btn_pgd.click(clf_evasion_evaluate,
                                                   inputs=[attack, model_type, model_path, model_channels, model_height,
                                                           model_width, model_clip, max_iter, eps, eps_steps, targeted,
                                                           dataset_type, dataset_path, dataset_split, image],
                                                   outputs=[original_gallery, benign_output, adversarial_gallery,
                                                            adversarial_output, clean_accuracy, robust_accuracy,
                                                            perturbation_added], api_name='pgd')

                        with gr.Row():
                            clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy,
                                                        adversarial_gallery, adversarial_output, robust_accuracy,
                                                        perturbation_added])
with gr.Tab("Adversarial Patch"): | |
gr.Markdown("This attack crafts an adversarial patch that facilitates evasion.") | |
with gr.Row(): | |
with gr.Column(): | |
attack = gr.Textbox(visible=True, value="Adversarial Patch", label="Attack", interactive=False) | |
max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=100) | |
                                x_location = gr.Slider(minimum=1, maximum=640, label="Location (x)", value=18)
                                y_location = gr.Slider(minimum=1, maximum=480, label="Location (y)", value=18)
                                patch_height = gr.Slider(minimum=1, maximum=480, label="Patch height", value=18)
                                patch_width = gr.Slider(minimum=1, maximum=640, label="Patch width", value=18)
                                targeted = gr.Textbox(placeholder="Target label (integer)", label="Target")
                                with gr.Accordion("Target mapping", open=False):
                                    cifar_labels = gr.Dataframe(pd.DataFrame(['airplane', 'automobile', 'bird', 'cat', 'deer',
                                                                              'dog', 'frog', 'horse', 'ship', 'truck'],
                                                                             columns=['label']).rename_axis('target').reset_index(),
                                                                visible=True, elem_classes=["small-font", "df-padding"],
                                                                type="pandas", interactive=False)
                                    xview_labels = gr.Dataframe(pd.DataFrame(['Building', 'Construction Site', 'Engineering Vehicle',
                                                                              'Fishing Vessel', 'Oil Tanker', 'Vehicle Lot'],
                                                                             columns=['label']).rename_axis('target').reset_index(),
                                                                visible=False, elem_classes=["small-font", "df-padding"],
                                                                type="pandas", interactive=False)

                                eval_btn_patch = gr.Button("Evaluate")
                                dataset_type.change(show_target_label_dataframe, dataset_type, [cifar_labels, xview_labels])

                            # Evaluation output: visualisations of the successes/failures of the evaluation attacks
                            with gr.Column():
                                with gr.Row():
                                    with gr.Column():
                                        original_gallery = gr.Gallery(label="Original", preview=True, height=600)
                                        benign_output = gr.Label(num_top_classes=3, visible=False)
                                        clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
                                    with gr.Column():
                                        adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, height=600)
                                        adversarial_output = gr.Label(num_top_classes=3, visible=False)
                                        robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)
                                        patch_image = gr.Image(label="Adversarial Patch")

                                dataset_type.change(patch_show_label_output, dataset_type,
                                                    [benign_output, adversarial_output,
                                                     clean_accuracy, robust_accuracy, patch_image])
                                eval_btn_patch.click(clf_evasion_evaluate,
                                                     inputs=[attack, model_type, model_path, model_channels, model_height,
                                                             model_width, model_clip, max_iter, x_location, y_location,
                                                             patch_height, patch_width, targeted,
                                                             dataset_type, dataset_path, dataset_split, image],
                                                     outputs=[original_gallery, benign_output, adversarial_gallery,
                                                              adversarial_output, clean_accuracy, robust_accuracy,
                                                              patch_image])

                        with gr.Row():
                            clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy,
                                                        adversarial_gallery, adversarial_output, robust_accuracy,
                                                        patch_image])
with gr.Tab("Poisoning"): | |
gr.Markdown("Coming soon.") | |
with gr.Tab("Black Box"): | |
gr.Markdown("Black box attacks assume the attacker __does not__ have full access to the model but can query it for predictions.") | |
with gr.Tab("Info"): | |
gr.Markdown("This is step 3. Select the type of black-box attack to evaluate.") | |
with gr.Tab("Evasion"): | |
gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.") | |
with gr.Tab("Info"): | |
gr.Markdown("This is step 4. Select the type of Evasion attack to evaluate.") | |
with gr.Tab("HopSkipJump"): | |
gr.Markdown("Coming soon.") | |
with gr.Tab("Square Attack"): | |
gr.Markdown("Coming soon.") | |
with gr.Tab("AutoAttack"): | |
gr.Markdown("Coming soon.") | |
with gr.Tab("Object Detection"): | |
gr.Markdown("Extracting objects from images and identifying their category.") | |
gr.Markdown("Coming soon.") | |
if __name__ == "__main__": | |
import os, sys, subprocess | |
# Huggingface does not support LFS via external https, disable smudge | |
os.putenv('GIT_LFS_SKIP_SMUDGE', '1') | |
HEART_USER=os.environ['HEART_USER'] | |
HEART_TOKEN=os.environ['HEART_TOKEN'] | |
HEART_INSTALL=f"git+https://{HEART_USER}:{HEART_TOKEN}@gitlab.jatic.net/jatic/ibm/hardened-extension-adversarial-robustness-toolbox.git@HEART-Gradio" | |
subprocess.run([sys.executable, '-m', 'pip', 'install', HEART_INSTALL]) | |
# during development, set debug=True | |
demo.launch() | |