import cv2
import numpy as np
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from torch import Tensor

def unnormalize(
    images: Tensor,
    mean: tuple[float, ...] = (0.5, 0.5, 0.5),
    std: tuple[float, ...] = (0.5, 0.5, 0.5),
) -> Tensor:
    """Reverts the normalization transformation applied before ViT.

    Args:
        images (Tensor): a batch of images
        mean (tuple[float, ...]): the means used for normalization - defaults to (0.5, 0.5, 0.5)
        std (tuple[float, ...]): the stds used for normalization - defaults to (0.5, 0.5, 0.5)

    Returns:
        the un-normalized batch of images
    """
    unnormalized_images = images.clone()
    for i, (m, s) in enumerate(zip(mean, std)):
        # Invert x_norm = (x - m) / s channel-wise: x = x_norm * s + m
        unnormalized_images[:, i, :, :].mul_(s).add_(m)
    return unnormalized_images
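

# Usage sketch (illustrative, not part of the original module): round-trip a
# batch through the same normalization that `unnormalize` inverts.
#
#     normalized = (torch.rand(4, 3, 224, 224) - 0.5) / 0.5  # values in [-1, 1]
#     restored = unnormalize(normalized)                     # values back in [0, 1]
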
def smoothen(mask: Tensor, patch_size: int = 16) -> Tensor:
    """Smoothens a mask by downsampling it and re-upsampling it
    with bi-linear interpolation.

    Args:
        mask (Tensor): a 2D float torch tensor with values in [0, 1]
        patch_size (int): the patch size in pixels

    Returns:
        a smoothened mask at the pixel level
    """
    device = mask.device
    (h, w) = mask.shape
    # Note: OpenCV's dsize argument is (width, height), not (height, width)
    mask = cv2.resize(
        mask.cpu().numpy(),
        (w // patch_size, h // patch_size),
        interpolation=cv2.INTER_NEAREST,
    )
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
    return torch.tensor(mask).to(device)
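

# Usage sketch (illustrative): for a 224x224 ViT input with 16-pixel patches,
# the mask is pooled down to the 14x14 patch grid and re-upsampled smoothly.
#
#     mask = torch.rand(224, 224)                  # pixel-level mask in [0, 1]
#     smooth = smoothen(mask, patch_size=16)       # (224, 224), bilinear
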
def draw_mask_on_image(image: Tensor, mask: Tensor) -> Tensor:
    """Overlays a dimming mask on the image.

    Args:
        image (Tensor): a float torch tensor with values in [0, 1]
        mask (Tensor): a float torch tensor with values in [0, 1]

    Returns:
        the image with parts of it dimmed according to the mask
    """
    masked_image = image * mask
    return masked_image

def draw_heatmap_on_image(
    image: Tensor,
    mask: Tensor,
    colormap: int = cv2.COLORMAP_JET,
) -> Tensor:
    """Overlays a heatmap on the image.

    Args:
        image (Tensor): a float torch tensor with values in [0, 1]
        mask (Tensor): a float torch tensor with values in [0, 1]
        colormap (int): the OpenCV colormap to be used

    Returns:
        the image with the heatmap overlaid
    """
    # Save the device of the image
    original_device = image.device

    # Convert image & mask to numpy
    image = image.permute(1, 2, 0).cpu().numpy()
    mask = mask.cpu().numpy()

    # Create heatmap (OpenCV returns BGR, so convert to RGB)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    # Overlay heatmap on image and rescale back into [0, 1]
    masked_image = image + heatmap
    masked_image = masked_image / np.max(masked_image)

    return torch.tensor(masked_image).permute(2, 0, 1).to(original_device)
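

# Usage sketch (illustrative): dim an image with its mask and render the
# corresponding heatmap overlay.
#
#     image = torch.rand(3, 224, 224)               # float image in [0, 1]
#     mask = smoothen(torch.rand(224, 224))         # float mask in [0, 1]
#     dimmed = draw_mask_on_image(image, mask)      # (3, 224, 224)
#     overlay = draw_heatmap_on_image(image, mask)  # (3, 224, 224)
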
def _prepare_samples(images: Tensor, masks: Tensor) -> tuple[Tensor, list[float]]:
    """Prepares the samples for the masking/heatmap visualization.

    Args:
        images (Tensor): a float torch tensor with values in [0, 1]
        masks (Tensor): a float torch tensor with values in [0, 1]

    Returns:
        a tuple of image triplets (img, masked, heatmap) and their
        corresponding masking percentages
    """
    num_channels = images[0].shape[0]

    # Smoothen masks
    masks = [smoothen(m) for m in masks]

    # Un-normalize images (grayscale images are repeated to 3 channels)
    if num_channels == 1:
        images = [
            torch.repeat_interleave(img, 3, 0)
            for img in unnormalize(images, mean=(0.5,), std=(0.5,))
        ]
    else:
        images = [img for img in unnormalize(images)]

    # Draw mask on sample images
    images_with_mask = [
        draw_mask_on_image(image, mask) for image, mask in zip(images, masks)
    ]

    # Draw heatmap on sample images
    images_with_heatmap = [
        draw_heatmap_on_image(image, mask) for image, mask in zip(images, masks)
    ]

    # Chunk to triplets (image, masked image, heatmap)
    samples = torch.cat(
        [
            torch.cat(images, dim=2),
            torch.cat(images_with_mask, dim=2),
            torch.cat(images_with_heatmap, dim=2),
        ],
        dim=1,
    ).chunk(len(images), dim=-1)

    # Compute masking percentages (share of pixels dimmed away per mask)
    masked_pixels_percentages = [100 * (1 - mask.mean().item()) for mask in masks]

    return samples, masked_pixels_percentages
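

# Usage sketch (illustrative): each returned sample stacks the original image,
# the dimmed image, and the heatmap vertically, one chunk per input image.
#
#     images = (torch.rand(2, 3, 224, 224) - 0.5) / 0.5  # normalized batch
#     masks = torch.rand(2, 224, 224)
#     samples, percentages = _prepare_samples(images, masks)
#     # len(samples) == 2; each sample has shape (3, 3 * 224, 224)
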
def log_masks(images: Tensor, masks: Tensor, key: str, logger: WandbLogger):
    """Logs a set of images with their masks to WandB.

    Args:
        images (Tensor): a float torch tensor with values in [0, 1]
        masks (Tensor): a float torch tensor with values in [0, 1]
        key (str): the key to log the images with
        logger (WandbLogger): the logger to log the images to
    """
    samples, masked_pixels_percentages = _prepare_samples(images, masks)

    # Log with wandb
    logger.log_image(
        key=key,
        images=list(samples),
        caption=[
            f"Masking: {masked_pixels_percentage:.2f}%"
            for masked_pixels_percentage in masked_pixels_percentages
        ],
    )
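

# Usage sketch (assumes a WandbLogger is already configured; `images` and
# `masks` are as in `_prepare_samples`, and the project name is hypothetical):
#
#     logger = WandbLogger(project="vision-diffmask")
#     log_masks(images, masks, key="val samples", logger=logger)
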
class DrawMaskCallback(Callback):
    def __init__(
        self,
        samples: tuple[Tensor, Tensor],
        log_every_n_steps: int = 200,
        key: str = "",
    ):
        """A callback that logs VisionDiffMask masks for the sample images to WandB.

        Args:
            samples (tuple[Tensor, Tensor]): a batch of images and their labels
            log_every_n_steps (int): the interval in steps to log the masks to WandB
            key (str): the key to log the images with (allows for multiple batches)
        """
        self.images = torch.stack([img for img in samples[0]])
        self.labels = [label.item() for label in samples[1]]
        self.log_every_n_steps = log_every_n_steps
        self.key = key
    def _log_masks(self, trainer: Trainer, pl_module: LightningModule):
        # Predict mask with the model in eval mode, then restore training mode
        with torch.no_grad():
            pl_module.eval()
            outputs = pl_module.get_mask(self.images)
            pl_module.train()

        # Unnest outputs
        masks = outputs["mask"]
        kl_divs = outputs["kl_div"]
        pred_classes = outputs["pred_class"].cpu()

        # Prepare masked samples for logging
        samples, masked_pixels_percentages = _prepare_samples(self.images, masks)

        # Log with wandb
        trainer.logger.log_image(
            key="DiffMask " + self.key,
            images=list(samples),
            caption=[
                f"Masking: {masked_pixels_percentage:.2f}%"
                f"\nKL-divergence: {kl_div:.4f}"
                f"\nClass: {pl_module.model.config.id2label[label]}"
                f"\nPredicted Class: {pl_module.model.config.id2label[pred_class.item()]}"
                for masked_pixels_percentage, kl_div, label, pred_class in zip(
                    masked_pixels_percentages, kl_divs, self.labels, pred_classes
                )
            ],
        )
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule):
        # Transfer sample images to correct device
        self.images = self.images.to(pl_module.device)

        # Log sample images
        self._log_masks(trainer, pl_module)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: dict,
        batch: tuple[Tensor, Tensor],
        batch_idx: int,
        unused: int = 0,
    ):
        # Log sample images every n steps
        if batch_idx % self.log_every_n_steps == 0:
            self._log_masks(trainer, pl_module)
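

# Usage sketch (illustrative): wire the callback into a Lightning Trainer.
# `diffmask` is assumed to be a LightningModule exposing `get_mask` as used
# above; the dataloader names and project name are hypothetical.
#
#     sample_batch = next(iter(val_dataloader))  # (images, labels)
#     callback = DrawMaskCallback(sample_batch, log_every_n_steps=200, key="val")
#     trainer = Trainer(
#         logger=WandbLogger(project="vision-diffmask"),
#         callbacks=[callback],
#     )
#     trainer.fit(diffmask, train_dataloader, val_dataloader)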