# File size: 4,062 Bytes
# 71f183c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import logging
import os

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from modelguidedattacks.data import get_dataset
from . import get_model

from .registry import ClsModel
from typing import Optional, List

DATASET_METADATA_DIR = "./dataset_metadata"

def correct_subset_cache_path(dataset_name: str, model_name: str, train: bool):
    """
    Build the cache file path for a dataset/model correct-subset.

    dataset_name: Name of the dataset
    model_name: Name of the model
    train: True for the training split, False for the validation split

    Returns "<DATASET_METADATA_DIR>/<dataset>_<model>_<train|val>.p".
    """
    split = "val"
    if train:
        split = "train"

    return os.path.join(DATASET_METADATA_DIR, f"{dataset_name}_{model_name}_{split}.p")

@torch.no_grad()
def _compute_correct_indices(model, dataset, batch_size):
    """Evaluate ``model`` over ``dataset`` in order; return indices of correct predictions."""
    # shuffle=False keeps batch positions aligned with dataset indices.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    correct_indices = []
    data_start_index = 0

    for batch_imgs, batch_gt_class in tqdm(dataloader, total=len(dataloader)):
        predictions = model(batch_imgs.to(model.device))  # [B, C]
        prediction_class_idx = predictions.argmax(dim=-1)  # [B] (long)
        prediction_correct = prediction_class_idx == batch_gt_class.to(model.device)
        batch_correct_idxs = data_start_index + prediction_correct.nonzero()[:, 0]
        # .tolist() copies to the host and therefore already synchronizes the device.
        correct_indices.extend(batch_correct_idxs.tolist())

        # Advance by the actual batch length rather than assuming batch_size.
        data_start_index += len(batch_imgs)

    return correct_indices


@torch.no_grad()
def get_correct_subset(model: Optional[ClsModel]=None, dataset_name: Optional[str]=None, 
                       model_name: Optional[str]=None, train=True, batch_size=256, 
                       force_cache=False, device="cuda"):
    """
    Return the dataset indices a model classifies correctly, cached on disk.

    Provide either ``model``, or both ``dataset_name`` and ``model_name``. In the
    latter case the model is only loaded (via ``get_model``) on a cache miss.

    model: Model to evaluate
    dataset_name: Name of dataset (not needed if model is provided)
    model_name: Name of model (not needed if model is provided)
    train: Use training dataset (otherwise the validation dataset)
    batch_size: Batch size to use while evaluating
    force_cache: Only read from cache and fail if not available
    device: Device used to load the model when ``model`` is not given;
        ignored otherwise (``model.device`` is used)

    Returns a set of indices (positions in the unshuffled dataset) of correctly
    classified items. The result is stored at
    ``correct_subset_cache_path(dataset_name, model_name, train)``.

    Raises:
        ValueError: if neither ``model`` nor both names are provided.
        FileNotFoundError: if ``force_cache`` is set and no cache exists.
    """

    if model is not None:
        assert dataset_name is None
        assert model_name is None

    if dataset_name is not None or model_name is not None:
        assert dataset_name is not None
        assert model_name is not None
        assert model is None

    if model is None and dataset_name is None:
        # Otherwise the attribute lookups below fail with an opaque AttributeError on None.
        raise ValueError(
            "Either `model` or both `dataset_name` and `model_name` must be provided."
        )

    if dataset_name is None:
        dataset_name = model.dataset_name

    if model_name is None:
        model_name = model.model_name

    subset_cache_path = correct_subset_cache_path(dataset_name, model_name, train)

    os.makedirs(DATASET_METADATA_DIR, exist_ok=True)

    if os.path.exists(subset_cache_path):
        return torch.load(subset_cache_path)

    if force_cache:
        # FileNotFoundError subclasses Exception, so existing handlers still catch it.
        raise FileNotFoundError("Cache not found and requested for cached correct subset.")

    logging.info("No cache found. Computing correct subset for %s-%s Train: %s",
                 dataset_name, model_name, train)

    if model is None:
        model = get_model(dataset_name, model_name, device)

    # NOTE(review): this leaves a caller-supplied model in eval mode.
    model.eval()

    train_dataset, val_dataset = get_dataset(dataset_name)
    dataset = train_dataset if train else val_dataset

    correct_subset = set(_compute_correct_indices(model, dataset, batch_size))

    # Write to a temporary file and rename it into place. An interrupted save
    # then never leaves a truncated cache that later loads as valid.
    tmp_path = f"{subset_cache_path}.tmp{os.getpid()}"
    torch.save(correct_subset, tmp_path)
    os.replace(tmp_path, subset_cache_path)

    return correct_subset

def get_correct_subset_for_models(model_names: List[str], dataset_name, device, train):
    """
    Return the dataset indices that every listed model classifies correctly.

    model_names: Names of the models whose correct subsets are intersected
    dataset_name: Name of the dataset shared by all models
    device: Device used if a model must be loaded to compute its subset
    train: Use the training split (otherwise the validation split)

    Returns a sorted list of dataset indices. The list is empty when
    ``model_names`` is empty.
    """
    correct_intersection = None
    for model_name in model_names:
        model_correct_subset = get_correct_subset(model_name=model_name, dataset_name=dataset_name,
                                                  device=device, train=train)

        if correct_intersection is None:
            # Copy, so the in-place intersection below never mutates a set
            # returned by get_correct_subset.
            correct_intersection = set(model_correct_subset)
        else:
            correct_intersection &= model_correct_subset

    if correct_intersection is None:
        # No models: previously list(None) raised TypeError.
        return []

    # Sort so the result order is deterministic rather than set-iteration order.
    return sorted(correct_intersection)