din0s's picture
Add code
d4ab5ac unverified
raw
history blame
5.92 kB
from argparse import ArgumentParser, Namespace
from attributions import attention_rollout, grad_cam
from datamodules import CIFAR10QADataModule, ImageDataModule
from datamodules.utils import datamodule_factory
from functools import partial
from models import ImageInterpretationNet
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from transformers import ViTForImageClassification
from utils.plot import DrawMaskCallback, log_masks
import pytorch_lightning as pl
def get_experiment_name(args: Namespace):
    """Build a run name from the arguments that define the experiment.

    Arguments that only concern data loading, augmentation, logging or
    bookkeeping are skipped. Every remaining argument is rendered as
    ``name=value``, and the pairs are joined with ``-`` in alphabetical
    order of the argument name.

    Args:
        args: Parsed command line arguments.

    Returns:
        The experiment name (an empty string if no argument qualifies).
    """
    # Arguments that do not distinguish one experiment from another
    excluded = (
        "add_blur",
        "add_noise",
        "add_rotation",
        "base_model",
        "batch_size",
        "class_idx",
        "data_dir",
        "enable_progress_bar",
        "from_pretrained",
        "log_every_n_steps",
        "num_epochs",
        "num_workers",
        "sample_images",
        "seed",
    )

    pairs = []
    for key, val in sorted(vars(args).items()):
        if key in excluded:
            continue
        pairs.append(f"{key}={val}")
    return "-".join(pairs)
def setup_sample_image_logs(
    dm: ImageDataModule,
    args: Namespace,
    logger: WandbLogger,
    n_panels: int = 2,  # TODO: change?
):
    """Sample validation images, log baseline attributions and build mask callbacks.

    For each of ``n_panels`` panels, the first ``args.sample_images`` images
    (and labels) of one validation batch are taken. GradCAM and Attention
    Rollout masks for those images are computed with a ViT loaded from
    ``args.from_pretrained`` and logged once through ``log_masks``. A
    ``DrawMaskCallback`` (configured with ``args.log_every_n_steps``) is then
    created for the same images.

    Args:
        dm: Datamodule whose ``val_dataloader()`` supplies the sample images.
            It presumably has to be set up already; ``main`` calls
            ``dm.setup("fit")`` before calling this function.
        args: Parsed arguments. Reads ``sample_images``, ``log_every_n_steps``
            and ``from_pretrained``.
        logger: Logger that the baseline attribution masks are written to.
        n_panels: Number of panels, i.e. validation batches to sample from.

    Returns:
        A list with one ``DrawMaskCallback`` per panel.

    Raises:
        ValueError: If the validation dataloader yields fewer than
            ``n_panels`` batches.
    """
    images_per_panel = args.sample_images

    # Take the first `images_per_panel` images from each of the first
    # `n_panels` validation batches.
    sample_images = []
    val_batches = iter(dm.val_dataloader())
    for panel in range(n_panels):
        try:
            X, Y = next(val_batches)
        except StopIteration:
            # A bare StopIteration escaping a regular function gives an
            # unhelpful traceback (and becomes a RuntimeError if raised inside
            # a generator), so report the actual problem instead.
            raise ValueError(
                f"Validation dataloader yielded only {panel} batch(es); "
                f"{n_panels} are needed to sample image panels."
            ) from None
        sample_images.append((X[:images_per_panel], Y[:images_per_panel]))

    # Shared configuration for every panel's mask callback
    mask_cb = partial(DrawMaskCallback, log_every_n_steps=args.log_every_n_steps)

    callbacks = []
    for panel, samples in enumerate(sample_images):
        # A fresh ViT is loaded for every panel.
        # NOTE(review): this could likely be hoisted out of the loop if
        # grad_cam / attention_rollout leave the model unchanged; confirm first.
        vit = ViTForImageClassification.from_pretrained(args.from_pretrained)
        X, _ = samples

        # Baseline attributions, logged once before training
        gradcam_masks = grad_cam(X, vit)
        log_masks(X, gradcam_masks, f"GradCAM {panel}", logger)
        rollout_masks = attention_rollout(X, vit)
        log_masks(X, rollout_masks, f"Attention Rollout {panel}", logger)

        callbacks.append(mask_cb(samples, key=f"{panel}"))
    return callbacks
