# QuadAttack release — results aggregation utilities
# (author: thomaspaniagua, commit 71f183c; header lines below were
# file-viewer scrape residue and have been converted to this comment)
import os
import torch
from pathlib import Path
import numpy as np
import copy
from collections import OrderedDict
# Configuration attributes that identify one experiment run. Used both to
# serialize configs into plain dicts and to detect unique parameter sets.
# (Previously contained duplicate entries for "topk_loss_coef_upper" and
# "binary_search_steps"; duplicates were harmless but misleading.)
config_parameter_keys = ["loss", "unguided_lr", "model", "k", "binary_search_steps",
                         "unguided_iterations", "topk_loss_coef_upper", "seed",
                         "opt_warmup_its", "cvx_proj_margin"]


def config_to_dict(config):
    """Extract the experiment parameters from *config* into a plain dict.

    ``cvx_proj_margin`` falls back to 0.2 for older result files saved before
    that option existed. (The previous ``hasattr`` fallback was dead code:
    the dict comprehension above it already read the attribute unconditionally
    and raised AttributeError first.) Every other key must be present on
    *config* or AttributeError is raised.
    """
    result_dict = {}
    for key in config_parameter_keys:
        if key == "cvx_proj_margin":
            # Default for runs that predate the cvx_proj_margin option.
            result_dict[key] = getattr(config, "cvx_proj_margin", 0.2)
        else:
            result_dict[key] = getattr(config, key)
    return result_dict
def load_all_results(load_min_max=False):
    """Load every saved result under ``results_rebuttal/`` into a list of dicts.

    Each ``*.save`` file is expected to be a ``torch.save``-d dict holding a
    ``config`` object plus ASR and L1/L2/L_inf energy metrics. When
    *load_min_max* is True, per-metric Min/Max entries are also copied for
    files that contain them. Returns [] when the results directory is absent.
    """
    if not os.path.isdir("results_rebuttal"):
        return []
    results_list = []
    for result_file in Path("results_rebuttal").rglob("*.save"):
        # map_location lets results saved on GPU machines load on CPU-only hosts.
        result = torch.load(result_file, map_location="cpu")
        result_dict = config_to_dict(result["config"])
        result_dict["ASR"] = result["ASR"]
        result_dict["L1"] = result["L1 Energy"]
        result_dict["L2"] = result["L2 Energy"]
        result_dict["L_inf"] = result["L_inf Energy"]
        # Older result files lack the min/max entries; probe one key for all six.
        if load_min_max and "L2 Energy Max" in result:
            result_dict["L1 Max"] = result["L1 Energy Max"]
            result_dict["L2 Max"] = result["L2 Energy Max"]
            result_dict["L_inf Max"] = result["L_inf Energy Max"]
            result_dict["L1 Min"] = result["L1 Energy Min"]
            result_dict["L2 Min"] = result["L2 Energy Min"]
            result_dict["L_inf Min"] = result["L_inf Energy Min"]
        results_list.append(result_dict)
    return results_list
def close(target, eps=1e-5):
    """Return a predicate that is True when its argument is within *eps* of *target*."""
    def _predicate(value):
        return np.allclose(value, target, atol=eps)
    return _predicate
def eq(target):
    """Return an equality predicate; float targets compare approximately via close()."""
    if not isinstance(target, float):
        return lambda value: value == target
    return close(target)
def gte(target):
    """Return a predicate testing float(value) >= target."""
    def _predicate(value):
        return target <= float(value)
    return _predicate
def lte(target):
    """Return a predicate testing float(value) <= target."""
    def _predicate(value):
        return target >= float(value)
    return _predicate
def in_set(target):
    """Return a predicate testing membership of a value in the collection *target*."""
    def _predicate(value):
        return value in target
    return _predicate
def filter_from_config(config):
    """Build a result filter matching exactly the parameter values of *config*."""
    return {key: eq(value) for key, value in config_to_dict(config).items()}
def filter_results(filter, results_list, only_with_minmax=False):
    """Return the results whose values satisfy every applicable predicate.

    *filter* maps key -> predicate; a predicate is applied only when the
    result actually contains that key. When *only_with_minmax* is True,
    results lacking the "L2 Max" entry are dropped as well.
    """
    kept = []
    for entry in results_list:
        # Guard clause: min/max stats are probed via a single representative key.
        if only_with_minmax and "L2 Max" not in entry:
            continue
        if all(check(entry[key]) for key, check in filter.items() if key in entry):
            kept.append(entry)
    return kept
