Spaces:
Runtime error
Runtime error
from argparse import ArgumentParser, Namespace | |
from attributions import attention_rollout, grad_cam | |
from datamodules import CIFAR10QADataModule, ImageDataModule | |
from datamodules.utils import datamodule_factory | |
from functools import partial | |
from models import ImageInterpretationNet | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from pytorch_lightning.loggers import WandbLogger | |
from transformers import ViTForImageClassification | |
from utils.plot import DrawMaskCallback, log_masks | |
import pytorch_lightning as pl | |
def get_experiment_name(args: Namespace):
    """Build a run name from the experiment-defining command-line arguments.

    Every argument whose name is not in the exclusion set below is rendered
    as ``name=value``; the pieces are ordered alphabetically by argument name
    and joined with ``-``. Excluded arguments (batch size, data directory,
    seed, logging frequency, ...) do not become part of the name.

    Args:
        args: Parsed command-line arguments.

    Returns:
        The experiment name (an empty string if every argument is excluded).
    """
    # Arguments that should not distinguish one experiment from another
    excluded = frozenset(
        {
            "add_blur",
            "add_noise",
            "add_rotation",
            "base_model",
            "batch_size",
            "class_idx",
            "data_dir",
            "enable_progress_bar",
            "from_pretrained",
            "log_every_n_steps",
            "num_epochs",
            "num_workers",
            "sample_images",
            "seed",
        }
    )

    params = vars(args)
    pieces = []
    for key in sorted(params):
        if key in excluded:
            continue
        pieces.append(f"{key}={params[key]}")
    return "-".join(pieces)
def setup_sample_image_logs(
    dm: ImageDataModule,
    args: Namespace,
    logger: WandbLogger,
    n_panels: int = 2,  # TODO: change?
):
    """Log baseline attribution masks for sampled images and build mask callbacks.

    Draws ``n_panels`` batches from ``dm.val_dataloader()`` and keeps the first
    ``args.sample_images`` images and labels of each batch as one "panel".
    For every panel it:

    * loads ``ViTForImageClassification`` from ``args.from_pretrained``;
    * logs GradCAM and Attention Rollout masks of that ViT to ``logger``;
    * creates a ``DrawMaskCallback`` on the panel's samples with
      ``log_every_n_steps=args.log_every_n_steps`` and ``key=str(panel)``.

    Args:
        dm: Datamodule supplying the validation dataloader. ``main`` prepares
            it and calls ``setup("fit")`` before calling this function.
        args: Parsed command-line arguments. Reads ``sample_images``,
            ``log_every_n_steps`` and ``from_pretrained``.
        logger: Logger the baseline masks are written to.
        n_panels: Number of panels (validation batches) to sample.

    Returns:
        A list with one ``DrawMaskCallback`` per panel.

    Raises:
        ValueError: If the validation dataloader yields fewer than
            ``n_panels`` batches.
    """
    images_per_panel = args.sample_images

    # Sample the leading images of the first `n_panels` validation batches
    sample_images = []
    val_batches = iter(dm.val_dataloader())
    for panel in range(n_panels):
        try:
            X, Y = next(val_batches)
        except StopIteration:
            # A bare StopIteration escaping a regular function is an opaque
            # failure; report the real cause instead.
            raise ValueError(
                f"Validation dataloader yielded only {panel} batch(es), but "
                f"{n_panels} are needed to sample the mask-callback panels."
            ) from None
        sample_images.append((X[:images_per_panel], Y[:images_per_panel]))

    # All mask callbacks share the same logging frequency
    mask_cb = partial(DrawMaskCallback, log_every_n_steps=args.log_every_n_steps)

    callbacks = []
    for panel, samples in enumerate(sample_images):
        # NOTE(review): a fresh ViT is loaded for every panel. This could be
        # hoisted out of the loop if grad_cam/attention_rollout leave the model
        # unmodified -- confirm before changing.
        vit = ViTForImageClassification.from_pretrained(args.from_pretrained)
        X, _ = samples

        # Baseline attributions for comparison with the learned masks
        gradcam_masks = grad_cam(X, vit)
        log_masks(X, gradcam_masks, f"GradCAM {panel}", logger)

        rollout_masks = attention_rollout(X, vit)
        log_masks(X, rollout_masks, f"Attention Rollout {panel}", logger)

        callbacks.append(mask_cb(samples, key=f"{panel}"))
    return callbacks
