"""Mask visualization utilities and a WandB callback for logging VisionDiffMask outputs."""
import cv2
import numpy as np
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from torch import Tensor
@torch.no_grad()
def unnormalize(
    images: Tensor,
    mean: tuple[float, ...] = (0.5, 0.5, 0.5),
    std: tuple[float, ...] = (0.5, 0.5, 0.5),
) -> Tensor:
    """Reverts the normalization transformation applied before ViT.

    Computes ``x * std + mean`` per channel, the inverse of the usual
    ``(x - mean) / std`` normalization.

    Args:
        images (Tensor): a batch of images, shape (B, C, H, W)
        mean (tuple[float, ...]): the means used for normalization - defaults to (0.5, 0.5, 0.5)
        std (tuple[float, ...]): the stds used for normalization - defaults to (0.5, 0.5, 0.5)
    Returns:
        the un-normalized batch of images
    """
    # Work on a copy so the caller's tensor is untouched; the in-place ops
    # below then only mutate the clone.
    unnormalized_images = images.clone()
    # Only the first len(mean) channels are touched, mirroring the
    # per-channel statistics supplied by the caller.
    for i, (m, s) in enumerate(zip(mean, std)):
        unnormalized_images[:, i, :, :].mul_(s).add_(m)
    return unnormalized_images
@torch.no_grad()
def smoothen(mask: Tensor, patch_size: int = 16) -> Tensor:
    """Smoothens a mask by downsampling it and re-upsampling it
    with bi-linear interpolation.

    Args:
        mask (Tensor): a 2D float torch tensor with values in [0, 1]
        patch_size (int): the patch size in pixels
    Returns:
        a smoothened mask at the pixel level
    """
    device = mask.device
    (h, w) = mask.shape
    # NOTE: cv2.resize takes dsize as (width, height). The previous code
    # passed (h // patch_size, w // patch_size) and (h, w), which transposed
    # the target size for non-square masks.
    mask = cv2.resize(
        mask.cpu().numpy(),
        (w // patch_size, h // patch_size),
        interpolation=cv2.INTER_NEAREST,
    )
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
    return torch.tensor(mask).to(device)
@torch.no_grad()
def draw_mask_on_image(image: Tensor, mask: Tensor) -> Tensor:
    """Overlays a dimming mask on the image.

    Pixels are scaled element-wise by the mask value, so mask=0 blacks a
    pixel out and mask=1 leaves it unchanged.

    Args:
        image (Tensor): a float torch tensor with values in [0, 1]
        mask (Tensor): a float torch tensor with values in [0, 1]
    Returns:
        the image with parts of it dimmed according to the mask
    """
    return image * mask
@torch.no_grad()
def draw_heatmap_on_image(
    image: Tensor,
    mask: Tensor,
    colormap: int = cv2.COLORMAP_JET,
) -> Tensor:
    """Overlays a heatmap on the image.

    Args:
        image (Tensor): a float torch tensor with values in [0, 1]
        mask (Tensor): a float torch tensor with values in [0, 1]
        colormap (int): the OpenCV colormap to be used
    Returns:
        the image with the heatmap overlaid
    """
    # Remember where the image lives so the result goes back there
    device = image.device
    # Move to numpy in HWC layout, as OpenCV expects
    img_np = image.permute(1, 2, 0).cpu().numpy()
    mask_np = mask.cpu().numpy()
    # Map the mask through the colormap (OpenCV returns BGR; convert to RGB)
    heat = cv2.applyColorMap(np.uint8(255 * mask_np), colormap)
    heat = np.float32(cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)) / 255
    # Blend additively, then rescale so the brightest value is 1
    blended = img_np + heat
    blended = blended / np.max(blended)
    # Back to CHW torch tensor on the original device
    return torch.tensor(blended).permute(2, 0, 1).to(device)
