HEART-Gradio / app.py
lockwooda's picture
Initial manual application publish
history blame
32.2 kB
HEART Gradio Example App
To run:
- clone the repository
- execute: gradio examples/gradio_app.py or python examples/gradio_app.py
- navigate to the local URL printed in the terminal, e.g. http://127.0.0.1:7860
import torch
import numpy as np
import pandas as pd
# from carbon_theme import Carbon
import gradio as gr
import os
# Custom CSS handed to gr.Blocks(css=css) below; the class names are attached to
# components through their elem_classes arguments.
# NOTE(review): the closing braces and the closing triple quote of this string are
# not visible in this view of the file.
css = """
.input-image { margin: auto !important }
.small-font span{
font-size: 0.6em;
.df-padding {
padding-left: 50px !important;
padding-right: 50px !important;
def basic_cifar10_model():
# NOTE(review): this view of the file has lost its indentation, the docstring quotes and
# some lines (closing parentheses, the np.load arguments). The comments below describe
# only what is visible.
Load an example CIFAR10 model
from heart.estimators.classification.pytorch import JaticPyTorchClassifier
# The ten class names handed to the classifier as `labels`.
labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# Not referenced on any visible line; presumably the base directory for the np.load
# calls below - TODO confirm.
path = './'
class Model(torch.nn.Module):
Create model for pytorch.
Here the model does not use maxpooling. Needed for certification tests.
# Layers: one Conv2d (3 -> 16 channels, 4x4 kernel, stride 3), a ReLU and a
# Linear layer mapping 1600 features to 10 outputs.
def __init__(self):
super(Model, self).__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=(4, 4), dilation=(1, 1), padding=(0, 0), stride=(3, 3)
self.fullyconnected = torch.nn.Linear(in_features=1600, out_features=10)
self.relu = torch.nn.ReLU()
# The arguments of the four np.load calls are not visible in this view; presumably
# they read stored weight/bias arrays for the two layers - TODO confirm.
w_conv2d = np.load(
b_conv2d = np.load(
w_dense = np.load(
b_dense = np.load(
# Replace the freshly initialised layer parameters with the loaded arrays.
self.conv.weight = torch.nn.Parameter(torch.Tensor(w_conv2d))
self.conv.bias = torch.nn.Parameter(torch.Tensor(b_conv2d))
self.fullyconnected.weight = torch.nn.Parameter(torch.Tensor(w_dense))
self.fullyconnected.bias = torch.nn.Parameter(torch.Tensor(b_dense))
# pylint: disable=W0221
# disable pylint because of API requirements for function
def forward(self, x):
Forward function to evaluate the model
:param x: Input to the model
:return: Prediction of the model
x = self.conv(x)
x = self.relu(x)
# Flatten to rows of 1600 features, matching in_features of the Linear layer.
x = x.reshape(-1, 1600)
x = self.fullyconnected(x)
return x
# Define the network
model = Model()
# Define a loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Get classifier
# Wrap the network for 3x32x32 inputs, 10 classes and pixel values clipped to (0, 1).
jptc = JaticPyTorchClassifier(
model=model, loss=loss_fn, optimizer=optimizer, input_shape=(3, 32, 32), nb_classes=10, clip_values=(0, 1), labels=labels
return jptc
def clf_evasion_evaluate(*args):
# NOTE(review): this view of the file has lost its indentation, the docstring quotes and
# some lines (closing brackets, the load_dataset arguments and, apparently, the `else:`
# lines). The comments below describe only what is visible.
Run a classification task evaluation
# Positional layout, as wired by the two `.click(...)` calls in the UI section of this file:
#   args[0:7]   attack name, model type, model path, input channels/height/width, pixel clip
#   args[7:-4]  attack-specific parameters
#               PGD:               max_iter, eps, eps_steps, target label
#               Adversarial Patch: max_iter, x location, y location, patch height, patch width, target label
#   args[-4:]   dataset type, dataset path, dataset split, uploaded image
attack = args[0]
model_type = args[1]
model_path = args[2]
model_channels = args[3]
model_height = args[4]
model_width = args[5]
model_clip = args[6]
dataset_type = args[-4]
dataset_path = args[-3]
dataset_split = args[-2]
image = args[-1]
# --- Dataset selection -------------------------------------------------------------
# For every dataset type below, `image` is replaced by a dict with 'image' and 'label'
# lists built from the loaded dataset. The arguments of load_dataset(...) are not visible
# in this view; dataset_path and dataset_split are presumably consumed there - TODO confirm.
# There is no visible branch for the "local" dataset type, in which case `image` stays
# the uploaded image.
if dataset_type == "Example XView":
from maite import load_dataset
import torchvision
jatic_dataset = load_dataset(
IMAGE_H, IMAGE_W = 224, 224
# Only the Resize step of the transform pipeline is visible here.
transform = torchvision.transforms.Compose(
torchvision.transforms.Resize((IMAGE_H, IMAGE_W)),
jatic_dataset.set_transform(lambda x: {"image": transform(x["image"]), "label": x["label"]})
image = {'image': [i['image'].numpy() for i in jatic_dataset],
'label': [i['label'] for i in jatic_dataset]}
elif dataset_type=="huggingface":
from maite import load_dataset
jatic_dataset = load_dataset(
image = {'image': [i['image'] for i in jatic_dataset],
'label': [i['label'] for i in jatic_dataset]}
elif dataset_type=="torchvision":
from maite import load_dataset
jatic_dataset = load_dataset(
image = {'image': [i['image'] for i in jatic_dataset],
'label': [i['label'] for i in jatic_dataset]}
elif dataset_type=="Example CIFAR10":
from maite import load_dataset
jatic_dataset = load_dataset(
# Only the first 100 samples are kept for the CIFAR10 example.
image = {'image': [i['image'] for i in jatic_dataset][:100],
'label': [i['label'] for i in jatic_dataset][:100]}
# --- Model selection ---------------------------------------------------------------
# Every branch ends with `jptc`, a JaticPyTorchClassifier wrapping the chosen model.
if model_type == "Example CIFAR10":
jptc = basic_cifar10_model()
elif model_type == "Example XView":
import torchvision
from heart.estimators.classification.pytorch import JaticPyTorchClassifier
# NOTE(review): only keys 1-5 are visible; the UI label table for XView also lists
# 'Building' first, so an entry (and the closing brace) may be missing from this view.
classes = {
1:'Construction Site',
2:'Engineering Vehicle',
3:'Fishing Vessel',
4:'Oil Tanker',
5:'Vehicle Lot'
# NOTE(review): positional False presumably means "no pretrained weights" - confirm against
# the installed torchvision. No weight loading is visible before model.eval() - TODO confirm
# whether a line is missing from this view.
model = torchvision.models.resnet18(False)
num_ftrs = model.fc.in_features
# Replace the final layer so the output size equals the number of classes.
model.fc = torch.nn.Linear(num_ftrs, len(classes.keys()))
_ = model.eval()
jptc = JaticPyTorchClassifier(
model=model, loss = torch.nn.CrossEntropyLoss(), input_shape=(3, 224, 224),
nb_classes=len(classes), clip_values=(0, 1), labels=list(classes.values())
elif model_type == "torchvision":
from maite.interop.torchvision import TorchVisionClassifier
from heart.estimators.classification.pytorch import JaticPyTorchClassifier
clf = TorchVisionClassifier.from_pretrained(model_path)
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
# NOTE(review): model_channels/height/width come from gr.Textbox components in the UI,
# so they may arrive as strings - TODO confirm what input_shape accepts.
jptc = JaticPyTorchClassifier(
model=clf._model, loss=loss_fn, input_shape=(model_channels, model_height, model_width),
nb_classes=len(clf._labels), clip_values=(0, model_clip), labels=clf._labels
elif model_type == "huggingface":
from maite.interop.huggingface import HuggingFaceImageClassifier
from heart.estimators.classification.pytorch import JaticPyTorchClassifier
clf = HuggingFaceImageClassifier.from_pretrained(model_path)
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
jptc = JaticPyTorchClassifier(
model=clf._model, loss=loss_fn, input_shape=(model_channels, model_height, model_width),
nb_classes=len(clf._labels), clip_values=(0, model_clip), labels=clf._labels
# --- Attack: Projected Gradient Descent --------------------------------------------
if attack=="PGD":
from art.attacks.evasion.projected_gradient_descent.projected_gradient_descent_pytorch import ProjectedGradientDescentPyTorch
from heart.attacks.attack import JaticAttack
from heart.metrics import AccuracyPerturbationMetric
from torch.nn.functional import softmax
from maite.protocols import HasDataImage, is_typed_dict, ArrayLike
# args[7:11] are max_iter, eps, eps_steps and the target textbox; the attack is
# targeted whenever the target textbox is non-empty.
pgd_attack = ProjectedGradientDescentPyTorch(estimator=jptc, max_iter=args[7], eps=args[8],
eps_step=args[9], targeted=args[10]!="")
# From here on `attack` is the JaticAttack wrapper, no longer the attack name string.
attack = JaticAttack(pgd_attack)
# Benign predictions, converted from logits to per-class probabilities.
preds = jptc(image)
preds = softmax(torch.from_numpy(preds.logits), dim=1)
# Label -> probability mapping for the first sample only.
labels = {}
for i, label in enumerate(jptc.get_labels()):
labels[label] = preds[0][i]
# Targeted run: attach the requested target label to the data handed to the attack.
# NOTE(review): the three `data = ...` assignments presumably sit in if/else/else branches
# whose `else:` lines are missing from this view. The target is passed as the raw textbox
# string - TODO confirm whether an integer is expected.
if args[10]!="":
if is_typed_dict(image, HasDataImage):
data = {'image': image['image'], 'label': [args[10]]*len(image['image'])}
data = {'image': image, 'label': [args[10]]}
data = image
x_adv = attack.run_attack(data=data)
# Predictions on the adversarial examples, again as probabilities.
adv_preds = jptc(x_adv.adversarial_examples)
adv_preds = softmax(torch.from_numpy(adv_preds.logits), dim=1)
adv_labels = {}
for i, label in enumerate(jptc.get_labels()):
adv_labels[label] = adv_preds[0][i]
metric = AccuracyPerturbationMetric()
metric.update(jptc, jptc.device, image, x_adv.adversarial_examples)
clean_accuracy, robust_accuracy, perturbation_added = metric.compute()
# NOTE(review): `metrics` is not used on any visible line.
metrics = pd.DataFrame([[clean_accuracy, robust_accuracy, perturbation_added]],
columns=['clean accuracy', 'robust accuracy', 'perturbation'])
# Reorder axes with transpose(1,2,0) for display; presumably channels-first ->
# channels-last - TODO confirm.
adv_imgs = [img.transpose(1,2,0) for img in x_adv.adversarial_examples]
# Normalise `image` to a plain list of images.
if is_typed_dict(image, HasDataImage):
image = image['image']
if not isinstance(image, list):
image = [image]
# in case where multiple images, use argmax to get the predicted label and add as caption
if dataset_type!="local":
temp = []
for i, img in enumerate(image):
# NOTE(review): the second append is presumably the `else:` branch of this isinstance
# check; the `else:` line is not visible.
if isinstance(img, ArrayLike):
temp.append((img.transpose(1,2,0), str(jptc.get_labels()[np.argmax(preds[i])]) ))
temp.append((img, str(jptc.get_labels()[np.argmax(preds[i])]) ))
image = temp
temp = []
for i, img in enumerate(adv_imgs):
temp.append((img, str(jptc.get_labels()[np.argmax(adv_preds[i])]) ))
adv_imgs = temp
# Order matches the `outputs=` list of the PGD "Evaluate" button.
return [image, labels, adv_imgs, adv_labels, clean_accuracy, robust_accuracy, perturbation_added]
# --- Attack: Adversarial Patch ------------------------------------------------------
elif attack=="Adversarial Patch":
from art.attacks.evasion.adversarial_patch.adversarial_patch_pytorch import AdversarialPatchPyTorch
from heart.attacks.attack import JaticAttack
from heart.metrics import AccuracyPerturbationMetric
from torch.nn.functional import softmax
from maite.protocols import HasDataImage, is_typed_dict, ArrayLike
# Fixed attack settings. NOTE(review): max_iter, patch_shape and patch_location below are
# not used on any visible line; the constructor call takes those values from args instead.
batch_size = 16
scale_min = 0.3
scale_max = 1.0
rotation_max = 0
learning_rate = 5000.
max_iter = 2000
patch_shape = (3, 14, 14)
patch_location = (18,18)
# args[7:13] are max_iter, x location, y location, patch height, patch width and the
# target textbox; the attack is targeted whenever the target textbox is non-empty.
patch_attack = AdversarialPatchPyTorch(estimator=jptc, rotation_max=rotation_max, patch_location=(args[8], args[9]),
scale_min=scale_min, scale_max=scale_max, patch_type='square',
learning_rate=learning_rate, max_iter=args[7], batch_size=batch_size,
patch_shape=(3, args[10], args[11]), verbose=False, targeted=args[12]!="")
# From here on `attack` is the JaticAttack wrapper, no longer the attack name string.
attack = JaticAttack(patch_attack)
# Benign predictions, converted from logits to per-class probabilities.
preds = jptc(image)
preds = softmax(torch.from_numpy(preds.logits), dim=1)
# Label -> probability mapping for the first sample only.
labels = {}
for i, label in enumerate(jptc.get_labels()):
labels[label] = preds[0][i]
# Targeted run: same data preparation as in the PGD branch (the `else:` lines are
# presumably missing from this view as well).
if args[12]!="":
if is_typed_dict(image, HasDataImage):
data = {'image': image['image'], 'label': [args[12]]*len(image['image'])}
data = {'image': image, 'label': [args[12]]}
data = image
attack_output = attack.run_attack(data=data)
adv_preds = jptc(attack_output.adversarial_examples)
adv_preds = softmax(torch.from_numpy(adv_preds.logits), dim=1)
adv_labels = {}
for i, label in enumerate(jptc.get_labels()):
adv_labels[label] = adv_preds[0][i]
metric = AccuracyPerturbationMetric()
metric.update(jptc, jptc.device, image, attack_output.adversarial_examples)
clean_accuracy, robust_accuracy, perturbation_added = metric.compute()
# NOTE(review): `metrics` and `perturbation_added` are not used on any visible line of this branch.
metrics = pd.DataFrame([[clean_accuracy, robust_accuracy, perturbation_added]],
columns=['clean accuracy', 'robust accuracy', 'perturbation'])
adv_imgs = [img.transpose(1,2,0) for img in attack_output.adversarial_examples]
if is_typed_dict(image, HasDataImage):
image = image['image']
if not isinstance(image, list):
image = [image]
# in case where multiple images, use argmax to get the predicted label and add as caption
if dataset_type!="local":
temp = []
for i, img in enumerate(image):
if isinstance(img, ArrayLike):
temp.append((img.transpose(1,2,0), str(jptc.get_labels()[np.argmax(preds[i])]) ))
temp.append((img, str(jptc.get_labels()[np.argmax(preds[i])]) ))
image = temp
temp = []
for i, img in enumerate(adv_imgs):
temp.append((img, str(jptc.get_labels()[np.argmax(adv_preds[i])]) ))
adv_imgs = temp
# The attack output carries the patch and its mask; the masked patch is reordered with
# transpose(1,2,0) and returned for the "Adversarial Patch" image component.
patch, patch_mask = attack_output.adversarial_patch
patch_image = ((patch) * patch_mask).transpose(1,2,0)
# Order matches the `outputs=` list of the patch "Evaluate" button.
# NOTE(review): no branch is visible for any other attack name, in which case nothing is
# returned explicitly.
return [image, labels, adv_imgs, adv_labels, clean_accuracy, robust_accuracy, patch_image]
def show_model_params(model_type):
    """
    Toggle the model-parameter column for the selected model type.

    The column is hidden for the two bundled example models and shown for
    every other model type.

    :param model_type: Value of the "Model type" radio button
    :return: A gr.Column update carrying the new visibility
    """
    is_example_model = model_type in ("Example CIFAR10", "Example XView")
    return gr.Column(visible=not is_example_model)
def show_dataset_params(dataset_type):
    """
    Toggle the dataset-parameter widgets for the selected dataset type.

    Three updates are returned, in the order the change listener expects:
    the enclosing column, the local-image row and the hosted-dataset row.

    * example datasets: everything hidden
    * "local": column and local-image row shown
    * anything else: column and hosted-dataset row shown

    :param dataset_type: Value of the "Dataset" radio button
    :return: List of [gr.Column, gr.Row, gr.Row] visibility updates
    """
    is_example = dataset_type in ("Example CIFAR10", "Example XView")
    is_local = dataset_type == "local"

    show_column = not is_example
    show_local_row = is_local
    show_hosted_row = not is_example and not is_local

    return [
        gr.Column(visible=show_column),
        gr.Row(visible=show_local_row),
        gr.Row(visible=show_hosted_row),
    ]
def pgd_show_label_output(dataset_type):
    """
    Choose which PGD result widgets are visible for the dataset type.

    For a "local" (single uploaded image) dataset the two label widgets are
    shown and the two accuracy numbers hidden; for every other dataset type
    it is the other way round. The last number is always shown.

    :param dataset_type: Value of the "Dataset" radio button
    :return: List of five updates: two gr.Label followed by three gr.Number
    """
    single_image = dataset_type == "local"
    return [
        gr.Label(visible=single_image),
        gr.Label(visible=single_image),
        gr.Number(visible=not single_image),
        gr.Number(visible=not single_image),
        gr.Number(visible=True),
    ]
def pgd_update_epsilon(clip_values):
    """
    Rescale the PGD epsilon slider to the model's pixel clip value.

    :param clip_values: Selected pixel clip value (the UI offers 1 or 255)
    :return: A gr.Slider update with the matching maximum and default value
    """
    if clip_values == 255:
        upper_bound, default_eps = 255, 55
    else:
        upper_bound, default_eps = 1, 0.05
    # The label text is kept identical to the slider's initial label so that
    # the label does not change when the range is updated.
    return gr.Slider(minimum=0.0001, maximum=upper_bound, label="Epslion", value=default_eps)
def patch_show_label_output(dataset_type):
    """
    Choose which adversarial patch result widgets are visible for the dataset type.

    For a "local" (single uploaded image) dataset the two label widgets are
    shown and the two accuracy numbers hidden; for every other dataset type
    it is the other way round. The patch image is always shown.

    The change listener that uses this function targets, in order, two
    gr.Label components, two gr.Number components and the patch gr.Image,
    so the last update is built as a gr.Image rather than a gr.Number.

    :param dataset_type: Value of the "Dataset" radio button
    :return: List of five updates: two gr.Label, two gr.Number and one gr.Image
    """
    single_image = dataset_type == "local"
    return [
        gr.Label(visible=single_image),
        gr.Label(visible=single_image),
        gr.Number(visible=not single_image),
        gr.Number(visible=not single_image),
        gr.Image(visible=True),
    ]
def show_target_label_dataframe(dataset_type):
    """
    Show the target-label table that belongs to the selected dataset.

    The CIFAR10 table is visible only for "Example CIFAR10", the XView table
    only for "Example XView"; for every other dataset type both are hidden.

    :param dataset_type: Value of the "Dataset" radio button
    :return: Tuple of two gr.Dataframe updates (CIFAR10 table, XView table)
    """
    show_cifar = dataset_type == "Example CIFAR10"
    show_xview = dataset_type == "Example XView"
    return gr.Dataframe(visible=show_cifar), gr.Dataframe(visible=show_xview)
# e.g. To use a local alternative theme: carbon_theme = Carbon()
# UI definition: nested tabs for task -> attack family -> attack type, with one parameter
# column and one results column per attack.
# NOTE(review): this view of the file has lost its indentation and some lines (for example
# the ends of the gr.Dataframe(...) calls); comments describe only what is visible.
with gr.Blocks(css=css, theme='xiaobaiyuan/theme_brief') as demo:
gr.Markdown("<h1>HEART Adversarial Robustness Gradio Example</h1>")
with gr.Tab("Info"):
gr.Markdown('This is step 1. Using the tabs, select a task for evaluation.')
with gr.Tab("Classification", elem_classes="task-tab"):
gr.Markdown("Classifying images with a set of categories.")
# Model and Dataset Selection
with gr.Row():
# Model and Dataset type e.g. Torchvision, HuggingFace, local etc.
with gr.Column():
# NOTE(review): "huggingface" is handled by clf_evasion_evaluate but is not offered as a
# model type here.
model_type = gr.Radio(label="Model type", choices=["Example CIFAR10", "Example XView", "torchvision"],
value="Example CIFAR10")
dataset_type = gr.Radio(label="Dataset", choices=["Example CIFAR10", "Example XView", "local", "torchvision", "huggingface"],
value="Example CIFAR10")
# Model parameters e.g. RESNET, VIT, input dimensions, clipping values etc.
with gr.Column(visible=False) as model_params:
model_path = gr.Textbox(placeholder="URL", label="Model path")
with gr.Row():
with gr.Column():
model_channels = gr.Textbox(placeholder="Integer, 3 for RGB images", label="Input Channels", value=3)
with gr.Column():
model_width = gr.Textbox(placeholder="Integer", label="Input Width", value=640)
with gr.Row():
with gr.Column():
model_height = gr.Textbox(placeholder="Integer", label="Input Height", value=480)
with gr.Column():
model_clip = gr.Radio(choices=[1, 255], label="Pixel clip", value=1)
# Dataset parameters e.g. Torchvision, HuggingFace, local etc.
with gr.Column(visible=False) as dataset_params:
with gr.Row() as local_image:
image = gr.Image(sources=['upload'], type="pil", height=150, width=150, elem_classes="input-image")
with gr.Row() as hosted_image:
dataset_path = gr.Textbox(placeholder="URL", label="Dataset path")
dataset_split = gr.Textbox(placeholder="test", label="Dataset split")
# Show or hide the parameter widgets whenever the model or dataset type changes.
model_type.change(show_model_params, model_type, model_params)
dataset_type.change(show_dataset_params, dataset_type, [dataset_params, local_image, hosted_image])
# Attack Selection
with gr.Row():
with gr.Tab("Info"):
gr.Markdown("This is step 2. Select the type of attack for evaluation.")
with gr.Tab("White Box"):
gr.Markdown("White box attacks assume the attacker has __full access__ to the model.")
with gr.Tab("Info"):
gr.Markdown("This is step 3. Select the type of white-box attack to evaluate.")
with gr.Tab("Evasion"):
gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.")
with gr.Tab("Info"):
gr.Markdown("This is step 4. Select the type of Evasion attack to evaluate.")
with gr.Tab("Projected Gradient Descent"):
gr.Markdown("This attack uses PGD to identify adversarial examples.")
with gr.Row():
with gr.Column():
# Read-only textbox whose value ("PGD") is passed as the first argument of clf_evasion_evaluate.
attack = gr.Textbox(visible=True, value="PGD", label="Attack", interactive=False)
max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=1000)
# NOTE(review): the label is spelled "Epslion" here and in pgd_update_epsilon.
eps = gr.Slider(minimum=0.0001, maximum=1, label="Epslion", value=0.05)
eps_steps = gr.Slider(minimum=0.001, maximum=1000, label="Epsilon steps", value=0.1)
targeted = gr.Textbox(placeholder="Target label (integer)", label="Target")
with gr.Accordion("Target mapping", open=False):
# The remaining arguments and closing brackets of the two gr.Dataframe(...) calls are not
# visible in this view.
cifar_labels = gr.Dataframe(pd.DataFrame(['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
visible=True, elem_classes=["small-font", "df-padding"],
xview_labels = gr.Dataframe(pd.DataFrame(['Building', 'Construction Site', 'Engineering Vehicle', 'Fishing Vessel', 'Oil Tanker',
'Vehicle Lot'],
visible=False, elem_classes=["small-font", "df-padding"],
eval_btn_pgd = gr.Button("Evaluate")
model_clip.change(pgd_update_epsilon, model_clip, eps)
dataset_type.change(show_target_label_dataframe, dataset_type, [cifar_labels, xview_labels])
# Evaluation Output. Visualisations of success/failures of running evaluation attacks.
with gr.Column():
with gr.Row():
with gr.Column():
original_gallery = gr.Gallery(label="Original", preview=True, height=600)
benign_output = gr.Label(num_top_classes=3, visible=False)
clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
with gr.Column():
adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, height=600)
adversarial_output = gr.Label(num_top_classes=3, visible=False)
robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)
perturbation_added = gr.Number(label="Perturbation Added", precision=2)
dataset_type.change(pgd_show_label_output, dataset_type, [benign_output, adversarial_output,
clean_accuracy, robust_accuracy, perturbation_added])
# The order of `inputs` defines the positional args seen by clf_evasion_evaluate.
# NOTE(review): api_name='patch' is set on the PGD button while the patch button below sets
# no api_name - looks like a copy-paste slip, TODO confirm the intended endpoint name.
eval_btn_pgd.click(clf_evasion_evaluate, inputs=[attack, model_type, model_path, model_channels, model_height, model_width,
model_clip, max_iter, eps, eps_steps, targeted,
dataset_type, dataset_path, dataset_split, image],
outputs=[original_gallery, benign_output, adversarial_gallery, adversarial_output, clean_accuracy,
robust_accuracy, perturbation_added], api_name='patch')
with gr.Row():
clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy,
adversarial_gallery, adversarial_output, robust_accuracy, perturbation_added])
with gr.Tab("Adversarial Patch"):
gr.Markdown("This attack crafts an adversarial patch that facilitates evasion.")
with gr.Row():
with gr.Column():
# The names attack, max_iter, targeted, cifar_labels, xview_labels and the output
# components are rebound in this tab; the PGD listeners above were registered with the
# earlier components before the rebinding.
attack = gr.Textbox(visible=True, value="Adversarial Patch", label="Attack", interactive=False)
max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=100)
x_location = gr.Slider(minimum=1, maximum=640, label="Location (x)", value=18)
y_location = gr.Slider(minimum=1, maximum=480, label="Location (y)", value=18)
patch_height = gr.Slider(minimum=1, maximum=640, label="Patch height", value=18)
patch_width = gr.Slider(minimum=1, maximum=480, label="Patch width", value=18)
targeted = gr.Textbox(placeholder="Target label (integer)", label="Target")
with gr.Accordion("Target mapping", open=False):
cifar_labels = gr.Dataframe(pd.DataFrame(['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
visible=True, elem_classes=["small-font", "df-padding"],
xview_labels = gr.Dataframe(pd.DataFrame(['Building', 'Construction Site', 'Engineering Vehicle', 'Fishing Vessel', 'Oil Tanker',
'Vehicle Lot'],
visible=False, elem_classes=["small-font", "df-padding"],
eval_btn_patch = gr.Button("Evaluate")
dataset_type.change(show_target_label_dataframe, dataset_type, [cifar_labels, xview_labels])
# Evaluation Output. Visualisations of success/failures of running evaluation attacks.
with gr.Column():
with gr.Row():
with gr.Column():
original_gallery = gr.Gallery(label="Original", preview=True, height=600)
benign_output = gr.Label(num_top_classes=3, visible=False)
clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
with gr.Column():
adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, height=600)
adversarial_output = gr.Label(num_top_classes=3, visible=False)
robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)
patch_image = gr.Image(label="Adversarial Patch")
dataset_type.change(patch_show_label_output, dataset_type, [benign_output, adversarial_output,
clean_accuracy, robust_accuracy, patch_image])
# The order of `inputs` defines the positional args seen by clf_evasion_evaluate.
eval_btn_patch.click(clf_evasion_evaluate, inputs=[attack, model_type, model_path, model_channels, model_height, model_width,
model_clip, max_iter, x_location, y_location, patch_height, patch_width, targeted,
dataset_type, dataset_path, dataset_split, image],
outputs=[original_gallery, benign_output, adversarial_gallery, adversarial_output, clean_accuracy,
robust_accuracy, patch_image])
with gr.Row():
clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy,
adversarial_gallery, adversarial_output, robust_accuracy, patch_image])
with gr.Tab("Poisoning"):
gr.Markdown("Coming soon.")
with gr.Tab("Black Box"):
gr.Markdown("Black box attacks assume the attacker __does not__ have full access to the model but can query it for predictions.")
with gr.Tab("Info"):
gr.Markdown("This is step 3. Select the type of black-box attack to evaluate.")
with gr.Tab("Evasion"):
gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.")
with gr.Tab("Info"):
gr.Markdown("This is step 4. Select the type of Evasion attack to evaluate.")
with gr.Tab("HopSkipJump"):
gr.Markdown("Coming soon.")
with gr.Tab("Square Attack"):
gr.Markdown("Coming soon.")
with gr.Tab("AutoAttack"):
gr.Markdown("Coming soon.")
with gr.Tab("Object Detection"):
gr.Markdown("Extracting objects from images and identifying their category.")
gr.Markdown("Coming soon.")
# Script entry point: installs the HEART package with pip before the app is started.
# NOTE(review): HEART_INSTALL is not defined on any visible line, and the call that
# launches the app is not visible either - both are presumably among the lines missing
# from this view.
if __name__ == "__main__":
import os, sys, subprocess
# Huggingface does not support LFS via external https, disable smudge
os.putenv('GIT_LFS_SKIP_SMUDGE', '1')
# Run pip through the current interpreter so the package lands in the same environment.
subprocess.run([sys.executable, '-m', 'pip', 'install', HEART_INSTALL])
# during development, set debug=True