# Source: thomaspaniagua, "QuadAttack release" (commit 71f183c)
# (hosting-page residue: raw | history blame | No virus | 4.9 kB)
import sys
from pprint import pformat
from typing import Any
import os
import torch
import ignite.distributed as idist
import yaml
from ignite.engine import Events
from ignite.metrics import Accuracy, Loss
from ignite.utils import manual_seed
from torch import nn, optim
from modelguidedattacks.data.setup import setup_data
from modelguidedattacks.losses.boilerplate import BoilerplateLoss
from modelguidedattacks.losses.energy import Energy, EnergyLoss
from modelguidedattacks.metrics.topk_accuracy import TopKAccuracy
from modelguidedattacks.models import setup_model
from modelguidedattacks.trainers import setup_evaluator, setup_trainer
from modelguidedattacks.utils import setup_parser, setup_output_dir
from modelguidedattacks.utils import setup_logging, log_metrics, Engine
def run(local_rank: int, config: Any):
    """Evaluate the configured attack in one (possibly distributed) process.

    Seeds the RNG per global rank and writes ``config-lock.yaml`` into the
    output directory. It then builds the model and evaluator and attaches an
    "ASR" metric (``TopKAccuracy``) plus L2/L1/L-inf ``EnergyLoss`` metrics.
    Each energy metric is attached with its default reduction and with the
    "min" and "max" reductions.

    Only ``config.guide_model`` values ``"unguided"`` and
    ``"instance_guided"`` are handled. They run evaluation only (no
    training). When ``config.out_dir`` is non-empty, rank 0 saves the final
    metrics, the config and the early-stop flag with ``torch.save`` to
    ``<out_dir>/results.save``.

    Args:
        local_rank: Process-local rank supplied by ``idist.Parallel.run``;
            only printed here (the global rank comes from ``idist.get_rank()``).
        config: Parsed run configuration; must support ``vars(config)``
            (e.g. an ``argparse.Namespace``).

    Raises:
        NotImplementedError: For any other ``config.guide_model``.
    """
    print("Running ", local_rank)

    # Offset the seed by the global rank so each process draws its own stream.
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # Create the output folder, set up python logging for the engines and
    # record the effective configuration next to the outputs.
    config.output_dir = setup_output_dir(config, rank)
    logger = setup_logging(config)
    logger.info("Configuration: \n%s", pformat(vars(config)))
    (config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))

    # Download datasets and create dataloaders. The training loader is not
    # used until a training path exists.
    dataloader_train, dataloader_eval = setup_data(config, rank)

    # Model, loss function and Lp "energy" modules (p = 2, 1, inf) on this
    # process's device. Construction order is kept as-is so the seeded RNG
    # stream is consumed in the same order as before.
    device = idist.device()
    model = idist.auto_model(setup_model(config, device))
    loss_fn = BoilerplateLoss().to(device=device)  # unused by the eval-only path
    l2_energy_loss = Energy(p=2).to(device)
    l1_energy_loss = Energy(p=1).to(device)
    l_inf_energy_loss = Energy(p=torch.inf).to(device)

    evaluator = setup_evaluator(config, model, device)
    evaluator.logger = logger

    # Attach metrics to the evaluator; dict order is the attach/logging order.
    metrics = {
        "ASR": TopKAccuracy(device=device),
        "L2 Energy": EnergyLoss(l2_energy_loss, device=device),
        "L1 Energy": EnergyLoss(l1_energy_loss, device=device),
        "L_inf Energy": EnergyLoss(l_inf_energy_loss, device=device),
        "L2 Energy Min": EnergyLoss(l2_energy_loss, reduction="min", device=device),
        "L1 Energy Min": EnergyLoss(l1_energy_loss, reduction="min", device=device),
        "L_inf Energy Min": EnergyLoss(l_inf_energy_loss, reduction="min", device=device),
        "L2 Energy Max": EnergyLoss(l2_energy_loss, reduction="max", device=device),
        "L1 Energy Max": EnergyLoss(l1_energy_loss, reduction="max", device=device),
        "L_inf Energy Max": EnergyLoss(l_inf_energy_loss, reduction="max", device=device),
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # A real exception (unlike ``assert``) is not stripped under ``python -O``.
    if config.guide_model not in ("unguided", "instance_guided"):
        raise NotImplementedError(
            f"This code path is for the future (guide_model={config.guide_model!r})"
        )

    first_batch_passed = False
    early_stopped = False

    def compute_metrics(engine: Engine, tag: str):
        """Compute all metrics mid-epoch and possibly stop evaluation early.

        On the first call (after ``config.log_every_iters`` iterations), the
        evaluation is terminated if "ASR" is below 1e-3, on the assumption
        that later batches will not succeed either. Otherwise the check is
        never repeated. ``tag`` is unused; it is accepted because the handler
        is registered with ``tag="eval"``.
        """
        nonlocal first_batch_passed, early_stopped
        for metric_name, attached_metric in metrics.items():
            attached_metric.completed(engine, metric_name)
        if first_batch_passed:
            return
        if engine.state.metrics["ASR"] < 1e-3:
            print("Early stop, assuming no success throughout")
            early_stopped = True
            engine.terminate()
        else:
            first_batch_passed = True

    evaluator.add_event_handler(
        Events.ITERATION_COMPLETED(every=config.log_every_iters),
        compute_metrics,
        tag="eval",
    )
    # Ignite fires handlers in registration order, so this logs the values
    # compute_metrics has just produced.
    evaluator.add_event_handler(
        Events.ITERATION_COMPLETED(every=config.log_every_iters),
        log_metrics,
        tag="eval",
    )

    evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
    log_metrics(evaluator, "eval")

    # Only rank 0 writes, so parallel workers do not race on the same file.
    if config.out_dir and rank == 0:
        os.makedirs(config.out_dir, exist_ok=True)
        # Copy so the evaluator's own metrics are not polluted with extra keys.
        results = dict(evaluator.state.metrics)
        results["config"] = config
        results["early_stopped"] = early_stopped
        torch.save(results, os.path.join(config.out_dir, "results.save"))
    # No need to train with an unguided model.
# main entrypoint
def launch(config=None):
    """Build the configuration if needed and execute ``run`` under ``idist.Parallel``.

    Args:
        config: Pre-built configuration. When ``None``, ``sys.argv[1]`` is
            the config path handed to ``setup_parser``. The remaining
            command-line arguments are then parsed by the returned parser.

    Distributed execution is disabled (``backend`` and ``nproc_per_node``
    both passed as ``None``) when ``config.nproc_per_node`` is 0 or
    ``config.backend`` is ``None``.

    Raises:
        SystemExit: If ``config`` is ``None`` and no config path was given on
            the command line (instead of an unexplained ``IndexError``).
    """
    if config is None:
        if len(sys.argv) < 2:
            prog = sys.argv[0] if sys.argv else "<script>"
            sys.exit(f"usage: {prog} CONFIG_PATH [CONFIG_ARGS...]")
        config_path = sys.argv[1]
        config = setup_parser(config_path).parse_args(sys.argv[2:])

    backend = config.backend
    nproc_per_node = config.nproc_per_node
    # Either setting turns distributed mode off; both are then cleared.
    if nproc_per_node == 0 or backend is None:
        backend = None
        nproc_per_node = None

    with idist.Parallel(backend, nproc_per_node) as parallel:
        parallel.run(run, config=config)
if __name__ == "__main__":
    # Script entry point: with no config supplied, launch() reads it from sys.argv.
    launch(config=None)