Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
import math | |
import numpy as np | |
from scipy import interpolate | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# Names exported by ``from <this module> import *``.
__all__ = [
    "window_partition", "window_unpartition", "add_decomposed_rel_pos",
    "get_abs_pos", "PatchEmbed", "VisionRotaryEmbeddingFast",
]
def window_partition(x, window_size):
    """
    Split a channels-last feature map into square, non-overlapping windows.

    The map is zero-padded on the bottom and right so that both spatial sizes
    become multiples of ``window_size``.

    Args:
        x (tensor): input tokens with shape [B, H, W, C].
        window_size (int): side length of each window.
    Returns:
        windows: tensor of shape [B * num_windows, window_size, window_size, C],
            windows ordered row-major within each image.
        (Hp, Wp): height and width after padding (needed to undo the split).
    """
    batch, height, width, channels = x.shape
    # -n % k is the distance from n up to the next multiple of k (0 if aligned).
    extra_h = -height % window_size
    extra_w = -width % window_size
    if extra_h or extra_w:
        # F.pad pads from the last axis backwards: (C), then W, then H.
        x = F.pad(x, (0, 0, 0, extra_w, 0, extra_h))
    padded_h = height + extra_h
    padded_w = width + extra_w

    rows = padded_h // window_size
    cols = padded_w // window_size
    tiles = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes next to each other before flattening them.
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, channels), (padded_h, padded_w)
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Reassemble windows into full feature maps and strip the padding.

    Inverse of ``window_partition``.

    Args:
        windows (tensor): windows with shape [B * num_windows, window_size, window_size, C].
        window_size (int): side length of each window.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: feature maps of shape [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw
    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image

    grid = windows.view(
        batch, padded_h // window_size, padded_w // window_size, window_size, window_size, -1
    )
    # (B, rows, cols, wh, ww, C) -> (B, rows, wh, cols, ww, C) so rows/cols merge back into H/W.
    merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, padded_h, padded_w, -1)

    needs_crop = padded_h > height or padded_w > width
    return merged[:, :height, :width, :].contiguous() if needs_crop else merged
def _log_interpolate_rel_pos(rel_pos, dst_size, q=1.0903078):
    """
    Resample an (L, C) relative-position table onto integer offsets with cubic splines.

    Source rows sit on a non-uniform grid whose offsets from the centre grow
    geometrically with ratio ``q`` (0, ±1, ±(1 + q), ±(1 + q + q**2), ...).
    Target rows sit on the integer offsets -(dst_size // 2) ... dst_size // 2.
    Every channel is evaluated with a cubic spline, extrapolating past the ends.

    Args:
        rel_pos (Tensor): table of shape (L, C). L must be odd, otherwise the
            generated grid has L + 1 points and ``interp1d`` raises.
        dst_size (int): requested number of rows (odd in every call from
            ``get_rel_pos``).
        q (float): geometric ratio of the source grid. This value is hard-coded
            from the original implementation (which also left 1.13492 commented
            out). NOTE(review): presumably tuned for a specific source/target size.

    Returns:
        Tensor of shape (2 * (dst_size // 2) + 1, C), i.e. (dst_size, C) for odd
        ``dst_size``, with the dtype and device of ``rel_pos``.
    """
    src_size = rel_pos.shape[0]
    dis = []
    cur = 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q ** (i + 1)
    x = [-d for d in reversed(dis)] + [0] + dis
    t = dst_size // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    # numpy() refuses tensors that require grad, so detach first (the SciPy round
    # trip never propagated gradients anyway). float() because NumPy has no bfloat16.
    z = rel_pos.detach().float().cpu().numpy()
    # A single spline along axis 0 resamples all C channels at once.
    f = interpolate.interp1d(x, z, axis=0, kind="cubic", fill_value="extrapolate")
    return torch.as_tensor(f(dx), dtype=rel_pos.dtype, device=rel_pos.device)