def resolve_nonunique_filter(filter, results_list, include_failed=False):
    """Among all parameter sets matching *filter*, pick the best-scoring one.

    Matching results are grouped by their (seed-excluded, float-rounded)
    parameter set. Each group is scored by ``mean ASR * 100 - mean L2`` over
    its finite values; the best group's parameter set and its raw per-seed
    results are returned. Groups with near-zero but non-zero ASR are skipped
    unless *include_failed* is True. Returns (None, None) when no group
    qualifies.
    """
    filtered_results = filter_results(filter, results_list)
    unique_parameters = []
    # Find unique parameter sets for results
    for result in filtered_results:
        result_parameters = {param_key: result[param_key] for param_key in config_parameter_keys}
        # Round to avoid floating pt imprecision from messing with set uniqueness checks
        for key in result_parameters.keys():
            if isinstance(result_parameters[key], float):
                result_parameters[key] = round(result_parameters[key], 5)
        # Seeds intentionally differ between repeat runs of the same setting.
        del result_parameters["seed"]
        unique_parameters.append(result_parameters)
    # Only keep unique dicts
    unique_parameters = [dict(y) for y in set(tuple(x.items()) for x in unique_parameters)]
    # np.Infinity was removed in NumPy 2.0; np.inf is equivalent on all versions.
    best_metric = -np.inf
    best_param_set = None
    best_result_list = None
    for param_set in unique_parameters:
        # Re-filter for exactly this parameter set (collects all seeds).
        unique_filter = {
            param_name: eq(param_value) for param_name, param_value in param_set.items()
        }
        filtered_results = filter_results(unique_filter, results_list)
        # Sanity check: every setting is expected to have been run with 5 seeds.
        assert len(filtered_results) == 5
        asrs = [result["ASR"] for result in filtered_results]
        l2_energies = [result["L2"] for result in filtered_results]
        # Means over finite values only; NaN/inf entries come from failed runs.
        mean_asr = np.mean(np.array(asrs)[np.isfinite(asrs)])
        mean_l2 = np.mean(np.array(l2_energies)[np.isfinite(l2_energies)])
        # Arbitrary point in tradeoff curve
        result_goodness = -mean_l2 + mean_asr * 100
        if (mean_asr > 0 and mean_asr < 0.025) and not include_failed:
            # Irrelevant result and associated energies
            continue
        if result_goodness > best_metric or (include_failed and best_param_set is None):
            best_param_set = param_set
            best_result_list = filtered_results
            best_metric = result_goodness
    return best_param_set, best_result_list
def get_combined_results(filtered_results):
    """Aggregate a list of per-seed result dicts into summary statistics.

    For each core metric (ASR, L1, L2, L_inf) this adds ``<metric>_mean`` and
    ``<metric>_median`` over finite values, plus ``<metric>_best`` /
    ``<metric>_worst`` taken jointly from the single run with the highest /
    lowest ASR. The raw per-seed value lists are kept in the returned dict.
    """
    combined_results = {}
    for result in filtered_results:
        for key in result:
            combined_results.setdefault(key, []).append(result[key])
    # Iterate a snapshot: we insert "<key>_mean"/"<key>_median" while looping.
    for key, val in list(combined_results.items()):
        if key in ["ASR", "L1", "L2", "L_inf"]:
            val = np.array(val)
            combined_results[f"{key}_mean"] = np.mean(val[np.isfinite(val)])
            combined_results[f"{key}_median"] = np.median(val[np.isfinite(val)])
    # Coupled results: all metrics come from the same best/worst-ASR run.
    _add_coupled_metrics(combined_results, int(np.argmax(combined_results["ASR"])), "best")
    _add_coupled_metrics(combined_results, int(np.argmin(combined_results["ASR"])), "worst")
    return combined_results


def _add_coupled_metrics(combined_results, idx, label):
    """Copy ASR/L1/L2/L_inf of the run at *idx* into "<metric>_<label>" entries."""
    for metric in ["ASR", "L1", "L2", "L_inf"]:
        combined_results[f"{metric}_{label}"] = combined_results[metric][idx]
def build_full_results_dict(model_name="resnet50", verbose=False,
                            all_k=(20, 15, 10, 5, 1),
                            all_num_iter=(60, 30),
                            all_search_steps=(1, 9),
                            all_methods=("cwk", "ad", "cvxproj")):
    """Build a nested results dict: results[k][search_steps][num_iter][method].

    For every combination of k / binary-search steps / iteration budget /
    attack method, the saved results are filtered, the best hyperparameter
    set is chosen via resolve_nonunique_filter, and the surviving runs are
    aggregated with get_combined_results — keeping only the mean/best/worst
    ASR and L1/L2/L_inf entries. Combinations with no results are absent.

    Defaults are tuples rather than lists to avoid the shared-mutable-default
    pitfall; callers may still pass lists.
    """
    if verbose:
        print("-" * 100)
        print("Results for", model_name)
    results_list = load_all_results()
    results = OrderedDict()
    for k in all_k:
        results[k] = OrderedDict()
        for num_binary_search_steps in all_search_steps:
            results[k][num_binary_search_steps] = OrderedDict()
            for num_iter in all_num_iter:
                results[k][num_binary_search_steps][num_iter] = OrderedDict()
                for method_name in all_methods:
                    filter = {
                        "loss": eq(method_name),
                        "model": eq(model_name),
                        "k": eq(k),
                        "unguided_iterations": eq(num_iter),
                        "binary_search_steps": eq(num_binary_search_steps)
                    }
                    best_param_set, filtered_results = resolve_nonunique_filter(filter, results_list)
                    if verbose and best_param_set is not None:
                        print(f"K={k} Lr={best_param_set['unguided_lr']} and loss_coef={best_param_set['topk_loss_coef_upper']} ")
                    if best_param_set is None:
                        continue
                    # Every surviving setting must carry all 5 seed runs.
                    assert len(filtered_results) == 5
                    combined_results = get_combined_results(filtered_results)
                    # Keep only aggregated metric entries (metric AND aggregate
                    # tag) — equivalent to the original two deletion passes.
                    for key in list(combined_results):
                        is_metric = any(tag in key for tag in ("L1", "L2", "L_inf", "ASR"))
                        is_aggregate = any(tag in key for tag in ("mean", "worst", "best"))
                        if not (is_metric and is_aggregate):
                            del combined_results[key]
                    results[k][num_binary_search_steps][num_iter][method_name] = combined_results
    return results
if __name__ == "__main__":
    # Print the best hyperparameters for the main setting (cvxproj loss,
    # single binary-search step) of every evaluated model. The stray debug
    # assignment `x = 5` that followed these calls was dead code and removed.
    for model in ("resnet50", "densenet121", "deit_small", "vit_base"):
        build_full_results_dict(model_name=model, verbose=True,
                                all_search_steps=[1], all_methods=["cvxproj"])