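"""Entry point for evaluating model-guided attacks.

Sets up distributed execution with ignite, builds the data loaders, model,
and evaluator, and reports the attack success rate (ASR) together with the
L1/L2/L-inf energies of the generated perturbations.
"""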
import sys
from pprint import pformat
from typing import Any
import os
import torch
import ignite.distributed as idist
import yaml
from ignite.engine import Events
from ignite.utils import manual_seed
from modelguidedattacks.data.setup import setup_data
from modelguidedattacks.losses.boilerplate import BoilerplateLoss
from modelguidedattacks.losses.energy import Energy, EnergyLoss
from modelguidedattacks.metrics.topk_accuracy import TopKAccuracy
from modelguidedattacks.models import setup_model
from modelguidedattacks.trainers import setup_evaluator, setup_trainer
from modelguidedattacks.utils import (
    Engine,
    log_metrics,
    setup_logging,
    setup_output_dir,
    setup_parser,
)

def run(local_rank: int, config: Any):
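    """Per-rank worker: seed the run, build the data, model, and evaluator,
    attach the ASR and energy metrics, run the evaluation, and optionally
    save the resulting metrics to ``config.out_dir``.
    """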
print ("Running ", local_rank)
# make a certain seed
rank = idist.get_rank()
manual_seed(config.seed + rank)
# create output folder
config.output_dir = setup_output_dir(config, rank)
# setup engines logger with python logging
# print training configurations
logger = setup_logging(config)
logger.info("Configuration: \n%s", pformat(vars(config)))
(config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
# donwload datasets and create dataloaders
dataloader_train, dataloader_eval = setup_data(config, rank)
# model, optimizer, loss function, device
device = idist.device()
model = idist.auto_model(setup_model(config, idist.device()))
loss_fn = BoilerplateLoss().to(device=device)
l2_energy_loss = Energy(p=2).to(device)
l1_energy_loss = Energy(p=1).to(device)
l_inf_energy_loss = Energy(p=torch.inf).to(device)
evaluator = setup_evaluator(config, model, device)
evaluator.logger = logger
# attach metrics to evaluator
accuracy = TopKAccuracy(device=device)
metrics = {
"ASR": accuracy,
"L2 Energy": EnergyLoss(l2_energy_loss, device=device),
"L1 Energy": EnergyLoss(l1_energy_loss, device=device),
"L_inf Energy": EnergyLoss(l_inf_energy_loss, device=device),
"L2 Energy Min": EnergyLoss(l2_energy_loss, reduction="min", device=device),
"L1 Energy Min": EnergyLoss(l1_energy_loss, reduction="min", device=device),
"L_inf Energy Min": EnergyLoss(l_inf_energy_loss, reduction="min", device=device),
"L2 Energy Max": EnergyLoss(l2_energy_loss, reduction="max", device=device),
"L1 Energy Max": EnergyLoss(l1_energy_loss, reduction="max", device=device),
"L_inf Energy Max": EnergyLoss(l_inf_energy_loss, reduction="max", device=device)
}
for name, metric in metrics.items():
metric.attach(evaluator, name)
if config.guide_model in ["unguided", "instance_guided"]:
first_batch_passed = False
early_stopped = False
def compute_metrics(engine: Engine, tag: str):
nonlocal first_batch_passed
nonlocal early_stopped
for name, metric in metrics.items():
metric.completed(engine, name)
if not first_batch_passed:
if engine.state.metrics["ASR"] < 1e-3:
print ("Early stop, assuming no success throughout")
early_stopped = True
engine.terminate()
else:
first_batch_passed = True
evaluator.add_event_handler(
Events.ITERATION_COMPLETED(every=config.log_every_iters),
compute_metrics,
tag="eval",
)
evaluator.add_event_handler(
Events.ITERATION_COMPLETED(every=config.log_every_iters),
log_metrics,
tag="eval",
)
evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
log_metrics(evaluator, "eval")
if len(config.out_dir) > 0:
# Store results in out_dir
os.makedirs(config.out_dir, exist_ok=True)
metrics_dict = evaluator.state.metrics
metrics_dict["config"] = config
metrics_dict["early_stopped"] = early_stopped
metrics_file_path = os.path.join(config.out_dir, "results.save")
torch.save(metrics_dict, metrics_file_path)
# No need to train with an unguided model
return
assert False, "This code path is for the future"
# main entrypoint
def launch(config=None):
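    """Parse the config (from ``sys.argv`` when not supplied) and execute
    :func:`run` on each process through ``idist.Parallel``.
    """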
    if config is None:
        config_path = sys.argv[1]
        config = setup_parser(config_path).parse_args(sys.argv[2:])

    backend = config.backend
    nproc_per_node = config.nproc_per_node

    # fall back to serial execution when no backend/processes are requested
    if nproc_per_node == 0 or backend is None:
        backend = None
        nproc_per_node = None

    with idist.Parallel(backend, nproc_per_node) as p:
        p.run(run, config=config)

if __name__ == "__main__":
    launch()