def get_rel_pos(q_size, k_size, rel_pos, use_log_interpolation=True):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    If the table length differs from ``2 * max(q_size, k_size) - 1`` it is
    resized first.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).
        use_log_interpolation (bool): when resizing is needed, use cubic
            interpolation over a geometrically spaced source grid (the default,
            and the original behavior). If False, use plain linear
            ``F.interpolate`` instead.

    Returns:
        Extracted positional embeddings of shape (q_size, k_size, C), where entry
        [i, j] holds the embedding for the (scaled) offset between query i and key j.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if rel_pos.shape[0] != max_rel_dist:
        if use_log_interpolation:
            rel_pos_resized = _log_interpolate_rel_pos(rel_pos, max_rel_dist)
        else:
            # (L, C) -> (1, C, L) so F.interpolate resamples along L.
            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    # Shift so the most negative offset maps to row 0 of the table.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Add decomposed (height + width) relative positional terms to an attention map,
    following :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    Args:
        attn (Tensor): attention map, viewable as (B, q_h * q_w, k_h * k_w).
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for the height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for the width axis.
        q_size (Tuple): spatial size of query q as (q_h, q_w).
        k_size (Tuple): spatial size of key k as (k_h, k_w).

    Returns:
        attn (Tensor): attention map of shape (B, q_h * q_w, k_h * k_w) with the
            relative positional terms added.
    """
    (q_h, q_w), (k_h, k_w) = q_size, k_size
    height_table = get_rel_pos(q_h, k_h, rel_pos_h)
    width_table = get_rel_pos(q_w, k_w, rel_pos_w)

    batch, _, channels = q.shape
    q_grid = q.reshape(batch, q_h, q_w, channels)
    # Dot each query with the embedding of its row offset / column offset to every key.
    rel_h = torch.einsum("bhwc,hkc->bhwk", q_grid, height_table)
    rel_w = torch.einsum("bhwc,wkc->bhwk", q_grid, width_table)

    # Broadcast the height term over key columns and the width term over key rows.
    attn_5d = attn.view(batch, q_h, q_w, k_h, k_w)
    attn_5d = attn_5d + rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)
    return attn_5d.view(batch, q_h * q_w, k_h * k_w)
def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Adapt absolute positional embeddings to a token grid of size ``hw``.

    The optional cls-token entry is dropped. The remaining square grid of
    embeddings is then bicubically resized when its side differs from the
    requested height or width.

    Args:
        abs_pos (Tensor): absolute positional embeddings with shape (1, num_position, C).
        has_cls_token (bool): if True, the first entry of ``abs_pos`` belongs to the
            cls token and is discarded.
        hw (Tuple): (H, W) size of the input image token grid.

    Returns:
        Absolute positional embeddings with shape (1, H, W, C).
    """
    h, w = hw
    grid = abs_pos[:, 1:] if has_cls_token else abs_pos
    num_tokens = grid.shape[1]
    side = int(math.sqrt(num_tokens))
    # The stored embeddings must form a perfect square grid.
    assert side * side == num_tokens

    if side == h and side == w:
        return grid.reshape(1, h, w, -1)

    # F.interpolate wants channels-first: (1, side, side, C) -> (1, C, side, side).
    resized = F.interpolate(
        grid.reshape(1, side, side, -1).permute(0, 3, 1, 2),
        size=(h, w),
        mode="bicubic",
        align_corners=False,
    )
    return resized.permute(0, 2, 3, 1)
class PatchEmbed(nn.Module):
    """
    Image to patch embedding via a single strided convolution.

    Output is channels-last: an input of shape (B, C, H, W) yields
    (B, H', W', embed_dim), where H' and W' are the convolution's output sizes.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection convolution.
            stride (Tuple): stride of the projection convolution.
            padding (Tuple): padding of the projection convolution.
            in_chans (int): number of input image channels.
            embed_dim (int): patch embedding dimension (convolution output channels).
        """
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, **conv_options)

    def forward(self, x):
        # Project patches, then move channels last: B C H W -> B H W C.
        return self.proj(x).permute(0, 2, 3, 1)
from math import pi | |
import torch | |
from torch import nn | |
from einops import rearrange, repeat | |
def broadcat(tensors, dim=-1):
    """
    Concatenate tensors along ``dim`` after expanding every other axis to a common size.

    All tensors must have the same rank. On every axis except ``dim``, at most two
    distinct sizes may appear; each tensor is expanded to the largest one (so
    size-1 axes broadcast) before ``torch.cat``.
    """
    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    (rank,) = ranks
    if dim < 0:
        dim += rank

    # per_axis[a] is the tuple of sizes that each tensor has on axis a.
    per_axis = list(zip(*(list(t.shape) for t in tensors)))
    other_axes = [(axis, sizes) for axis, sizes in enumerate(per_axis) if axis != dim]
    assert all(len(set(sizes)) <= 2 for _, sizes in other_axes), 'invalid dimensions for broadcastable concatentation'

    # Every tensor targets the max size on the other axes and keeps its own size on `dim`.
    targets = [(axis, (max(sizes),) * len(tensors)) for axis, sizes in other_axes]
    targets.insert(dim, (dim, per_axis[dim]))
    shapes = list(zip(*(sizes for _, sizes in targets)))

    expanded = [t.expand(*shape) for t, shape in zip(tensors, shapes)]
    return torch.cat(expanded, dim=dim)
