# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Some utilities for backbones, in particular for windowing"""

import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows, (Hp, Wp)
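
# Illustrative usage sketch (not part of the original module): partitioning a
# feature map whose spatial size is not a multiple of the window size. The
# input is zero-padded up to (Hp, Wp) before being split into windows.
#
# >>> x = torch.randn(2, 10, 10, 32)           # [B, H, W, C]
# >>> windows, (Hp, Wp) = window_partition(x, window_size=7)
# >>> (Hp, Wp)                                  # padded to the next multiple of 7
# (14, 14)
# >>> windows.shape                             # [B * num_windows, 7, 7, C]
# torch.Size([8, 7, 7, 32])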


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Unpartition windows back into original sequences, removing padding if it was added.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(
        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
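
# Illustrative usage sketch (not part of the original module): window_unpartition
# is the exact inverse of window_partition, since the padding added for
# partitioning is cropped off again here.
#
# >>> x = torch.randn(2, 10, 10, 32)
# >>> windows, pad_hw = window_partition(x, window_size=7)
# >>> y = window_unpartition(windows, window_size=7, pad_hw=pad_hw, hw=(10, 10))
# >>> torch.equal(x, y)
# True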


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
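
# Illustrative usage sketch (not part of the original module): with the default
# kernel_size=(7, 7), stride=(4, 4), padding=(3, 3), a 224x224 image is mapped
# to a 56x56 grid of 768-dimensional patch tokens in channels-last layout.
#
# >>> patch_embed = PatchEmbed()
# >>> img = torch.randn(1, 3, 224, 224)        # [B, C, H, W]
# >>> tokens = patch_embed(img)
# >>> tokens.shape                              # [B, H/4, W/4, embed_dim]
# torch.Size([1, 56, 56, 768])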


def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.

    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )
        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)
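
# Illustrative usage sketch (not part of the original module): resizing a ViT
# positional embedding pretrained on a 14x14 token grid (plus a cls token) to a
# larger 20x20 grid of image tokens via bicubic interpolation.
#
# >>> abs_pos = torch.randn(1, 1 + 14 * 14, 768)    # cls token + 14x14 grid
# >>> pos = get_abs_pos(abs_pos, has_cls_token=True, hw=(20, 20))
# >>> pos.shape
# torch.Size([1, 20, 20, 768])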