from typing import Any
import os

import torch
import ignite.distributed as idist
import torchvision
import torchvision.transforms as T
from torch.utils import data as torch_data
from torch.utils.data import Subset
from tqdm import tqdm

from modelguidedattacks.data import get_dataset
from modelguidedattacks.cls_models.accuracy import get_correct_subset_for_models, DATASET_METADATA_DIR

from .classification_wrapper import TopKClassificationWrapper

def get_gt_labels(dataset: torch_data.Dataset, train: bool, dataset_name: str):
    """Collect (and cache) the ground-truth label of every sample in `dataset`."""
    training_str = "train" if train else "val"
    save_name = os.path.join(DATASET_METADATA_DIR, f"{dataset_name}_labels_{training_str}.p")

    if os.path.exists(save_name):
        print("Found labels cache")
        return torch.load(save_name)

    # No cache yet: iterate the dataset once and record every label
    dataloader = torch_data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

    gt_labels = []
    for batch in tqdm(dataloader):
        gt_labels.extend(batch[1].tolist())

    gt_labels = torch.tensor(gt_labels)
    torch.save(gt_labels, save_name)

    return gt_labels
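
# Illustrative usage sketch ("cifar10" is a hypothetical dataset name here): the
# first call walks the split once and writes the cache file under
# DATASET_METADATA_DIR; later calls return the cached tensor immediately.
#
#   _, dataset_eval = get_dataset("cifar10")
#   gt_labels = get_gt_labels(dataset_eval, train=False, dataset_name="cifar10")
#   assert gt_labels.ndim == 1 and gt_labels.dtype == torch.long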

def class_balanced_sampling(dataset, gt_labels: torch.Tensor,
                            correct_labels: list, total_samples=1000):
    """Sample up to `total_samples` indices from `dataset`, restricted to the
    correctly classified sample indices in `correct_labels`, keeping the
    per-class counts as balanced as possible."""
    num_classes = len(dataset.classes)

    # Mask of dataset indices that the compared models classify correctly
    correct_labels = torch.tensor(correct_labels)
    correct_mask = torch.zeros((len(dataset),), dtype=torch.bool)
    correct_mask[correct_labels] = True

    total_sampled_indices = 0
    sampled_indices = [[] for _ in range(num_classes)]

    shuffled_inds = torch.randperm(len(dataset))
    for sample_i in shuffled_inds:
        if not correct_mask[sample_i]:
            continue

        sample_class = gt_labels[sample_i]

        # Let each class grow only in lockstep with the running total, which
        # keeps the class counts within one sample of each other
        desired_samples_in_class = (total_sampled_indices // num_classes) + 1
        if len(sampled_indices[sample_class]) < desired_samples_in_class:
            sampled_indices[sample_class].append(sample_i.item())
            total_sampled_indices += 1

        if total_sampled_indices >= total_samples:
            break

    flattened_indices = []
    for class_samples in sampled_indices:
        flattened_indices.extend(class_samples)

    return torch.tensor(flattened_indices)
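
# Illustrative sketch (toy numbers, not from the source): with 10 classes and
# the default total_samples=1000, the loop admits roughly 100 samples per
# class; classes with fewer correct samples simply contribute what they have.
#
#   subset_inds = class_balanced_sampling(dataset_eval, gt_labels,
#                                         correct_labels=correct_eval_set)
#   balanced_eval = Subset(dataset_eval, subset_inds)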

def sample_attack_labels(dataset, gt_labels, k, sampler):
    """
    dataset: Dataset we're generating attack labels for
    gt_labels: Tensor of gt class idx for each sample in the dataset
    k: attack size
    sampler: ["random"]
    """
    # Sample from a uniform distribution and argsort to simulate
    # a batched randperm
    attack_label_uniforms = torch.rand((len(gt_labels), len(dataset.classes)))

    # We don't want to sample the gt class for any sample, so force its
    # entry below every uniform draw before sorting
    batch_inds = torch.arange(len(gt_labels))
    attack_label_uniforms[batch_inds, gt_labels] = -1.

    attack_labels = attack_label_uniforms.argsort(dim=-1, descending=True)[:, :k]

    return attack_labels
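
# Illustrative sketch of the batched-randperm trick above (toy sizes): each row
# of `u` gets an independent random ordering of the class ids, and pinning the
# gt entry at -1 guarantees it sorts last and never appears in the top-k.
#
#   u = torch.rand((4, 10))                     # 4 samples, 10 classes
#   gt = torch.tensor([3, 0, 7, 7])
#   u[torch.arange(4), gt] = -1.
#   topk = u.argsort(dim=-1, descending=True)[:, :3]
#   assert not (topk == gt[:, None]).any()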

def setup_data(config: Any, rank):
    """Download datasets and create dataloaders

    Parameters
    ----------
    config: needs to contain `dataset`, `k`, `attack_sampling`, `compare_models`,
        `overfit`, `train_batch_size`, `eval_batch_size`, and `num_workers`
    """
    dataset_train, dataset_eval = get_dataset(config.dataset)

    train_subset = None
    val_subset = None
    attack_labels_train = None
    attack_labels_val = None

    # Build subsets and attack labels once on rank 0, then broadcast so every
    # process works with identical samples and targets
    if rank == 0:
        gt_labels_train = get_gt_labels(dataset_train, True, config.dataset)
        gt_labels_val = get_gt_labels(dataset_eval, False, config.dataset)

        attack_labels_train = sample_attack_labels(dataset_train, gt_labels_train, k=config.k,
                                                   sampler=config.attack_sampling)
        attack_labels_val = sample_attack_labels(dataset_eval, gt_labels_val, k=config.k,
                                                 sampler=config.attack_sampling)

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Only attack samples that the compared models already classify correctly
        correct_train_set = get_correct_subset_for_models(config.compare_models,
                                                          config.dataset, device,
                                                          train=True)
        correct_eval_set = get_correct_subset_for_models(config.compare_models,
                                                         config.dataset, device,
                                                         train=False)

        # Balanced sampling
        train_subset = class_balanced_sampling(dataset_train, gt_labels_train,
                                               correct_train_set)
        val_subset = class_balanced_sampling(dataset_eval, gt_labels_val,
                                             correct_eval_set)

        if config.overfit:
            # Overfit on a tiny fixed subset; sample each split independently
            # since the two subsets may differ in length
            train_subset = train_subset[torch.randperm(len(train_subset))[:16]]
            val_subset = val_subset[torch.randperm(len(val_subset))[:16]]

    train_subset = idist.broadcast(train_subset, safe_mode=True)
    val_subset = idist.broadcast(val_subset, safe_mode=True)
    attack_labels_train = idist.broadcast(attack_labels_train, safe_mode=True)
    attack_labels_val = idist.broadcast(attack_labels_val, safe_mode=True)

    dataset_train = TopKClassificationWrapper(dataset_train, k=config.k,
                                              attack_labels=attack_labels_train)
    dataset_eval = TopKClassificationWrapper(dataset_eval, k=config.k,
                                             attack_labels=attack_labels_val)

    dataset_train = Subset(dataset_train, train_subset)
    dataset_eval = Subset(dataset_eval, val_subset)
    dataloader_train = idist.auto_dataloader(
        dataset_train,
        batch_size=config.train_batch_size,
        shuffle=not config.overfit,
        num_workers=config.num_workers,
    )

    dataloader_eval = idist.auto_dataloader(
        dataset_eval,
        batch_size=config.eval_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
    )

    return dataloader_train, dataloader_eval
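
# Minimal usage sketch (hypothetical config values; real runs would come from
# the project's config system):
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(dataset="cifar10", k=5, attack_sampling="random",
#                            compare_models=[...], overfit=False,
#                            train_batch_size=64, eval_batch_size=64,
#                            num_workers=4)
#   dataloader_train, dataloader_eval = setup_data(config, rank=idist.get_rank())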