def _prepare_samples(images: Tensor, masks: Tensor) -> tuple[Tensor, list[float]]:
    """Prepares the samples for the masking/heatmap visualization.

    Args:
        images (Tensor): a float torch tensor with values in [0, 1]
        masks (Tensor): a float torch tensor with values in [0, 1]
    Returns
        a tuple of image triplets (img, masked, heatmap) and their
        corresponding masking percentages
    """
    n_channels = images[0].shape[0]
    # Smooth every mask up to pixel resolution
    smooth_masks = [smoothen(m) for m in masks]
    # Undo the ViT normalization; grayscale inputs are tiled to 3 channels
    if n_channels == 1:
        rgb_images = [
            torch.repeat_interleave(img, 3, 0)
            for img in unnormalize(images, mean=(0.5,), std=(0.5,))
        ]
    else:
        rgb_images = list(unnormalize(images))
    # Build the masked and heatmap variants of every sample
    masked_variants = [
        draw_mask_on_image(img, m) for img, m in zip(rgb_images, smooth_masks)
    ]
    heatmap_variants = [
        draw_heatmap_on_image(img, m) for img, m in zip(rgb_images, smooth_masks)
    ]
    # Lay each variant out as a wide strip (cat over width), stack the three
    # strips vertically, then cut back into one column per sample
    grid = torch.cat(
        [
            torch.cat(rgb_images, dim=2),
            torch.cat(masked_variants, dim=2),
            torch.cat(heatmap_variants, dim=2),
        ],
        dim=1,
    )
    samples = grid.chunk(len(rgb_images), dim=-1)
    # Percentage of the image that the (smoothed) mask dims away
    masked_pixels_percentages = [100 * (1 - m.mean().item()) for m in smooth_masks]
    return samples, masked_pixels_percentages
def log_masks(images: Tensor, masks: Tensor, key: str, logger: WandbLogger):
    """Logs a set of images with their masks to WandB.

    Args:
        images (Tensor): a float torch tensor with values in [0, 1]
        masks (Tensor): a float torch tensor with values in [0, 1]
        key (str): the key to log the images with
        logger (WandbLogger): the logger to log the images to
    """
    samples, percentages = _prepare_samples(images, masks)
    # One caption per sample with its masking percentage
    captions = [f"Masking: {p:.2f}% " for p in percentages]
    logger.log_image(key=key, images=list(samples), caption=captions)
class DrawMaskCallback(Callback):
    def __init__(
        self,
        samples: list[tuple[Tensor, Tensor]],
        log_every_n_steps: int = 200,
        key: str = "",
    ):
        """A callback that logs VisionDiffMask masks for the sample images to WandB.

        Args:
            samples (list[tuple[Tensor, Tensor]): a list of image, label pairs
            log_every_n_steps (int): the interval in steps to log the masks to WandB
            key (str): the key to log the images with (allows for multiple batches)
        """
        # NOTE(review): samples is indexed as (images, labels) batch pair here,
        # not iterated as per-sample tuples — confirm against the caller.
        images, labels = samples[0], samples[1]
        self.images = torch.stack(list(images))
        self.labels = [lbl.item() for lbl in labels]
        self.log_every_n_steps = log_every_n_steps
        self.key = key

    def _log_masks(self, trainer: Trainer, pl_module: LightningModule):
        # Predict masks in eval mode without tracking gradients
        with torch.no_grad():
            pl_module.eval()
            outputs = pl_module.get_mask(self.images)
            pl_module.train()
        # Pull the pieces out of the model's output dict
        masks = outputs["mask"]
        kl_divs = outputs["kl_div"]
        pred_classes = outputs["pred_class"].cpu()
        # Build the (image, masked, heatmap) triplets for logging
        samples, percentages = _prepare_samples(self.images, masks)
        id2label = pl_module.model.config.id2label
        captions = [
            f"Masking: {pct:.2f}% "
            f"\n KL-divergence: {kl:.4f} "
            f"\n Class: {id2label[lbl]} "
            f"\n Predicted Class: {id2label[pred.item()]}"
            for pct, kl, lbl, pred in zip(
                percentages, kl_divs, self.labels, pred_classes
            )
        ]
        trainer.logger.log_image(
            key="DiffMask " + self.key,
            images=list(samples),
            caption=captions,
        )

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule):
        # Move the fixed sample batch onto the module's device once
        self.images = self.images.to(pl_module.device)
        # Log an initial snapshot before training starts
        self._log_masks(trainer, pl_module)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: dict,
        batch: tuple[Tensor, Tensor],
        batch_idx: int,
        unused: int = 0,
    ):
        # Log the sample masks every log_every_n_steps training batches
        if batch_idx % self.log_every_n_steps == 0:
            self._log_masks(trainer, pl_module)