Spaces:
Runtime error
Runtime error
File size: 5,923 Bytes
d4ab5ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
from argparse import ArgumentParser, Namespace
from attributions import attention_rollout, grad_cam
from datamodules import CIFAR10QADataModule, ImageDataModule
from datamodules.utils import datamodule_factory
from functools import partial
from models import ImageInterpretationNet
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from transformers import ViTForImageClassification
from utils.plot import DrawMaskCallback, log_masks
import pytorch_lightning as pl
def get_experiment_name(args: Namespace):
    """Build a run name from the experiment-defining command line arguments.

    Each remaining argument is rendered as ``name=value``. The pairs are
    ordered alphabetically by argument name and joined with ``-``. Arguments
    that only affect bookkeeping (data location, logging cadence, seeding,
    augmentation flags, ...) are skipped. As a result, runs that differ only
    in those settings get the same name.

    Args:
        args: Parsed command line arguments.

    Returns:
        The experiment name, e.g. ``"alpha=1.0-lr=0.01"``; an empty string
        when every argument is in the skip-list.
    """
    # Arguments that do not identify an experiment and are left out of the name
    ignored = {
        "add_blur",
        "add_noise",
        "add_rotation",
        "base_model",
        "batch_size",
        "class_idx",
        "data_dir",
        "enable_progress_bar",
        "from_pretrained",
        "log_every_n_steps",
        "num_epochs",
        "num_workers",
        "sample_images",
        "seed",
    }
    pairs = (
        f"{key}={val}"
        for key, val in sorted(vars(args).items())
        if key not in ignored
    )
    return "-".join(pairs)
def setup_sample_image_logs(
    dm: ImageDataModule,
    args: Namespace,
    logger: WandbLogger,
    n_panels: int = 2,  # TODO: change?
):
    """Log attribution maps for sample images and create mask-drawing callbacks.

    The first ``n_panels`` batches of the validation dataloader are drawn. The
    first ``args.sample_images`` images (with their labels) of each batch form
    one panel. For every panel, this function:

    - logs GradCAM and Attention Rollout masks of a pretrained ViT through
      ``logger``;
    - creates a ``DrawMaskCallback`` that receives the panel's samples.

    Args:
        dm: Datamodule whose validation dataloader supplies the samples. It
            should already be prepared and set up, as ``main`` does.
        args: Parsed command line arguments. ``sample_images``,
            ``log_every_n_steps`` and ``from_pretrained`` are read.
        logger: Logger the GradCAM / Attention Rollout masks are written to.
        n_panels: Number of panels, i.e. validation batches to sample.

    Returns:
        A list with one ``DrawMaskCallback`` per panel.

    Raises:
        ValueError: If the validation dataloader yields fewer than
            ``n_panels`` batches.
    """
    images_per_panel = args.sample_images

    # Draw one validation batch per panel before any model is loaded, and keep
    # only the first `images_per_panel` samples of each batch.
    sample_images = []
    batches = iter(dm.val_dataloader())
    exhausted = object()  # sentinel: distinguishes "no more batches" from data
    for panel in range(n_panels):
        batch = next(batches, exhausted)
        if batch is exhausted:
            # A bare StopIteration escaping here would be cryptic; explain it.
            raise ValueError(
                f"Validation dataloader yielded only {panel} batch(es), "
                f"but {n_panels} panels were requested."
            )
        X, Y = batch
        sample_images.append((X[:images_per_panel], Y[:images_per_panel]))

    # All mask callbacks share the same logging cadence
    mask_cb = partial(DrawMaskCallback, log_every_n_steps=args.log_every_n_steps)

    callbacks = []
    for panel, samples in enumerate(sample_images):
        # NOTE(review): a fresh ViT is loaded for every panel. Presumably this
        # keeps any state left by the attribution helpers from leaking between
        # panels. If it is not needed, hoisting the load out of the loop would
        # save time.
        vit = ViTForImageClassification.from_pretrained(args.from_pretrained)
        X, _ = samples

        # Log GradCAM and Attention Rollout masks for this panel
        gradcam_masks = grad_cam(X, vit)
        log_masks(X, gradcam_masks, f"GradCAM {panel}", logger)
        rollout_masks = attention_rollout(X, vit)
        log_masks(X, rollout_masks, f"Attention Rollout {panel}", logger)

        callbacks.append(mask_cb(samples, key=f"{panel}"))

    return callbacks