def rotate_half(x):
    """
    Rotate interleaved feature pairs by 90 degrees: (a, b) -> (-b, a).

    The last axis is read as consecutive (even, odd) pairs, so its size must be even.
    """
    # (..., D) -> (..., D/2, 2)
    pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    even, odd = pairs.unbind(dim=-1)
    # Re-interleave as (-odd, even) and flatten back to (..., D).
    return torch.stack((-odd, even), dim=-1).flatten(start_dim=-2)
class VisionRotaryEmbedding(nn.Module):
    """
    2D rotary position embedding with per-position (H, W) frequency tables.

    The same 1D frequency table is used for the height and width axes. Each
    frequency is duplicated for the interleaved feature pairs rotated by
    ``rotate_half``, and the two tables are concatenated along the feature axis.
    The ``freqs_cos`` / ``freqs_sin`` buffers therefore have shape
    (ft_seq_len, ft_seq_len, 4 * num_base_freqs).
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        """
        Args:
            dim (int): 'lang' and 'pixel' schemes build ``dim // 2`` base frequencies.
            pt_seq_len (int): grid side length the positions are scaled to.
            ft_seq_len (int, optional): grid side length actually used; defaults to
                ``pt_seq_len``.
            custom_freqs (Tensor, optional): base frequencies to use instead of a
                built-in scheme.
            freqs_for (str): 'lang' (inverse powers of ``theta``), 'pixel'
                (linear up to ``max_freq / 2``, times pi) or 'constant'
                (``num_freqs`` ones).
            theta (float): base for the 'lang' scheme.
            max_freq (float): upper bound parameter for the 'pixel' scheme.
            num_freqs (int): number of frequencies for the 'constant' scheme.

        Raises:
            ValueError: if ``freqs_for`` is not one of the supported schemes.
        """
        super().__init__()
        # Compare against None rather than truth-testing: bool() of a
        # multi-element tensor raises, and a one-element zero tensor would be
        # silently ignored.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Map ft_seq_len positions onto the range [0, pt_seq_len).
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Height and width share an identical 1D table, so build it once.
        axis_freqs = torch.einsum('..., f -> ... f', t, freqs)
        axis_freqs = repeat(axis_freqs, '... n -> ... (n r)', r=2)
        freqs = broadcat((axis_freqs[:, None, :], axis_freqs[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())
        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index=0):
        """
        Rotate ``t[..., start_index:start_index + rot_dim]`` and leave the other features unchanged.

        ``rot_dim`` is the last-axis size of ``freqs_cos``. The rotated slice must
        broadcast against the (ft_seq_len, ft_seq_len, rot_dim) tables.
        NOTE(review): this presumably expects t's trailing axes to be
        (H, W, features); confirm with callers.
        """
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim=-1)
class VisionRotaryEmbeddingFast(nn.Module):
    """
    2D rotary position embedding with the (H, W) grid flattened into one token axis.

    Frequencies are built like ``VisionRotaryEmbedding``. The grid is then
    flattened, so ``freqs_cos`` / ``freqs_sin`` have shape
    (ft_seq_len * ft_seq_len, 4 * num_base_freqs) and broadcast directly
    against a (..., num_tokens, rot_dim) input.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        """
        Args:
            dim (int): 'lang' and 'pixel' schemes build ``dim // 2`` base frequencies.
            pt_seq_len (int): grid side length the positions are scaled to.
            ft_seq_len (int, optional): grid side length actually used; defaults to
                ``pt_seq_len``.
            custom_freqs (Tensor, optional): base frequencies to use instead of a
                built-in scheme.
            freqs_for (str): 'lang', 'pixel' or 'constant' frequency scheme.
            theta (float): base for the 'lang' scheme.
            max_freq (float): upper bound parameter for the 'pixel' scheme.
            num_freqs (int): number of frequencies for the 'constant' scheme.

        Raises:
            ValueError: if ``freqs_for`` is not one of the supported schemes.
        """
        super().__init__()
        # Compare against None rather than truth-testing: bool() of a
        # multi-element tensor raises, and a one-element zero tensor would be
        # silently ignored.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Map ft_seq_len positions onto the range [0, pt_seq_len).
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
        # Flatten the (H, W) grid into a single token axis.
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)
        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t):
        """
        Apply the rotary embedding to ``t``.

        The token axis is ``t.shape[-2]`` and the last axis must match the table
        width. This covers e.g. (B, heads, N, D), (B, N, D) and (N, D) inputs.
        Only the first N rows of the tables are used, so ``t`` may hold fewer
        tokens than the grid; more tokens fail to broadcast.
        """
        num_tokens = t.shape[-2]
        cos = self.freqs_cos[:num_tokens]
        sin = self.freqs_sin[:num_tokens]
        return t * cos + rotate_half(t) * sin