thomaspaniagua
QuadAttack release
71f183c
raw
history blame
2.49 kB
from typing import Any, Union
import ignite.distributed as idist
import torch
from ignite.engine import DeterministicEngine, Engine, Events
from torch.cuda.amp import autocast
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler, Sampler
def setup_trainer(
    config: Any,
    model: Module,
    optimizer: Optimizer,
    loss_fn: Module,
    device: Union[str, torch.device],
    train_sampler: Sampler,
) -> Union[Engine, DeterministicEngine]:
    """Build the ignite training engine.

    Each iteration moves the first three batch entries (samples, targets,
    attack targets) to ``device``. It runs ``model(samples, attack_targets)``
    under autocast when ``config.use_amp`` is set, then computes
    ``loss_fn(outputs, attack_targets, targets)`` and takes one optimizer step.

    Args:
        config: Run configuration; ``config.overfit`` and ``config.use_amp``
            are read.
        model: Module being optimized.
        optimizer: Optimizer stepped once per iteration.
        loss_fn: Called as ``loss_fn(outputs, attack_targets, targets)``.
        device: Device the batch tensors are moved to.
        train_sampler: Training sampler. When it is a ``DistributedSampler``
            and the world size is > 1, its epoch is set at every epoch start.

    Returns:
        An ``Engine`` whose process function returns ``{"train_loss": float}``
        and stores ``epoch`` / ``train_loss`` in ``engine.state.metrics``.
    """
    # Under float16 autocast, small gradients can underflow to zero; the
    # scaler prevents that. With enabled=False it is a transparent
    # pass-through: scale() returns the loss unchanged and step() just calls
    # optimizer.step().
    scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)

    def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
        if config.overfit:
            # Overfit mode keeps the model in eval mode so BatchNorm running
            # statistics are not updated (and dropout is disabled).
            model.eval()
        else:
            model.train()

        # batch[3] (sample ids) is not used by the training step, so it is
        # not copied to the device.
        samples = batch[0].to(device, non_blocking=True)
        targets = batch[1].to(device, non_blocking=True)
        attack_targets = batch[2].to(device, non_blocking=True)

        with autocast(config.use_amp):
            outputs = model(samples, attack_targets)
            loss = loss_fn(outputs, attack_targets, targets)

        # Backward and the optimizer step run outside autocast, as the
        # PyTorch AMP guide recommends.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        train_loss = loss.item()
        engine.state.metrics = {
            "epoch": engine.state.epoch,
            "train_loss": train_loss,
        }
        return {"train_loss": train_loss}

    trainer = Engine(train_function)

    # Set the epoch on the distributed sampler so that its shuffling (if
    # enabled) differs from epoch to epoch. ignite epochs are 1-based.
    @trainer.on(Events.EPOCH_STARTED)
    def set_epoch():
        if idist.get_world_size() > 1 and isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(trainer.state.epoch - 1)

    return trainer
def setup_evaluator(
    config: Any,
    model: Module,
    device: Union[str, torch.device],
) -> Engine:
    """Build the ignite evaluation engine.

    Each iteration unpacks a four-element batch
    ``(samples, gt_labels, attack_targets, sample_ids)``. It moves the first
    three entries to ``device`` and, without gradient tracking, runs
    ``model(samples, attack_targets, gt_labels)``. Autocast is applied when
    ``config.use_amp`` is set.

    Args:
        config: Run configuration; ``config.use_amp`` is read.
        model: Module evaluated; expected to return an
            ``(outputs, perturbations)`` pair.
        device: Device the batch tensors are moved to.

    Returns:
        An ``Engine`` whose process function returns
        ``(outputs, attack_targets, {"gt_targets": gt_labels,
        "perturbations": perturbations})``.
    """

    @torch.no_grad()
    def eval_function(engine: Engine, batch: Any):
        model.eval()
        # The fourth entry (sample ids) is unpacked only to keep the
        # four-element batch contract; it is not used here, so it is not
        # copied to the device.
        samples, gt_labels, attack_targets, _sample_ids = batch
        samples = samples.to(device, non_blocking=True)
        gt_labels = gt_labels.to(device, non_blocking=True)
        attack_targets = attack_targets.to(device, non_blocking=True)

        with autocast(config.use_amp):
            outputs, perturbations = model(samples, attack_targets, gt_labels)

        return outputs, attack_targets, {
            "gt_targets": gt_labels,
            "perturbations": perturbations,
        }

    return Engine(eval_function)