import os
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch

config_parameter_keys = ["loss", "unguided_lr", "model", "k", "binary_search_steps",
                         "unguided_iterations", "topk_loss_coef_upper", "seed",
                         "opt_warmup_its", "cvx_proj_margin"]

def config_to_dict(config):
    """Extract the identifying hyperparameters of a run from its config object."""
    result_dict = {
        key: getattr(config, key) for key in config_parameter_keys
        if key != "cvx_proj_margin"
    }

    # Older configs were saved before cvx_proj_margin existed; fall back to its default.
    result_dict["cvx_proj_margin"] = getattr(config, "cvx_proj_margin", 0.2)

    return result_dict

def load_all_results(load_min_max=False):
    """Load every saved result under results_rebuttal/ into a flat list of dicts."""
    if not os.path.isdir("results_rebuttal"):
        return []

    all_result_files = Path("results_rebuttal").rglob("*.save")

    results_list = []
    for result_file in all_result_files:
        result = torch.load(result_file)
        config = result["config"]
        
        result_dict = config_to_dict(config)

        result_dict["ASR"] = result["ASR"]
        result_dict["L1"] = result["L1 Energy"]
        result_dict["L2"] = result["L2 Energy"]
        result_dict["L_inf"] = result["L_inf Energy"]

        if "L2 Energy Max" in result and load_min_max:
            result_dict["L1 Max"] = result["L1 Energy Max"]
            result_dict["L2 Max"] = result["L2 Energy Max"]
            result_dict["L_inf Max"] = result["L_inf Energy Max"]

            result_dict["L1 Min"] = result["L1 Energy Min"]
            result_dict["L2 Min"] = result["L2 Energy Min"]
            result_dict["L_inf Min"] = result["L_inf Energy Min"]

        results_list.append(result_dict)

    return results_list

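# Predicate factories: each returns a callable that a filter dict maps a result key to;
# filter_results() keeps a result only when every provided predicate passes.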
def close(target, eps=1e-5):
    return lambda x: np.allclose(x, target, atol=eps)

def eq(target):
    if isinstance(target, float):
        return close(target)
    else:
        return lambda x: x == target

def gte(target):
    return lambda x: float(x) >= target

def lte(target):
    return lambda x: float(x) <= target

def in_set(target):
    return lambda x: x in target

def filter_from_config(config):
    config_dict = config_to_dict(config)

    # Build an exact-match filter (with float tolerance via eq) for this config.
    return {
        key: eq(val) for (key, val) in config_dict.items()
    }

def filter_results(filter, results_list, only_with_minmax=False):
    filtered_results = []
    for result in results_list:
        pass_filter = True
        for key, val in result.items():
            if key not in filter:
                continue

            if not filter[key](val):
                pass_filter = False
                break

        if only_with_minmax and "L2 Max" not in result:
            continue

        if pass_filter:
            filtered_results.append(result)

    return filtered_results
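
# A minimal usage sketch (illustrative only, not called by the pipeline below):
# compose the predicate helpers into a filter and apply it. The key names mirror
# the result dicts built in load_all_results(); the thresholds are arbitrary.
def example_filter_usage():
    example_filter = {
        "model": eq("resnet50"),
        "k": in_set({5, 10, 20}),
        "ASR": gte(0.5),         # keep runs with at least 50% attack success rate
        "unguided_lr": lte(0.1),
    }
    return filter_results(example_filter, load_all_results())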

def resolve_nonunique_filter(filter, results_list, include_failed=False):
    """Among all parameter sets matching `filter` (ignoring the seed), pick the one
    with the best ASR/L2 tradeoff and return (best_param_set, its per-seed results)."""
    filtered_results = filter_results(filter, results_list)

    unique_parameters = []
    # Find unique parameter sets for results
    for result in filtered_results:
        result_parameters = {param_key:result[param_key] for param_key in config_parameter_keys}

        # Round to avoid floating pt imprecision from messing with set uniqueness checks
        for key in result_parameters.keys():
            if isinstance(result_parameters[key], float):
                result_parameters[key] = round(result_parameters[key], 5)

        del result_parameters["seed"]
        unique_parameters.append(result_parameters)

    # Only keep unique dicts
    unique_parameters = [dict(y) for y in set(tuple(x.items()) for x in unique_parameters)]

    best_metric = -np.inf
    best_param_set = None
    best_result_list = None
    for param_set in unique_parameters:
        # Perform another search
        unique_filter = {
            param_name: eq(param_value) for param_name, param_value in param_set.items()
        }

        filtered_results = filter_results(unique_filter, results_list)

        # One result per seed is expected for each unique parameter set.
        assert len(filtered_results) == 5

        asrs = [result["ASR"] for result in filtered_results]
        l2_energies = [result["L2"] for result in filtered_results]

        mean_asr = np.mean(np.array(asrs)[np.isfinite(asrs)])
        mean_l2 = np.mean(np.array(l2_energies)[np.isfinite(l2_energies)])

        # Scalar score trading ASR off against L2 energy; the weighting is arbitrary
        # but picks a consistent point on the tradeoff curve.
        result_goodness = -mean_l2 + mean_asr * 100

        if 0 < mean_asr < 0.025 and not include_failed:
            # Skip near-failed attacks: their energy statistics are not meaningful.
            continue

        if result_goodness > best_metric or (include_failed and best_param_set is None):
            best_param_set = param_set
            best_result_list = filtered_results
            best_metric = result_goodness

    return best_param_set, best_result_list

def get_combined_results(filtered_results):
    """Aggregate per-seed results: mean/median per metric, plus the metrics of the
    best-ASR and worst-ASR seeds (kept coupled so each energy matches its ASR)."""
    combined_results = {}

    for result in filtered_results:
        for key in result:
            if key not in combined_results:
                combined_results[key] = []

            combined_results[key].append(result[key])

    # Sanity check (disabled): every seed should appear exactly once.
    # assert len(combined_results["seed"]) == len(np.unique(combined_results["seed"]))

    for key, val in list(combined_results.items()):
        if key in ["ASR", "L1", "L2", "L_inf"]:
            val = np.array(val)
            combined_results[f"{key}_mean"] = np.mean(val[np.isfinite(val)])
            combined_results[f"{key}_median"] = np.median(val[np.isfinite(val)])

    # Coupled results
    best_asr_idx = np.argmax(combined_results["ASR"])
    best_asr = combined_results["ASR"][best_asr_idx]
    best_l1 = combined_results["L1"][best_asr_idx]
    best_l2 = combined_results["L2"][best_asr_idx]
    best_linf = combined_results["L_inf"][best_asr_idx]

    combined_results["ASR_best"] = best_asr
    combined_results["L1_best"] = best_l1
    combined_results["L2_best"] = best_l2
    combined_results["L_inf_best"] = best_linf

    worst_asr_idx = np.argmin(combined_results["ASR"])
    worst_asr = combined_results["ASR"][worst_asr_idx]
    worst_l1 = combined_results["L1"][worst_asr_idx]
    worst_l2 = combined_results["L2"][worst_asr_idx]
    worst_linf = combined_results["L_inf"][worst_asr_idx]

    combined_results["ASR_worst"] = worst_asr
    combined_results["L1_worst"] = worst_l1
    combined_results["L2_worst"] = worst_l2
    combined_results["L_inf_worst"] = worst_linf

    return combined_results

def build_full_results_dict(model_name="resnet50", verbose=False,
                            all_k=[20, 15, 10, 5, 1],
                            all_num_iter=[60, 30],
                            all_search_steps=[1, 9],
                            all_methods=["cwk", "ad", "cvxproj"]):
    """Build a nested dict results[k][binary_search_steps][iterations][method] holding
    aggregated metrics for the given model, picking the best hyperparameters per cell."""

    if verbose:
        print ("-" * 100)
        print ("Results for", model_name)

    results_list = load_all_results()
    results = OrderedDict()

    for k in all_k:
        results[k] = OrderedDict()

        for num_binary_search_steps in all_search_steps:
            results[k][num_binary_search_steps] = OrderedDict()

            for num_iter in all_num_iter:
                results[k][num_binary_search_steps][num_iter] = OrderedDict()

                for method_name in all_methods:
                    filter = {
                        "loss": eq(method_name),
                        "model": eq(model_name),
                        "k": eq(k),
                        "unguided_iterations": eq(num_iter),
                        "binary_search_steps": eq(num_binary_search_steps)
                    }

                    best_param_set, filtered_results = resolve_nonunique_filter(filter, results_list)

                    if verbose and best_param_set is not None:
                        print(f"K={k} lr={best_param_set['unguided_lr']} "
                              f"loss_coef={best_param_set['topk_loss_coef_upper']}")
                
                    if best_param_set is None:
                        continue

                    assert len(filtered_results) == 5
                    
                    combined_results = get_combined_results(filtered_results)

                    # Keep only aggregated metric entries (e.g. "ASR_mean", "L2_worst", "L_inf_best").
                    for key in list(combined_results):
                        is_metric = any(m in key for m in ("ASR", "L1", "L2", "L_inf"))
                        is_aggregate = any(s in key for s in ("mean", "best", "worst"))
                        if not (is_metric and is_aggregate):
                            del combined_results[key]

                    results[k][num_binary_search_steps][num_iter][method_name] = combined_results

    return results
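
# A minimal consumption sketch (illustrative only): walk the nested dict returned by
# build_full_results_dict(), results[k][search_steps][iterations][method], and print
# the mean ASR and L2 energy per cell. Not part of the main entry point below.
def example_print_summary(model_name="resnet50"):
    results = build_full_results_dict(model_name=model_name,
                                      all_search_steps=[1],
                                      all_methods=["cvxproj"])
    for k, by_steps in results.items():
        for steps, by_iter in by_steps.items():
            for num_iter, by_method in by_iter.items():
                for method, metrics in by_method.items():
                    print(f"k={k} steps={steps} iters={num_iter} {method}: "
                          f"ASR_mean={metrics['ASR_mean']:.3f} L2_mean={metrics['L2_mean']:.3f}")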

if __name__ == "__main__":
    build_full_results_dict(model_name="resnet50", verbose=True, all_search_steps=[1], all_methods=["cvxproj"])
    build_full_results_dict(model_name="densenet121", verbose=True, all_search_steps=[1], all_methods=["cvxproj"])
    build_full_results_dict(model_name="deit_small", verbose=True, all_search_steps=[1], all_methods=["cvxproj"])
    build_full_results_dict(model_name="vit_base", verbose=True, all_search_steps=[1], all_methods=["cvxproj"])