Add code
- .gitattributes +1 -0
- app.py +63 -0
- checkpoints/.gitkeep +1 -0
- checkpoints/diffmask.ckpt +3 -0
- code/attributions/__init__.py +2 -0
- code/attributions/attention_rollout.py +59 -0
- code/attributions/grad_cam.py +55 -0
- code/datamodules/__init__.py +3 -0
- code/datamodules/base.py +156 -0
- code/datamodules/image_classification.py +44 -0
- code/datamodules/transformations.py +41 -0
- code/datamodules/utils.py +133 -0
- code/datamodules/visual_qa.py +241 -0
- code/eval_base.py +102 -0
- code/main.py +215 -0
- code/models/__init__.py +2 -0
- code/models/classification.py +112 -0
- code/models/gates.py +261 -0
- code/models/interpretation.py +482 -0
- code/models/utils.py +64 -0
- code/train_base.py +123 -0
- code/utils/__init__.py +0 -0
- code/utils/distributions.py +64 -0
- code/utils/getters_setters.py +122 -0
- code/utils/metrics.py +67 -0
- code/utils/optimizer.py +151 -0
- code/utils/plot.py +252 -0
- requirements.txt +7 -0
.gitattributes
CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/diffmask.ckpt filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,63 @@
import sys
sys.path.insert(0, './code')

from datamodules.transformations import UnNest
from models.interpretation import ImageInterpretationNet
from transformers import ViTFeatureExtractor, ViTForImageClassification
from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image

import gradio as gr
import numpy as np
import torch

# Load Vision Transformer
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
vit = ViTForImageClassification.from_pretrained(hf_model)
vit.eval()

# Load Feature Extractor
feature_extractor = ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
feature_extractor = UnNest(feature_extractor)

# Load Vision DiffMask
diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
diffmask.set_vision_transformer(vit)


# Define mask plotting functions
def draw_mask(image, mask):
    return draw_mask_on_image(image, smoothen(mask))\
        .permute(1, 2, 0)\
        .clip(0, 1)\
        .numpy()


def draw_heatmap(image, mask):
    return draw_heatmap_on_image(image, smoothen(mask))\
        .permute(1, 2, 0)\
        .clip(0, 1)\
        .numpy()


# Define callable method for the demo
def get_mask(image):
    if image is None:
        return None

    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
    dm_image = feature_extractor(image).unsqueeze(0)
    mask = diffmask.get_mask(dm_image)["mask"][0].detach()

    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)
    return np.hstack((masked_img, heatmap))


# Launch demo interface
gr.Interface(
    get_mask,
    inputs=gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
    outputs=[gr.outputs.Image(label="Output")],
    title="Vision DiffMask Demo",
    live=True,
).launch()
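The file above wires the feature extractor, the DiffMask checkpoint and the plotting helpers into a single Gradio callback. As a rough smoke test of that callback outside Gradio (not part of this commit; the random image is purely illustrative), one could run:

# Hypothetical smoke test for the get_mask callback defined in app.py above;
# assumes app.py has been executed so vit, feature_extractor and diffmask are loaded.
# A random 224x224 RGB array stands in for a real upload.
import numpy as np

dummy_upload = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
side_by_side = get_mask(dummy_upload)  # masked image and heatmap, stacked horizontally
print(side_by_side.shape)  # expected (224, 448, 3)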
checkpoints/.gitkeep
ADDED
@@ -0,0 +1 @@
checkpoints/diffmask.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33ceff3adc10ffb86bdaa3c90380e7925e76e7b170ed42d1cc00ff33328fc77b
size 16610391
code/attributions/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .attention_rollout import attention_rollout
from .grad_cam import grad_cam
code/attributions/attention_rollout.py
ADDED
@@ -0,0 +1,59 @@
import torch
import torch.nn.functional as F

from math import sqrt
from torch import Tensor
from transformers import ViTForImageClassification


@torch.no_grad()
def attention_rollout(
    images: Tensor,
    vit: ViTForImageClassification,
    discard_ratio: float = 0.9,
    head_fusion: str = "mean",
    device: str = "cpu",
) -> Tensor:
    """Performs the Attention Rollout method on a batch of images (https://arxiv.org/pdf/2005.00928.pdf)."""
    # Forward pass and save attention maps
    attentions = vit(images, output_attentions=True).attentions

    B, _, H, W = images.shape  # Batch size, channels, height, width
    P = attentions[0].size(-1)  # Number of patches

    mask = torch.eye(P).to(device)
    # Iterate over layers
    for j, attention in enumerate(attentions):
        if head_fusion == "mean":
            attention_heads_fused = attention.mean(axis=1)
        elif head_fusion == "max":
            attention_heads_fused = attention.max(axis=1)[0]
        elif head_fusion == "min":
            attention_heads_fused = attention.min(axis=1)[0]
        else:
            raise ValueError("Attention head fusion type not supported")

        # Drop the lowest attentions, but don't drop the class token
        flat = attention_heads_fused.view(B, -1)
        _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
        indices = indices[indices != 0]
        flat[0, indices] = 0

        # I = torch.eye(P)
        a = (attention_heads_fused + torch.eye(P).to(device)) / 2
        a = a / a.sum(dim=-1).view(-1, P, 1)

        mask = a @ mask

    # Look at the total attention between the class token and the image patches
    mask = mask[:, 0, 1:]
    mask = mask / torch.max(mask)

    N = int(sqrt(P))
    S = int(H / N)

    mask = mask.reshape(B, 1, N, N)
    mask = F.interpolate(mask, scale_factor=S)
    mask = mask.reshape(B, H, W)

    return mask
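A short usage sketch of attention_rollout (not part of the commit; the model name is the one used in app.py, and the random batch merely stands in for feature-extracted images):

# Hypothetical usage of attention_rollout on a small batch.
import torch
from transformers import ViTForImageClassification

vit = ViTForImageClassification.from_pretrained(
    "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
)
vit.eval()

images = torch.rand(2, 3, 224, 224)  # stand-in for preprocessed inputs
masks = attention_rollout(images, vit, discard_ratio=0.9, head_fusion="mean")
print(masks.shape)  # (2, 224, 224): one relevance map per image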
code/attributions/grad_cam.py
ADDED
@@ -0,0 +1,55 @@
import torch

from pytorch_grad_cam import GradCAM
from torch import Tensor
from transformers import ViTForImageClassification


def grad_cam(images: Tensor, vit: ViTForImageClassification, use_cuda: bool = False) -> Tensor:
    """Performs the Grad-CAM method on a batch of images (https://arxiv.org/pdf/1610.02391.pdf)."""

    # Wrap the ViT model to be compatible with GradCAM
    vit = ViTWrapper(vit)
    vit.eval()

    # Create GradCAM object
    cam = GradCAM(
        model=vit,
        target_layers=[vit.target_layer],
        reshape_transform=_reshape_transform,
        use_cuda=use_cuda,
    )

    # Compute GradCAM masks
    grayscale_cam = cam(
        input_tensor=images,
        targets=None,
        eigen_smooth=True,
        aug_smooth=True,
    )

    return torch.from_numpy(grayscale_cam)


def _reshape_transform(tensor, height=14, width=14):
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))

    # Bring the channels to the first dimension
    result = result.transpose(2, 3).transpose(1, 2)

    return result


class ViTWrapper(torch.nn.Module):
    """ViT Wrapper to use with Grad-CAM."""

    def __init__(self, vit: ViTForImageClassification):
        super().__init__()
        self.vit = vit

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.vit(x).logits

    @property
    def target_layer(self):
        return self.vit.vit.encoder.layer[-2].layernorm_after
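Analogously, a hedged usage sketch for grad_cam (not part of the commit; it assumes the pytorch_grad_cam package pinned by this repository is installed):

# Hypothetical usage of the grad_cam helper above.
import torch
from transformers import ViTForImageClassification

vit = ViTForImageClassification.from_pretrained(
    "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
)
images = torch.rand(2, 3, 224, 224)
cam_masks = grad_cam(images, vit, use_cuda=False)
print(cam_masks.shape)  # one Grad-CAM heatmap per image, at input resolution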
code/datamodules/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .base import ImageDataModule
from .image_classification import CIFAR10DataModule, MNISTDataModule
from .visual_qa import CIFAR10QADataModule, ToyQADataModule
code/datamodules/base.py
ADDED
@@ -0,0 +1,156 @@
from .transformations import AddGaussianNoise
from abc import abstractmethod, ABCMeta
from argparse import ArgumentParser
from pytorch_lightning import LightningDataModule
from torch.utils.data import (
    DataLoader,
    Dataset,
    default_collate,
    RandomSampler,
    SequentialSampler,
)
from torchvision import transforms
from typing import Optional


class ImageDataModule(LightningDataModule, metaclass=ABCMeta):
    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group("Data Modules")
        parser.add_argument(
            "--data_dir",
            type=str,
            default="data/",
            help="The directory where the data is stored.",
        )
        parser.add_argument(
            "--batch_size",
            type=int,
            default=32,
            help="The batch size to use.",
        )
        parser.add_argument(
            "--add_noise",
            action="store_true",
            help="Use gaussian noise augmentation.",
        )
        parser.add_argument(
            "--add_rotation",
            action="store_true",
            help="Use rotation augmentation.",
        )
        parser.add_argument(
            "--add_blur",
            action="store_true",
            help="Use blur augmentation.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=4,
            help="Number of workers to use for data loading.",
        )
        return parent_parser

    # Declare variables that will be initialized later
    train_data: Dataset
    val_data: Dataset
    test_data: Dataset

    def __init__(
        self,
        feature_extractor: Optional[callable] = None,
        data_dir: str = "data/",
        batch_size: int = 32,
        add_noise: bool = False,
        add_rotation: bool = False,
        add_blur: bool = False,
        num_workers: int = 4,
    ):
        """Abstract Pytorch Lightning DataModule for image datasets.

        Args:
            feature_extractor (callable): feature extractor instance
            data_dir (str): directory to store the dataset
            batch_size (int): batch size for the train/val/test dataloaders
            add_noise (bool): whether to add noise to the images
            add_rotation (bool): whether to add random rotation to the images
            add_blur (bool): whether to add blur to the images
            num_workers (int): number of workers for train/val/test dataloaders
        """
        super().__init__()

        # Store hyperparameters
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.feature_extractor = feature_extractor
        self.num_workers = num_workers

        # Set the transforms
        # If the feature_extractor is None, then we do not split the images into features
        init_transforms = [feature_extractor] if feature_extractor else []
        self.transform = transforms.Compose(init_transforms)
        self._add_transforms(add_noise, add_rotation, add_blur)

        # Set the collate function and the samplers
        # These can be adapted in a child datamodule class to have a different behavior
        self.collate_fn = default_collate
        self.shuffled_sampler = RandomSampler
        self.sequential_sampler = SequentialSampler

    def _add_transforms(self, noise: bool, rotation: bool, blur: bool):
        """Add transforms to the module's transformations list.

        Args:
            noise (bool): whether to add noise to the images
            rotation (bool): whether to add random rotation to the images
            blur (bool): whether to add blur to the images
        """
        # TODO:
        # - Which order to add the transforms in?
        # - Applied in both train and test or just test?
        # - Check what transforms are applied by the model
        if noise:
            self.transform.transforms.append(AddGaussianNoise(0.0, 1.0))
        if rotation:
            self.transform.transforms.append(transforms.RandomRotation(20))
        if blur:
            self.transform.transforms.append(transforms.GaussianBlur(3))

    @abstractmethod
    def prepare_data(self):
        raise NotImplementedError()

    @abstractmethod
    def setup(self, stage: Optional[str] = None):
        raise NotImplementedError()

    # noinspection PyTypeChecker
    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            sampler=self.shuffled_sampler(self.train_data),
        )

    # noinspection PyTypeChecker
    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            sampler=self.sequential_sampler(self.val_data),
        )

    # noinspection PyTypeChecker
    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            sampler=self.sequential_sampler(self.test_data),
        )
code/datamodules/image_classification.py
ADDED
@@ -0,0 +1,44 @@
from .base import ImageDataModule
from torch.utils.data import random_split
from torchvision.datasets import MNIST, CIFAR10
from typing import Optional


class MNISTDataModule(ImageDataModule):
    """Datamodule for the MNIST dataset."""

    def prepare_data(self):
        # Download MNIST
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        # Set the training and validation data
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.train_data, self.val_data = random_split(mnist_full, [55000, 5000])

        # Set the test data
        if stage == "test" or stage is None:
            self.test_data = MNIST(self.data_dir, train=False, transform=self.transform)


class CIFAR10DataModule(ImageDataModule):
    """Datamodule for the CIFAR10 dataset."""

    def prepare_data(self):
        # Download CIFAR10
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        # Set the training and validation data
        if stage == "fit" or stage is None:
            cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.train_data, self.val_data = random_split(cifar10_full, [45000, 5000])

        # Set the test data
        if stage == "test" or stage is None:
            self.test_data = CIFAR10(
                self.data_dir, train=False, transform=self.transform
            )
code/datamodules/transformations.py
ADDED
@@ -0,0 +1,41 @@
from torch import Tensor
from transformers.image_utils import ImageInput

import torch


class AddGaussianNoise:
    """Add Gaussian noise to an image.

    Args:
        mean (float): mean of the Gaussian noise
        std (float): standard deviation of the Gaussian noise
    """

    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.std = std
        self.mean = mean

    def __call__(self, tensor: Tensor) -> Tensor:
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self) -> str:
        return self.__class__.__name__ + "(mean={0}, std={1})".format(
            self.mean, self.std
        )


class UnNest:
    """Un-nest the output of a feature extractor"""

    def __init__(self, feature_extractor: callable):
        self.feature_extractor = feature_extractor

    def __call__(self, x: ImageInput) -> Tensor:
        # Pass the input through the feature extractor
        x = self.feature_extractor(x)
        # Un-nest the pixel_values tensor
        x = torch.tensor(x["pixel_values"][0])

        # HuggingFace models expect 3D tensors [C, H, W]
        return x if len(x) == 3 else x.unsqueeze(0)
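To illustrate what UnNest does (illustrative only, not committed code): it unwraps the nested output of a HuggingFace feature extractor into a plain [C, H, W] tensor.

# Hypothetical illustration of UnNest around a ViT feature extractor.
import torch
from transformers import ViTFeatureExtractor

fe = UnNest(ViTFeatureExtractor(image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5]))
pixel_values = fe(torch.rand(3, 224, 224))  # PIL images are accepted as well
print(pixel_values.shape)  # torch.Size([3, 224, 224])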
code/datamodules/utils.py
ADDED
@@ -0,0 +1,133 @@
from .image_classification import CIFAR10DataModule, ImageDataModule, MNISTDataModule
from .transformations import UnNest
from .visual_qa import CIFAR10QADataModule, ToyQADataModule
from argparse import Namespace
from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor


def get_configs(args: Namespace) -> tuple[dict, dict]:
    """Get the model and feature extractor configs from the command line args.

    Args:
        args (Namespace): the argparse Namespace object

    Returns:
        a tuple containing the model and feature extractor configs
    """
    if args.dataset == "MNIST":
        # We upsample the MNIST images to 112x112, with 1 channel (grayscale)
        # and 10 classes (0-9). We normalize the image to have a mean of 0.5
        # and a standard deviation of ±0.5.
        model_cfg_args = {
            "image_size": 112,
            "num_channels": 1,
            "num_labels": 10,
        }
        fe_cfg_args = {
            "image_mean": [0.5],
            "image_std": [0.5],
        }
    elif args.dataset.startswith("CIFAR10"):
        if args.dataset not in ("CIFAR10", "CIFAR10_QA"):
            raise Exception(f"Unknown CIFAR10 variant: {args.dataset}")

        # We upsample the CIFAR10 images to 224x224, with 3 channels (RGB) and
        # 10 classes (0-9) for the normal dataset, or (grid_size)^2 + 1 for the
        # toy task. We normalize the image to have a mean of 0.5 and a standard
        # deviation of ±0.5.
        model_cfg_args = {
            "image_size": 224,  # fixed to 224 because pretrained models have that size
            "num_channels": 3,
            "num_labels": (args.grid_size**2) + 1
            if args.dataset == "CIFAR10_QA"
            else 10,
        }
        fe_cfg_args = {
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
        }
    elif args.dataset == "toy":
        # We use an image size so that each patch contains a single color, with
        # 3 channels (RGB) and (grid_size)^2 + 1 classes. We normalize the image
        # to have a mean of 0.5 and a standard deviation of ±0.5.
        model_cfg_args = {
            "image_size": args.grid_size * 16,
            "num_channels": 3,
            "num_labels": (args.grid_size**2) + 1,
        }
        fe_cfg_args = {
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
        }
    else:
        raise Exception(f"Unknown dataset: {args.dataset}")

    # Set the feature extractor's size attribute to be the same as the model's image size
    fe_cfg_args["size"] = model_cfg_args["image_size"]
    # Set the tensors' return type to PyTorch tensors
    fe_cfg_args["return_tensors"] = "pt"

    return model_cfg_args, fe_cfg_args


def datamodule_factory(args: Namespace) -> ImageDataModule:
    """A factory method for creating a datamodule based on the command line args.

    Args:
        args (Namespace): the argparse Namespace object

    Returns:
        an ImageDataModule instance
    """
    # Get the model and feature extractor configs
    model_cfg_args, fe_cfg_args = get_configs(args)

    # Set the feature extractor class based on the provided base model name
    if args.base_model == "ViT":
        fe_class = ViTFeatureExtractor
    elif args.base_model == "ConvNeXt":
        fe_class = ConvNextFeatureExtractor
    else:
        raise Exception(f"Unknown base model: {args.base_model}")

    # Create the feature extractor instance
    if args.from_pretrained:
        feature_extractor = fe_class.from_pretrained(
            args.from_pretrained, **fe_cfg_args
        )
    else:
        feature_extractor = fe_class(**fe_cfg_args)

    # Un-nest the feature extractor's output
    feature_extractor = UnNest(feature_extractor)

    # Define the datamodule's configuration
    dm_cfg = {
        "feature_extractor": feature_extractor,
        "batch_size": args.batch_size,
        "add_noise": args.add_noise,
        "add_rotation": args.add_rotation,
        "add_blur": args.add_blur,
        "num_workers": args.num_workers,
    }

    # Determine the dataset class based on the provided dataset name
    if args.dataset.startswith("CIFAR10"):
        if args.dataset == "CIFAR10":
            dm_class = CIFAR10DataModule
        elif args.dataset == "CIFAR10_QA":
            dm_cfg["class_idx"] = args.class_idx
            dm_cfg["grid_size"] = args.grid_size
            dm_class = CIFAR10QADataModule
        else:
            raise Exception(f"Unknown CIFAR10 variant: {args.dataset}")
    elif args.dataset == "MNIST":
        dm_class = MNISTDataModule
    elif args.dataset == "toy":
        dm_cfg["class_idx"] = args.class_idx
        dm_cfg["grid_size"] = args.grid_size
        dm_class = ToyQADataModule
    else:
        raise Exception(f"Unknown dataset: {args.dataset}")

    return dm_class(**dm_cfg)
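A sketch of how datamodule_factory might be driven with a hand-built Namespace (the field values are assumptions mirroring the argparse defaults elsewhere in this commit, not committed code):

# Hypothetical direct call to datamodule_factory without the CLI.
from argparse import Namespace

args = Namespace(
    dataset="CIFAR10",
    base_model="ViT",
    from_pretrained="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
    batch_size=32,
    add_noise=False,
    add_rotation=False,
    add_blur=False,
    num_workers=4,
    class_idx=3,
    grid_size=3,
)
dm = datamodule_factory(args)  # a CIFAR10DataModule with an UnNest-wrapped extractor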
code/datamodules/visual_qa.py
ADDED
@@ -0,0 +1,241 @@
from .image_classification import CIFAR10DataModule
from argparse import ArgumentParser
from functools import partial
from torch import LongTensor
from torch.utils.data import default_collate, random_split, Sampler
from torchvision import transforms
from torchvision.datasets import VisionDataset
from typing import Iterator, Optional

import itertools
import random
import torch


class CIFAR10QADataModule(CIFAR10DataModule):
    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group("Visual QA")
        parser.add_argument(
            "--class_idx",
            type=int,
            default=3,
            help="The class (index) to count.",
        )
        parser.add_argument(
            "--grid_size",
            type=int,
            default=3,
            help="The number of images per row in the grid.",
        )
        return parent_parser

    def __init__(
        self,
        class_idx: int,
        grid_size: int = 3,
        feature_extractor: callable = None,
        data_dir: str = "data/",
        batch_size: int = 32,
        add_noise: bool = False,
        add_rotation: bool = False,
        add_blur: bool = False,
        num_workers: int = 4,
    ):
        """A datamodule for a modified CIFAR10 dataset that is used for Question Answering.
        More specifically, the task is to count the number of images of a certain class in a grid.

        Args:
            class_idx (int): the class (index) to count
            grid_size (int): the number of images per row in the grid
            feature_extractor (callable): a callable feature extractor instance
            data_dir (str): the directory to store the dataset
            batch_size (int): the batch size for the train/val/test dataloaders
            add_noise (bool): whether to add noise to the images
            add_rotation (bool): whether to add rotation augmentation
            add_blur (bool): whether to add blur augmentation
            num_workers (int): the number of workers to use for data loading
        """
        super().__init__(
            feature_extractor,
            data_dir,
            (grid_size**2) * batch_size,
            add_noise,
            add_rotation,
            add_blur,
            num_workers,
        )

        # Store hyperparameters
        self.class_idx = class_idx
        self.grid_size = grid_size

        # Save the existing transformations to be applied after creating the grid
        self.post_transform = self.transform
        # Set the pre-batch transformation to be the conversion from PIL to tensor
        self.transform = transforms.PILToTensor()

        # Specify the custom collate function and samplers
        self.collate_fn = self.custom_collate_fn
        self.shuffled_sampler = partial(
            FairGridSampler,
            class_idx=class_idx,
            grid_size=grid_size,
            shuffle=True,
        )
        self.sequential_sampler = partial(
            FairGridSampler,
            class_idx=class_idx,
            grid_size=grid_size,
            shuffle=False,
        )

    def custom_collate_fn(self, batch):
        # Split the batch into groups of grid_size**2
        idx = range(len(batch))
        grids = zip(*(iter(idx),) * (self.grid_size**2))

        new_batch = []
        for grid in grids:
            # Create a grid of images from the indices in the batch
            img = torch.hstack(
                [
                    torch.dstack(
                        [batch[i][0] for i in grid[idx : idx + self.grid_size]]
                    )
                    for idx in range(
                        0, self.grid_size**2 - self.grid_size + 1, self.grid_size
                    )
                ]
            )
            # Apply the post transformations to the grid
            img = self.post_transform(img)
            # Define the target as the number of images that have the class_idx
            targets = [batch[i][1] for i in grid]
            target = targets.count(self.class_idx)
            # Append grid and target to the batch
            new_batch += [(img, target)]

        return default_collate(new_batch)


class ToyQADataModule(CIFAR10QADataModule):
    """A datamodule for the toy dataset as described in the paper."""

    def prepare_data(self):
        # No need to download anything for the toy task
        pass

    def setup(self, stage: Optional[str] = None):
        img_size = 16

        samples = []
        # Generate 6000 samples based on 6 different colors
        for r, g, b in itertools.product((0, 1), (0, 1), (0, 1)):
            if r == g == b:
                # We do not want black/white patches
                continue

            for _ in range(1000):
                patch = torch.vstack(
                    [
                        r * torch.ones(1, img_size, img_size),
                        g * torch.ones(1, img_size, img_size),
                        b * torch.ones(1, img_size, img_size),
                    ]
                )

                # Assign a unique id to each color
                target = int(f"{r}{g}{b}", 2) - 1
                # Append the patch and target to the samples
                samples += [(patch, target)]

        # Split the data to 90% train, 5% validation and 5% test
        train_size = int(len(samples) * 0.9)
        val_size = (len(samples) - train_size) // 2
        test_size = len(samples) - train_size - val_size
        self.train_data, self.val_data, self.test_data = random_split(
            samples,
            [
                train_size,
                val_size,
                test_size,
            ],
        )


class FairGridSampler(Sampler[int]):
    def __init__(
        self,
        dataset: VisionDataset,
        class_idx: int,
        grid_size: int,
        shuffle: bool = False,
    ):
        """A sampler that returns a grid of images from the dataset, with a uniformly random
        amount of appearances for a specific class of interest.

        Args:
            dataset (VisionDataset): the dataset to sample from
            class_idx (int): the class (index) to treat as the class of interest
            grid_size (int): the number of images per row in the grid
            shuffle (bool): whether to shuffle the dataset before sampling
        """
        super().__init__(dataset)

        # Save the hyperparameters
        self.dataset = dataset
        self.grid_size = grid_size
        self.n_images = grid_size**2

        # Get the indices of the class of interest
        self.class_indices = LongTensor(
            [i for i, x in enumerate(dataset) if x[1] == class_idx]
        )
        # Get the indices of all other classes
        self.other_indices = LongTensor(
            [i for i, x in enumerate(dataset) if x[1] != class_idx]
        )

        # Fix the seed if shuffle is False
        self.seed = None if shuffle else self._get_seed()

    @staticmethod
    def _get_seed() -> int:
        """Utility function for generating a random seed."""
        return int(torch.empty((), dtype=torch.int64).random_().item())

    def __iter__(self) -> Iterator[int]:
        # Create a torch Generator object
        seed = self.seed if self.seed is not None else self._get_seed()
        gen = torch.Generator()
        gen.manual_seed(seed)

        # Sample the batches
        for _ in range(len(self.dataset) // self.n_images):
            # Pick the number of instances for the class of interest
            n_samples = torch.randint(self.n_images + 1, (), generator=gen).item()

            # Sample the indices from the class of interest
            idx_from_class = torch.randperm(
                len(self.class_indices),
                generator=gen,
            )[:n_samples]
            # Sample the indices from the other classes
            idx_from_other = torch.randperm(
                len(self.other_indices),
                generator=gen,
            )[: self.n_images - n_samples]

            # Concatenate the corresponding lists of patches to form a grid
            grid = (
                self.class_indices[idx_from_class].tolist()
                + self.other_indices[idx_from_other].tolist()
            )

            # Shuffle the order of the patches within the grid
            random.shuffle(grid)
            yield from grid

    def __len__(self) -> int:
        return len(self.dataset)
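As a rough sketch of how these pieces fit together (illustrative only, and slow, since FairGridSampler scans the dataset and CIFAR-10 is downloaded on first use): each sampler draw yields grid_size**2 indices with a uniformly random number of class_idx occurrences, which custom_collate_fn then tiles into a single grid image with the count as the target.

# Hypothetical end-to-end use of the QA datamodule without a feature extractor.
dm = CIFAR10QADataModule(class_idx=3, grid_size=3, batch_size=4)
dm.prepare_data()
dm.setup("fit")

grids, counts = next(iter(dm.train_dataloader()))
print(grids.shape)  # expected (4, 3, 96, 96): four 3x3 grids of 32x32 CIFAR images
print(counts)       # per-grid number of class_idx images, each in [0, 9]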
code/eval_base.py
ADDED
@@ -0,0 +1,102 @@
from datamodules import CIFAR10QADataModule, ImageDataModule
from datamodules.utils import datamodule_factory
from models import ImageClassificationNet
from models.utils import model_factory
from pytorch_lightning.loggers import WandbLogger

import argparse
import pytorch_lightning as pl


def main(args: argparse.Namespace):
    # Seed
    pl.seed_everything(args.seed)

    # Create base model
    base = model_factory(args, own_config=True)

    # Load datamodule
    dm = datamodule_factory(args)

    # Load the model from the specified checkpoint
    model = ImageClassificationNet.load_from_checkpoint(
        args.checkpoint,
        model=base,
        num_train_steps=0,
    )

    # Create wandb logger
    wandb_logger = WandbLogger(
        name=f"{args.dataset}_eval_{args.base_model} ({args.from_pretrained})",
        project="Patch-DiffMask",
    )

    # Create trainer
    trainer = pl.Trainer(
        accelerator="auto",
        logger=wandb_logger,
        max_epochs=1,
        enable_progress_bar=args.enable_progress_bar,
    )

    # Evaluate the model
    trainer.test(model, dm)

    # Save the HuggingFace model to be used with --from_pretrained
    save_dir = f"checkpoints/{args.base_model}_{args.dataset}"
    model.model.save_pretrained(save_dir)
    dm.feature_extractor.save_pretrained(save_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint to resume the training from.",
    )

    # Trainer
    parser.add_argument(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to show progress bar during training. NOT recommended when logging to files.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducibility.",
    )

    # Base (classification) model
    parser.add_argument(
        "--base_model",
        type=str,
        default="ViT",
        choices=["ViT", "ConvNeXt"],
        help="Base model architecture to train.",
    )
    parser.add_argument(
        "--from_pretrained",
        type=str,
        # default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
        help="The name of the pretrained HF model to fine-tune from.",
    )

    # Datamodule
    ImageDataModule.add_model_specific_args(parser)
    CIFAR10QADataModule.add_model_specific_args(parser)
    parser.add_argument(
        "--dataset",
        type=str,
        default="toy",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    args = parser.parse_args()

    main(args)
code/main.py
ADDED
@@ -0,0 +1,215 @@
from argparse import ArgumentParser, Namespace
from attributions import attention_rollout, grad_cam
from datamodules import CIFAR10QADataModule, ImageDataModule
from datamodules.utils import datamodule_factory
from functools import partial
from models import ImageInterpretationNet
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from transformers import ViTForImageClassification
from utils.plot import DrawMaskCallback, log_masks

import pytorch_lightning as pl


def get_experiment_name(args: Namespace):
    """Create a name for the experiment based on the command line arguments."""
    # Convert to dictionary
    args = vars(args)

    # Create a list with non-experiment arguments
    non_experiment_args = [
        "add_blur",
        "add_noise",
        "add_rotation",
        "base_model",
        "batch_size",
        "class_idx",
        "data_dir",
        "enable_progress_bar",
        "from_pretrained",
        "log_every_n_steps",
        "num_epochs",
        "num_workers",
        "sample_images",
        "seed",
    ]

    # Create experiment name from experiment arguments
    return "-".join(
        [
            f"{name}={value}"
            for name, value in sorted(args.items())
            if name not in non_experiment_args
        ]
    )


def setup_sample_image_logs(
    dm: ImageDataModule,
    args: Namespace,
    logger: WandbLogger,
    n_panels: int = 2,  # TODO: change?
):
    """Setup the log callbacks for sampling and plotting images."""
    images_per_panel = args.sample_images

    # Sample images
    sample_images = []
    iter_loader = iter(dm.val_dataloader())
    for panel in range(n_panels):
        X, Y = next(iter_loader)
        sample_images += [(X[:images_per_panel], Y[:images_per_panel])]

    # Define mask callback
    mask_cb = partial(DrawMaskCallback, log_every_n_steps=args.log_every_n_steps)

    callbacks = []
    for panel in range(n_panels):
        # Initialize ViT model
        vit = ViTForImageClassification.from_pretrained(args.from_pretrained)

        # Extract samples for current panel
        samples = sample_images[panel]
        X, _ = samples

        # Log GradCAM
        gradcam_masks = grad_cam(X, vit)
        log_masks(X, gradcam_masks, f"GradCAM {panel}", logger)

        # Log Attention Rollout
        rollout_masks = attention_rollout(X, vit)
        log_masks(X, rollout_masks, f"Attention Rollout {panel}", logger)

        # Create mask callback
        callbacks += [mask_cb(samples, key=f"{panel}")]

    return callbacks


def main(args: Namespace):
    # Seed
    pl.seed_everything(args.seed)

    # Load pre-trained Transformer
    model = ViTForImageClassification.from_pretrained(args.from_pretrained)

    # Load datamodule
    dm = datamodule_factory(args)

    # Setup datamodule to sample images for the mask callback
    dm.prepare_data()
    dm.setup("fit")

    # Create Vision DiffMask for the model
    diffmask = ImageInterpretationNet(
        model_cfg=model.config,
        alpha=args.alpha,
        lr=args.lr,
        eps=args.eps,
        lr_placeholder=args.lr_placeholder,
        lr_alpha=args.lr_alpha,
        mul_activation=args.mul_activation,
        add_activation=args.add_activation,
        placeholder=not args.no_placeholder,
        weighted_layer_pred=args.weighted_layer_distribution,
    )
    diffmask.set_vision_transformer(model)

    # Create wandb logger instance
    wandb_logger = WandbLogger(
        name=get_experiment_name(args),
        project="Patch-DiffMask",
    )

    # Create checkpoint callback
    ckpt_cb = ModelCheckpoint(
        save_top_k=-1,
        dirpath=f"checkpoints/{wandb_logger.version}",
        every_n_train_steps=args.log_every_n_steps,
    )

    # Create mask callbacks
    mask_cbs = setup_sample_image_logs(dm, args, wandb_logger)

    # Create trainer
    trainer = pl.Trainer(
        accelerator="auto",
        callbacks=[ckpt_cb, *mask_cbs],
        enable_progress_bar=args.enable_progress_bar,
        logger=wandb_logger,
        max_epochs=args.num_epochs,
    )

    # Train the model
    trainer.fit(diffmask, dm)


if __name__ == "__main__":
    parser = ArgumentParser()

    # Trainer
    parser.add_argument(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to enable the progress bar (NOT recommended when logging to file).",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=5,
        help="Number of epochs to train.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducibility.",
    )

    # Logging
    parser.add_argument(
        "--sample_images",
        type=int,
        default=8,
        help="Number of images to sample for the mask callback.",
    )
    parser.add_argument(
        "--log_every_n_steps",
        type=int,
        default=200,
        help="Number of steps between logging media & checkpoints.",
    )

    # Base (classification) model
    parser.add_argument(
        "--base_model",
        type=str,
        default="ViT",
        choices=["ViT"],
        help="Base model architecture to train.",
    )
    parser.add_argument(
        "--from_pretrained",
        type=str,
        default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
        help="The name of the pretrained HF model to load.",
    )

    # Interpretation model
    ImageInterpretationNet.add_model_specific_args(parser)

    # Datamodule
    ImageDataModule.add_model_specific_args(parser)
    CIFAR10QADataModule.add_model_specific_args(parser)
    parser.add_argument(
        "--dataset",
        type=str,
        default="CIFAR10",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    args = parser.parse_args()

    main(args)
code/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .classification import ImageClassificationNet
from .interpretation import ImageInterpretationNet
code/models/classification.py
ADDED
@@ -0,0 +1,112 @@
"""
Parts of this file have been adapted from
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html
"""

import pytorch_lightning as pl
import torch.nn.functional as F

from argparse import ArgumentParser
from torch import Tensor
from torch.optim import AdamW, Optimizer, RAdam
from torch.optim.lr_scheduler import _LRScheduler
from transformers import get_scheduler, PreTrainedModel


class ImageClassificationNet(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group("Classification Model")
        parser.add_argument(
            "--optimizer",
            type=str,
            default="AdamW",
            choices=["AdamW", "RAdam"],
            help="The optimizer to use to train the model.",
        )
        parser.add_argument(
            "--weight_decay",
            type=float,
            default=1e-2,
            help="The optimizer's weight decay.",
        )
        parser.add_argument(
            "--lr",
            type=float,
            default=5e-5,
            help="The initial learning rate for the model.",
        )
        return parent_parser

    def __init__(
        self,
        model: PreTrainedModel,
        num_train_steps: int,
        optimizer: str = "AdamW",
        weight_decay: float = 1e-2,
        lr: float = 5e-5,
    ):
        """A PyTorch Lightning Module for a HuggingFace model used for image classification.

        Args:
            model (PreTrainedModel): a pretrained model for image classification
            num_train_steps (int): number of training steps
            optimizer (str): optimizer to use
            weight_decay (float): weight decay for optimizer
            lr (float): the learning rate used for training
        """
        super().__init__()

        # Save the hyperparameters and the model
        self.save_hyperparameters(ignore=["model"])
        self.model = model

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x).logits

    def configure_optimizers(self) -> tuple[list[Optimizer], list[_LRScheduler]]:
        # Set the optimizer class based on the hyperparameter
        if self.hparams.optimizer == "AdamW":
            optim_class = AdamW
        elif self.hparams.optimizer == "RAdam":
            optim_class = RAdam
        else:
            raise Exception(f"Unknown optimizer {self.hparams.optimizer}")

        # Create the optimizer and the learning rate scheduler
        optimizer = optim_class(
            self.parameters(),
            weight_decay=self.hparams.weight_decay,
            lr=self.hparams.lr,
        )
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=self.hparams.num_train_steps,
        )

        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch: tuple[Tensor, Tensor], mode: str) -> Tensor:
        imgs, labels = batch

        preds = self.model(imgs).logits
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log(f"{mode}_loss", loss)
        self.log(f"{mode}_acc", acc)

        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], _: Tensor) -> Tensor:
        loss = self._calculate_loss(batch, mode="train")

        return loss

    def validation_step(self, batch: tuple[Tensor, Tensor], _: Tensor):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch: tuple[Tensor, Tensor], _: Tensor):
        self._calculate_loss(batch, mode="test")
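A brief sketch of how ImageClassificationNet wraps a HuggingFace model (illustrative only; the checkpoint name is the one used elsewhere in this commit and num_train_steps is an arbitrary assumption):

# Hypothetical instantiation of the Lightning classification wrapper.
from transformers import ViTForImageClassification

vit = ViTForImageClassification.from_pretrained(
    "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
)
net = ImageClassificationNet(model=vit, num_train_steps=1000, optimizer="AdamW", lr=5e-5)
# pl.Trainer(...).fit(net, datamodule) would then train with a linear LR schedule.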
code/models/gates.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Parts of this file have been adapted from
|
3 |
+
https://github.com/nicola-decao/diffmask/blob/master/diffmask/models/gates.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from torch import Tensor
|
10 |
+
from typing import Optional
|
11 |
+
from utils.distributions import RectifiedStreched, BinaryConcrete
|
12 |
+
|
13 |
+
|
14 |
+
class MLPGate(nn.Module):
|
15 |
+
def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
|
16 |
+
"""
|
17 |
+
This is an MLP with the following structure;
|
18 |
+
Linear(input_size, hidden_size), Tanh(), Linear(hidden_size, 1)
|
19 |
+
The bias of the last layer is set to 5.0 to start with high probability
|
20 |
+
of keeping states (fundamental for good convergence as the initialized
|
21 |
+
DiffMask has not learned what to mask yet).
|
22 |
+
|
23 |
+
Args:
|
24 |
+
input_size (int): the number of input features
|
25 |
+
hidden_size (int): the number of hidden units
|
26 |
+
bias (bool): whether to use a bias term
|
27 |
+
"""
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.f = nn.Sequential(
|
31 |
+
nn.utils.weight_norm(nn.Linear(input_size, hidden_size)),
|
32 |
+
nn.Tanh(),
|
33 |
+
nn.utils.weight_norm(nn.Linear(hidden_size, 1, bias=bias)),
|
34 |
+
)
|
35 |
+
|
36 |
+
if bias:
|
37 |
+
self.f[-1].bias.data[:] = 5.0
|
38 |
+
|
39 |
+
    def forward(self, *args: Tensor) -> Tensor:
        return self.f(torch.cat(args, -1))


class MLPMaxGate(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        mul_activation: float = 10.0,
        add_activation: float = 5.0,
        bias: bool = True,
    ):
        """
        This is an MLP with the following structure:
        Linear(input_size, hidden_size), Tanh(), Linear(hidden_size, 1)
        The bias of the last layer is set to 5.0 to start with a high probability
        of keeping states (fundamental for good convergence, as the initialized
        DiffMask has not learned what to mask yet).
        It also uses a scaler for the output of the activation function.

        Args:
            input_size (int): the number of input features
            hidden_size (int): the number of hidden units
            mul_activation (float): the scaler for the output of the activation function
            add_activation (float): the offset for the output of the activation function
            bias (bool): whether to use a bias term
        """
        super().__init__()

        self.f = nn.Sequential(
            nn.utils.weight_norm(nn.Linear(input_size, hidden_size)),
            nn.Tanh(),
            nn.utils.weight_norm(nn.Linear(hidden_size, 1, bias=bias)),
            nn.Tanh(),
        )
        self.add_activation = nn.Parameter(torch.tensor(add_activation))
        self.mul_activation = mul_activation

    def forward(self, *args: Tensor) -> Tensor:
        return self.f(torch.cat(args, -1)) * self.mul_activation + self.add_activation


class DiffMaskGateInput(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        hidden_attention: int,
        num_hidden_layers: int,
        max_position_embeddings: int,
        gate_fn: nn.Module = MLPMaxGate,
        mul_activation: float = 10.0,
        add_activation: float = 5.0,
        gate_bias: bool = True,
        placeholder: bool = False,
        init_vector: Tensor = None,
    ):
        """This is a DiffMask module that masks the input of the first layer.

        Args:
            hidden_size (int): the size of the hidden representations
            hidden_attention (int): the number of units in the gate's hidden (bottleneck) layer
            num_hidden_layers (int): the number of hidden layers (and thus gates to use)
            max_position_embeddings (int): the number of placeholder embeddings to learn for the masked positions
            gate_fn (nn.Module): the PyTorch module to use as a gate
            mul_activation (float): the scaler for the output of the activation function
            add_activation (float): the offset for the output of the activation function
            gate_bias (bool): whether to use a bias term
            placeholder (bool): whether to use placeholder embeddings or a zero vector
            init_vector (Tensor): the initial vector to use for the placeholder embeddings
        """
        super().__init__()

        # Create a ModuleList with the gates
        self.g_hat = nn.ModuleList(
            [
                gate_fn(
                    hidden_size * 2,
                    hidden_attention,
                    mul_activation,
                    add_activation,
                    gate_bias,
                )
                for _ in range(num_hidden_layers)
            ]
        )

        if placeholder:
            # Use a placeholder embedding for the masked positions
            self.placeholder = nn.Parameter(
                nn.init.xavier_normal_(
                    torch.empty((1, max_position_embeddings, hidden_size))
                )
                if init_vector is None
                else init_vector.view(1, 1, hidden_size).repeat(
                    1, max_position_embeddings, 1
                )
            )
        else:
            # Use a zero vector for the masked positions
            self.register_buffer(
                "placeholder",
                torch.zeros((1, 1, hidden_size)),
            )

    def forward(
        self, hidden_states: tuple[Tensor], layer_pred: Optional[int]
    ) -> tuple[tuple[Tensor], Tensor, Tensor, Tensor, Tensor]:
        # Concatenate the output of all the gates
        logits = torch.cat(
            [
                self.g_hat[i](hidden_states[0], hidden_states[i])
                for i in range(
                    (layer_pred + 1) if layer_pred is not None else len(hidden_states)
                )
            ],
            -1,
        )

        # Define a Hard Concrete distribution
        dist = RectifiedStreched(
            BinaryConcrete(torch.full_like(logits, 0.2), logits),
            l=-0.2,
            r=1.0,
        )

        # Calculate the expectation for the full gate probabilities
        # These act as votes for the masked positions
        gates_full = dist.rsample().cumprod(-1)
        expected_L0_full = dist.log_expected_L0().cumsum(-1)

        # Extract the probabilities from the last layer, which acts
        # as an aggregation of the votes per position
        gates = gates_full[..., -1]
        expected_L0 = expected_L0_full[..., -1]

        return (
            hidden_states[0] * gates.unsqueeze(-1)
            + self.placeholder[:, : hidden_states[0].shape[-2]]
            * (1 - gates).unsqueeze(-1),
            gates,
            expected_L0,
            gates_full,
            expected_L0_full,
        )


# class DiffMaskGateHidden(nn.Module):
#     def __init__(
#         self,
#         hidden_size: int,
#         hidden_attention: int,
#         num_hidden_layers: int,
#         max_position_embeddings: int,
#         gate_fn: nn.Module = MLPMaxGate,
#         gate_bias: bool = True,
#         placeholder: bool = False,
#         init_vector: Tensor = None,
#     ):
#         super().__init__()
#
#         self.g_hat = nn.ModuleList(
#             [
#                 gate_fn(hidden_size, hidden_attention, bias=gate_bias)
#                 for _ in range(num_hidden_layers)
#             ]
#         )
#
#         if placeholder:
#             self.placeholder = nn.ParameterList(
#                 [
#                     nn.Parameter(
#                         nn.init.xavier_normal_(
#                             torch.empty((1, max_position_embeddings, hidden_size))
#                         )
#                         if init_vector is None
#                         else init_vector.view(1, 1, hidden_size).repeat(
#                             1, max_position_embeddings, 1
#                         )
#                     )
#                     for _ in range(num_hidden_layers)
#                 ]
#             )
#         else:
#             self.register_buffer(
#                 "placeholder",
#                 torch.zeros((num_hidden_layers, 1, 1, hidden_size)),
#             )
#
#     def forward(
#         self, hidden_states: tuple[Tensor], layer_pred: Optional[int]
#     ) -> tuple[tuple[Tensor], Tensor, Tensor, Tensor, Tensor]:
#         if layer_pred is not None:
#             logits = self.g_hat[layer_pred](hidden_states[layer_pred])
#         else:
#             logits = torch.cat(
#                 [self.g_hat[i](hidden_states[i]) for i in range(len(hidden_states))], -1
#             )
#
#         dist = RectifiedStreched(
#             BinaryConcrete(torch.full_like(logits, 0.2), logits),
#             l=-0.2,
#             r=1.0,
#         )
#
#         gates_full = dist.rsample()
#         expected_L0_full = dist.log_expected_L0()
#
#         gates = gates_full if layer_pred is not None else gates_full[..., :1]
#         expected_L0 = (
#             expected_L0_full if layer_pred is not None else expected_L0_full[..., :1]
#         )
#
#         layer_pred = layer_pred or 0  # equiv. to "layer_pred if layer_pred else 0"
#         return (
#             hidden_states[layer_pred] * gates
#             + self.placeholder[layer_pred][:, : hidden_states[layer_pred].shape[-2]]
#             * (1 - gates),
#             gates.squeeze(-1),
#             expected_L0.squeeze(-1),
#             gates_full,
#             expected_L0_full,
#         )
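For orientation, here is a minimal usage sketch of DiffMaskGateInput on its own (not part of the repository files). It assumes `code/` is on the Python path and a ViT-Base-like setup with 12 encoder layers; the tensor sizes are illustrative only.

# --- usage sketch (editor's example, not part of the repo) ---
import torch
from models.gates import DiffMaskGateInput

# Hypothetical batch: 2 images, 197 tokens (1 [CLS] + 196 patches), hidden size 768,
# with 14 captured hidden states (embeddings + 12 layer inputs + final output)
hidden_states = [torch.randn(2, 197, 768) for _ in range(14)]

gate = DiffMaskGateInput(
    hidden_size=768,
    hidden_attention=768 // 4,
    num_hidden_layers=14,
    max_position_embeddings=1,
    placeholder=False,
)

masked_input, gates, expected_L0, gates_full, expected_L0_full = gate(
    hidden_states=hidden_states, layer_pred=None
)
print(masked_input.shape, gates.shape)  # torch.Size([2, 197, 768]) torch.Size([2, 197])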
code/models/interpretation.py
ADDED
@@ -0,0 +1,482 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from .gates import DiffMaskGateInput
from argparse import ArgumentParser
from math import sqrt
from pytorch_lightning.core.optimizer import LightningOptimizer
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from transformers import (
    get_constant_schedule_with_warmup,
    get_constant_schedule,
    ViTForImageClassification,
)
from transformers.models.vit.configuration_vit import ViTConfig
from typing import Optional, Union
from utils.getters_setters import vit_getter, vit_setter
from utils.metrics import accuracy_precision_recall_f1
from utils.optimizer import LookaheadAdam


class ImageInterpretationNet(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group("Vision DiffMask")
        parser.add_argument(
            "--alpha",
            type=float,
            default=20.0,
            help="Initial value for the Lagrangian",
        )
        parser.add_argument(
            "--lr",
            type=float,
            default=2e-5,
            help="Learning rate for DiffMask.",
        )
        parser.add_argument(
            "--eps",
            type=float,
            default=0.1,
            help="KL divergence tolerance.",
        )
        parser.add_argument(
            "--no_placeholder",
            action="store_true",
            help="Whether to not use placeholder",
        )
        parser.add_argument(
            "--lr_placeholder",
            type=float,
            default=1e-3,
            help="Learning rate for mask vectors.",
        )
        parser.add_argument(
            "--lr_alpha",
            type=float,
            default=0.3,
            help="Learning rate for lagrangian optimizer.",
        )
        parser.add_argument(
            "--mul_activation",
            type=float,
            default=15.0,
            help="Value to multiply gate activations.",
        )
        parser.add_argument(
            "--add_activation",
            type=float,
            default=8.0,
            help="Value to add to gate activations.",
        )
        parser.add_argument(
            "--weighted_layer_distribution",
            action="store_true",
            help="Whether to use a weighted distribution when picking a layer in DiffMask forward.",
        )
        return parent_parser

    # Declare variables that will be initialized later
    model: ViTForImageClassification

    def __init__(
        self,
        model_cfg: ViTConfig,
        alpha: float = 1,
        lr: float = 3e-4,
        eps: float = 0.1,
        eps_valid: float = 0.8,
        acc_valid: float = 0.75,
        lr_placeholder: float = 1e-3,
        lr_alpha: float = 0.3,
        mul_activation: float = 10.0,
        add_activation: float = 5.0,
        placeholder: bool = True,
        weighted_layer_pred: bool = False,
    ):
        """A PyTorch Lightning Module for the VisionDiffMask model on the Vision Transformer.

        Args:
            model_cfg (ViTConfig): the configuration of the Vision Transformer model
            alpha (float): the initial value for the Lagrangian
            lr (float): the learning rate for the DiffMask gates
            eps (float): the tolerance for the KL divergence
            eps_valid (float): the tolerance for the KL divergence in the validation step
            acc_valid (float): the accuracy threshold for the validation step
            lr_placeholder (float): the learning rate for the learnable masking embeddings
            lr_alpha (float): the learning rate for the Lagrangian
            mul_activation (float): the value to multiply the gate activations by
            add_activation (float): the value to add to the gate activations
            placeholder (bool): whether to use placeholder embeddings or a zero vector
            weighted_layer_pred (bool): whether to use a weighted distribution when picking a layer
        """
        super().__init__()

        # Save the hyperparameters
        self.save_hyperparameters()

        # Create DiffMask instance
        self.gate = DiffMaskGateInput(
            hidden_size=model_cfg.hidden_size,
            hidden_attention=model_cfg.hidden_size // 4,
            num_hidden_layers=model_cfg.num_hidden_layers + 2,
            max_position_embeddings=1,
            mul_activation=mul_activation,
            add_activation=add_activation,
            placeholder=placeholder,
        )

        # Create the Lagrangian values for the dual optimization
        self.alpha = torch.nn.ParameterList(
            [
                torch.nn.Parameter(torch.ones(()) * alpha)
                for _ in range(model_cfg.num_hidden_layers + 2)
            ]
        )

        # Register buffers for running metrics
        self.register_buffer(
            "running_acc", torch.ones((model_cfg.num_hidden_layers + 2,))
        )
        self.register_buffer(
            "running_l0", torch.ones((model_cfg.num_hidden_layers + 2,))
        )
        self.register_buffer(
            "running_steps", torch.zeros((model_cfg.num_hidden_layers + 2,))
        )

    def set_vision_transformer(self, model: ViTForImageClassification):
        """Set the Vision Transformer model to be used with this module."""
        # Save the model instance as a class attribute
        self.model = model
        # Freeze the model's parameters
        for param in self.model.parameters():
            param.requires_grad = False

    def forward_explainer(
        self, x: Tensor, attribution: bool = False
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int, int]:
        """Performs a forward pass through the explainer (VisionDiffMask) model."""
        # Get the original logits and hidden states from the model
        logits_orig, hidden_states = vit_getter(self.model, x)

        # Add [CLS] token to deal with shape mismatch in self.gate() call
        patch_embeddings = hidden_states[0]
        batch_size = len(patch_embeddings)
        cls_tokens = self.model.vit.embeddings.cls_token.expand(batch_size, -1, -1)
        hidden_states[0] = torch.cat((cls_tokens, patch_embeddings), dim=1)

        # Select the layer to generate the mask from in this pass
        n_hidden = len(hidden_states)
        if self.hparams.weighted_layer_pred:
            # If weighted layer prediction is enabled, use a weighted distribution
            # instead of uniformly picking a layer after a number of steps
            low_weight = (
                lambda i: self.running_acc[i] > 0.75
                and self.running_l0[i] < 0.1
                and self.running_steps[i] > 100
            )
            layers = torch.tensor(list(range(n_hidden)))
            p = torch.tensor([0.1 if low_weight(i) else 1 for i in range(n_hidden)])
            p = p / p.sum()
            idx = p.multinomial(num_samples=1)
            layer_pred = layers[idx].item()
        else:
            layer_pred = torch.randint(n_hidden, ()).item()

        # Set the layer to drop to 0, since we are only interested in masking the input
        layer_drop = 0

        (
            new_hidden_state,
            gates,
            expected_L0,
            gates_full,
            expected_L0_full,
        ) = self.gate(
            hidden_states=hidden_states,
            layer_pred=None
            if attribution
            else layer_pred,  # if attribution, we get all the hidden states
        )

        # Create the list of the new hidden states for the new forward pass
        new_hidden_states = (
            [None] * layer_drop
            + [new_hidden_state]
            + [None] * (n_hidden - layer_drop - 1)
        )

        # Get the new logits from the masked input
        logits, _ = vit_setter(self.model, x, new_hidden_states)

        return (
            logits,
            logits_orig,
            gates,
            expected_L0,
            gates_full,
            expected_L0_full,
            layer_drop,
            layer_pred,
        )

    def get_mask(
        self,
        x: Tensor,
        idx: int = -1,
        aggregated_mask: bool = True,
    ) -> dict[str, Tensor]:
        """
        Generates a mask for the given input.

        Args:
            x: the input to generate the mask for
            idx: the index of the layer to generate the mask from
            aggregated_mask: whether to use an aggregative mask from each layer

        Returns:
            a dictionary containing the mask, kl divergence and the predicted class
        """

        # Pass from forward explainer with attribution=True
        (
            logits,
            logits_orig,
            gates,
            expected_L0,
            gates_full,
            expected_L0_full,
            layer_drop,
            layer_pred,
        ) = self.forward_explainer(x, attribution=True)

        # Calculate KL-divergence
        kl_div = torch.distributions.kl_divergence(
            torch.distributions.Categorical(logits=logits_orig),
            torch.distributions.Categorical(logits=logits),
        )

        # Get predicted class
        pred_class = logits.argmax(-1)

        # Calculate mask
        if aggregated_mask:
            mask = expected_L0_full[:, :, idx].exp()
        else:
            mask = gates_full[:, :, idx]

        mask = mask[:, 1:]

        C, H, W = x.shape[1:]  # channels, height, width
        B, P = mask.shape  # batch, patches
        N = int(sqrt(P))  # patches per side
        S = int(H / N)  # patch size

        # Reshape mask to match input shape
        mask = mask.reshape(B, 1, N, N)
        mask = F.interpolate(mask, scale_factor=S)
        mask = mask.reshape(B, H, W)

        return {"mask": mask, "kl_div": kl_div, "pred_class": pred_class}

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x).logits

    def training_step(self, batch: tuple[Tensor, Tensor], *args, **kwargs) -> dict:
        # Unpack the batch
        x, y = batch

        # Pass the batch through the explainer (VisionDiffMask) model
        (
            logits,
            logits_orig,
            gates,
            expected_L0,
            gates_full,
            expected_L0_full,
            layer_drop,
            layer_pred,
        ) = self.forward_explainer(x)

        # Calculate the KL-divergence loss term
        loss_c = (
            torch.distributions.kl_divergence(
                torch.distributions.Categorical(logits=logits_orig),
                torch.distributions.Categorical(logits=logits),
            )
            - self.hparams.eps
        )

        # Calculate the L0 loss term
        loss_g = expected_L0.mean(-1)

        # Calculate the full loss term
        loss = self.alpha[layer_pred] * loss_c + loss_g

        # Calculate the accuracy
        acc, _, _, _ = accuracy_precision_recall_f1(
            logits.argmax(-1), logits_orig.argmax(-1), average=True
        )

        # Calculate the average L0 loss
        l0 = expected_L0.exp().mean(-1)

        outputs_dict = {
            "loss_c": loss_c.mean(-1),
            "loss_g": loss_g.mean(-1),
            "alpha": self.alpha[layer_pred].mean(-1),
            "acc": acc,
            "l0": l0.mean(-1),
            "layer_pred": layer_pred,
            "r_acc": self.running_acc[layer_pred],
            "r_l0": self.running_l0[layer_pred],
            "r_steps": self.running_steps[layer_pred],
            "debug_loss": loss.mean(-1),
        }

        outputs_dict = {
            "loss": loss.mean(-1),
            **outputs_dict,
            "log": outputs_dict,
            "progress_bar": outputs_dict,
        }

        self.log(
            "loss", outputs_dict["loss"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "loss_c", outputs_dict["loss_c"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "loss_g", outputs_dict["loss_g"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log("acc", outputs_dict["acc"], on_step=True, on_epoch=True, prog_bar=True)
        self.log("l0", outputs_dict["l0"], on_step=True, on_epoch=True, prog_bar=True)
        self.log(
            "alpha", outputs_dict["alpha"], on_step=True, on_epoch=True, prog_bar=True
        )

        outputs_dict = {
            "{}{}".format("" if self.training else "val_", k): v
            for k, v in outputs_dict.items()
        }

        if self.training:
            self.running_acc[layer_pred] = (
                self.running_acc[layer_pred] * 0.9 + acc * 0.1
            )
            self.running_l0[layer_pred] = (
                self.running_l0[layer_pred] * 0.9 + l0.mean(-1) * 0.1
            )
            self.running_steps[layer_pred] += 1

        return outputs_dict

    def validation_epoch_end(self, outputs: list[dict]):
        outputs_dict = {
            k: [e[k] for e in outputs if k in e]
            for k in ("val_loss_c", "val_loss_g", "val_acc", "val_l0")
        }

        outputs_dict = {k: sum(v) / len(v) for k, v in outputs_dict.items()}

        outputs_dict["val_loss_c"] += self.hparams.eps

        outputs_dict = {
            "val_loss": outputs_dict["val_l0"]
            if outputs_dict["val_loss_c"] <= self.hparams.eps_valid
            and outputs_dict["val_acc"] >= self.hparams.acc_valid
            else torch.full_like(outputs_dict["val_l0"], float("inf")),
            **outputs_dict,
            "log": outputs_dict,
        }

        return outputs_dict

    def configure_optimizers(self) -> tuple[list[Optimizer], list[_LRScheduler]]:
        optimizers = [
            LookaheadAdam(
                params=[
                    {
                        "params": self.gate.g_hat.parameters(),
                        "lr": self.hparams.lr,
                    },
                    {
                        "params": self.gate.placeholder.parameters()
                        if isinstance(self.gate.placeholder, torch.nn.ParameterList)
                        else [self.gate.placeholder],
                        "lr": self.hparams.lr_placeholder,
                    },
                ],
                # centered=True,  # this is for LookaheadRMSprop
            ),
            LookaheadAdam(
                params=[self.alpha]
                if isinstance(self.alpha, torch.Tensor)
                else self.alpha.parameters(),
                lr=self.hparams.lr_alpha,
            ),
        ]

        schedulers = [
            {
                "scheduler": get_constant_schedule_with_warmup(optimizers[0], 12 * 100),
                "interval": "step",
            },
            get_constant_schedule(optimizers[1]),
        ]
        return optimizers, schedulers

    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: Union[Optimizer, LightningOptimizer],
        optimizer_idx: int = 0,
        optimizer_closure: Optional[callable] = None,
        on_tpu: bool = False,
        using_native_amp: bool = False,
        using_lbfgs: bool = False,
    ):
        # Optimizer 0: Minimize loss w.r.t. DiffMask's parameters
        if optimizer_idx == 0:
            # Gradient descent on DiffMask's parameters
            optimizer.step(closure=optimizer_closure)
            optimizer.zero_grad()
            for g in optimizer.param_groups:
                for p in g["params"]:
                    p.grad = None

        # Optimizer 1: Maximize loss w.r.t. the Lagrangian
        elif optimizer_idx == 1:
            # Reverse the sign of the Lagrangian's gradients
            for i in range(len(self.alpha)):
                if self.alpha[i].grad:
                    self.alpha[i].grad *= -1

            # Gradient ascent on the Lagrangian
            optimizer.step(closure=optimizer_closure)
            optimizer.zero_grad()
            for g in optimizer.param_groups:
                for p in g["params"]:
                    p.grad = None

            # Clip the Lagrangian's values
            for i in range(len(self.alpha)):
                self.alpha[i].data = torch.where(
                    self.alpha[i].data < 0,
                    torch.full_like(self.alpha[i].data, 0),
                    self.alpha[i].data,
                )
                self.alpha[i].data = torch.where(
                    self.alpha[i].data > 200,
                    torch.full_like(self.alpha[i].data, 200),
                    self.alpha[i].data,
                )

    def on_save_checkpoint(self, ckpt: dict):
        # Remove ViT from checkpoint as we can load it dynamically
        keys = list(ckpt["state_dict"].keys())
        for key in keys:
            if key.startswith("model."):
                del ckpt["state_dict"][key]
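A rough sketch of how the module above is wired up for training (editor's example, not part of the repository files). It assumes `code/` is on the Python path and a randomly initialized ViT with 10 labels; the hyperparameter values are placeholders rather than the values used in the actual runs.

# --- usage sketch (editor's example, not part of the repo) ---
import torch
from transformers import ViTConfig, ViTForImageClassification
from models.interpretation import ImageInterpretationNet

config = ViTConfig(image_size=224, num_labels=10)   # hypothetical CIFAR-10-sized ViT
vit = ViTForImageClassification(config)

diffmask = ImageInterpretationNet(model_cfg=config, alpha=20.0, lr=2e-5)
diffmask.set_vision_transformer(vit)                 # attaches and freezes the ViT

x = torch.randn(2, 3, 224, 224)                      # dummy batch
logits, logits_orig, *rest = diffmask.forward_explainer(x)
print(logits.shape)                                  # torch.Size([2, 10])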
code/models/utils.py
ADDED
@@ -0,0 +1,64 @@
from datamodules.utils import get_configs
from transformers import (
    ConvNextConfig,
    ConvNextForImageClassification,
    PreTrainedModel,
    ViTConfig,
    ViTForImageClassification,
)

import argparse
import torch


def set_clf_head(base: PreTrainedModel, num_classes: int):
    """Set the classification head of the model in case of an output mismatch.

    Args:
        base (PreTrainedModel): the model to modify
        num_classes (int): the number of classes to use for the output layer
    """
    if base.classifier.out_features != num_classes:
        in_features = base.classifier.in_features
        base.classifier = torch.nn.Linear(in_features, num_classes)


def model_factory(
    args: argparse.Namespace,
    own_config: bool = False,
) -> PreTrainedModel:
    """A factory method for creating a HuggingFace model based on the command line args.

    Args:
        args (Namespace): the argparse Namespace object
        own_config (bool): whether to create our own model config instead of a pretrained one;
            this is recommended when the model was pre-trained on another task with a different
            amount of classes for its classifier head

    Returns:
        a PreTrainedModel instance
    """
    if args.base_model == "ViT":
        # Create a new Vision Transformer
        config_class = ViTConfig
        base_class = ViTForImageClassification
    elif args.base_model == "ConvNeXt":
        # Create a new ConvNeXt model
        config_class = ConvNextConfig
        base_class = ConvNextForImageClassification
    else:
        raise Exception(f"Unknown base model: {args.base_model}")

    # Get the model config
    model_cfg_args, _ = get_configs(args)
    if not own_config and args.from_pretrained:
        # Create a model from a pretrained model
        base = base_class.from_pretrained(args.from_pretrained)
        # Set the classifier head if needed
        set_clf_head(base, model_cfg_args["num_labels"])
    else:
        # Create a model based on the config
        config = config_class(**model_cfg_args)
        base = base_class(config)

    return base
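As a quick illustration of `set_clf_head`, the following sketch (editor's example, not part of the repository files) swaps the output layer of a pretrained ViT when the label count differs; the checkpoint name is only an example of a 1000-class model.

# --- usage sketch (editor's example, not part of the repo) ---
from transformers import ViTForImageClassification
from models.utils import set_clf_head

base = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
set_clf_head(base, num_classes=10)      # replaces the 1000-way head with a 10-way one
print(base.classifier.out_features)     # 10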
code/train_base.py
ADDED
@@ -0,0 +1,123 @@
import argparse
import pytorch_lightning as pl

from datamodules import CIFAR10QADataModule, ImageDataModule
from datamodules.utils import datamodule_factory
from models import ImageClassificationNet
from models.utils import model_factory
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger


def main(args: argparse.Namespace):
    # Seed
    pl.seed_everything(args.seed)

    # Create base model
    base = model_factory(args)

    # Load datamodule
    dm = datamodule_factory(args)
    dm.prepare_data()
    dm.setup("fit")

    if args.checkpoint:
        # Load the model from the specified checkpoint
        model = ImageClassificationNet.load_from_checkpoint(args.checkpoint, model=base)
    else:
        # Create a new instance of the classification model
        model = ImageClassificationNet(
            model=base,
            num_train_steps=args.num_epochs * len(dm.train_dataloader()),
            optimizer=args.optimizer,
            weight_decay=args.weight_decay,
            lr=args.lr,
        )

    # Create wandb logger
    wandb_logger = WandbLogger(
        name=f"{args.dataset}_training_{args.base_model} ({args.from_pretrained})",
        project="Patch-DiffMask",
    )

    # Create checkpoint callback
    ckpt_cb = ModelCheckpoint(dirpath=f"checkpoints/{wandb_logger.version}")
    # Create early stopping callback
    es_cb = EarlyStopping(monitor="val_acc", mode="max", patience=5)

    # Create trainer
    trainer = pl.Trainer(
        accelerator="auto",
        callbacks=[ckpt_cb, es_cb],
        logger=wandb_logger,
        max_epochs=args.num_epochs,
        enable_progress_bar=args.enable_progress_bar,
    )

    trainer_args = {}
    if args.checkpoint:
        # Resume trainer from checkpoint
        trainer_args["ckpt_path"] = args.checkpoint

    # Train the model
    trainer.fit(model, dm, **trainer_args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Checkpoint to resume the training from.",
    )

    # Trainer
    parser.add_argument(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to show progress bar during training. NOT recommended when logging to files.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=5,
        help="Number of epochs to train.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducibility.",
    )

    # Base (classification) model
    ImageClassificationNet.add_model_specific_args(parser)
    parser.add_argument(
        "--base_model",
        type=str,
        default="ViT",
        choices=["ViT", "ConvNeXt"],
        help="Base model architecture to train.",
    )
    parser.add_argument(
        "--from_pretrained",
        type=str,
        # default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
        help="The name of the pretrained HF model to fine-tune from.",
    )

    # Datamodule
    ImageDataModule.add_model_specific_args(parser)
    CIFAR10QADataModule.add_model_specific_args(parser)
    parser.add_argument(
        "--dataset",
        type=str,
        default="toy",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    args = parser.parse_args()

    main(args)
code/utils/__init__.py
ADDED
File without changes
code/utils/distributions.py
ADDED
@@ -0,0 +1,64 @@
"""
File copied from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/models/distributions.py
"""

import torch
import torch.distributions as distr
import torch.nn.functional as F

from torch import Tensor


class BinaryConcrete(distr.relaxed_bernoulli.RelaxedBernoulli):
    def __init__(self, temperature: Tensor, logits: Tensor):
        super().__init__(temperature=temperature, logits=logits)
        self.device = self.temperature.device

    def cdf(self, value: Tensor) -> Tensor:
        return torch.sigmoid(
            (torch.log(value) - torch.log(1.0 - value)) * self.temperature - self.logits
        )

    def log_prob(self, value: Tensor) -> Tensor:
        return torch.where(
            (value > 0) & (value < 1),
            super().log_prob(value),
            torch.full_like(value, -float("inf")),
        )

    def log_expected_L0(self, value: Tensor) -> Tensor:
        return -F.softplus(
            (torch.log(value) - torch.log(1 - value)) * self.temperature - self.logits
        )


class Streched(distr.TransformedDistribution):
    def __init__(self, base_dist, l: float = -0.1, r: float = 1.1):
        super().__init__(base_dist, distr.AffineTransform(loc=l, scale=r - l))

    def log_expected_L0(self) -> Tensor:
        value = torch.tensor(0.0, device=self.base_dist.device)
        for transform in self.transforms[::-1]:
            value = transform.inv(value)
        if self._validate_args:
            self.base_dist._validate_sample(value)
        value = self.base_dist.log_expected_L0(value)
        value = self._monotonize_cdf(value)
        return value

    def expected_L0(self) -> Tensor:
        return self.log_expected_L0().exp()


class RectifiedStreched(Streched):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def sample(self, sample_shape: torch.Size = torch.Size([])) -> Tensor:
        return self.rsample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size([])) -> Tensor:
        x = super().rsample(sample_shape)
        return x.clamp(0, 1)
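A small sketch of the hard-concrete machinery above (editor's example, not part of the repository files): sampling from `RectifiedStreched(BinaryConcrete(...))` yields differentiable gate values clamped to [0, 1], while `expected_L0` gives the probability that each gate is non-zero. The temperature 0.2 and stretch bounds mirror the values used in the gates module.

# --- usage sketch (editor's example, not part of the repo) ---
import torch
from utils.distributions import BinaryConcrete, RectifiedStreched

logits = torch.zeros(4)  # neutral gate logits
dist = RectifiedStreched(
    BinaryConcrete(torch.full_like(logits, 0.2), logits), l=-0.2, r=1.0
)
samples = dist.rsample()       # differentiable samples in [0, 1]
p_open = dist.expected_L0()    # probability that each gate is > 0
print(samples.shape, p_open.shape)  # torch.Size([4]) torch.Size([4])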
code/utils/getters_setters.py
ADDED
@@ -0,0 +1,122 @@
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import ViTForImageClassification
from typing import Optional, Union


def _add_hooks(
    model: ViTForImageClassification, get_hook: callable
) -> list[RemovableHandle]:
    """Adds a list of hooks to the model according to the get_hook function provided.

    Args:
        model (ViTForImageClassification): the ViT instance to add hooks to
        get_hook (callable): a function that takes an index and returns a hook

    Returns:
        a list of RemovableHandle instances
    """
    return (
        [model.vit.embeddings.patch_embeddings.register_forward_hook(get_hook(0))]
        + [
            layer.register_forward_pre_hook(get_hook(i + 1))
            for i, layer in enumerate(model.vit.encoder.layer)
        ]
        + [
            model.vit.encoder.layer[-1].register_forward_hook(
                get_hook(len(model.vit.encoder.layer) + 1)
            )
        ]
    )


def vit_getter(
    model: ViTForImageClassification, x: Tensor
) -> tuple[Tensor, list[Tensor]]:
    """A function that returns the logits and hidden states of the model.

    Args:
        model (ViTForImageClassification): the ViT instance to use for the forward pass
        x (Tensor): the input to the model

    Returns:
        a tuple of the model's logits and hidden states
    """
    hidden_states_ = []

    def get_hook(i: int) -> callable:
        def hook(_: Module, inputs: tuple, outputs: Optional[tuple] = None):
            if i == 0:
                hidden_states_.append(outputs)
            elif 1 <= i <= len(model.vit.encoder.layer):
                hidden_states_.append(inputs[0])
            elif i == len(model.vit.encoder.layer) + 1:
                hidden_states_.append(outputs[0])

        return hook

    handles = _add_hooks(model, get_hook)
    try:
        logits = model(x).logits
    finally:
        for handle in handles:
            handle.remove()

    return logits, hidden_states_


def vit_setter(
    model: ViTForImageClassification, x: Tensor, hidden_states: list[Optional[Tensor]]
) -> tuple[Tensor, list[Tensor]]:
    """A function that sets some of the model's hidden states and returns its (new) logits
    and hidden states after another forward pass.

    Args:
        model (ViTForImageClassification): the ViT instance to use for the forward pass
        x (Tensor): the input to the model
        hidden_states (list[Optional[Tensor]]): a list, with each element corresponding to
            a hidden state to set or None to calculate anew for that index

    Returns:
        a tuple of the model's logits and (new) hidden states
    """
    hidden_states_ = []

    def get_hook(i: int) -> callable:
        def hook(
            _: Module, inputs: tuple, outputs: Optional[tuple] = None
        ) -> Optional[Union[tuple, Tensor]]:
            if i == 0:
                if hidden_states[i] is not None:
                    # print(hidden_states[i].shape)
                    hidden_states_.append(hidden_states[i][:, 1:])
                    return hidden_states_[-1]
                else:
                    hidden_states_.append(outputs)

            elif 1 <= i <= len(model.vit.encoder.layer):
                if hidden_states[i] is not None:
                    hidden_states_.append(hidden_states[i])
                    return (hidden_states[i],) + inputs[1:]
                else:
                    hidden_states_.append(inputs[0])

            elif i == len(model.vit.encoder.layer) + 1:
                if hidden_states[i] is not None:
                    hidden_states_.append(hidden_states[i])
                    return (hidden_states[i],) + outputs[1:]
                else:
                    hidden_states_.append(outputs[0])

        return hook

    handles = _add_hooks(model, get_hook)

    try:
        logits = model(x).logits
    finally:
        for handle in handles:
            handle.remove()

    return logits, hidden_states_
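A hedged sketch of the hook helpers above (editor's example, not part of the repository files): `vit_getter` records the hidden states of a forward pass, and `vit_setter` re-runs the model with any subset of them replaced; passing all `None` leaves every state untouched, so the logits should agree up to numerical noise. The checkpoint name is just an example.

# --- usage sketch (editor's example, not part of the repo) ---
import torch
from transformers import ViTForImageClassification
from utils.getters_setters import vit_getter, vit_setter

vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
x = torch.randn(1, 3, 224, 224)

logits, hidden_states = vit_getter(vit, x)                       # capture all hidden states
new_logits, _ = vit_setter(vit, x, [None] * len(hidden_states))  # recompute everything
print(torch.allclose(logits, new_logits, atol=1e-5))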
code/utils/metrics.py
ADDED
@@ -0,0 +1,67 @@
"""
File copied from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/utils/util.py
"""

import torch

from torch import Tensor


def accuracy_precision_recall_f1(
    y_pred: Tensor, y_true: Tensor, average: bool = True
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Calculates the accuracy, precision, recall and f1 score given the predicted and true labels.

    Args:
        y_pred (Tensor): predicted labels
        y_true (Tensor): true labels
        average (bool): whether to average the scores or not

    Returns:
        a tuple of the accuracy, precision, recall and f1 score
    """
    M = confusion_matrix(y_pred, y_true)

    tp = M.diagonal(dim1=-2, dim2=-1).float()

    precision_den = M.sum(-2)
    precision = torch.where(
        precision_den == 0, torch.zeros_like(tp), tp / precision_den
    )

    recall_den = M.sum(-1)
    recall = torch.where(recall_den == 0, torch.ones_like(tp), tp / recall_den)

    f1_den = precision + recall
    f1 = torch.where(
        f1_den == 0, torch.zeros_like(tp), 2 * (precision * recall) / f1_den
    )

    # noinspection PyTypeChecker
    return ((y_pred == y_true).float().mean(-1),) + (
        tuple(e.mean(-1) for e in (precision, recall, f1))
        if average
        else (precision, recall, f1)
    )


def confusion_matrix(y_pred: Tensor, y_true: Tensor) -> Tensor:
    """Creates a confusion matrix given the predicted and true labels."""
    device = y_pred.device
    labels = max(y_pred.max().item() + 1, y_true.max().item() + 1)

    return (
        (
            torch.stack((y_true, y_pred), -1).unsqueeze(-2).unsqueeze(-2)
            == torch.stack(
                (
                    torch.arange(labels, device=device).unsqueeze(-1).repeat(1, labels),
                    torch.arange(labels, device=device).unsqueeze(-2).repeat(labels, 1),
                ),
                -1,
            )
        )
        .all(-1)
        .sum(-3)
    )
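On a toy batch the helper above behaves as follows (editor's example, not part of the repository files); three of the four predictions match, so the accuracy is 0.75.

# --- usage sketch (editor's example, not part of the repo) ---
import torch
from utils.metrics import accuracy_precision_recall_f1

y_pred = torch.tensor([0, 1, 1, 2])
y_true = torch.tensor([0, 1, 2, 2])
acc, precision, recall, f1 = accuracy_precision_recall_f1(y_pred, y_true, average=True)
print(acc)  # tensor(0.7500)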
code/utils/optimizer.py
ADDED
@@ -0,0 +1,151 @@
"""
File copied from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/optim/lookahead.py
"""

import torch
import torch.optim as optim

from collections import defaultdict
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing import Iterable, Optional, Union


_params_type = Union[Iterable[Tensor], Iterable[dict]]


class Lookahead(Optimizer):
    """Lookahead optimizer: https://arxiv.org/abs/1907.08610"""

    # noinspection PyMissingConstructor
    def __init__(self, base_optimizer: Optimizer, alpha: float = 0.5, k: int = 6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid slow update rate: {alpha}")
        if not 1 <= k:
            raise ValueError(f"Invalid lookahead steps: {k}")
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group: dict):
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
                param_state["slow_buffer"].copy_(fast_p.data)
            slow = param_state["slow_buffer"]
            slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure: Optional[callable] = None) -> Optional[float]:
        # print(self.k)
        # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group["lookahead_step"] += 1
            if group["lookahead_step"] % group["lookahead_k"] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self) -> dict:
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict: dict):
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = False
        if "slow_state" not in state_dict:
            print("Loading state_dict from optimizer without Lookahead applied.")
            state_dict["slow_state"] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],  # this is pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)


def LookaheadAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-08,
    weight_decay: float = 0,
    amsgrad: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        torch.optim.Adam(params, lr, betas, eps, weight_decay, amsgrad), lalpha, k
    )


def LookaheadRAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(optim.RAdam(params, lr, betas, eps, weight_decay), lalpha, k)


def LookaheadRMSprop(
    params: _params_type,
    lr: float = 1e-2,
    alpha: float = 0.99,
    eps: float = 1e-08,
    weight_decay: float = 0,
    momentum: float = 0,
    centered: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        torch.optim.RMSprop(params, lr, alpha, eps, weight_decay, momentum, centered),
        lalpha,
        k,
    )
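A minimal sketch of wrapping a model with the Lookahead-Adam combination defined above (editor's example, not part of the repository files); the model, data, and learning rate are placeholders.

# --- usage sketch (editor's example, not part of the repo) ---
import torch
from utils.optimizer import LookaheadAdam

model = torch.nn.Linear(10, 2)
optimizer = LookaheadAdam(model.parameters(), lr=1e-3, k=6)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()        # fast (Adam) update; slow weights sync every k steps
optimizer.zero_grad()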
code/utils/plot.py
ADDED
@@ -0,0 +1,252 @@
import cv2
import numpy as np
import torch

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from torch import Tensor


@torch.no_grad()
def unnormalize(
    images: Tensor,
    mean: tuple[float] = (0.5, 0.5, 0.5),
    std: tuple[float] = (0.5, 0.5, 0.5),
) -> Tensor:
    """Reverts the normalization transformation applied before ViT.

    Args:
        images (Tensor): a batch of images
        mean (tuple[float]): the means used for normalization - defaults to (0.5, 0.5, 0.5)
        std (tuple[float]): the stds used for normalization - defaults to (0.5, 0.5, 0.5)

    Returns:
        the un-normalized batch of images
    """
    unnormalized_images = images.clone()
    for i, (m, s) in enumerate(zip(mean, std)):
        unnormalized_images[:, i, :, :].mul_(s).add_(m)

    return unnormalized_images


@torch.no_grad()
def smoothen(mask: Tensor, patch_size: int = 16) -> Tensor:
    """Smoothens a mask by downsampling it and re-upsampling it
    with bi-linear interpolation.

    Args:
        mask (Tensor): a 2D float torch tensor with values in [0, 1]
        patch_size (int): the patch size in pixels

    Returns:
        a smoothened mask at the pixel level
    """
    device = mask.device
    (h, w) = mask.shape
    mask = cv2.resize(
        mask.cpu().numpy(),
        (h // patch_size, w // patch_size),
        interpolation=cv2.INTER_NEAREST,
    )
    mask = cv2.resize(mask, (h, w), interpolation=cv2.INTER_LINEAR)
    return torch.tensor(mask).to(device)


@torch.no_grad()
def draw_mask_on_image(image: Tensor, mask: Tensor) -> Tensor:
    """Overlays a dimming mask on the image.

    Args:
        image (Tensor): a float torch tensor with values in [0, 1]
        mask (Tensor): a float torch tensor with values in [0, 1]

    Returns:
        the image with parts of it dimmed according to the mask
    """
    masked_image = image * mask

    return masked_image


@torch.no_grad()
def draw_heatmap_on_image(
    image: Tensor,
    mask: Tensor,
    colormap: int = cv2.COLORMAP_JET,
) -> Tensor:
    """Overlays a heatmap on the image.

    Args:
        image (Tensor): a float torch tensor with values in [0, 1]
        mask (Tensor): a float torch tensor with values in [0, 1]
        colormap (int): the OpenCV colormap to be used

    Returns:
        the image with the heatmap overlaid
    """
    # Save the device of the image
    original_device = image.device

    # Convert image & mask to numpy
    image = image.permute(1, 2, 0).cpu().numpy()
    mask = mask.cpu().numpy()

    # Create heatmap
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    # Overlay heatmap on image
    masked_image = image + heatmap
    masked_image = masked_image / np.max(masked_image)

    return torch.tensor(masked_image).permute(2, 0, 1).to(original_device)


def _prepare_samples(images: Tensor, masks: Tensor) -> tuple[Tensor, list[float]]:
    """Prepares the samples for the masking/heatmap visualization.

    Args:
        images (Tensor): a float torch tensor with values in [0, 1]
        masks (Tensor): a float torch tensor with values in [0, 1]

    Returns:
        a tuple of image triplets (img, masked, heatmap) and their
        corresponding masking percentages
    """
    num_channels = images[0].shape[0]

    # Smoothen masks
    masks = [smoothen(m) for m in masks]

    # Un-normalize images
    if num_channels == 1:
        images = [
            torch.repeat_interleave(img, 3, 0)
            for img in unnormalize(images, mean=(0.5,), std=(0.5,))
        ]
    else:
        images = [img for img in unnormalize(images)]

    # Draw mask on sample images
    images_with_mask = [
        draw_mask_on_image(image, mask) for image, mask in zip(images, masks)
    ]

    # Draw heatmap on sample images
    images_with_heatmap = [
        draw_heatmap_on_image(image, mask) for image, mask in zip(images, masks)
    ]

    # Chunk to triplets (image, masked image, heatmap)
    samples = torch.cat(
        [
            torch.cat(images, dim=2),
            torch.cat(images_with_mask, dim=2),
            torch.cat(images_with_heatmap, dim=2),
        ],
        dim=1,
    ).chunk(len(images), dim=-1)

    # Compute masking percentages
    masked_pixels_percentages = [
        100 * (1 - torch.stack(masks)[i].mean(-1).mean(-1).item())
        for i in range(len(masks))
    ]

    return samples, masked_pixels_percentages


def log_masks(images: Tensor, masks: Tensor, key: str, logger: WandbLogger):
    """Logs a set of images with their masks to WandB.

    Args:
        images (Tensor): a float torch tensor with values in [0, 1]
        masks (Tensor): a float torch tensor with values in [0, 1]
        key (str): the key to log the images with
        logger (WandbLogger): the logger to log the images to
    """
    samples, masked_pixels_percentages = _prepare_samples(images, masks)

    # Log with wandb
    logger.log_image(
        key=key,
        images=list(samples),
        caption=[
            f"Masking: {masked_pixels_percentage:.2f}% "
            for masked_pixels_percentage in masked_pixels_percentages
        ],
    )


class DrawMaskCallback(Callback):
    def __init__(
        self,
        samples: list[tuple[Tensor, Tensor]],
        log_every_n_steps: int = 200,
        key: str = "",
    ):
        """A callback that logs VisionDiffMask masks for the sample images to WandB.

        Args:
            samples (list[tuple[Tensor, Tensor]]): a list of image, label pairs
            log_every_n_steps (int): the interval in steps to log the masks to WandB
            key (str): the key to log the images with (allows for multiple batches)
        """
        self.images = torch.stack([img for img in samples[0]])
        self.labels = [label.item() for label in samples[1]]
        self.log_every_n_steps = log_every_n_steps
        self.key = key

    def _log_masks(self, trainer: Trainer, pl_module: LightningModule):
        # Predict mask
        with torch.no_grad():
            pl_module.eval()
            outputs = pl_module.get_mask(self.images)
            pl_module.train()

        # Unnest outputs
        masks = outputs["mask"]
        kl_divs = outputs["kl_div"]
        pred_classes = outputs["pred_class"].cpu()

        # Prepare masked samples for logging
        samples, masked_pixels_percentages = _prepare_samples(self.images, masks)

        # Log with wandb
        trainer.logger.log_image(
            key="DiffMask " + self.key,
            images=list(samples),
            caption=[
                f"Masking: {masked_pixels_percentage:.2f}% "
                f"\n KL-divergence: {kl_div:.4f} "
                f"\n Class: {pl_module.model.config.id2label[label]} "
                f"\n Predicted Class: {pl_module.model.config.id2label[pred_class.item()]}"
                for masked_pixels_percentage, kl_div, label, pred_class in zip(
                    masked_pixels_percentages, kl_divs, self.labels, pred_classes
                )
            ],
        )

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule):
        # Transfer sample images to correct device
        self.images = self.images.to(pl_module.device)

        # Log sample images
        self._log_masks(trainer, pl_module)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: dict,
        batch: tuple[Tensor, Tensor],
        batch_idx: int,
        unused: int = 0,
    ):
        # Log sample images every n steps
        if batch_idx % self.log_every_n_steps == 0:
            self._log_masks(trainer, pl_module)
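A hedged sketch of how `DrawMaskCallback` might be attached to a Lightning trainer (editor's example, not part of the repository files); the sample batch, the `diffmask` module, and the datamodule are placeholders, and a WandbLogger is assumed since the callback logs images through `trainer.logger.log_image`.

# --- usage sketch (editor's example, not part of the repo) ---
import pytorch_lightning as pl
import torch
from utils.plot import DrawMaskCallback

# Hypothetical sample batch: 4 images with their labels
images = torch.randn(4, 3, 224, 224)
labels = torch.tensor([0, 1, 2, 3])

callback = DrawMaskCallback(samples=(images, labels), log_every_n_steps=200, key="batch 0")
trainer = pl.Trainer(max_epochs=1, callbacks=[callback])
# trainer.fit(diffmask, datamodule)  # diffmask: an ImageInterpretationNet with a ViT attached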
requirements.txt
ADDED
@@ -0,0 +1,7 @@
numpy
opencv-python
pytorch_lightning
torch
torchvision
transformers