def main(args: Namespace):
    """Train a Vision DiffMask interpretation network on top of a pretrained ViT.

    The steps run in this order:

    1. Seed all RNGs.
    2. Load the ViT named by ``args.from_pretrained``.
    3. Build the datamodule and set it up.
    4. Wrap the ViT in an ``ImageInterpretationNet``.
    5. Create a W&B logger, a checkpoint callback and the sample-image mask
       callbacks.
    6. Fit with a Lightning trainer.

    Args:
        args: Parsed command line arguments (see the ``__main__`` section).
    """
    # Seeding must happen before any model or data is created
    pl.seed_everything(args.seed)

    vit = ViTForImageClassification.from_pretrained(args.from_pretrained)

    datamodule = datamodule_factory(args)
    # Set up eagerly (rather than leaving it to the trainer) because the
    # sample-image callbacks below draw validation batches before training.
    datamodule.prepare_data()
    datamodule.setup("fit")

    interpretation_net = ImageInterpretationNet(
        model_cfg=vit.config,
        alpha=args.alpha,
        lr=args.lr,
        eps=args.eps,
        lr_placeholder=args.lr_placeholder,
        lr_alpha=args.lr_alpha,
        mul_activation=args.mul_activation,
        add_activation=args.add_activation,
        placeholder=not args.no_placeholder,
        weighted_layer_pred=args.weighted_layer_distribution,
    )
    interpretation_net.set_vision_transformer(vit)

    run_logger = WandbLogger(
        name=get_experiment_name(args),
        project="Patch-DiffMask",
    )

    # Checkpoints go into a directory named after the logger's run version
    checkpoint_callback = ModelCheckpoint(
        save_top_k=-1,  # keep every checkpoint
        dirpath=f"checkpoints/{run_logger.version}",
        every_n_train_steps=args.log_every_n_steps,
    )

    all_callbacks = [checkpoint_callback]
    all_callbacks.extend(setup_sample_image_logs(datamodule, args, run_logger))

    trainer = pl.Trainer(
        accelerator="auto",
        callbacks=all_callbacks,
        enable_progress_bar=args.enable_progress_bar,
        logger=run_logger,
        max_epochs=args.num_epochs,
    )
    trainer.fit(interpretation_net, datamodule)
if __name__ == "__main__":
    parser = ArgumentParser()

    # Options owned by this script, as (flag, add_argument kwargs) pairs.
    # They are registered in list order: trainer, logging, base model.
    script_options = [
        # Trainer
        (
            "--enable_progress_bar",
            {
                "action": "store_true",
                "help": "Whether to enable the progress bar (NOT recommended when logging to file).",
            },
        ),
        (
            "--num_epochs",
            {"type": int, "default": 5, "help": "Number of epochs to train."},
        ),
        (
            "--seed",
            {"type": int, "default": 123, "help": "Random seed for reproducibility."},
        ),
        # Logging
        (
            "--sample_images",
            {
                "type": int,
                "default": 8,
                "help": "Number of images to sample for the mask callback.",
            },
        ),
        (
            "--log_every_n_steps",
            {
                "type": int,
                "default": 200,
                "help": "Number of steps between logging media & checkpoints.",
            },
        ),
        # Base (classification) model
        (
            "--base_model",
            {
                "type": str,
                "default": "ViT",
                "choices": ["ViT"],
                "help": "Base model architecture to train.",
            },
        ),
        (
            "--from_pretrained",
            {
                "type": str,
                "default": "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
                "help": "The name of the pretrained HF model to load.",
            },
        ),
    ]
    for flag, options in script_options:
        parser.add_argument(flag, **options)

    # The interpretation model and the datamodules contribute their own options
    ImageInterpretationNet.add_model_specific_args(parser)
    ImageDataModule.add_model_specific_args(parser)
    CIFAR10QADataModule.add_model_specific_args(parser)

    parser.add_argument(
        "--dataset",
        type=str,
        default="CIFAR10",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    main(parser.parse_args())