# File size: 985 Bytes
# Revision: 71f183c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torchvision import models

from modelguidedattacks.guides.instance_guide import InstanceGuide
from modelguidedattacks.guides.unguided import Unguided
from modelguidedattacks import losses

from .cls_models.registry import get_model

# Guide implementations selectable by name through ``config.guide_model``.
# ``setup_model`` looks the name up here and calls the entry as
# ``entry(model, config, **kwargs)``.
guide_model_registry = dict(
    instance_guided=InstanceGuide,
    unguided=Unguided,
)

# Loss implementations selectable by name through ``config.loss``.
# ``setup_model`` forwards the looked-up entry itself (it is not called here)
# as the ``loss_fn`` keyword when the "unguided" guide is selected.
loss_registry = dict(
    cvxproj=losses.CVXProjLoss,
    cwk=losses.CWExtensionLoss,
    ad=losses.AdversarialDistillationLoss,
)

def setup_model(config, device):
    """Build the classifier and wrap it in the configured guide.

    The classifier is obtained from ``get_model(config.dataset, config.model,
    device)``. The guide is then looked up in ``guide_model_registry`` under
    ``config.guide_model`` and called with the classifier, the config, and any
    guide-specific keyword arguments.

    Only the ``"unguided"`` guide receives extra keyword arguments:
    ``iterations``, ``lr``, ``loss_fn`` (the ``loss_registry`` entry for
    ``config.loss``), ``binary_search_steps`` and ``topk_loss_coef_upper``.
    Every other guide is called with just ``(model, config)``.

    Args:
        config: Object exposing ``dataset``, ``model`` and ``guide_model``;
            for the unguided guide also ``unguided_iterations``,
            ``unguided_lr``, ``loss``, ``binary_search_steps`` and
            ``topk_loss_coef_upper``.
        device: Passed through unchanged to ``get_model``.

    Returns:
        Whatever the selected ``guide_model_registry`` entry returns.

    Raises:
        KeyError: If ``config.guide_model`` is not in ``guide_model_registry``
            or, for the unguided guide, ``config.loss`` is not in
            ``loss_registry``.
    """
    classifier = get_model(config.dataset, config.model, device)

    if config.guide_model == "unguided":
        # Optimisation settings are only forwarded to the unguided guide.
        guide_kwargs = dict(
            iterations=config.unguided_iterations,
            lr=config.unguided_lr,
            loss_fn=loss_registry[config.loss],
            binary_search_steps=config.binary_search_steps,
            topk_loss_coef_upper=config.topk_loss_coef_upper,
        )
    else:
        guide_kwargs = {}

    guide_factory = guide_model_registry[config.guide_model]
    return guide_factory(classifier, config, **guide_kwargs)