Spaces:
Runtime error
Runtime error
# Copyright (c) EPFL VILAB. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
# -------------------------------------------------------- | |
# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases | |
# https://github.com/rwightman/pytorch-image-models/tree/master/timm | |
# https://github.com/facebookresearch/deit | |
# https://github.com/facebookresearch/dino | |
# https://github.com/facebookresearch/moco-v3 | |
# https://github.com/microsoft/unilm/tree/master/beit | |
# https://github.com/BUPT-PRIV/MAE-priv | |
# https://github.com/facebookresearch/mae | |
# -------------------------------------------------------- | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
class MaskedCrossEntropyLoss(nn.Module):
    """Cross-entropy loss with optional token masking.

    :param patch_size: Patch size
    :param stride: Stride of task / modality
    :param label_smoothing: Amount of smoothing in the loss (default is 0.0)
    """

    def __init__(self, patch_size: int = 16, stride: int = 1, label_smoothing: float = 0.0):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        # Number of output pixels covered by one token along each spatial axis.
        self.scale_factor = patch_size // stride
        self.label_smoothing = label_smoothing

    def forward(self, input, target, mask=None):
        """Compute the cross-entropy, optionally restricted to masked tokens.

        :param input: Logits passed to ``F.cross_entropy`` (class dimension at dim 1).
            The last two dims are read as the spatial size (H, W).
        :param target: Class-index targets passed to ``F.cross_entropy``.
        :param mask: Optional per-token weights of shape (B, nh * nw), with
            ``nh = H // scale_factor`` and ``nw = W // scale_factor``. A value of 1 marks
            a token that contributes to the loss and 0 marks one that is ignored.
        :return: Scalar loss tensor.
            - Without a mask: the mean over all elements.
            - With a mask: the loss is averaged over the masked pixels of each sample, then
              over the samples whose mask is non-empty.
            - If the whole mask is empty: a floating-point zero in the loss's dtype and
              device.
        """
        loss = F.cross_entropy(input, target, reduction='none', label_smoothing=self.label_smoothing)

        if mask is None:
            return loss.mean()  # If this is ever nan, we want it to stop training

        if mask.sum() == 0:
            # Nothing to supervise. Return a float zero matching the loss's dtype/device;
            # torch.tensor(0) would be an int64 tensor created on the CPU and then copied.
            return loss.new_zeros(())

        H, W = input.shape[-2:]
        nh, nw = H // self.scale_factor, W // self.scale_factor
        # Token mask (B, nh * nw) -> token grid (B, nh, nw) -> pixel mask (B, H, W).
        # Nearest upsampling replicates each token over its scale_factor x scale_factor block.
        mask = mask.reshape(mask.shape[0], nh, nw)
        mask = F.interpolate(mask.unsqueeze(1).float(), size=(H, W), mode='nearest').squeeze(1)

        loss = loss * mask
        # Mean per sample. A sample with an all-zero mask gives 0 / 0 = nan, and nanmean
        # skips those samples in the batch average.
        loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum(dim=1)
        return loss.nanmean()  # Account for zero masks
class MaskedMSELoss(nn.Module):
    """MSE loss with optional token masking.

    :param patch_size: Patch size
    :param stride: Stride of task / modality
    :param norm_pix: Normalized pixel loss
    """

    def __init__(self, patch_size: int = 16, stride: int = 1, norm_pix: bool = False):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        # Number of output pixels covered by one token along each spatial axis.
        self.scale_factor = patch_size // stride
        self.norm_pix = norm_pix

    def patchify(self, imgs, nh, nw):
        """Split images (B, C, nh*p, nw*p) into patches (B, nh*nw, p*p*C), p = scale_factor."""
        p = self.scale_factor
        x = rearrange(imgs, "b c (nh p1) (nw p2) -> b (nh nw) (p1 p2 c)", nh=nh, nw=nw, p1=p, p2=p)
        return x

    def unpatchify(self, x, nh, nw):
        """Inverse of ``patchify``: patches (B, nh*nw, p*p*C) back to images (B, C, nh*p, nw*p)."""
        p = self.scale_factor
        imgs = rearrange(x, "b (nh nw) (p1 p2 c) -> b c (nh p1) (nw p2)", nh=nh, nw=nw, p1=p, p2=p)
        return imgs

    def forward(self, input, target, mask=None):
        """Compute the mean squared error, optionally restricted to masked tokens.

        :param input: Predictions. The last two dims are read as (H, W). With a mask, the
            per-element error is averaged over dim 1 before masking, so a (B, C, H, W)
            layout is assumed on that path.
        :param target: Targets compared element-wise with ``input``. If ``norm_pix`` is
            set, each scale_factor x scale_factor x C target patch is first normalized
            to zero mean and unit variance.
        :param mask: Optional per-token weights of shape (B, nh * nw), with
            ``nh = H // scale_factor`` and ``nw = W // scale_factor``. A value of 1 marks
            a token that contributes to the loss and 0 marks one that is ignored.
        :return: Scalar loss tensor.
            - Without a mask: the mean over all elements.
            - With a mask: the loss is averaged over the masked pixels of each sample, then
              over the samples whose mask is non-empty.
            - If the whole mask is empty: a floating-point zero in the loss's dtype and
              device.
        """
        H, W = input.shape[-2:]
        nh, nw = H // self.scale_factor, W // self.scale_factor

        if self.norm_pix:
            target = self.patchify(target, nh, nw)
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            eps = 1e-6
            target = (target - mean) / torch.sqrt(var + eps)
            target = self.unpatchify(target, nh, nw)

        loss = F.mse_loss(input, target, reduction='none')

        if mask is None:
            return loss.mean()  # If this is ever nan, we want it to stop training

        if mask.sum() == 0:
            # Nothing to supervise. Return a float zero matching the loss's dtype/device;
            # torch.tensor(0) would be an int64 tensor created on the CPU and then copied.
            return loss.new_zeros(())

        # Token mask (B, nh * nw) -> token grid (B, nh, nw) -> pixel mask (B, H, W).
        # Nearest upsampling replicates each token over its scale_factor x scale_factor block.
        mask = mask.reshape(mask.shape[0], nh, nw)
        mask = F.interpolate(mask.unsqueeze(1).float(), size=(H, W), mode='nearest').squeeze(1)

        loss = loss.mean(dim=1)  # B, C, H, W -> B, H, W
        loss = loss * mask
        # Mean per sample. A sample with an all-zero mask gives 0 / 0 = nan, and nanmean
        # skips those samples in the batch average.
        loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum(dim=1)
        return loss.nanmean()  # Account for zero masks
