from typing import Any, Union

import ignite.distributed as idist
import torch
from ignite.engine import DeterministicEngine, Engine, Events
from torch.cuda.amp import autocast
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler, Sampler


def setup_trainer(
    config: Any,
    model: Module,
    optimizer: Optimizer,
    loss_fn: Module,
    device: Union[str, torch.device],
    train_sampler: Sampler,
) -> Union[Engine, DeterministicEngine]:
    def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
        if config.overfit:
            # keep batch-norm statistics and dropout frozen (eval mode) when overfitting
            model.eval()
        else:
            model.train()

        samples = batch[0].to(device, non_blocking=True)
        targets = batch[1].to(device, non_blocking=True)
        attack_targets = batch[2].to(device, non_blocking=True)
        sample_ids = batch[3].to(device, non_blocking=True)

        with autocast(enabled=config.use_amp):
            outputs = model(samples, attack_targets)
            loss = loss_fn(outputs, attack_targets, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss = loss.item()
        engine.state.metrics = {
            "epoch": engine.state.epoch,
            "train_loss": train_loss,
        }
        return {"train_loss": train_loss}


    trainer = Engine(train_function)

    # set the epoch for the distributed sampler so shuffling differs across epochs
    @trainer.on(Events.EPOCH_STARTED)
    def set_epoch():
        if idist.get_world_size() > 1 and isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(trainer.state.epoch - 1)

    return trainer
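
# A note on mixed precision: the training step above uses autocast alone; float16
# autocast is commonly paired with a GradScaler to avoid gradient underflow.
# A minimal sketch of how the backward pass could be wrapped (illustrative only,
# not part of this module):
#
#   scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
#   ...
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()
#   optimizer.zero_grad()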


def setup_evaluator(
    config: Any,
    model: Module,
    device: Union[str, torch.device],
) -> Engine:
    @torch.no_grad()
    def eval_function(engine: Engine, batch: Any):
        model.eval()

        samples, gt_labels, attack_targets, sample_ids = batch

        samples = samples.to(device, non_blocking=True)
        gt_labels = gt_labels.to(device, non_blocking=True)
        attack_targets = attack_targets.to(device, non_blocking=True)
        sample_ids = sample_ids.to(device, non_blocking=True)

        with autocast(enabled=config.use_amp):
            outputs, perturbations = model(samples, attack_targets, gt_labels)

        # return (y_pred, y, extras) so downstream metrics and handlers can consume the output
        return outputs, attack_targets, {
            "gt_targets": gt_labels,
            "perturbations": perturbations,
        }

    return Engine(eval_function)
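

# Example wiring of the trainer and evaluator (a minimal sketch; the config fields,
# dataloaders, and max_epochs below are illustrative assumptions, not defined in
# this module):
#
#   trainer = setup_trainer(config, model, optimizer, loss_fn, device, train_sampler)
#   evaluator = setup_evaluator(config, model, device)
#
#   @trainer.on(Events.EPOCH_COMPLETED)
#   def run_validation():
#       evaluator.run(val_dataloader)
#
#   trainer.run(train_dataloader, max_epochs=config.max_epochs)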