# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from the timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from copy import deepcopy
from typing import List, Optional, Tuple
import math
from functools import partial

from sympy import flatten
import torch
import torch.nn as nn
from torch import Tensor, pixel_shuffle
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.nn.modules import GELU
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6
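
# A minimal usage sketch for the timing helper above (illustrative only; the
# tensor shapes and the choice of scaled_dot_product_attention as the timed
# callable are assumptions, nothing in this file runs this benchmark):
#
#     q = k = v = torch.randn(4, 8, 1024, 64, device='cuda')
#     sdpa_us = benchmark_torch_function_in_microseconds(
#         torch.nn.functional.scaled_dot_product_attention, q, k, v)
#     print(f"SDPA runs in {sdpa_us:.1f} microseconds")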
# from vit.vision_transformer import Conv3DCrossAttentionBlock
from .utils import trunc_normal_
from pdb import set_trace as st

# import apex
# from apex.normalization import FusedRMSNorm as RMSNorm
try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    from dit.norm import RMSNorm
from torch.nn import LayerNorm
try:
    import xformers
    import xformers.ops
    from xformers.ops import memory_efficient_attention, unbind, fmha
    from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficientAttentionCutlassOp
    # from xformers.ops import RMSNorm

    XFORMERS_AVAILABLE = True
except ImportError:
    # logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False

from packaging import version

assert version.parse(torch.__version__) >= version.parse("2.0.0")
SDP_IS_AVAILABLE = True
# from torch.backends.cuda import SDPBackend, sdp_kernel
from torch.nn.attention import sdpa_kernel, SDPBackend
class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 enable_rmsnorm=False,
                 qk_norm=False,
                 no_flash_op=False,
                 enable_rope=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # https://github.com/huggingface/pytorch-image-models/blob/5dce71010174ad6599653da4e8ba37fd5f9fa572/timm/models/vision_transformer.py#L79C1-L80C78
        self.enable_rope = enable_rope
        if enable_rope:
            self.q_norm = RMSNorm(dim, elementwise_affine=True) if qk_norm else nn.Identity()
            self.k_norm = RMSNorm(dim, elementwise_affine=True) if qk_norm else nn.Identity()
        else:
            self.q_norm = RMSNorm(head_dim, elementwise_affine=True) if qk_norm else nn.Identity()
            self.k_norm = RMSNorm(head_dim, elementwise_affine=True) if qk_norm else nn.Identity()
        # if qk_norm:
        #     self.q_norm = LayerNorm(dim, eps=1e-5)
        #     self.k_norm = LayerNorm(dim, eps=1e-5)

        self.qk_norm = qk_norm
        self.no_flash_op = no_flash_op
        self.attn_mode = "torch"
        self.backend = SDPBackend.FLASH_ATTENTION  # FA implemented by torch.
    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        """
        Reshape frequency tensor for broadcasting it with another tensor.

        This function reshapes the frequency tensor to have the same shape as
        the target tensor 'x' for the purpose of broadcasting the frequency
        tensor during element-wise operations.

        Args:
            freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
            x (torch.Tensor): Target tensor for broadcasting compatibility.

        Returns:
            torch.Tensor: Reshaped frequency tensor.

        Raises:
            AssertionError: If the frequency tensor doesn't match the expected
                shape.
            AssertionError: If the target tensor 'x' doesn't have the expected
                number of dimensions.
        """
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    @staticmethod
    def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency
        tensor.

        This function applies rotary embeddings to the given query 'xq' and
        key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
        input tensors are reshaped as complex numbers, and the frequency tensor
        is reshaped for broadcasting compatibility. The resulting tensors
        contain rotary embeddings and are returned as real tensors.

        Args:
            xq (torch.Tensor): Query tensor to apply rotary embeddings.
            xk (torch.Tensor): Key tensor to apply rotary embeddings.
            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
                exponentials.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
                and key tensor with rotary embeddings.
        """
        with torch.cuda.amp.autocast(enabled=False):
            xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
            freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
            xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
            return xq_out.type_as(xq), xk_out.type_as(xk)
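
    # The rotary helpers above expect a precomputed complex-valued `freqs_cis`
    # tensor of shape (seq_len, rotated_dim // 2), where rotated_dim is the last
    # dimension of the q/k tensors passed in. No such tensor is built in this
    # file; a standard llama-style precompute sketch would look like the
    # commented helper below (the base `theta=10000.0` is an assumption):
    #
    #     def precompute_freqs_cis(dim, end, theta=10000.0):
    #         freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    #         t = torch.arange(end, device=freqs.device).float()
    #         freqs = torch.outer(t, freqs)
    #         return torch.polar(torch.ones_like(freqs), freqs)  # complex64, (end, dim // 2)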
    def forward(self, x):
        # B, N, C = x.shape
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
        #                           C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q, k, v = qkv[0], qkv[1], qkv[2]
        # attn = (q @ k.transpose(-2, -1)) * self.scale
        # attn = attn.softmax(dim=-1)
        # attn = self.attn_drop(attn)
        # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # return x, attn

        # https://github.com/Stability-AI/generative-models/blob/863665548f95ff827273948766a3f732ab01bc49/sgm/modules/attention.py#L179
        B, L, C = x.shape
        qkv = self.qkv(x)

        if self.attn_mode == "torch":
            qkv = rearrange(
                qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
            ).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            q, k = self.q_norm(q), self.k_norm(k)
            with sdpa_kernel([self.backend]):  # new signature
                x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            del q, k, v
            x = rearrange(x, "B H L D -> B L (H D)")

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
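
# A minimal usage sketch of the torch-SDPA Attention block above (illustrative
# only; dim/num_heads are arbitrary assumptions, not values from a config in
# this file). The flash SDPA backend expects half precision, hence the autocast:
#
#     attn = Attention(dim=768, num_heads=12, qkv_bias=True).cuda()
#     tokens = torch.randn(2, 256, 768, device='cuda')  # (B, L, C)
#     with torch.autocast('cuda', dtype=torch.bfloat16):
#         out = attn(tokens)
#     assert out.shape == tokens.shape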
class MemEffAttention(Attention):

    def forward(self, x: Tensor, attn_bias=None, freqs_cis=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x)
        dtype = qkv.dtype

        if self.enable_rope:
            assert freqs_cis is not None
            qkv = qkv.reshape(B, N, 3, C)
            q, k, v = unbind(qkv, 2)
            q, k = self.q_norm(q), self.k_norm(k)  # do q-k norm on the full seq instead.
            q, k = Attention.apply_rotary_emb(q, k, freqs_cis=freqs_cis)
            q, k, v = map(
                lambda t: t.reshape(B, N, self.num_heads, C // self.num_heads),
                (q, k, v),
            )
            q, k = q.to(dtype), k.to(dtype)
        else:
            qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
            q, k, v = unbind(qkv, 2)
            q, k = self.q_norm(q), self.k_norm(k)
        # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)  # if not bf16, no flash-attn here.
        # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp)  # force flash attention

        if self.no_flash_op:  # F-A does not support large batch size? force cutlass?
            if version.parse(xformers.__version__) >= version.parse("0.0.21"):
                # NOTE: workaround for
                # https://github.com/facebookresearch/xformers/issues/845
                # chunk the flattened batch so each call stays under the op's limit
                max_bs = 32768
                L = q.shape[0]
                n_batches = math.ceil(L / max_bs)
                x = list()
                for i_batch in range(n_batches):
                    batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
                    x.append(
                        xformers.ops.memory_efficient_attention(
                            q[batch],
                            k[batch],
                            v[batch],
                            attn_bias=None,
                            op=MemoryEfficientAttentionCutlassOp,
                        ))
                x = torch.cat(x, 0)
                # Measured with benchmark_torch_function_in_microseconds at max_bs=32768:
                # the cutlass implementation runs in ~8396.681 microseconds,
                # the flash implementation runs in ~19473.491 microseconds.
            else:
                # older xformers: fall back to a single cutlass call
                x = memory_efficient_attention(
                    q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionCutlassOp)
        else:  # will enable flash attention by default.
            # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp)  # force flash attention
            x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)

        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MemEffCrossAttention(MemEffAttention):
    # for cross attention, where context serves as k and v

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0, proj_drop=0):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
        del self.qkv
        self.q = nn.Linear(dim, dim * 1, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)

    def forward(self, x: Tensor, context: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        B, N, C = x.shape
        M = context.shape[1]  # context length
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads)
        kv = self.kv(context).reshape(B, M, 2, self.num_heads, C // self.num_heads)
        k, v = unbind(kv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
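
# A minimal, illustrative sketch of calling the cross-attention module above
# (shapes are arbitrary assumptions): queries come from `x`, keys/values from
# `context`, which may use a different sequence length. Requires xformers and
# a CUDA device.
#
#     xattn = MemEffCrossAttention(dim=768, num_heads=12, qkv_bias=True).cuda()
#     x = torch.randn(2, 256, 768, device='cuda')       # query tokens
#     context = torch.randn(2, 77, 768, device='cuda')  # key/value tokens
#     out = xattn(x, context)
#     assert out.shape == x.shape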
# https://github.com/IBM/CrossViT/blob/main/models/crossvit.py | |
class CrossAttention(nn.Module): | |
def __init__(self, | |
dim, | |
num_heads=8, | |
qkv_bias=False, | |
qk_scale=None, | |
attn_drop=0., | |
proj_drop=0.): | |
super().__init__() | |
self.num_heads = num_heads | |
head_dim = dim // num_heads | |
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights | |
self.scale = qk_scale or head_dim**-0.5 | |
self.wq = nn.Linear(dim, dim, bias=qkv_bias) | |
self.wk = nn.Linear(dim, dim, bias=qkv_bias) | |
self.wv = nn.Linear(dim, dim, bias=qkv_bias) | |
self.attn_drop = nn.Dropout(attn_drop) | |
self.proj = nn.Linear(dim, dim) | |
self.proj_drop = nn.Dropout(proj_drop) | |
def forward(self, x): | |
B, N, C = x.shape | |
q = self.wq(x[:, | |
0:1, ...]).reshape(B, 1, self.num_heads, | |
C // self.num_heads).permute( | |
0, 2, 1, | |
3) # B1C -> B1H(C/H) -> BH1(C/H) | |
k = self.wk(x).reshape(B, N, | |
self.num_heads, C // self.num_heads).permute( | |
0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) | |
v = self.wv(x).reshape(B, N, | |
self.num_heads, C // self.num_heads).permute( | |
0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) | |
attn = (q @ k.transpose( | |
-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N | |
attn = attn.softmax(dim=-1) | |
attn = self.attn_drop(attn) | |
x = (attn @ v).transpose(1, 2).reshape( | |
B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
return x | |
class Conv3D_Aware_CrossAttention(nn.Module): | |
def __init__(self, | |
dim, | |
num_heads=8, | |
qkv_bias=False, | |
qk_scale=None, | |
attn_drop=0., | |
proj_drop=0.): | |
super().__init__() | |
self.num_heads = num_heads | |
head_dim = dim // num_heads | |
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights | |
self.scale = qk_scale or head_dim**-0.5 | |
self.wq = nn.Linear(dim, dim, bias=qkv_bias) | |
self.wk = nn.Linear(dim, dim, bias=qkv_bias) | |
self.wv = nn.Linear(dim, dim, bias=qkv_bias) | |
self.attn_drop = nn.Dropout(attn_drop) | |
self.proj = nn.Linear(dim, dim) | |
self.proj_drop = nn.Dropout(proj_drop) | |
def forward(self, x): | |
B, group_size, N, C = x.shape # B 3 N C | |
p = int(N**0.5) # patch size | |
assert p**2 == N, 'check input dim, no [cls] needed here' | |
assert group_size == 3, 'designed for triplane here' | |
x = x.reshape(B, group_size, p, p, C) # expand patch token dim | |
# * init qkv | |
# q = torch.empty(B * group_size * N, | |
# 1, | |
# self.num_heads, | |
# C // self.num_heads, | |
# device=x.device).permute(0, 2, 1, 3) | |
# k = torch.empty(B * group_size * N, | |
# 2 * p, | |
# self.num_heads, | |
# C // self.num_heads, | |
# device=x.device).permute(0, 2, 1, 3) | |
# v = torch.empty_like(k) | |
q_x = torch.empty( | |
B * group_size * N, | |
1, | |
# self.num_heads, | |
# C // self.num_heads, | |
C, | |
device=x.device) | |
k_x = torch.empty( | |
B * group_size * N, | |
2 * p, | |
# self.num_heads, | |
# C // self.num_heads, | |
C, | |
device=x.device) | |
v_x = torch.empty_like(k_x) | |
# ! refer to the following plane order | |
# N, M, _ = coordinates.shape | |
# xy_coords = coordinates[..., [0, 1]] | |
# yz_coords = coordinates[..., [1, 2]] | |
# zx_coords = coordinates[..., [2, 0]] | |
# return torch.stack([xy_coords, yz_coords, zx_coords], | |
# dim=1).reshape(N * 3, M, 2) | |
index_i, index_j = torch.meshgrid(torch.arange(0, p), | |
torch.arange(0, p), | |
indexing='ij') # 16*16 | |
index_mesh_grid = torch.stack([index_i, index_j], 0).to( | |
x.device).unsqueeze(0).repeat_interleave(B, | |
0).reshape(B, 2, p, | |
p) # B 2 p p. | |
for i in range(group_size): | |
q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( | |
0, 2, 3, 1, 4).reshape(B * N, 1, C) # B 1 p p C -> B*N, 1, C | |
# TODO, how to batchify gather ops? | |
plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + | |
1] # B 1 p p C | |
plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] | |
assert plane_yz.shape == plane_zx.shape == ( | |
B, 1, p, p, C), 'check sub plane dimensions' | |
pooling_plane_yz = torch.gather( | |
plane_yz, | |
dim=2, | |
index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand( | |
-1, -1, -1, p, | |
C)).permute(0, 2, 1, 3, 4) # B 1 256 16 C => B 256 1 16 C | |
pooling_plane_zx = torch.gather( | |
plane_zx, | |
dim=3, | |
index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand( | |
-1, -1, p, -1, | |
C)).permute(0, 3, 1, 2, 4) # B 1 16 256 C => B 256 1 16 C | |
k_x[B * i * N:B * (i + 1) * | |
N] = v_x[B * i * N:B * (i + 1) * N] = torch.cat( | |
[pooling_plane_yz, pooling_plane_zx], | |
dim=2).reshape(B * N, 2 * p, | |
C) # B 256 2 16 C => (B*256) 2*16 C | |
# q[B * i * N: B * (i+1) * N] = self.wq(q_x).reshape(B*N, 1, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) | |
# k[B * i * N: B * (i+1) * N] = self.wk(k_x).reshape(B*N, 2*p, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) | |
# v[B * i * N: B * (i+1) * N] = self.wv(v_x).reshape(B*N, 2*p, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3) | |
q = self.wq(q_x).reshape(B * group_size * N, 1, | |
self.num_heads, C // self.num_heads).permute( | |
0, 2, 1, | |
3) # merge num_heads into Batch dimention | |
k = self.wk(k_x).reshape(B * group_size * N, 2 * p, self.num_heads, | |
C // self.num_heads).permute(0, 2, 1, 3) | |
v = self.wv(v_x).reshape(B * group_size * N, 2 * p, self.num_heads, | |
C // self.num_heads).permute(0, 2, 1, 3) | |
attn = (q @ k.transpose( | |
-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N, N=2p here | |
attn = attn.softmax(dim=-1) | |
attn = self.attn_drop(attn) | |
x = (attn @ v).transpose(1, 2).reshape( | |
B * 3 * N, 1, | |
C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
# reshape x back | |
x = x.reshape(B, 3, N, C) | |
return x | |
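
# Illustrative shape walk-through for the 3D-aware cross attention above
# (values are arbitrary assumptions): each of the B*3*N query tokens attends to
# the 2*p key/value tokens gathered from the corresponding row/column of the
# other two planes, and the output keeps the input layout.
#
#     xattn3d = Conv3D_Aware_CrossAttention(dim=96, num_heads=4)
#     planes = torch.randn(2, 3, 16 * 16, 96)  # (B, 3 planes, N=p*p tokens, C)
#     out = xattn3d(planes)
#     assert out.shape == planes.shape  # (2, 3, 256, 96)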
class xformer_Conv3D_Aware_CrossAttention(nn.Module): | |
# https://github.dev/facebookresearch/dinov2 | |
def __init__(self, | |
dim, | |
num_heads=8, | |
qkv_bias=False, | |
qk_scale=None, | |
attn_drop=0., | |
proj_drop=0.): | |
super().__init__() | |
# https://pytorch.org/blog/accelerated-generative-diffusion-models/ | |
self.num_heads = num_heads | |
self.wq = nn.Linear(dim, dim * 1, bias=qkv_bias) | |
self.w_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) | |
self.attn_drop = nn.Dropout(attn_drop) | |
self.proj = nn.Linear(dim, dim) | |
self.proj_drop = nn.Dropout(proj_drop) | |
self.index_mesh_grid = None | |
def forward(self, x, attn_bias=None): | |
B, group_size, N, C = x.shape # B 3 N C | |
p = int(N**0.5) # patch size | |
assert p**2 == N, 'check input dim, no [cls] needed here' | |
assert group_size == 3, 'designed for triplane here' | |
x = x.reshape(B, group_size, p, p, C) # expand patch token dim | |
q_x = torch.empty(B * group_size * N, 1, C, device=x.device) | |
context = torch.empty(B * group_size * N, 2 * p, C, | |
device=x.device) # k_x=v_x | |
if self.index_mesh_grid is None: # further accelerate | |
index_i, index_j = torch.meshgrid(torch.arange(0, p), | |
torch.arange(0, p), | |
indexing='ij') # 16*16 | |
index_mesh_grid = torch.stack([index_i, index_j], 0).to( | |
x.device).unsqueeze(0).repeat_interleave(B, 0).reshape( | |
B, 2, p, p) # B 2 p p. | |
self.index_mesh_grid = index_mesh_grid[0:1] | |
else: | |
index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave( | |
B, 0) | |
assert index_mesh_grid.shape == ( | |
B, 2, p, p), 'check index_mesh_grid dimension' | |
for i in range(group_size): | |
q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( | |
0, 2, 3, 1, 4).reshape(B * N, 1, C) # B 1 p p C -> B*N, 1, C | |
# TODO, how to batchify gather ops? | |
plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + | |
1] # B 1 p p C | |
plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] | |
assert plane_yz.shape == plane_zx.shape == ( | |
B, 1, p, p, C), 'check sub plane dimensions' | |
pooling_plane_yz = torch.gather( | |
plane_yz, | |
dim=2, | |
index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand( | |
-1, -1, -1, p, | |
C)).permute(0, 2, 1, 3, 4) # B 1 256 16 C => B 256 1 16 C | |
pooling_plane_zx = torch.gather( | |
plane_zx, | |
dim=3, | |
index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand( | |
-1, -1, p, -1, | |
C)).permute(0, 3, 1, 2, 4) # B 1 16 256 C => B 256 1 16 C | |
context[B * i * N:B * (i + 1) * N] = torch.cat( | |
[pooling_plane_yz, pooling_plane_zx], | |
dim=2).reshape(B * N, 2 * p, | |
C) # B 256 2 16 C => (B*256) 2*16 C | |
# B, N, C = x.shape | |
q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads, | |
C // self.num_heads) | |
kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2, | |
self.num_heads, C // self.num_heads) | |
k, v = unbind(kv, 2) | |
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) | |
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp) | |
x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C) | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
return x | |
class xformer_Conv3D_Aware_CrossAttention_xygrid(
        xformer_Conv3D_Aware_CrossAttention):
    """Implementation-wise clearer, but yields results identical to
    xformer_Conv3D_Aware_CrossAttention.
    """
def __init__(self, | |
dim, | |
num_heads=8, | |
qkv_bias=False, | |
qk_scale=None, | |
attn_drop=0.0, | |
proj_drop=0.0): | |
super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, | |
proj_drop) | |
def forward(self, x, attn_bias=None): | |
B, group_size, N, C = x.shape # B 3 N C | |
p = int(N**0.5) # patch size | |
assert p**2 == N, 'check input dim, no [cls] needed here' | |
assert group_size == 3, 'designed for triplane here' | |
x = x.reshape(B, group_size, p, p, C) # expand patch token dim | |
q_x = torch.empty(B * group_size * N, 1, C, device=x.device) | |
context = torch.empty(B * group_size * N, 2 * p, C, | |
device=x.device) # k_x=v_x | |
if self.index_mesh_grid is None: # further accelerate | |
index_u, index_v = torch.meshgrid( | |
torch.arange(0, p), torch.arange(0, p), | |
indexing='xy') # ! switch to 'xy' here to match uv coordinate | |
index_mesh_grid = torch.stack([index_u, index_v], 0).to( | |
x.device).unsqueeze(0).repeat_interleave(B, 0).reshape( | |
B, 2, p, p) # B 2 p p. | |
self.index_mesh_grid = index_mesh_grid[0:1] | |
else: | |
index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave( | |
B, 0) | |
assert index_mesh_grid.shape == ( | |
B, 2, p, p), 'check index_mesh_grid dimension' | |
for i in range(group_size): | |
q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( | |
0, 2, 3, 1, 4).reshape(B * N, 1, C) # B 1 p p C -> B*N, 1, C | |
# TODO, how to batchify gather ops? | |
plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + | |
1] # B 1 p p C | |
plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] | |
assert plane_yz.shape == plane_zx.shape == ( | |
B, 1, p, p, C), 'check sub plane dimensions' | |
pooling_plane_yz = torch.gather( | |
plane_yz, | |
dim=2, | |
index=index_mesh_grid[:, 1:2].reshape(B, 1, N, 1, 1).expand( | |
-1, -1, -1, p, | |
C)).permute(0, 2, 1, 3, 4) # B 1 256 16 C => B 256 1 16 C | |
pooling_plane_zx = torch.gather( | |
plane_zx, | |
dim=3, | |
index=index_mesh_grid[:, 0:1].reshape(B, 1, 1, N, 1).expand( | |
-1, -1, p, -1, | |
C)).permute(0, 3, 1, 2, 4) # B 1 16 256 C => B 256 1 16 C | |
context[B * i * N:B * (i + 1) * N] = torch.cat( | |
[pooling_plane_yz, pooling_plane_zx], | |
dim=2).reshape(B * N, 2 * p, | |
C) # B 256 2 16 C => (B*256) 2*16 C | |
# B, N, C = x.shape | |
q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads, | |
C // self.num_heads) | |
kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2, | |
self.num_heads, C // self.num_heads) | |
k, v = unbind(kv, 2) | |
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) | |
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp) | |
x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C) | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
return x | |
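
# As the docstring claims, the xy-grid variant is expected to match the original
# ij-grid module when the weights are shared; a quick, hypothetical check,
# assuming xformers and a CUDA device are available:
#
#     m_ij = xformer_Conv3D_Aware_CrossAttention(dim=96, num_heads=4).cuda()
#     m_xy = xformer_Conv3D_Aware_CrossAttention_xygrid(dim=96, num_heads=4).cuda()
#     m_xy.load_state_dict(m_ij.state_dict())
#     planes = torch.randn(1, 3, 16 * 16, 96, device='cuda')
#     with torch.no_grad():
#         torch.testing.assert_close(m_ij(planes), m_xy(planes))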
class xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( | |
xformer_Conv3D_Aware_CrossAttention_xygrid): | |
def __init__(self, | |
dim, | |
num_heads=8, | |
qkv_bias=False, | |
qk_scale=None, | |
attn_drop=0, | |
proj_drop=0): | |
super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, | |
proj_drop) | |
def forward(self, x, attn_bias=None): | |
# ! split x: B N C into B 3 N C//3 | |
B, N, C = x.shape | |
x = x.reshape(B, N, C // 3, 3).permute(0, 3, 1, | |
2) # B N C 3 -> B 3 N C | |
x_out = super().forward(x, attn_bias) # B 3 N C | |
x_out = x_out.permute(0, 2, 3, 1)# B 3 N C -> B N C 3 | |
x_out = x_out.reshape(*x_out.shape[:2], -1) # B N C 3 -> B N C3 | |
return x_out.contiguous() | |
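
# Usage sketch for the "withinC" wrapper above (illustrative; the shapes are
# assumptions): a flat (B, N, C) token sequence is split channel-wise into three
# C//3 sub-planes, so `dim` must equal C // 3. Requires xformers and CUDA.
#
#     m = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(dim=32, num_heads=4).cuda()
#     tokens = torch.randn(2, 16 * 16, 96, device='cuda')  # C = 96 = 3 x 32
#     assert m(tokens).shape == tokens.shape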
class self_cross_attn(nn.Module): | |
def __init__(self, dino_attn, cross_attn, *args, **kwargs) -> None: | |
super().__init__(*args, **kwargs) | |
self.dino_attn = dino_attn | |
self.cross_attn = cross_attn | |
def forward(self, x_norm): | |
y = self.dino_attn(x_norm) + x_norm | |
return self.cross_attn(y) # will add x in the original code | |
# class RodinRollOutConv(nn.Module): | |
# """implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention | |
# Use Group Conv | |
# """ | |
# def __init__(self, in_chans, out_chans=None): | |
# super().__init__() | |
# # input: B 3C H W | |
# if out_chans is None: | |
# out_chans = in_chans | |
# self.roll_out_convs = nn.Conv2d(in_chans, | |
# out_chans, | |
# kernel_size=3, | |
# groups=3, | |
# padding=1) | |
# def forward(self, x): | |
# return self.roll_out_convs(x) | |
class RodinRollOutConv3D(nn.Module):
    """Rodin-style 3D-aware roll-out convolution: each plane is concatenated along
    the width with axis-pooled features from the other two planes, then processed
    by a shared Conv2d.
    """
def __init__(self, in_chans, out_chans=None): | |
super().__init__() | |
if out_chans is None: | |
out_chans = in_chans | |
self.out_chans = out_chans // 3 | |
self.roll_out_convs = nn.Conv2d(in_chans, | |
self.out_chans, | |
kernel_size=3, | |
padding=1) | |
def forward(self, x): | |
# todo, reshape before input? | |
B, C3, p, p = x.shape # B 3C H W | |
C = C3 // 3 | |
group_size = C3 // C | |
assert group_size == 3 | |
x = x.reshape(B, 3, C, p, p) | |
roll_out_x = torch.empty(B, group_size * C, p, 3 * p, | |
device=x.device) # B, 3C, H, 3W | |
for i in range(group_size): | |
plane_xy = x[:, i] # B C H W | |
# TODO, simply do the average pooling? | |
plane_yz_pooling = x[:, (i + 1) % group_size].mean( | |
dim=-1, keepdim=True).repeat_interleave( | |
p, dim=-1) # B C H W -> B C H 1 -> B C H W, reduce z dim | |
plane_zx_pooling = x[:, (i + 2) % group_size].mean( | |
dim=-2, keepdim=True).repeat_interleave( | |
p, dim=-2) # B C H W -> B C 1 W -> B C H W, reduce z dim | |
roll_out_x[..., i * p:(i + 1) * p] = torch.cat( | |
[plane_xy, plane_yz_pooling, plane_zx_pooling], | |
1) # fill in the 3W dim | |
x = self.roll_out_convs(roll_out_x) # B C H 3W | |
x = x.reshape(B, self.out_chans, p, 3, p) | |
x = x.permute(0, 3, 1, 2, 4).reshape(B, 3 * self.out_chans, p, | |
p) # B 3C H W | |
return x | |
class RodinRollOutConv3D_GroupConv(nn.Module):
    """Group-conv variant of the 3D-aware roll-out convolution: each plane is stacked
    along channels with axis-pooled features from the other two planes and processed
    by a grouped Conv2d (groups=3).
    """
def __init__(self, | |
in_chans, | |
out_chans=None, | |
kernel_size=3, | |
stride=1, | |
padding=1): | |
super().__init__() | |
if out_chans is None: | |
out_chans = in_chans | |
self.roll_out_convs = nn.Conv2d( | |
in_chans * 3, | |
out_chans, | |
kernel_size=kernel_size, | |
groups=3, # B 9C H W | |
stride=stride, | |
padding=padding) | |
# @torch.autocast(device_type='cuda') | |
def forward(self, x): | |
# todo, reshape before input? | |
B, C3, p, p = x.shape # B 3C H W | |
C = C3 // 3 | |
group_size = C3 // C | |
assert group_size == 3 | |
x = x.reshape(B, 3, C, p, p) | |
roll_out_x = torch.empty(B, group_size * C * 3, p, p, | |
device=x.device) # B, 3C, H, 3W | |
for i in range(group_size): | |
plane_xy = x[:, i] # B C H W | |
# # TODO, simply do the average pooling? | |
plane_yz_pooling = x[:, (i + 1) % group_size].mean( | |
dim=-1, keepdim=True).repeat_interleave( | |
p, dim=-1) # B C H W -> B C H 1 -> B C H W, reduce z dim | |
plane_zx_pooling = x[:, (i + 2) % group_size].mean( | |
dim=-2, keepdim=True).repeat_interleave( | |
p, dim=-2) # B C H W -> B C 1 W -> B C H W, reduce z dim | |
roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat( | |
[plane_xy, plane_yz_pooling, plane_zx_pooling], | |
1) # fill in the 3W dim | |
# ! directly cat, avoid intermediate vars | |
# ? why OOM | |
# roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat( | |
# [ | |
# x[:, i], | |
# x[:, (i + 1) % group_size].mean( | |
# dim=-1, keepdim=True).repeat_interleave(p, dim=-1), | |
# x[:, (i + 2) % group_size].mean( | |
# dim=-2, keepdim=True).repeat_interleave( | |
# p, dim=-2 | |
# ) # B C H W -> B C 1 W -> B C H W, reduce z dim | |
# ], | |
# 1) # fill in the 3C dim | |
x = self.roll_out_convs(roll_out_x) # B 3C H W | |
return x | |
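
# Shape sketch for the grouped roll-out conv above (illustrative values): the
# input packs the three planes along channels, and each plane is convolved
# together with axis-pooled features from the other two planes.
#
#     conv3d = RodinRollOutConv3D_GroupConv(in_chans=96, out_chans=96)
#     planes = torch.randn(2, 96, 32, 32)  # (B, 3*C with C=32 per plane, H, W)
#     assert conv3d(planes).shape == (2, 96, 32, 32)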
class RodinRollOut_GroupConv_noConv3D(nn.Module):
    """Roll-out only: applies a grouped Conv2d to each plane independently,
    without the 3D-aware cross-plane pooling.
    """
def __init__(self, | |
in_chans, | |
out_chans=None, | |
kernel_size=3, | |
stride=1, | |
padding=1): | |
super().__init__() | |
if out_chans is None: | |
out_chans = in_chans | |
self.roll_out_inplane_conv = nn.Conv2d( | |
in_chans, | |
out_chans, | |
kernel_size=kernel_size, | |
groups=3, # B 3C H W | |
stride=stride, | |
padding=padding) | |
def forward(self, x): | |
x = self.roll_out_inplane_conv(x) # B 3C H W | |
return x | |
# class RodinConv3D_SynthesisLayer_withact(nn.Module): | |
# def __init__(self, in_chans, out_chans) -> None: | |
# super().__init__() | |
# self.act = nn.LeakyReLU(inplace=True) | |
# self.conv = nn.Sequential( | |
# RodinRollOutConv3D_GroupConv(in_chans, out_chans), | |
# nn.LeakyReLU(inplace=True), | |
# ) | |
# if in_chans != out_chans: | |
# self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration. | |
# else: | |
# self.short_cut = None | |
# def forward(self, feats): | |
# if self.short_cut is not None: | |
# res_feats = self.short_cut(feats) | |
# else: | |
# res_feats = feats | |
# # return res_feats + self.conv(feats) | |
# feats = res_feats + self.conv(feats) | |
# return self.act(feats) # as in resnet, add an act before return | |
class RodinConv3D_SynthesisLayer_mlp_unshuffle_as_residual(nn.Module): | |
def __init__(self, in_chans, out_chans) -> None: | |
super().__init__() | |
self.act = nn.LeakyReLU(inplace=True) | |
self.conv = nn.Sequential( | |
RodinRollOutConv3D_GroupConv(in_chans, out_chans), | |
nn.LeakyReLU(inplace=True), | |
) | |
self.out_chans = out_chans | |
if in_chans != out_chans: | |
# self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration. | |
self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W | |
in_chans // 3, # 144 / 3 = 48 | |
out_chans // 3 * 4 * 4, # 32 * 16 | |
bias=True) # decoder to pat | |
# RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration. | |
else: | |
self.short_cut = None | |
def shortcut_unpatchify_triplane(self, | |
x, | |
p=None, | |
unpatchify_out_chans=None): | |
"""separate triplane version; x shape: B (3*257) 768 | |
""" | |
assert self.short_cut is not None | |
# B, L, C = x.shape | |
B, C3, h, w = x.shape | |
assert h == w | |
L = h * w | |
x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, | |
1) # (B, 3, L // 3, C) | |
x = self.short_cut(x) | |
p = h * 4 | |
x = x.reshape(shape=(B, 3, h, w, p, p, unpatchify_out_chans)) | |
x = torch.einsum('ndhwpqc->ndchpwq', | |
x) # nplanes, C order in the renderer.py | |
x = x.reshape(shape=(B, 3 * self.out_chans, h * p, h * p)) | |
return x | |
def forward(self, feats): | |
if self.short_cut is not None: | |
res_feats = self.shortcut_unpatchify_triplane(feats) | |
else: | |
res_feats = feats | |
# return res_feats + self.conv(feats) | |
feats = res_feats + self.conv(feats) | |
return self.act(feats) # as in resnet, add an act before return | |
# class RodinConv3D_SynthesisLayer(nn.Module): | |
# def __init__(self, in_chans, out_chans) -> None: | |
# super().__init__() | |
# self.act = nn.LeakyReLU(inplace=True) | |
# self.conv = nn.Sequential( | |
# RodinRollOutConv3D_GroupConv(in_chans, out_chans), | |
# nn.LeakyReLU(inplace=True), | |
# ) | |
# if in_chans != out_chans: | |
# self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration. | |
# else: | |
# self.short_cut = None | |
# def forward(self, feats): | |
# if self.short_cut is not None: | |
# res_feats = self.short_cut(feats) | |
# else: | |
# res_feats = feats | |
# # return res_feats + self.conv(feats) | |
# feats = res_feats + self.conv(feats) | |
# # return self.act(feats) # as in resnet, add an act before return | |
# return feats # ! old behaviour, no act | |
# previous worked version | |
class RodinConv3D_SynthesisLayer(nn.Module): | |
def __init__(self, in_chans, out_chans) -> None: | |
super().__init__() | |
# x2 SR + 1x1 Conv Residual BLK | |
# self.conv3D = RodinRollOutConv3D(in_chans, out_chans) | |
self.act = nn.LeakyReLU(inplace=True) | |
self.conv = nn.Sequential( | |
RodinRollOutConv3D_GroupConv(in_chans, out_chans), | |
nn.LeakyReLU(inplace=True), | |
) | |
if in_chans != out_chans: | |
self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) | |
else: | |
self.short_cut = None | |
def forward(self, feats): | |
feats_out = self.conv(feats) | |
if self.short_cut is not None: | |
# ! failed below | |
feats_out = self.short_cut( | |
feats | |
) + feats_out # ! only difference here, no act() compared with baseline | |
# feats_out = self.act(self.short_cut(feats)) + feats_out # ! only difference here, no act() compared with baseline | |
else: | |
feats_out = feats_out + feats | |
return feats_out | |
class RodinRollOutConv3DSR2X(nn.Module): | |
def __init__(self, in_chans, **kwargs) -> None: | |
super().__init__() | |
self.conv3D = RodinRollOutConv3D_GroupConv(in_chans) | |
# self.conv3D = RodinRollOutConv3D(in_chans) | |
self.act = nn.LeakyReLU(inplace=True) | |
self.input_resolution = 224 | |
def forward(self, x): | |
# x: B 3 112*112 C | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
group_size = C3 // C | |
assert group_size == 3 | |
# p = int(N**0.5) # patch size | |
# assert p**2 == N, 'check input dim, no [cls] needed here' | |
assert group_size == 3, 'designed for triplane here' | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
if x.shape[-1] != self.input_resolution: | |
x = torch.nn.functional.interpolate(x, | |
size=(self.input_resolution, | |
self.input_resolution), | |
mode='bilinear', | |
align_corners=False, | |
antialias=True) | |
x = x + self.conv3D(x) | |
return x | |
class RodinRollOutConv3DSR4X_lite(nn.Module): | |
def __init__(self, in_chans, input_resolutiopn=256, **kwargs) -> None: | |
super().__init__() | |
self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans) | |
self.conv3D_1 = RodinRollOutConv3D_GroupConv(in_chans) | |
self.act = nn.LeakyReLU(inplace=True) | |
self.input_resolution = input_resolutiopn | |
def forward(self, x): | |
# x: B 3 112*112 C | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
group_size = C3 // C | |
assert group_size == 3 | |
# p = int(N**0.5) # patch size | |
# assert p**2 == N, 'check input dim, no [cls] needed here' | |
assert group_size == 3, 'designed for triplane here' | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
if x.shape[-1] != self.input_resolution: | |
x = torch.nn.functional.interpolate(x, | |
size=(self.input_resolution, | |
self.input_resolution), | |
mode='bilinear', | |
align_corners=False, | |
antialias=True) | |
        # ! still not converging, not a bug here?
# x = x + self.conv3D_0(x) | |
# x = x + self.conv3D_1(x) | |
x = x + self.act(self.conv3D_0(x)) | |
x = x + self.act(self.conv3D_1(x)) | |
# TODO: which is better, bilinear + conv or PixelUnshuffle? | |
return x | |
# class RodinConv3D2X_lite_mlp_as_residual(nn.Module): | |
# """lite 4X version, with MLP unshuffle to change the dimention | |
# """ | |
# def __init__(self, in_chans, out_chans, input_resolution=256) -> None: | |
# super().__init__() | |
# self.act = nn.LeakyReLU(inplace=True) | |
# self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans) | |
# self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans) | |
# self.act = nn.LeakyReLU(inplace=True) | |
# self.input_resolution = input_resolution | |
# self.out_chans = out_chans | |
# if in_chans != out_chans: # ! only change the dimension | |
# self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W | |
# in_chans//3, # 144 / 3 = 48 | |
# out_chans//3, # 32 * 16 | |
# bias=True) # decoder to pat | |
# else: | |
# self.short_cut = None | |
# def shortcut_unpatchify_triplane(self, x, p=None): | |
# """separate triplane version; x shape: B (3*257) 768 | |
# """ | |
# assert self.short_cut is not None | |
# # B, L, C = x.shape | |
# B, C3, h, w = x.shape | |
# assert h == w | |
# L = h*w | |
# x = x.reshape(B, C3//3, 3, L).permute(0,2,3,1) # (B, 3, L // 3, C_in) | |
# x = self.short_cut(x) # B 3 L//3 C_out | |
# x = x.permute(0,1,3,2) # B 3 C_out L//3 | |
# x = x.reshape(shape=(B, self.out_chans, h, w)) | |
# # directly resize to the target, no unpatchify here since no 3D ViT is included here | |
# if w != self.input_resolution: | |
# x = torch.nn.functional.interpolate(x, # 4X SR | |
# size=(self.input_resolution, | |
# self.input_resolution), | |
# mode='bilinear', | |
# align_corners=False, | |
# antialias=True) | |
# return x | |
# def forward(self, x): | |
# # x: B 3 112*112 C | |
# B, C3, p, p = x.shape # after unpachify triplane | |
# C = C3 // 3 | |
# if self.short_cut is not None: | |
# res_feats = self.shortcut_unpatchify_triplane(x) | |
# else: | |
# res_feats = x | |
# """following forward code copied from lite4x version | |
# """ | |
# x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
# p) # B 3 C N -> B 3C h W | |
# if x.shape[-1] != self.input_resolution: | |
# x = torch.nn.functional.interpolate(x, # 4X SR | |
# size=(self.input_resolution, | |
# self.input_resolution), | |
# mode='bilinear', | |
# align_corners=False, | |
# antialias=True) | |
# x = res_feats + self.act(self.conv3D_0(x)) | |
# x = x + self.act(self.conv3D_1(x)) | |
# return x | |
class RodinConv3D4X_lite_mlp_as_residual(nn.Module): | |
"""lite 4X version, with MLP unshuffle to change the dimention | |
""" | |
def __init__(self, | |
in_chans, | |
out_chans, | |
input_resolution=256, | |
interp_mode='bilinear', | |
bcg_triplane=False) -> None: | |
super().__init__() | |
self.interp_mode = interp_mode | |
self.act = nn.LeakyReLU(inplace=True) | |
self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans) | |
self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans) | |
self.bcg_triplane = bcg_triplane | |
if bcg_triplane: | |
self.conv3D_1_bg = RodinRollOutConv3D_GroupConv( | |
out_chans, out_chans) | |
self.act = nn.LeakyReLU(inplace=True) | |
self.input_resolution = input_resolution | |
self.out_chans = out_chans | |
if in_chans != out_chans: # ! only change the dimension | |
self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W | |
in_chans // 3, # 144 / 3 = 48 | |
out_chans // 3, # 32 * 16 | |
bias=True) # decoder to pat | |
else: | |
self.short_cut = None | |
def shortcut_unpatchify_triplane(self, x, p=None): | |
"""separate triplane version; x shape: B (3*257) 768 | |
""" | |
assert self.short_cut is not None | |
B, C3, h, w = x.shape | |
assert h == w | |
L = h * w | |
x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, | |
1) # (B, 3, L // 3, C_in) | |
x = self.short_cut(x) # B 3 L//3 C_out | |
x = x.permute(0, 1, 3, 2) # B 3 C_out L//3 | |
x = x.reshape(shape=(B, self.out_chans, h, w)) | |
# directly resize to the target, no unpatchify here since no 3D ViT is included here | |
if w != self.input_resolution: | |
x = torch.nn.functional.interpolate( | |
x, # 4X SR | |
size=(self.input_resolution, self.input_resolution), | |
mode='bilinear', | |
align_corners=False, | |
antialias=True) | |
return x | |
def interpolate(self, feats): | |
if self.interp_mode == 'bilinear': | |
return torch.nn.functional.interpolate( | |
feats, # 4X SR | |
size=(self.input_resolution, self.input_resolution), | |
mode='bilinear', | |
align_corners=False, | |
antialias=True) | |
else: | |
return torch.nn.functional.interpolate( | |
feats, # 4X SR | |
size=(self.input_resolution, self.input_resolution), | |
mode='nearest', | |
) | |
def forward(self, x): | |
# x: B 3 112*112 C | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
if self.short_cut is not None: | |
res_feats = self.shortcut_unpatchify_triplane(x) | |
else: | |
res_feats = x | |
if res_feats.shape[-1] != self.input_resolution: | |
res_feats = self.interpolate(res_feats) | |
"""following forward code copied from lite4x version | |
""" | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
if x.shape[-1] != self.input_resolution: | |
x = self.interpolate(x) | |
x0 = res_feats + self.act(self.conv3D_0(x)) # the base feature | |
x = x0 + self.act(self.conv3D_1(x0)) | |
if self.bcg_triplane: | |
x_bcg = x0 + self.act(self.conv3D_1_bg(x0)) | |
return torch.cat([x, x_bcg], 1) | |
else: | |
return x | |
class RodinConv3D4X_lite_mlp_as_residual_litev2( | |
RodinConv3D4X_lite_mlp_as_residual): | |
def __init__(self, | |
in_chans, | |
out_chans, | |
num_feat=128, | |
input_resolution=256, | |
interp_mode='bilinear', | |
bcg_triplane=False) -> None: | |
super().__init__(in_chans, out_chans, input_resolution, interp_mode, | |
bcg_triplane) | |
self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, in_chans) | |
self.conv_before_upsample = RodinRollOut_GroupConv_noConv3D( | |
in_chans, num_feat * 3) | |
self.conv3D_1 = RodinRollOut_GroupConv_noConv3D( | |
num_feat * 3, num_feat * 3) | |
self.conv_last = RodinRollOut_GroupConv_noConv3D( | |
num_feat * 3, out_chans) | |
self.short_cut = None | |
def forward(self, x): | |
# x: B 3 112*112 C | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
# if self.short_cut is not None: | |
# res_feats = self.shortcut_unpatchify_triplane(x) | |
# else: | |
# res_feats = x | |
# if res_feats.shape[-1] != self.input_resolution: | |
# res_feats = self.interpolate(res_feats) | |
"""following forward code copied from lite4x version | |
""" | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
x = x + self.conv3D_0(x) # the base feature | |
x = self.act(self.conv_before_upsample(x)) | |
# if x.shape[-1] != self.input_resolution: | |
x = self.conv_last(self.act(self.conv3D_1(self.interpolate(x)))) | |
return x | |
class RodinConv3D4X_lite_mlp_as_residual_lite( | |
RodinConv3D4X_lite_mlp_as_residual): | |
def __init__(self, | |
in_chans, | |
out_chans, | |
input_resolution=256, | |
interp_mode='bilinear') -> None: | |
super().__init__(in_chans, out_chans, input_resolution, interp_mode) | |
"""replace the first Rodin Conv 3D with ordinary rollout conv to save memory | |
""" | |
self.conv3D_0 = RodinRollOut_GroupConv_noConv3D(in_chans, out_chans) | |
class SR3D(nn.Module): | |
# https://github.com/SeanChenxy/Mimic3D/blob/77d313656df3cd5536d2c4c5766db3a56208eea6/training/networks_stylegan2.py#L629 | |
# roll-out and apply two deconv/pixelUnshuffle layer | |
def __init__(self, *args, **kwargs) -> None: | |
super().__init__(*args, **kwargs) | |
class RodinConv3D4X_lite_mlp_as_residual_improved(nn.Module): | |
def __init__(self, | |
in_chans, | |
num_feat, | |
out_chans, | |
input_resolution=256) -> None: | |
super().__init__() | |
assert in_chans == 4 * out_chans | |
assert num_feat == 2 * out_chans | |
self.input_resolution = input_resolution | |
# refer to https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750 | |
self.upscale = 4 | |
self.conv_after_body = RodinRollOutConv3D_GroupConv( | |
in_chans, in_chans, 3, 1, 1) | |
self.conv_before_upsample = nn.Sequential( | |
RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1), | |
nn.LeakyReLU(inplace=True)) | |
self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, | |
1) | |
if self.upscale == 4: | |
self.conv_up2 = RodinRollOutConv3D_GroupConv( | |
num_feat, num_feat, 3, 1, 1) | |
self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, | |
1) | |
self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, | |
1, 1) | |
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) | |
def forward(self, x): | |
# x: B 3 112*112 C | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
"""following forward code copied from lite4x version | |
""" | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
# ? nearest or bilinear | |
x = self.conv_after_body(x) + x | |
x = self.conv_before_upsample(x) | |
x = self.lrelu( | |
self.conv_up1( | |
torch.nn.functional.interpolate( | |
x, | |
scale_factor=2, | |
mode='nearest', | |
# align_corners=False, | |
# antialias=True | |
))) | |
if self.upscale == 4: | |
x = self.lrelu( | |
self.conv_up2( | |
torch.nn.functional.interpolate( | |
x, | |
scale_factor=2, | |
mode='nearest', | |
# align_corners=False, | |
# antialias=True | |
))) | |
x = self.conv_last(self.lrelu(self.conv_hr(x))) | |
assert x.shape[-1] == self.input_resolution | |
return x | |
class RodinConv3D4X_lite_improved_lint_withresidual(nn.Module): | |
def __init__(self, | |
in_chans, | |
num_feat, | |
out_chans, | |
input_resolution=256) -> None: | |
super().__init__() | |
assert in_chans == 4 * out_chans | |
assert num_feat == 2 * out_chans | |
self.input_resolution = input_resolution | |
# refer to https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750 | |
self.upscale = 4 | |
self.conv_after_body = RodinRollOutConv3D_GroupConv( | |
in_chans, in_chans, 3, 1, 1) | |
self.conv_before_upsample = nn.Sequential( | |
RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1), | |
nn.LeakyReLU(inplace=True)) | |
self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, | |
1) | |
if self.upscale == 4: | |
self.conv_up2 = RodinRollOutConv3D_GroupConv( | |
num_feat, num_feat, 3, 1, 1) | |
self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, | |
1) | |
self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, | |
1, 1) | |
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) | |
def forward(self, x): | |
# x: B 3 112*112 C | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
"""following forward code copied from lite4x version | |
""" | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
# ? nearest or bilinear | |
x = self.conv_after_body(x) + x | |
x = self.conv_before_upsample(x) | |
x = self.lrelu( | |
self.conv_up1( | |
torch.nn.functional.interpolate( | |
x, | |
scale_factor=2, | |
mode='nearest', | |
# align_corners=False, | |
# antialias=True | |
))) | |
if self.upscale == 4: | |
x = self.lrelu( | |
self.conv_up2( | |
torch.nn.functional.interpolate( | |
x, | |
scale_factor=2, | |
mode='nearest', | |
# align_corners=False, | |
# antialias=True | |
))) | |
x = self.conv_last(self.lrelu(self.conv_hr(x) + x)) | |
assert x.shape[-1] == self.input_resolution | |
return x | |
class RodinRollOutConv3DSR_FlexibleChannels(nn.Module): | |
def __init__(self, | |
in_chans, | |
num_out_ch=96, | |
input_resolution=256, | |
**kwargs) -> None: | |
super().__init__() | |
self.block0 = RodinConv3D_SynthesisLayer(in_chans, | |
num_out_ch) # in_chans=48 | |
self.block1 = RodinConv3D_SynthesisLayer(num_out_ch, num_out_ch) | |
self.input_resolution = input_resolution # 64 -> 256 SR | |
def forward(self, x): | |
# x: B 3 112*112 C | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
# group_size = C3 // C | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
if x.shape[-1] != self.input_resolution: | |
x = torch.nn.functional.interpolate(x, | |
size=(self.input_resolution, | |
self.input_resolution), | |
mode='bilinear', | |
align_corners=False, | |
antialias=True) | |
x = self.block0(x) | |
x = self.block1(x) | |
return x | |
# previous worked version | |
class RodinRollOutConv3DSR4X(nn.Module): | |
# follow PixelUnshuffleUpsample | |
def __init__(self, in_chans, **kwargs) -> None: | |
super().__init__() | |
# self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96 * 2) # TODO, match the old behaviour now. | |
# self.block1 = RodinConv3D_SynthesisLayer(96 * 2, 96) | |
self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96) | |
self.block1 = RodinConv3D_SynthesisLayer( | |
96, 96) # baseline choice, validate with no LPIPS loss here | |
self.input_resolution = 64 # 64 -> 256 | |
def forward(self, x): | |
# x: B 3 112*112 C | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
# group_size = C3 // C | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
if x.shape[-1] != self.input_resolution: | |
x = torch.nn.functional.interpolate(x, | |
size=(self.input_resolution, | |
self.input_resolution), | |
mode='bilinear', | |
align_corners=False, | |
antialias=True) | |
x = self.block0(x) | |
x = self.block1(x) | |
return x | |
class Upsample3D(nn.Module):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n.
        num_feat (int): Channel number of intermediate features.
    """
def __init__(self, scale, num_feat): | |
super().__init__() | |
m_convs = [] | |
m_pixelshuffle = [] | |
assert (scale & (scale - 1)) == 0, 'scale = 2^n' | |
self.scale = scale | |
for _ in range(int(math.log(scale, 2))): | |
m_convs.append( | |
RodinRollOutConv3D_GroupConv(num_feat, 4 * num_feat, 3, 1, 1)) | |
m_pixelshuffle.append(nn.PixelShuffle(2)) | |
self.m_convs = nn.ModuleList(m_convs) | |
self.m_pixelshuffle = nn.ModuleList(m_pixelshuffle) | |
# @torch.autocast(device_type='cuda') | |
def forward(self, x): | |
for scale_idx in range(int(math.log(self.scale, 2))): | |
x = self.m_convs[scale_idx](x) # B 3C H W | |
# x = | |
# B, C3, H, W = x.shape | |
x = x.reshape(x.shape[0] * 3, x.shape[1] // 3, *x.shape[2:]) | |
x = self.m_pixelshuffle[scale_idx](x) | |
x = x.reshape(x.shape[0] // 3, x.shape[1] * 3, *x.shape[2:]) | |
return x | |
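
# Shape sketch for Upsample3D (illustrative values): each 2x step applies a
# grouped roll-out conv to quadruple the channels, then pixel-shuffles every
# plane independently, so the total channel count is preserved while H and W grow.
#
#     up = Upsample3D(scale=4, num_feat=96)   # 96 = 3 planes x 32 channels
#     planes = torch.randn(2, 96, 32, 32)
#     assert up(planes).shape == (2, 96, 128, 128)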
class RodinConv3DPixelUnshuffleUpsample(nn.Module): | |
def __init__(self, | |
output_dim, | |
num_feat=32 * 6, | |
num_out_ch=32 * 3, | |
sr_ratio=4, | |
*args, | |
**kwargs) -> None: | |
super().__init__() | |
self.conv_after_body = RodinRollOutConv3D_GroupConv( | |
output_dim, output_dim, 3, 1, 1) | |
self.conv_before_upsample = nn.Sequential( | |
RodinRollOutConv3D_GroupConv(output_dim, num_feat, 3, 1, 1), | |
nn.LeakyReLU(inplace=True)) | |
self.upsample = Upsample3D(sr_ratio, num_feat) # 4 time SR | |
self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, num_out_ch, 3, | |
1, 1) | |
# @torch.autocast(device_type='cuda') | |
def forward(self, x, input_skip_connection=True, *args, **kwargs): | |
# x = self.conv_first(x) | |
if input_skip_connection: | |
x = self.conv_after_body(x) + x | |
else: | |
x = self.conv_after_body(x) | |
x = self.conv_before_upsample(x) | |
x = self.upsample(x) | |
x = self.conv_last(x) | |
return x | |
class RodinConv3DPixelUnshuffleUpsample_improvedVersion(nn.Module): | |
def __init__( | |
self, | |
output_dim, | |
num_out_ch=32 * 3, | |
sr_ratio=4, | |
input_resolution=256, | |
) -> None: | |
super().__init__() | |
self.input_resolution = input_resolution | |
# self.conv_first = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, | |
# 3, 1, 1) | |
self.upsample = Upsample3D(sr_ratio, output_dim) # 4 time SR | |
self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, | |
3, 1, 1) | |
def forward(self, x, bilinear_upsample=True): | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
group_size = C3 // C | |
assert group_size == 3, 'designed for triplane here' | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
if bilinear_upsample and x.shape[-1] != self.input_resolution: | |
x_bilinear_upsample = torch.nn.functional.interpolate( | |
x, | |
size=(self.input_resolution, self.input_resolution), | |
mode='bilinear', | |
align_corners=False, | |
antialias=True) | |
x = self.upsample(x) + x_bilinear_upsample | |
else: | |
# x_bilinear_upsample = x | |
x = self.upsample(x) | |
x = self.conv_last(x) | |
return x | |
class RodinConv3DPixelUnshuffleUpsample_improvedVersion2(nn.Module):
    """Removes the nearest-neighbour residual connection and adds a conv-layer residual connection.
    """
def __init__( | |
self, | |
output_dim, | |
num_out_ch=32 * 3, | |
sr_ratio=4, | |
input_resolution=256, | |
) -> None: | |
super().__init__() | |
self.input_resolution = input_resolution | |
self.conv_after_body = RodinRollOutConv3D_GroupConv( | |
output_dim, num_out_ch, 3, 1, 1) | |
self.upsample = Upsample3D(sr_ratio, output_dim) # 4 time SR | |
self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, | |
3, 1, 1) | |
def forward(self, x, input_skip_connection=True): | |
B, C3, p, p = x.shape # after unpachify triplane | |
C = C3 // 3 | |
group_size = C3 // C | |
assert group_size == 3, 'designed for triplane here' | |
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, | |
p) # B 3 C N -> B 3C h W | |
if input_skip_connection: | |
x = self.conv_after_body(x) + x | |
else: | |
x = self.conv_after_body(x) | |
x = self.upsample(x) | |
x = self.conv_last(x) | |
return x | |
class CLSCrossAttentionBlock(nn.Module): | |
def __init__(self, | |
dim, | |
num_heads, | |
mlp_ratio=4., | |
qkv_bias=False, | |
qk_scale=None, | |
drop=0., | |
attn_drop=0., | |
drop_path=0., | |
act_layer=nn.GELU, | |
norm_layer=nn.LayerNorm, | |
has_mlp=False): | |
super().__init__() | |
self.norm1 = norm_layer(dim) | |
self.attn = CrossAttention(dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
attn_drop=attn_drop, | |
proj_drop=drop) | |
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | |
self.drop_path = DropPath( | |
drop_path) if drop_path > 0. else nn.Identity() | |
self.has_mlp = has_mlp | |
if has_mlp: | |
self.norm2 = norm_layer(dim) | |
mlp_hidden_dim = int(dim * mlp_ratio) | |
self.mlp = Mlp(in_features=dim, | |
hidden_features=mlp_hidden_dim, | |
act_layer=act_layer, | |
drop=drop) | |
def forward(self, x): | |
x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x))) | |
if self.has_mlp: | |
x = x + self.drop_path(self.mlp(self.norm2(x))) | |
return x | |
class Conv3DCrossAttentionBlock(nn.Module): | |
def __init__(self, | |
dim, | |
num_heads, | |
mlp_ratio=4., | |
qkv_bias=False, | |
qk_scale=None, | |
drop=0., | |
attn_drop=0., | |
drop_path=0., | |
act_layer=nn.GELU, | |
norm_layer=nn.LayerNorm, | |
has_mlp=False): | |
super().__init__() | |
self.norm1 = norm_layer(dim) | |
self.attn = Conv3D_Aware_CrossAttention(dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
attn_drop=attn_drop, | |
proj_drop=drop) | |
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | |
self.drop_path = DropPath( | |
drop_path) if drop_path > 0. else nn.Identity() | |
self.has_mlp = has_mlp | |
if has_mlp: | |
self.norm2 = norm_layer(dim) | |
mlp_hidden_dim = int(dim * mlp_ratio) | |
self.mlp = Mlp(in_features=dim, | |
hidden_features=mlp_hidden_dim, | |
act_layer=act_layer, | |
drop=drop) | |
def forward(self, x): | |
x = x + self.drop_path(self.attn(self.norm1(x))) | |
if self.has_mlp: | |
x = x + self.drop_path(self.mlp(self.norm2(x))) | |
return x | |
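
# A minimal usage sketch of the fusion block above with its default settings
# (illustrative dim/num_heads; drop_path=0 keeps the identity path and
# has_mlp=False skips the MLP):
#
#     blk = Conv3DCrossAttentionBlock(dim=96, num_heads=4)
#     planes = torch.randn(2, 3, 16 * 16, 96)
#     assert blk(planes).shape == planes.shape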
class Conv3DCrossAttentionBlockXformerMHA(Conv3DCrossAttentionBlock): | |
def __init__(self, | |
dim, | |
num_heads, | |
mlp_ratio=4, | |
qkv_bias=False, | |
qk_scale=None, | |
drop=0, | |
attn_drop=0, | |
drop_path=0, | |
act_layer=nn.GELU, | |
norm_layer=nn.LayerNorm, | |
has_mlp=False): | |
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, | |
attn_drop, drop_path, act_layer, norm_layer, has_mlp) | |
# self.attn = xformer_Conv3D_Aware_CrossAttention(dim, | |
self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid( | |
dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
attn_drop=attn_drop, | |
proj_drop=drop) | |
class Conv3DCrossAttentionBlockXformerMHANested( | |
Conv3DCrossAttentionBlockXformerMHA): | |
def __init__(self, | |
dim, | |
num_heads, | |
mlp_ratio=4, | |
qkv_bias=False, | |
qk_scale=None, | |
drop=0., | |
attn_drop=0., | |
drop_path=0., | |
act_layer=nn.GELU, | |
norm_layer=nn.LayerNorm, | |
has_mlp=False): | |
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, | |
attn_drop, drop_path, act_layer, norm_layer, has_mlp) | |
"""for in-place replaing the internal attn in Dino ViT. | |
""" | |
def forward(self, x): | |
Bx3, N, C = x.shape | |
B, group_size = Bx3 // 3, 3 | |
x = x.reshape(B, group_size, N, C) # in plane vit | |
x = super().forward(x) | |
return x.reshape(B * group_size, N, | |
C) # to match the original attn size | |
class Conv3DCrossAttentionBlockXformerMHANested_withinC( | |
Conv3DCrossAttentionBlockXformerMHANested): | |
def __init__(self, | |
dim, | |
num_heads, | |
mlp_ratio=4, | |
qkv_bias=False, | |
qk_scale=None, | |
drop=0, | |
attn_drop=0, | |
drop_path=0, | |
act_layer=nn.GELU, | |
norm_layer=nn.LayerNorm, | |
has_mlp=False): | |
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, | |
attn_drop, drop_path, act_layer, norm_layer, has_mlp) | |
self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( | |
dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
attn_drop=attn_drop, | |
proj_drop=drop) | |
def forward(self, x): | |
# basic TX attention forward function | |
x = x + self.drop_path(self.attn(self.norm1(x))) | |
if self.has_mlp: | |
x = x + self.drop_path(self.mlp(self.norm2(x))) | |
return x | |
class TriplaneFusionBlock(nn.Module): | |
"""4 ViT blocks + 1 CrossAttentionBlock | |
""" | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
cross_attention_blk=CLSCrossAttentionBlock, | |
*args, | |
**kwargs) -> None: | |
super().__init__(*args, **kwargs) | |
self.num_branches = 3 # triplane | |
self.vit_blks = vit_blks | |
if use_fusion_blk: | |
self.fusion = nn.ModuleList() | |
# copied vit settings from https://github.dev/facebookresearch/dinov2 | |
nh = num_heads | |
dim = embed_dim | |
mlp_ratio = 4 # defined for all dino2 model | |
qkv_bias = True | |
norm_layer = partial(nn.LayerNorm, eps=1e-6) | |
drop_path_rate = 0.3 # default setting | |
attn_drop = proj_drop = 0.0 | |
qk_scale = None # TODO, double check | |
for d in range(self.num_branches): | |
self.fusion.append( | |
cross_attention_blk( | |
dim=dim, | |
num_heads=nh, | |
mlp_ratio=mlp_ratio, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
# drop=drop, | |
drop=proj_drop, | |
attn_drop=attn_drop, | |
drop_path=drop_path_rate, | |
norm_layer=norm_layer, # type: ignore | |
has_mlp=False)) | |
else: | |
self.fusion = None | |
def forward(self, x): | |
# modified from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132 | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
# self attention, by merging the triplane channel into B for parallel computation | |
# ! move the below to the front of the first call | |
B, group_size, N, C = x.shape # has [cls] token in N | |
assert group_size == 3, 'triplane' | |
x = x.view(B * group_size, N, C) | |
for blk in self.vit_blks: | |
x = blk(x) # B 3 N C | |
if self.fusion is None: | |
return x.view(B, group_size, N, C) | |
# outs_b = x.view(B, group_size, N, | |
# C).chunk(chunks=3, | |
# dim=1) # 3 * [B, 1, N//3, C] Tensors, for fusion | |
outs_b = x.chunk(chunks=3, dim=0)  # 3 tensors of shape [B, N, C], one per plane, for fusion
# only take the cls token out | |
proj_cls_token = [x[:, 0:1] for x in outs_b] | |
# cross attention | |
outs = [] | |
for i in range(self.num_branches): | |
tmp = torch.cat( | |
(proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, | |
...]), | |
dim=1) | |
tmp = self.fusion[i](tmp) | |
# reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...]) | |
reverted_proj_cls_token = tmp[:, 0:1, ...] | |
tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), | |
dim=1) | |
outs.append(tmp) | |
# outs = ? needs to merge back? | |
outs = torch.stack(outs, 1) # B 3 N C | |
return outs | |
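# Hedged illustration (not used by the model): in the CrossViT-style fusion above,
# plane i's [CLS] token is concatenated in front of plane (i+1)'s patch tokens,
# refined by the per-branch cross-attention block, and then placed back in front of
# plane i's own patch tokens. The sketch below only shows the concatenation layout;
# token counts are assumptions for the example.
def _example_triplane_cls_fusion_layout():  # hypothetical helper, never called
    B, N, C = 2, 257, 768                      # N = 1 [CLS] + 256 patch tokens
    outs_b = [torch.randn(B, N, C) for _ in range(3)]
    i = 0
    mixed = torch.cat((outs_b[i][:, 0:1], outs_b[(i + 1) % 3][:, 1:]), dim=1)
    assert mixed.shape == (B, N, C)            # [CLS] of plane i + patches of plane i+1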
class TriplaneFusionBlockv2(nn.Module): | |
"""4 ViT blocks + 1 CrossAttentionBlock | |
""" | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=Conv3DCrossAttentionBlock, | |
*args, | |
**kwargs) -> None: | |
super().__init__(*args, **kwargs) | |
self.num_branches = 3 # triplane | |
self.vit_blks = vit_blks | |
if use_fusion_blk: | |
# self.fusion = nn.ModuleList() | |
# copied vit settings from https://github.dev/facebookresearch/dinov2 | |
nh = num_heads | |
dim = embed_dim | |
mlp_ratio = 4 # defined for all dino2 model | |
qkv_bias = True | |
norm_layer = partial(nn.LayerNorm, eps=1e-6) | |
drop_path_rate = 0.3 # default setting | |
attn_drop = proj_drop = 0.0 | |
qk_scale = None # TODO, double check | |
# for d in range(self.num_branches): | |
self.fusion = fusion_ca_blk( # one fusion is enough | |
dim=dim, | |
num_heads=nh, | |
mlp_ratio=mlp_ratio, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
# drop=drop, | |
drop=proj_drop, | |
attn_drop=attn_drop, | |
drop_path=drop_path_rate, | |
norm_layer=norm_layer, # type: ignore | |
has_mlp=False) | |
else: | |
self.fusion = None | |
def forward(self, x): | |
# modified from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132 | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
# self attention, by merging the triplane channel into B for parallel computation | |
# ! move the below to the front of the first call | |
B, group_size, N, C = x.shape # has [cls] token in N | |
assert group_size == 3, 'triplane' | |
x = x.reshape(B * group_size, N, C) | |
for blk in self.vit_blks: | |
x = blk(x) # B 3 N C | |
if self.fusion is None: | |
return x.reshape(B, group_size, N, C) | |
x = x.reshape(B, group_size, N, C)  # regroup the three planes for the fusion block
return self.fusion(x) | |
class TriplaneFusionBlockv3(TriplaneFusionBlockv2): | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA, | |
*args, | |
**kwargs) -> None: | |
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, | |
fusion_ca_blk, *args, **kwargs) | |
class TriplaneFusionBlockv4(TriplaneFusionBlockv3): | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA, | |
*args, | |
**kwargs) -> None: | |
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, | |
fusion_ca_blk, *args, **kwargs) | |
"""OOM? directly replace the atten here | |
""" | |
assert len(vit_blks) == 2 | |
# del self.vit_blks[1].attn | |
del self.vit_blks[1].attn, self.vit_blks[1].ls1, self.vit_blks[1].norm1 | |
def ffn_residual_func(self, tx_blk, x: Tensor) -> Tensor: | |
return tx_blk.ls2( | |
tx_blk.mlp(tx_blk.norm2(x)) | |
) # https://github.com/facebookresearch/dinov2/blob/c3c2683a13cde94d4d99f523cf4170384b00c34c/dinov2/layers/block.py#L86C1-L87C53 | |
def forward(self, x): | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
assert self.fusion is not None | |
B, group_size, N, C = x.shape # has [cls] token in N | |
x = x.reshape(B * group_size, N, C) # in plane vit | |
# in plane self attention | |
x = self.vit_blks[0](x) | |
# 3D cross attention blk + ffn | |
x = x + self.fusion(x.reshape(B, group_size, N, C)).reshape( | |
B * group_size, N, C) | |
x = x + self.ffn_residual_func(self.vit_blks[1], x) | |
return x.reshape(B, group_size, N, C) | |
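# Hedged illustration (not used by the model): TriplaneFusionBlockv4 keeps
# vit_blks[0] as a plain in-plane block, replaces the attention of vit_blks[1]
# with the 3D-aware fusion residual, and reuses vit_blks[1]'s norm2/mlp/ls2 as the
# trailing FFN residual. The sketch below reproduces that wiring with nn.Identity
# stand-ins so it is runnable without the real attention modules.
def _example_v4_residual_structure():  # hypothetical helper, never called
    B, N, C = 2, 257, 768
    blk0 = nn.Identity()     # stands in for vit_blks[0] (self-attention block)
    fusion = nn.Identity()   # stands in for the 3D-aware cross-attention fusion
    ffn = nn.Identity()      # stands in for norm2 -> mlp -> ls2 of vit_blks[1]
    x = torch.randn(B * 3, N, C)
    x = blk0(x)                                                  # in-plane self-attention
    x = x + fusion(x.reshape(B, 3, N, C)).reshape(B * 3, N, C)   # cross-plane residual
    x = x + ffn(x)                                               # FFN residual
    assert x.shape == (B * 3, N, C)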
class TriplaneFusionBlockv4_nested(nn.Module): | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, | |
*args, | |
**kwargs) -> None: | |
super().__init__() | |
self.num_branches = 3 # triplane | |
self.vit_blks = vit_blks | |
assert use_fusion_blk | |
assert len(vit_blks) == 2 | |
# ! replace vit_blks[1] attn layer with 3D aware attention | |
del self.vit_blks[1].attn  # , self.vit_blks[1].ls1, self.vit_blks[1].norm1
# copied vit settings from https://github.dev/facebookresearch/dinov2 | |
nh = num_heads | |
dim = embed_dim | |
mlp_ratio = 4 # defined for all dino2 model | |
qkv_bias = True | |
norm_layer = partial(nn.LayerNorm, eps=1e-6) | |
drop_path_rate = 0.3 # default setting | |
attn_drop = proj_drop = 0.0 | |
qk_scale = None # TODO, double check | |
self.vit_blks[1].attn = fusion_ca_blk( # one fusion is enough | |
dim=dim, | |
num_heads=nh, | |
mlp_ratio=mlp_ratio, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
# drop=drop, | |
drop=proj_drop, | |
attn_drop=attn_drop, | |
drop_path=drop_path_rate, | |
norm_layer=norm_layer, # type: ignore | |
has_mlp=False) | |
def forward(self, x): | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
# self attention, by merging the triplane channel into B for parallel computation | |
# ! move the below to the front of the first call | |
B, group_size, N, C = x.shape # has [cls] token in N | |
assert group_size == 3, 'triplane' | |
x = x.reshape(B * group_size, N, C) | |
for blk in self.vit_blks: | |
x = blk(x) # B 3 N C | |
# TODO, avoid the reshape overhead? | |
return x.reshape(B, group_size, N, C) | |
class TriplaneFusionBlockv4_nested_init_from_dino(nn.Module): | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, | |
init_from_dino=True, | |
*args, | |
**kwargs) -> None: | |
super().__init__() | |
self.num_branches = 3 # triplane | |
self.vit_blks = vit_blks | |
assert use_fusion_blk | |
assert len(vit_blks) == 2 | |
# copied vit settings from https://github.dev/facebookresearch/dinov2 | |
nh = num_heads | |
dim = embed_dim | |
mlp_ratio = 4 # defined for all dino2 model | |
qkv_bias = True | |
norm_layer = partial(nn.LayerNorm, eps=1e-6) | |
drop_path_rate = 0.3 # default setting | |
attn_drop = proj_drop = 0.0 | |
qk_scale = None # TODO, double check | |
attn_3d = fusion_ca_blk( # one fusion is enough | |
dim=dim, | |
num_heads=nh, | |
mlp_ratio=mlp_ratio, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
# drop=drop, | |
drop=proj_drop, | |
attn_drop=attn_drop, | |
drop_path=drop_path_rate, | |
norm_layer=norm_layer, # type: ignore | |
has_mlp=False) | |
# ! initialize 3dattn from dino attn | |
if init_from_dino: | |
merged_qkv_linear = self.vit_blks[1].attn.qkv | |
attn_3d.attn.proj.load_state_dict( | |
self.vit_blks[1].attn.proj.state_dict()) | |
# Initialize the Q, K, and V linear layers using the weights of the merged QKV linear layer | |
attn_3d.attn.wq.weight.data = merged_qkv_linear.weight.data[:dim, :]
attn_3d.attn.w_kv.weight.data = merged_qkv_linear.weight.data[dim:, :]
# Optionally, you can initialize the biases as well (if your QKV linear layer has biases) | |
if qkv_bias: | |
attn_3d.attn.wq.bias.data = merged_qkv_linear.bias.data[:dim] | |
attn_3d.attn.w_kv.bias.data = merged_qkv_linear.bias.data[dim:] | |
del self.vit_blks[1].attn | |
# ! assign | |
self.vit_blks[1].attn = attn_3d | |
def forward(self, x): | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
# self attention, by merging the triplane channel into B for parallel computation | |
# ! move the below to the front of the first call | |
B, group_size, N, C = x.shape # has [cls] token in N | |
assert group_size == 3, 'triplane' | |
x = x.reshape(B * group_size, N, C) | |
for blk in self.vit_blks: | |
x = blk(x) # B 3 N C | |
# TODO, avoid the reshape overhead? | |
return x.reshape(B, group_size, N, C) | |
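# Hedged standalone sketch of the warm-start trick above (not used by the model):
# a fused QKV projection with weight shape (3*dim, dim) is split row-wise into a Q
# projection (first dim rows) and a joint KV projection (remaining 2*dim rows), so a
# separate cross-attention layer reproduces DINO's fused qkv outputs exactly. The
# attribute names wq / w_kv mirror the assumption made above; dim=8 is arbitrary.
def _example_split_fused_qkv(dim: int = 8):  # hypothetical helper, never called
    fused_qkv = nn.Linear(dim, dim * 3, bias=True)
    wq = nn.Linear(dim, dim, bias=True)
    w_kv = nn.Linear(dim, dim * 2, bias=True)
    wq.weight.data = fused_qkv.weight.data[:dim, :].clone()
    w_kv.weight.data = fused_qkv.weight.data[dim:, :].clone()
    wq.bias.data = fused_qkv.bias.data[:dim].clone()
    w_kv.bias.data = fused_qkv.bias.data[dim:].clone()
    x = torch.randn(4, dim)
    q, kv = fused_qkv(x).split([dim, 2 * dim], dim=-1)
    assert torch.allclose(wq(x), q) and torch.allclose(w_kv(x), kv)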
class TriplaneFusionBlockv4_nested_init_from_dino_lite(nn.Module): | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=None, | |
*args, | |
**kwargs) -> None: | |
super().__init__() | |
self.num_branches = 3 # triplane | |
self.vit_blks = vit_blks | |
assert use_fusion_blk | |
assert len(vit_blks) == 2 | |
# copied vit settings from https://github.dev/facebookresearch/dinov2 | |
nh = num_heads | |
dim = embed_dim | |
mlp_ratio = 4 # defined for all dino2 model | |
qkv_bias = True | |
norm_layer = partial(nn.LayerNorm, eps=1e-6) | |
drop_path_rate = 0.3 # default setting | |
attn_drop = proj_drop = 0.0 | |
qk_scale = None # TODO, double check | |
attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( # ! raw 3D attn layer | |
dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
attn_drop=attn_drop, | |
proj_drop=proj_drop) | |
del self.vit_blks[1].attn | |
# ! assign | |
self.vit_blks[1].attn = attn_3d | |
def forward(self, x): | |
"""x: B N C, where N = H*W tokens. Just raw ViT forward pass | |
""" | |
# ! move the below to the front of the first call | |
B, N, C = x.shape # has [cls] token in N | |
for blk in self.vit_blks: | |
x = blk(x) # B N C | |
return x | |
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge(nn.Module): | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=None, | |
*args, | |
**kwargs) -> None: | |
super().__init__() | |
self.vit_blks = vit_blks | |
assert use_fusion_blk | |
assert len(vit_blks) == 2 | |
# copied vit settings from https://github.dev/facebookresearch/dinov2 | |
nh = num_heads | |
dim = embed_dim | |
qkv_bias = True | |
attn_drop = proj_drop = 0.0 | |
qk_scale = None # TODO, double check | |
if False:  # disabled ablation: wrap each block's attn with an added 3D-aware cross attention
for blk in self.vit_blks: | |
attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( # ! raw 3D attn layer | |
dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
attn_drop=attn_drop, | |
proj_drop=proj_drop) | |
blk.attn = self_cross_attn(blk.attn, attn_3d) | |
def forward(self, x): | |
"""x: B N C, where N = H*W tokens. Just raw ViT forward pass | |
""" | |
# ! move the below to the front of the first call | |
B, N, C = x.shape # has [cls] token in N | |
for blk in self.vit_blks: | |
x = blk(x) # B N C | |
return x | |
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge): | |
# no roll out + B 3L C (the three planes are merged along the token dimension)
def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: | |
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) | |
def forward(self, x): | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
# ! move the below to the front of the first call | |
# B, N, C = x.shape # has [cls] token in N | |
B, group_size, N, C = x.shape # has [cls] token in N | |
x = x.reshape(B, group_size*N, C) | |
for blk in self.vit_blks: | |
x = blk(x) # B N C | |
x = x.reshape(B, group_size, N, C) # outer loop tradition | |
return x | |
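# Hedged illustration (not used by the model): in the "B x 3L x C" variant above the
# three planes are concatenated along the token dimension, so the unmodified DINO
# self-attention in vit_blks already mixes tokens across planes; no extra 3D-aware
# attention module is required. Shapes below are assumptions for the example.
def _example_b_3l_c_rollout():  # hypothetical helper, never called
    B, N, C = 2, 257, 768
    x = torch.randn(B, 3, N, C)
    rolled = x.reshape(B, 3 * N, C)            # one joint sequence over all three planes
    restored = rolled.reshape(B, 3, N, C)      # back to the outer-loop convention
    assert torch.equal(x, restored)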
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge): | |
# roll out + B 3L C | |
def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: | |
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) | |
def forward(self, x): | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
# ! move the below to the front of the first call | |
# B, N, C = x.shape # has [cls] token in N | |
B, group_size, N, C = x.shape # has [cls] token in N | |
x = x.reshape(B*group_size, N, C) | |
x = self.vit_blks[0](x) | |
x = x.reshape(B,group_size*N, C) | |
x = self.vit_blks[1](x) | |
x = x.reshape(B, group_size, N, C) # outer loop tradition | |
return x | |
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_add3DAttn(TriplaneFusionBlockv4_nested_init_from_dino): | |
# no roll out + 3D Attention | |
def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: | |
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) | |
def forward(self, x): | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
B, group_size, N, C = x.shape # has [cls] token in N | |
x = x.reshape(B, group_size*N, C) | |
x = self.vit_blks[0](x) # B 3 L C | |
# ! move the below to the front of the first call | |
x = x.reshape(B, group_size, N, C).reshape(B*group_size, N, C) | |
x = self.vit_blks[1](x) # has 3D attention | |
return x.reshape(B, group_size, N, C)
class TriplaneFusionBlockv5_ldm_addCA(nn.Module): | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, | |
*args, | |
**kwargs) -> None: | |
super().__init__() | |
self.num_branches = 3 # triplane | |
self.vit_blks = vit_blks | |
assert use_fusion_blk | |
assert len(vit_blks) == 2 | |
# ! rather than replacing, add a 3D attention block after. | |
# del self.vit_blks[ | |
# 1].attn # , self.vit_blks[1].ls1, self.vit_blks[1].norm1 | |
self.norm_for_atten_3d = deepcopy(self.vit_blks[1].norm1) | |
# copied vit settings from https://github.dev/facebookresearch/dinov2 | |
nh = num_heads | |
dim = embed_dim | |
mlp_ratio = 4 # defined for all dino2 model | |
qkv_bias = True | |
norm_layer = partial(nn.LayerNorm, eps=1e-6) | |
drop_path_rate = 0.3 # default setting | |
attn_drop = proj_drop = 0.0 | |
qk_scale = None # TODO, double check | |
self.attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid( | |
dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
attn_drop=attn_drop, | |
proj_drop=proj_drop) | |
def forward(self, x): | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
# self attention, by merging the triplane channel into B for parallel computation | |
# ! move the below to the front of the first call | |
B, group_size, N, C = x.shape # has [cls] token in N | |
assert group_size == 3, 'triplane' | |
flatten_token = lambda x: x.reshape(B * group_size, N, C) | |
unflatten_token = lambda x: x.reshape(B, group_size, N, C) | |
x = flatten_token(x) | |
x = self.vit_blks[0](x) | |
x = unflatten_token(x) | |
x = self.attn_3d(self.norm_for_atten_3d(x)) + x | |
x = flatten_token(x) | |
x = self.vit_blks[1](x) | |
return unflatten_token(x) | |
class TriplaneFusionBlockv6_ldm_addCA_Init3DAttnfrom2D( | |
TriplaneFusionBlockv5_ldm_addCA): | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, | |
*args, | |
**kwargs) -> None: | |
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, | |
fusion_ca_blk, *args, **kwargs) | |
def forward(self, x): | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
# self attention, by merging the triplane channel into B for parallel computation | |
# ! move the below to the front of the first call | |
B, group_size, N, C = x.shape # has [cls] token in N | |
assert group_size == 3, 'triplane' | |
flatten_token = lambda x: x.reshape(B * group_size, N, C) | |
unflatten_token = lambda x: x.reshape(B, group_size, N, C) | |
x = flatten_token(x) | |
x = self.vit_blks[0](x) | |
x = unflatten_token(x) | |
x = self.attn_3d(self.norm_for_atten_3d(x)) + x | |
x = flatten_token(x) | |
x = self.vit_blks[1](x) | |
return unflatten_token(x) | |
class TriplaneFusionBlockv5_ldm_add_dualCA(nn.Module): | |
def __init__(self, | |
vit_blks, | |
num_heads, | |
embed_dim, | |
use_fusion_blk=True, | |
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, | |
*args, | |
**kwargs) -> None: | |
super().__init__() | |
self.num_branches = 3 # triplane | |
self.vit_blks = vit_blks | |
assert use_fusion_blk | |
assert len(vit_blks) == 2 | |
# ! rather than replacing, add a 3D attention block after. | |
# del self.vit_blks[ | |
# 1].attn # , self.vit_blks[1].ls1, self.vit_blks[1].norm1 | |
self.norm_for_atten_3d_0 = deepcopy(self.vit_blks[0].norm1) | |
self.norm_for_atten_3d_1 = deepcopy(self.vit_blks[1].norm1) | |
# copied vit settings from https://github.dev/facebookresearch/dinov2 | |
nh = num_heads | |
dim = embed_dim | |
mlp_ratio = 4 # defined for all dino2 model | |
qkv_bias = True | |
norm_layer = partial(nn.LayerNorm, eps=1e-6) | |
drop_path_rate = 0.3 # default setting | |
attn_drop = proj_drop = 0.0 | |
qk_scale = None # TODO, double check | |
self.attn_3d_0 = xformer_Conv3D_Aware_CrossAttention_xygrid( | |
dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
attn_drop=attn_drop, | |
proj_drop=proj_drop) | |
self.attn_3d_1 = deepcopy(self.attn_3d_0) | |
def forward(self, x): | |
"""x: B 3 N C, where N = H*W tokens | |
""" | |
# self attention, by merging the triplane channel into B for parallel computation | |
# ! move the below to the front of the first call | |
B, group_size, N, C = x.shape # has [cls] token in N | |
assert group_size == 3, 'triplane' | |
flatten_token = lambda x: x.reshape(B * group_size, N, C) | |
unflatten_token = lambda x: x.reshape(B, group_size, N, C) | |
x = flatten_token(x) | |
x = self.vit_blks[0](x) | |
x = unflatten_token(x) | |
x = self.attn_3d_0(self.norm_for_atten_3d_0(x)) + x | |
x = flatten_token(x) | |
x = self.vit_blks[1](x) | |
x = unflatten_token(x) | |
x = self.attn_3d_1(self.norm_for_atten_3d_1(x)) + x | |
return unflatten_token(x) | |
def drop_path(x, drop_prob: float = 0., training: bool = False): | |
if drop_prob == 0. or not training: | |
return x | |
keep_prob = 1 - drop_prob | |
shape = (x.shape[0], ) + (1, ) * ( | |
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets | |
random_tensor = keep_prob + torch.rand( | |
shape, dtype=x.dtype, device=x.device) | |
random_tensor.floor_() # binarize | |
output = x.div(keep_prob) * random_tensor | |
return output | |
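# Hedged numerical check (not used by the model): dividing the kept samples by
# keep_prob preserves the expected activation, which is the point of the scaling in
# drop_path() above. Sample counts and tolerances below are arbitrary.
def _example_drop_path_expectation():  # hypothetical helper, never called
    torch.manual_seed(0)
    x = torch.ones(10000, 8)
    y = drop_path(x, drop_prob=0.3, training=True)
    kept_fraction = (y.sum(dim=1) > 0).float().mean()       # ~0.7 of samples survive
    assert abs(kept_fraction.item() - 0.7) < 0.05
    assert torch.allclose(y.mean(), x.mean(), atol=0.05)    # expectation preserved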
class DropPath(nn.Module): | |
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |
""" | |
def __init__(self, drop_prob=None): | |
super(DropPath, self).__init__() | |
self.drop_prob = drop_prob | |
def forward(self, x): | |
return drop_path(x, self.drop_prob, self.training) | |
class Mlp(nn.Module): | |
def __init__(self, | |
in_features, | |
hidden_features=None, | |
out_features=None, | |
act_layer=nn.GELU, | |
drop=0.): | |
super().__init__() | |
out_features = out_features or in_features | |
hidden_features = hidden_features or in_features | |
self.fc1 = nn.Linear(in_features, hidden_features) | |
self.act = act_layer() | |
self.fc2 = nn.Linear(hidden_features, out_features) | |
self.drop = nn.Dropout(drop) | |
def forward(self, x): | |
x = self.fc1(x) | |
x = self.act(x) | |
x = self.drop(x) | |
x = self.fc2(x) | |
x = self.drop(x) | |
return x | |
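# Hedged usage sketch (not used by the model): the Mlp above expands the channel
# width by mlp_ratio (hidden_features) and projects back, leaving the token shape
# unchanged. Widths below are assumptions for the example.
def _example_mlp_expansion():  # hypothetical helper, never called
    mlp = Mlp(in_features=768, hidden_features=768 * 4)
    x = torch.randn(2, 257, 768)
    assert mlp(x).shape == (2, 257, 768)       # 4x hidden width, same output width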
class Block(nn.Module): | |
def __init__(self, | |
dim, | |
num_heads, | |
mlp_ratio=4., | |
qkv_bias=False, | |
qk_scale=None, | |
drop=0., | |
attn_drop=0., | |
drop_path=0., | |
act_layer=nn.GELU, | |
norm_layer=nn.LayerNorm): | |
super().__init__() | |
self.norm1 = norm_layer(dim) | |
# self.attn = Attention(dim, | |
self.attn = MemEffAttention(dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
attn_drop=attn_drop, | |
proj_drop=drop) | |
self.drop_path = DropPath( | |
drop_path) if drop_path > 0. else nn.Identity() | |
self.norm2 = norm_layer(dim) | |
mlp_hidden_dim = int(dim * mlp_ratio) | |
self.mlp = Mlp(in_features=dim, | |
hidden_features=mlp_hidden_dim, | |
act_layer=act_layer, | |
drop=drop) | |
def forward(self, x, return_attention=False): | |
y, attn = self.attn(self.norm1(x)) | |
if return_attention: | |
return attn | |
x = x + self.drop_path(y) | |
x = x + self.drop_path(self.mlp(self.norm2(x))) | |
return x | |
class PatchEmbed(nn.Module): | |
""" Image to Patch Embedding | |
""" | |
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): | |
super().__init__() | |
num_patches = (img_size // patch_size) * (img_size // patch_size) | |
self.img_size = img_size | |
self.patch_size = patch_size | |
self.num_patches = num_patches | |
self.proj = nn.Conv2d(in_chans, | |
embed_dim, | |
kernel_size=patch_size, | |
stride=patch_size) | |
def forward(self, x): | |
B, C, H, W = x.shape | |
x = self.proj(x).flatten(2).transpose(1, 2) # B, C, L -> B, L, C | |
return x | |
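# Hedged usage sketch (not used by the model): with img_size=224 and patch_size=16,
# PatchEmbed turns an image into (224 // 16) ** 2 = 196 tokens of width embed_dim.
def _example_patch_embed_tokens():  # hypothetical helper, never called
    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    imgs = torch.randn(2, 3, 224, 224)
    assert embed(imgs).shape == (2, 196, 768)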
class VisionTransformer(nn.Module): | |
""" Vision Transformer """ | |
def __init__(self, | |
img_size=[224], | |
patch_size=16, | |
in_chans=3, | |
num_classes=0, | |
embed_dim=768, | |
depth=12, | |
num_heads=12, | |
mlp_ratio=4., | |
qkv_bias=False, | |
qk_scale=None, | |
drop_rate=0., | |
attn_drop_rate=0., | |
drop_path_rate=0., | |
norm_layer='nn.LayerNorm', | |
patch_embedding=True, | |
cls_token=True, | |
pixel_unshuffle=False, | |
**kwargs): | |
super().__init__() | |
self.num_features = self.embed_dim = embed_dim | |
self.patch_size = patch_size | |
# the norm_layer argument is currently ignored; LayerNorm(eps=1e-6) is always used
norm_layer = partial(nn.LayerNorm, eps=1e-6)
if patch_embedding: | |
self.patch_embed = PatchEmbed(img_size=img_size[0], | |
patch_size=patch_size, | |
in_chans=in_chans, | |
embed_dim=embed_dim) | |
num_patches = self.patch_embed.num_patches | |
self.img_size = self.patch_embed.img_size | |
else: | |
self.patch_embed = None | |
self.img_size = img_size[0] | |
num_patches = (img_size[0] // patch_size) * (img_size[0] // | |
patch_size) | |
if cls_token: | |
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) | |
self.pos_embed = nn.Parameter( | |
torch.zeros(1, num_patches + 1, embed_dim)) | |
else: | |
self.cls_token = None | |
self.pos_embed = nn.Parameter( | |
torch.zeros(1, num_patches, embed_dim)) | |
self.pos_drop = nn.Dropout(p=drop_rate) | |
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) | |
] # stochastic depth decay rule | |
self.blocks = nn.ModuleList([ | |
Block(dim=embed_dim, | |
num_heads=num_heads, | |
mlp_ratio=mlp_ratio, | |
qkv_bias=qkv_bias, | |
qk_scale=qk_scale, | |
drop=drop_rate, | |
attn_drop=attn_drop_rate, | |
drop_path=dpr[i], | |
norm_layer=norm_layer) for i in range(depth) | |
]) | |
self.norm = norm_layer(embed_dim) | |
# Classifier head | |
self.head = nn.Linear( | |
embed_dim, num_classes) if num_classes > 0 else nn.Identity() | |
trunc_normal_(self.pos_embed, std=.02) | |
if cls_token: | |
trunc_normal_(self.cls_token, std=.02) | |
self.apply(self._init_weights) | |
# if pixel_unshuffle: | |
# self.decoder_pred = nn.Linear(embed_dim, | |
# patch_size**2 * out_chans, | |
# bias=True) # decoder to patch | |
def _init_weights(self, m): | |
if isinstance(m, nn.Linear): | |
trunc_normal_(m.weight, std=.02) | |
if isinstance(m, nn.Linear) and m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.LayerNorm): | |
nn.init.constant_(m.bias, 0) | |
nn.init.constant_(m.weight, 1.0) | |
def interpolate_pos_encoding(self, x, w, h): | |
npatch = x.shape[1] - 1 | |
N = self.pos_embed.shape[1] - 1 | |
if npatch == N and w == h: | |
return self.pos_embed | |
patch_pos_embed = self.pos_embed[:, 1:] | |
dim = x.shape[-1] | |
w0 = w // self.patch_size | |
h0 = h // self.patch_size | |
# we add a small number to avoid floating point error in the interpolation | |
# see discussion at https://github.com/facebookresearch/dino/issues/8 | |
w0, h0 = w0 + 0.1, h0 + 0.1 | |
patch_pos_embed = nn.functional.interpolate( | |
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), | |
dim).permute(0, 3, 1, 2), | |
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), | |
mode='bicubic', | |
) | |
assert int(w0) == patch_pos_embed.shape[-2] and int( | |
h0) == patch_pos_embed.shape[-1] | |
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)  # pos_embed has batch dim 1
if self.cls_token is not None: | |
class_pos_embed = self.pos_embed[:, 0] | |
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), | |
dim=1) | |
return patch_pos_embed | |
def prepare_tokens(self, x): | |
B, nc, w, h = x.shape | |
x = self.patch_embed(x) # patch linear embedding | |
# add the [CLS] token to the embed patch tokens | |
cls_tokens = self.cls_token.expand(B, -1, -1) | |
x = torch.cat((cls_tokens, x), dim=1) | |
# add positional encoding to each token | |
x = x + self.interpolate_pos_encoding(x, w, h) | |
return self.pos_drop(x) | |
def forward(self, x): | |
x = self.prepare_tokens(x) | |
for blk in self.blocks: | |
x = blk(x) | |
x = self.norm(x) | |
return x[:, 1:] # return spatial feature maps, not the [CLS] token | |
# return x[:, 0] | |
def get_last_selfattention(self, x): | |
x = self.prepare_tokens(x) | |
for i, blk in enumerate(self.blocks): | |
if i < len(self.blocks) - 1: | |
x = blk(x) | |
else: | |
# return attention of the last block | |
return blk(x, return_attention=True) | |
def get_intermediate_layers(self, x, n=1): | |
x = self.prepare_tokens(x) | |
# we return the output tokens from the `n` last blocks | |
output = [] | |
for i, blk in enumerate(self.blocks): | |
x = blk(x) | |
if len(self.blocks) - i <= n: | |
output.append(self.norm(x)) | |
return output | |
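# Hedged standalone sketch (not used by the model) of what interpolate_pos_encoding()
# does when the input resolution differs from training: the stored patch positional
# table is viewed as a sqrt(N) x sqrt(N) grid and bicubically resized to the new patch
# grid. Grid sizes below are assumptions for the example (224 -> 320 input, patch 16).
def _example_resize_patch_pos_embed():  # hypothetical helper, never called
    dim, old_grid, new_grid = 192, 14, 20
    patch_pos = torch.randn(1, old_grid * old_grid, dim)
    grid = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    resized = nn.functional.interpolate(grid, size=(new_grid, new_grid), mode='bicubic')
    resized = resized.permute(0, 2, 3, 1).view(1, -1, dim)
    assert resized.shape == (1, new_grid * new_grid, dim)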
def vit_tiny(patch_size=16, **kwargs): | |
model = VisionTransformer(patch_size=patch_size, | |
embed_dim=192, | |
depth=12, | |
num_heads=3, | |
mlp_ratio=4, | |
qkv_bias=True, | |
norm_layer=partial(nn.LayerNorm, eps=1e-6), | |
**kwargs) | |
return model | |
def vit_small(patch_size=16, **kwargs): | |
model = VisionTransformer( | |
patch_size=patch_size, | |
embed_dim=384, | |
depth=12, | |
num_heads=6, | |
mlp_ratio=4, | |
qkv_bias=True, | |
norm_layer=partial(nn.LayerNorm, eps=1e-6), # type: ignore | |
**kwargs) | |
return model | |
def vit_base(patch_size=16, **kwargs): | |
model = VisionTransformer(patch_size=patch_size, | |
embed_dim=768, | |
depth=12, | |
num_heads=12, | |
mlp_ratio=4, | |
qkv_bias=True, | |
norm_layer=partial(nn.LayerNorm, eps=1e-6), | |
**kwargs) | |
return model | |
vits = vit_small | |
vitb = vit_base | |
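# Hedged usage note (not used above): the factory functions only pin the architecture
# hyper-parameters; extra keyword arguments are forwarded to VisionTransformer. The
# sketch below builds a ViT-S/16 backbone, assuming the MemEffAttention defined
# earlier in this file accepts the constructor arguments used in Block; the parameter
# count (roughly 21M for ViT-S/16) is approximate.
def _example_build_vit_small():  # hypothetical helper, never called
    model = vits(patch_size=16, drop_path_rate=0.1)
    n_params = sum(p.numel() for p in model.parameters())
    assert n_params > 0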