class MaskedL1Loss(nn.Module):
    """L1 loss with optional token masking.

    :param patch_size: Patch size
    :param stride: Stride of task / modality
    :param norm_pix: Normalized pixel loss
    """

    def __init__(self, patch_size: int = 16, stride: int = 1, norm_pix: bool = False):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        # Number of output pixels covered by one token along each spatial axis.
        self.scale_factor = patch_size // stride
        self.norm_pix = norm_pix

    def patchify(self, imgs, nh, nw):
        """Split images (B, C, nh*p, nw*p) into patches (B, nh*nw, p*p*C), p = scale_factor."""
        p = self.scale_factor
        x = rearrange(imgs, "b c (nh p1) (nw p2) -> b (nh nw) (p1 p2 c)", nh=nh, nw=nw, p1=p, p2=p)
        return x

    def unpatchify(self, x, nh, nw):
        """Inverse of ``patchify``: patches (B, nh*nw, p*p*C) back to images (B, C, nh*p, nw*p)."""
        p = self.scale_factor
        imgs = rearrange(x, "b (nh nw) (p1 p2 c) -> b c (nh p1) (nw p2)", nh=nh, nw=nw, p1=p, p2=p)
        return imgs

    def forward(self, input, target, mask=None):
        """Compute the mean absolute error, optionally restricted to masked tokens.

        :param input: Predictions. The last two dims are read as (H, W). With a mask, the
            per-element error is averaged over dim 1 before masking, so a (B, C, H, W)
            layout is assumed on that path.
        :param target: Targets compared element-wise with ``input``. If ``norm_pix`` is
            set, each scale_factor x scale_factor x C target patch is first normalized
            to zero mean and unit variance.
        :param mask: Optional per-token weights of shape (B, nh * nw), with
            ``nh = H // scale_factor`` and ``nw = W // scale_factor``. A value of 1 marks
            a token that contributes to the loss and 0 marks one that is ignored.
        :return: Scalar loss tensor.
            - Without a mask: the mean over all elements.
            - With a mask: the loss is averaged over the masked pixels of each sample, then
              over the samples whose mask is non-empty.
            - If the whole mask is empty: a floating-point zero in the loss's dtype and
              device.
        """
        H, W = input.shape[-2:]
        nh, nw = H // self.scale_factor, W // self.scale_factor

        if self.norm_pix:
            target = self.patchify(target, nh, nw)
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            eps = 1e-6
            target = (target - mean) / torch.sqrt(var + eps)
            target = self.unpatchify(target, nh, nw)

        loss = F.l1_loss(input, target, reduction='none')

        if mask is None:
            return loss.mean()  # If this is ever nan, we want it to stop training

        if mask.sum() == 0:
            # Nothing to supervise. Return a float zero matching the loss's dtype/device;
            # torch.tensor(0) would be an int64 tensor created on the CPU and then copied.
            return loss.new_zeros(())

        # Token mask (B, nh * nw) -> token grid (B, nh, nw) -> pixel mask (B, H, W).
        # Nearest upsampling replicates each token over its scale_factor x scale_factor block.
        mask = mask.reshape(mask.shape[0], nh, nw)
        mask = F.interpolate(mask.unsqueeze(1).float(), size=(H, W), mode='nearest').squeeze(1)

        loss = loss.mean(dim=1)  # B, C, H, W -> B, H, W
        loss = loss * mask
        # Mean per sample. A sample with an all-zero mask gives 0 / 0 = nan, and nanmean
        # skips those samples in the batch average.
        loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum(dim=1)
        return loss.nanmean()  # Account for zero masks