def main(args: Namespace):
    """Train a Vision DiffMask interpretation network on a pretrained ViT.

    Steps, in order:

    1. Seed all RNGs.
    2. Load the pretrained classifier and prepare the datamodule.
    3. Wrap the classifier in an ``ImageInterpretationNet``.
    4. Configure Weights & Biases logging, checkpointing and the
       mask-drawing callbacks.
    5. Fit the interpretation network.

    Args:
        args: Parsed command line arguments (see the ``__main__`` block).
    """
    pl.seed_everything(args.seed)

    # Pretrained classifier whose decisions are being interpreted
    vit = ViTForImageClassification.from_pretrained(args.from_pretrained)

    # Prepared and set up eagerly: the mask callbacks below sample images from
    # the validation dataloader before training starts.
    datamodule = datamodule_factory(args)
    datamodule.prepare_data()
    datamodule.setup("fit")

    interpreter = ImageInterpretationNet(
        model_cfg=vit.config,
        alpha=args.alpha,
        lr=args.lr,
        eps=args.eps,
        lr_placeholder=args.lr_placeholder,
        lr_alpha=args.lr_alpha,
        mul_activation=args.mul_activation,
        add_activation=args.add_activation,
        placeholder=not args.no_placeholder,
        weighted_layer_pred=args.weighted_layer_distribution,
    )
    interpreter.set_vision_transformer(vit)

    logger = WandbLogger(name=get_experiment_name(args), project="Patch-DiffMask")

    # Keep every checkpoint (save_top_k=-1), written every `log_every_n_steps`
    # training steps into a directory named after the wandb run version.
    checkpoint_cb = ModelCheckpoint(
        save_top_k=-1,
        dirpath=f"checkpoints/{logger.version}",
        every_n_train_steps=args.log_every_n_steps,
    )
    callbacks = [checkpoint_cb] + setup_sample_image_logs(datamodule, args, logger)

    trainer = pl.Trainer(
        accelerator="auto",
        callbacks=callbacks,
        enable_progress_bar=args.enable_progress_bar,
        logger=logger,
        max_epochs=args.num_epochs,
    )
    trainer.fit(interpreter, datamodule)
if __name__ == "__main__":
    parser = ArgumentParser()

    # --- Trainer -----------------------------------------------------------
    parser.add_argument(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to enable the progress bar (NOT recommended when logging to file).",
    )

    # --- Trainer & logging: plain integer options, registered in order -----
    integer_options = (
        ("--num_epochs", 5, "Number of epochs to train."),
        ("--seed", 123, "Random seed for reproducibility."),
        ("--sample_images", 8, "Number of images to sample for the mask callback."),
        ("--log_every_n_steps", 200, "Number of steps between logging media & checkpoints."),
    )
    for flag, default_value, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    # --- Base (classification) model ----------------------------------------
    parser.add_argument(
        "--base_model",
        type=str,
        default="ViT",
        choices=["ViT"],
        help="Base model architecture to train.",
    )
    parser.add_argument(
        "--from_pretrained",
        type=str,
        default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
        help="The name of the pretrained HF model to load.",
    )

    # --- Interpretation model and datamodules contribute their own options --
    for component in (ImageInterpretationNet, ImageDataModule, CIFAR10QADataModule):
        component.add_model_specific_args(parser)

    parser.add_argument(
        "--dataset",
        type=str,
        default="CIFAR10",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    main(parser.parse_args())
|