def main(args: Namespace):
    """Train a Vision DiffMask interpretation network for a pretrained ViT.

    The function:

    * seeds everything with ``args.seed``;
    * loads the ViT named by ``args.from_pretrained`` and the datamodule
      returned by ``datamodule_factory(args)``;
    * wraps the ViT in an ``ImageInterpretationNet``;
    * fits it for ``args.num_epochs`` epochs, logging to the ``Patch-DiffMask``
      wandb project.

    A checkpoint is written every ``args.log_every_n_steps`` training steps to
    ``checkpoints/<wandb run id>/``, and all checkpoints are kept
    (``save_top_k=-1``).

    Args:
        args: Parsed command-line arguments (see the ``__main__`` block).
    """
    # Seed
    pl.seed_everything(args.seed)

    # Load pre-trained Transformer
    model = ViTForImageClassification.from_pretrained(args.from_pretrained)

    # Load datamodule
    dm = datamodule_factory(args)

    # Set up the datamodule now so images can be sampled for the mask callbacks
    dm.prepare_data()
    dm.setup("fit")

    # Create Vision DiffMask for the model
    diffmask = ImageInterpretationNet(
        model_cfg=model.config,
        alpha=args.alpha,
        lr=args.lr,
        eps=args.eps,
        lr_placeholder=args.lr_placeholder,
        lr_alpha=args.lr_alpha,
        mul_activation=args.mul_activation,
        add_activation=args.add_activation,
        # The CLI argument is the negation (`no_placeholder`)
        placeholder=not args.no_placeholder,
        weighted_layer_pred=args.weighted_layer_distribution,
    )
    diffmask.set_vision_transformer(model)

    # Create wandb logger instance
    wandb_logger = WandbLogger(
        name=get_experiment_name(args),
        project="Patch-DiffMask",
    )

    # `WandbLogger.version` reports the wandb run id only once the run exists.
    # Before that it falls back to the id given to the constructor (None here),
    # which would send every run's checkpoints to "checkpoints/None".
    # Accessing `experiment` starts the run so the id is available.
    _ = wandb_logger.experiment

    # Create checkpoint callback (keeps every checkpoint)
    ckpt_cb = ModelCheckpoint(
        save_top_k=-1,
        dirpath=f"checkpoints/{wandb_logger.version}",
        every_n_train_steps=args.log_every_n_steps,
    )

    # Create mask callbacks (also logs baseline GradCAM / rollout masks)
    mask_cbs = setup_sample_image_logs(dm, args, wandb_logger)

    # Create trainer
    trainer = pl.Trainer(
        accelerator="auto",
        callbacks=[ckpt_cb, *mask_cbs],
        enable_progress_bar=args.enable_progress_bar,
        logger=wandb_logger,
        max_epochs=args.num_epochs,
    )

    # Train the model
    trainer.fit(diffmask, dm)
if __name__ == "__main__":
    parser = ArgumentParser()

    # ---- Trainer ----
    parser.add_argument(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to enable the progress bar (NOT recommended when logging to file).",
    )

    # Integer-valued trainer & logging options, registered in this order:
    # (flag, default, help)
    integer_options = (
        ("--num_epochs", 5, "Number of epochs to train."),
        ("--seed", 123, "Random seed for reproducibility."),
        ("--sample_images", 8, "Number of images to sample for the mask callback."),
        (
            "--log_every_n_steps",
            200,
            "Number of steps between logging media & checkpoints.",
        ),
    )
    for flag, default, description in integer_options:
        parser.add_argument(flag, type=int, default=default, help=description)

    # ---- Base (classification) model ----
    parser.add_argument(
        "--base_model",
        choices=["ViT"],
        default="ViT",
        type=str,
        help="Base model architecture to train.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
        type=str,
        help="The name of the pretrained HF model to load.",
    )

    # ---- Interpretation model and datamodules register their own options ----
    for component in (ImageInterpretationNet, ImageDataModule, CIFAR10QADataModule):
        component.add_model_specific_args(parser)

    parser.add_argument(
        "--dataset",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        default="CIFAR10",
        type=str,
        help="The dataset to use.",
    )

    args = parser.parse_args()
    main(args)