Spaces:
Sleeping
Sleeping
from typing import Any, Union | |
import ignite.distributed as idist | |
import torch | |
from ignite.engine import DeterministicEngine, Engine, Events | |
from torch.cuda.amp import autocast | |
from torch.nn import Module | |
from torch.optim import Optimizer | |
from torch.utils.data import DistributedSampler, Sampler | |
def setup_trainer(
    config: Any,
    model: Module,
    optimizer: Optimizer,
    loss_fn: Module,
    device: Union[str, torch.device],
    train_sampler: Sampler,
) -> Union[Engine, DeterministicEngine]:
    """Build the training engine.

    Args:
        config: Run configuration; the attributes ``overfit`` and ``use_amp``
            are read on every iteration.
        model: Network, called as ``model(samples, attack_targets)``.
        optimizer: Optimizer stepped once per iteration.
        loss_fn: Criterion, called as
            ``loss_fn(outputs, attack_targets, targets)``.
        device: Device the batch tensors are moved to.
        train_sampler: Sampler of the training data loader. When it is a
            ``DistributedSampler`` and the world size is greater than one,
            its epoch is updated at the start of every epoch.

    Returns:
        An ``Engine`` whose per-iteration output is
        ``{"train_loss": <float>}``; the same value and the current epoch
        are also written to ``engine.state.metrics``.
    """

    def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
        """Run one optimisation step on ``batch`` and report the loss."""
        if config.overfit:
            # No batch norm
            model.eval()
        else:
            model.train()

        # Only the first three batch entries are consumed by the training
        # step; any further entries (e.g. sample ids) are left on the host.
        samples = batch[0].to(device, non_blocking=True)
        targets = batch[1].to(device, non_blocking=True)
        attack_targets = batch[2].to(device, non_blocking=True)

        with autocast(config.use_amp):
            outputs = model(samples, attack_targets)
            loss = loss_fn(outputs, attack_targets, targets)

        # NOTE(review): the loss is back-propagated without a gradient
        # scaler even when autocast is enabled -- confirm this is intended.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss = loss.item()
        engine.state.metrics = {
            "epoch": engine.state.epoch,
            "train_loss": train_loss,
        }
        return {"train_loss": train_loss}

    trainer = Engine(train_function)

    # Set the epoch for the distributed sampler. The handler has to be
    # registered on the engine; otherwise it never runs and the sampler
    # keeps producing the same ordering in every epoch.
    @trainer.on(Events.EPOCH_STARTED)
    def set_epoch():
        if idist.get_world_size() > 1 and isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(trainer.state.epoch - 1)

    return trainer
def setup_evaluator(
    config: Any,
    model: Module,
    device: Union[str, torch.device],
) -> Engine:
    """Build the evaluation engine.

    Args:
        config: Run configuration; the attribute ``use_amp`` is read on
            every iteration.
        model: Network, called as ``model(samples, attack_targets, gt_labels)``
            and expected to return an ``(outputs, perturbations)`` pair.
        device: Device the batch tensors are moved to.

    Returns:
        An ``Engine`` whose per-iteration output is the tuple
        ``(outputs, attack_targets, {"gt_targets": ..., "perturbations": ...})``.
    """

    def eval_function(engine: Engine, batch: Any):
        """Run the model on ``batch`` in eval mode and return its outputs."""
        model.eval()

        # The batch must unpack into exactly four items. The fourth one
        # (sample ids) is not used below, so it is not copied to the device.
        samples, gt_labels, attack_targets, _sample_ids = batch
        samples = samples.to(device, non_blocking=True)
        gt_labels = gt_labels.to(device, non_blocking=True)
        attack_targets = attack_targets.to(device, non_blocking=True)

        # NOTE(review): gradient tracking is not disabled here; presumably
        # the model needs autograd to produce the perturbations -- confirm
        # before wrapping this in torch.no_grad().
        with autocast(config.use_amp):
            outputs, perturbations = model(samples, attack_targets, gt_labels)

        extras = {
            "gt_targets": gt_labels,
            "perturbations": perturbations,
        }
        return outputs, attack_targets, extras

    return Engine(eval_function)