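"""Launch script for model-guided attack evaluation.

Reads a YAML config, spawns one worker per device via
``ignite.distributed.Parallel``, evaluates attack success rate (ASR) and
L1/L2/L_inf perturbation energies on the eval set, and stores the resulting
metrics with ``torch.save``.
"""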
import os
import sys
from pprint import pformat
from typing import Any

import torch
import yaml

import ignite.distributed as idist
from ignite.engine import Events
from ignite.utils import manual_seed

from modelguidedattacks.data.setup import setup_data
from modelguidedattacks.losses.boilerplate import BoilerplateLoss
from modelguidedattacks.losses.energy import Energy, EnergyLoss
from modelguidedattacks.metrics.topk_accuracy import TopKAccuracy
from modelguidedattacks.models import setup_model
from modelguidedattacks.trainers import setup_evaluator, setup_trainer
from modelguidedattacks.utils import (
    Engine,
    log_metrics,
    setup_logging,
    setup_output_dir,
    setup_parser,
)

def run(local_rank: int, config: Any):
    """Per-process entry point, called once per worker by ``idist.Parallel``."""
    print(f"Running on local rank {local_rank}")

    # seed each process deterministically, offset by its global rank
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # create output folder
    config.output_dir = setup_output_dir(config, rank)

    # set up the engine's logger with Python logging and print the run configuration
    logger = setup_logging(config)
    logger.info("Configuration: \n%s", pformat(vars(config)))
    # argparse.Namespace is not YAML-serializable, so dump its underlying dict
    (config.output_dir / "config-lock.yaml").write_text(yaml.dump(vars(config)))

    # download datasets and create dataloaders
    dataloader_train, dataloader_eval = setup_data(config, rank)

    # model, loss function, and device
    device = idist.device()
    model = idist.auto_model(setup_model(config, device))
    loss_fn = BoilerplateLoss().to(device=device)

    # perturbation-energy measures under the L2, L1, and L_inf norms
    l2_energy_loss = Energy(p=2).to(device)
    l1_energy_loss = Energy(p=1).to(device)
    l_inf_energy_loss = Energy(p=torch.inf).to(device)

    evaluator = setup_evaluator(config, model, device)
    evaluator.logger = logger

    # attach metrics to the evaluator: ASR (attack success rate) plus the
    # mean, min, and max perturbation energy under the L2, L1, and L_inf norms
    accuracy = TopKAccuracy(device=device)
    metrics = {
        "ASR": accuracy,
        "L2 Energy": EnergyLoss(l2_energy_loss, device=device),
        "L1 Energy": EnergyLoss(l1_energy_loss, device=device),
        "L_inf Energy": EnergyLoss(l_inf_energy_loss, device=device),

        "L2 Energy Min": EnergyLoss(l2_energy_loss, reduction="min", device=device),
        "L1 Energy Min": EnergyLoss(l1_energy_loss, reduction="min", device=device),
        "L_inf Energy Min": EnergyLoss(l_inf_energy_loss, reduction="min", device=device),

        "L2 Energy Max": EnergyLoss(l2_energy_loss, reduction="max", device=device),
        "L1 Energy Max": EnergyLoss(l1_energy_loss, reduction="max", device=device),
        "L_inf Energy Max": EnergyLoss(l_inf_energy_loss, reduction="max", device=device)
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)
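    # attach() wires each metric into the engine's event loop; values appear
    # under evaluator.state.metrics[name] once the metric completes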

    if config.guide_model in ["unguided", "instance_guided"]:

        first_batch_passed = False
        early_stopped = False

        def compute_metrics(engine: Engine, tag: str):
            nonlocal first_batch_passed
            nonlocal early_stopped

            for name, metric in metrics.items():
                metric.completed(engine, name)

            # early-stop heuristic: if the attack succeeds on effectively none
            # of the first logged batch, assume it will not succeed later either
            if not first_batch_passed:
                if engine.state.metrics["ASR"] < 1e-3:
                    print("Early stopping: assuming no success throughout")
                    early_stopped = True
                    engine.terminate()
                else:
                    first_batch_passed = True

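        # refresh and log metric values every `config.log_every_iters`
        # iterations so progress is visible during a long evaluation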
        evaluator.add_event_handler(
            Events.ITERATION_COMPLETED(every=config.log_every_iters),
            compute_metrics,
            tag="eval",
        )

        evaluator.add_event_handler(
            Events.ITERATION_COMPLETED(every=config.log_every_iters),
            log_metrics,
            tag="eval",
        )

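        # run a single evaluation pass over the attack dataloader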
        evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

        if len(config.out_dir) > 0:
            # persist results (metrics, config, early-stop flag) to out_dir
            os.makedirs(config.out_dir, exist_ok=True)
            metrics_dict = evaluator.state.metrics
            metrics_dict["config"] = config
            metrics_dict["early_stopped"] = early_stopped

            # torch.save pickles the dict, so the config Namespace round-trips
            metrics_file_path = os.path.join(config.out_dir, "results.save")
            torch.save(metrics_dict, metrics_file_path)

        # no training is needed for unguided or instance-guided models
        return

    raise NotImplementedError("Guided-model training is reserved for future work")

# main entry point
def launch(config=None):
    if config is None:
        # first CLI argument is the config path; the rest are overrides
        config_path = sys.argv[1]
        config = setup_parser(config_path).parse_args(sys.argv[2:])

    backend = config.backend
    nproc_per_node = config.nproc_per_node

    # fall back to serial (non-distributed) execution when no backend is requested
    if nproc_per_node == 0 or backend is None:
        backend = None
        nproc_per_node = None

    with idist.Parallel(backend=backend, nproc_per_node=nproc_per_node) as p:
        p.run(run, config=config)


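# Example invocation (the config filename and flag values below are
# illustrative, not taken from this repo): the first CLI argument is the
# config path, and any remaining arguments are parsed as overrides.
#
#   python main.py configs/attack.yaml --backend nccl --nproc_per_node 2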
if __name__ == "__main__":
    launch()