# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from copy import deepcopy
from typing import List, Optional, Tuple
import math
from functools import partial
from sympy import flatten
import torch
import torch.nn as nn
from torch import Tensor, pixel_shuffle
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.nn.modules import GELU
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
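# Usage sketch (illustrative only; the timed callable and its arguments below are assumed):
#
#   q = k = v = torch.randn(2, 8, 256, 64, device='cuda')
#   t_us = benchmark_torch_function_in_microseconds(
#       torch.nn.functional.scaled_dot_product_attention, q, k, v)
#   print(f"SDPA runs in {t_us:.3f} microseconds")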
# from vit.vision_transformer import Conv3DCrossAttentionBlock
from .utils import trunc_normal_
from pdb import set_trace as st
# import apex
# from apex.normalization import FusedRMSNorm as RMSNorm
try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
from dit.norm import RMSNorm
from torch.nn import LayerNorm
try:
import xformers
import xformers.ops
from xformers.ops import memory_efficient_attention, unbind, fmha
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficientAttentionCutlassOp
# from xformers.ops import RMSNorm
XFORMERS_AVAILABLE = True
except ImportError:
# logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
from packaging import version
assert version.parse(torch.__version__) >= version.parse("2.0.0")
SDP_IS_AVAILABLE = True
# from torch.backends.cuda import SDPBackend, sdp_kernel
from torch.nn.attention import sdpa_kernel, SDPBackend
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
enable_rmsnorm=False,
qk_norm=False,
no_flash_op=False,
enable_rope=False,):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
# https://github.com/huggingface/pytorch-image-models/blob/5dce71010174ad6599653da4e8ba37fd5f9fa572/timm/models/vision_transformer.py#L79C1-L80C78
self.enable_rope = enable_rope
# st()
if enable_rope:
self.q_norm = RMSNorm(dim, elementwise_affine=True) if qk_norm else nn.Identity()
self.k_norm = RMSNorm(dim, elementwise_affine=True) if qk_norm else nn.Identity()
else:
self.q_norm = RMSNorm(head_dim, elementwise_affine=True) if qk_norm else nn.Identity()
self.k_norm = RMSNorm(head_dim, elementwise_affine=True) if qk_norm else nn.Identity()
# if qk_norm:
# self.q_norm = LayerNorm(dim, eps=1e-5)
# self.k_norm = LayerNorm(dim, eps=1e-5)
self.qk_norm = qk_norm
self.no_flash_op = no_flash_op
self.attn_mode = "torch"
self.backend = SDPBackend.FLASH_ATTENTION # FA implemented by torch.
@staticmethod
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as
the target tensor 'x' for the purpose of broadcasting the frequency
tensor during element-wise operations.
Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected
shape.
AssertionError: If the target tensor 'x' doesn't have the expected
number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
@staticmethod
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings.
xk (torch.Tensor): Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
with torch.cuda.amp.autocast(enabled=False):
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def forward(self, x):
# B, N, C = x.shape
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
# C // self.num_heads).permute(2, 0, 3, 1, 4)
# q, k, v = qkv[0], qkv[1], qkv[2]
# attn = (q @ k.transpose(-2, -1)) * self.scale
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
# return x, attn
# https://github.com/Stability-AI/generative-models/blob/863665548f95ff827273948766a3f732ab01bc49/sgm/modules/attention.py#L179
B, L, C = x.shape
qkv = self.qkv(x)
if self.attn_mode == "torch":
qkv = rearrange(
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
).float()
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
q, k = self.q_norm(q), self.k_norm(k)
with sdpa_kernel([self.backend]): # new signature
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
del q, k, v
x = rearrange(x, "B H L D -> B L (H D)")
x = self.proj(x)
x = self.proj_drop(x)
return x
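# Usage sketch (illustrative only; dim, head count and shapes are assumed). The forward pass
# routes q/k/v through torch scaled_dot_product_attention with the FLASH_ATTENTION backend,
# which generally expects CUDA (and typically half-precision) tensors:
#
#   attn = Attention(dim=768, num_heads=12, qkv_bias=True).cuda()
#   x = torch.randn(2, 256, 768, device='cuda')   # (B, L, C)
#   y = attn(x)                                    # -> (2, 256, 768)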
class MemEffAttention(Attention):
def forward(self, x: Tensor, attn_bias=None, freqs_cis=None) -> Tensor:
if not XFORMERS_AVAILABLE:
assert attn_bias is None, "xFormers is required for nested tensors usage"
return super().forward(x)
B, N, C = x.shape
qkv = self.qkv(x)
dtype = qkv.dtype
if self.enable_rope:
assert freqs_cis is not None
qkv = qkv.reshape(B, N, 3, C)
q, k, v = unbind(qkv, 2)
            q, k = self.q_norm(q), self.k_norm(k)  # do q-k norm on the full sequence instead.
            q, k = Attention.apply_rotary_emb(q, k, freqs_cis=freqs_cis)
            # reshape (B, N, C) -> (B, N, num_heads, head_dim) for memory_efficient_attention
            q, k, v = map(
                lambda t: t.reshape(B, N, self.num_heads, C // self.num_heads),
                (q, k, v),
            )
            q, k = q.to(dtype), k.to(dtype)
else:
qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
q, k = self.q_norm(q), self.k_norm(k)
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) # if not bf16, no flash-attn here.
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp) # force flash attention
if self.no_flash_op: # F-A does not support large batch size? force cutlas?
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionCutlassOp) # force flash attention
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
# NOTE: workaround for
# https://github.com/facebookresearch/xformers/issues/845
# def attn(max_bs, op):
max_bs = 32768
L = q.shape[0]
n_batches = math.ceil(L / max_bs)
x = list()
for i_batch in range(n_batches):
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
x.append(
xformers.ops.memory_efficient_attention(
q[batch],
k[batch],
v[batch],
attn_bias=None,
# op=MemoryEfficientAttentionFlashAttentionOp,
# op=op,
op=MemoryEfficientAttentionCutlassOp,
)
)
x = torch.cat(x, 0)
# return x
# The cutlas implementation runs in 8396.681 microseconds
# The Flash implementation runs in 19473.491 microseconds
# max_bs = 32768
# math_time = benchmark_torch_function_in_microseconds(attn, max_bs, MemoryEfficientAttentionCutlassOp)
# print(f"The cutlas implementation runs in {math_time:.3f} microseconds")
# max_bs = 32768 // 2 # works for flash attention
# math_time = benchmark_torch_function_in_microseconds(attn, max_bs, MemoryEfficientAttentionFlashAttentionOp)
# print(f"The Flash implementation runs in {math_time:.3f} microseconds")
# st()
# pass
else: # will enable flash attention by default.
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp) # force flash attention
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) # force flash attention
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
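# Usage sketch (illustrative only; shapes are assumed). xFormers memory_efficient_attention
# consumes (B, N, num_heads, head_dim) tensors, which MemEffAttention builds internally:
#
#   attn = MemEffAttention(dim=768, num_heads=12, qkv_bias=True).cuda()
#   x = torch.randn(4, 1024, 768, device='cuda')   # (B, N, C)
#   y = attn(x)                                     # -> (4, 1024, 768)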
class MemEffCrossAttention(MemEffAttention):
# for cross attention, where context serves as k and v
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0, proj_drop=0):
super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
del self.qkv
self.q = nn.Linear(dim, dim * 1, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
def forward(self, x: Tensor, context: Tensor, attn_bias=None) -> Tensor:
if not XFORMERS_AVAILABLE:
assert attn_bias is None, "xFormers is required for nested tensors usage"
return super().forward(x)
        B, N, C = x.shape
        M = context.shape[1]  # context length; context provides the keys/values
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads)
        kv = self.kv(context).reshape(B, M, 2, self.num_heads,
                                      C // self.num_heads)
        k, v = unbind(kv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
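# Usage sketch (illustrative only; shapes are assumed). `x` provides the queries and
# `context` provides the keys/values:
#
#   xattn = MemEffCrossAttention(dim=768, num_heads=12, qkv_bias=True).cuda()
#   x = torch.randn(2, 256, 768, device='cuda')    # query tokens
#   ctx = torch.randn(2, 77, 768, device='cuda')   # context tokens (k, v)
#   y = xattn(x, ctx)                               # -> (2, 256, 768)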
# https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
class CrossAttention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim**-0.5
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
self.wv = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
q = self.wq(x[:,
0:1, ...]).reshape(B, 1, self.num_heads,
C // self.num_heads).permute(
0, 2, 1,
3) # B1C -> B1H(C/H) -> BH1(C/H)
k = self.wk(x).reshape(B, N,
self.num_heads, C // self.num_heads).permute(
0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H)
v = self.wv(x).reshape(B, N,
self.num_heads, C // self.num_heads).permute(
0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H)
attn = (q @ k.transpose(
-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(
B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
x = self.proj(x)
x = self.proj_drop(x)
return x
class Conv3D_Aware_CrossAttention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim**-0.5
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
self.wv = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, group_size, N, C = x.shape # B 3 N C
p = int(N**0.5) # patch size
assert p**2 == N, 'check input dim, no [cls] needed here'
assert group_size == 3, 'designed for triplane here'
x = x.reshape(B, group_size, p, p, C) # expand patch token dim
# * init qkv
# q = torch.empty(B * group_size * N,
# 1,
# self.num_heads,
# C // self.num_heads,
# device=x.device).permute(0, 2, 1, 3)
# k = torch.empty(B * group_size * N,
# 2 * p,
# self.num_heads,
# C // self.num_heads,
# device=x.device).permute(0, 2, 1, 3)
# v = torch.empty_like(k)
q_x = torch.empty(
B * group_size * N,
1,
# self.num_heads,
# C // self.num_heads,
C,
device=x.device)
k_x = torch.empty(
B * group_size * N,
2 * p,
# self.num_heads,
# C // self.num_heads,
C,
device=x.device)
v_x = torch.empty_like(k_x)
# ! refer to the following plane order
# N, M, _ = coordinates.shape
# xy_coords = coordinates[..., [0, 1]]
# yz_coords = coordinates[..., [1, 2]]
# zx_coords = coordinates[..., [2, 0]]
# return torch.stack([xy_coords, yz_coords, zx_coords],
# dim=1).reshape(N * 3, M, 2)
index_i, index_j = torch.meshgrid(torch.arange(0, p),
torch.arange(0, p),
indexing='ij') # 16*16
index_mesh_grid = torch.stack([index_i, index_j], 0).to(
x.device).unsqueeze(0).repeat_interleave(B,
0).reshape(B, 2, p,
p) # B 2 p p.
for i in range(group_size):
q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute(
0, 2, 3, 1, 4).reshape(B * N, 1, C) # B 1 p p C -> B*N, 1, C
# TODO, how to batchify gather ops?
plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size +
1] # B 1 p p C
plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1]
assert plane_yz.shape == plane_zx.shape == (
B, 1, p, p, C), 'check sub plane dimensions'
pooling_plane_yz = torch.gather(
plane_yz,
dim=2,
index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand(
-1, -1, -1, p,
C)).permute(0, 2, 1, 3, 4) # B 1 256 16 C => B 256 1 16 C
pooling_plane_zx = torch.gather(
plane_zx,
dim=3,
index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand(
-1, -1, p, -1,
C)).permute(0, 3, 1, 2, 4) # B 1 16 256 C => B 256 1 16 C
k_x[B * i * N:B * (i + 1) *
N] = v_x[B * i * N:B * (i + 1) * N] = torch.cat(
[pooling_plane_yz, pooling_plane_zx],
dim=2).reshape(B * N, 2 * p,
C) # B 256 2 16 C => (B*256) 2*16 C
# q[B * i * N: B * (i+1) * N] = self.wq(q_x).reshape(B*N, 1, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3)
# k[B * i * N: B * (i+1) * N] = self.wk(k_x).reshape(B*N, 2*p, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3)
# v[B * i * N: B * (i+1) * N] = self.wv(v_x).reshape(B*N, 2*p, self.num_heads, C // self.num_heads).permute( 0, 2, 1, 3)
q = self.wq(q_x).reshape(B * group_size * N, 1,
self.num_heads, C // self.num_heads).permute(
0, 2, 1,
3) # merge num_heads into Batch dimention
k = self.wk(k_x).reshape(B * group_size * N, 2 * p, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
v = self.wv(v_x).reshape(B * group_size * N, 2 * p, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
attn = (q @ k.transpose(
-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N, N=2p here
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(
B * 3 * N, 1,
C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
x = self.proj(x)
x = self.proj_drop(x)
# reshape x back
x = x.reshape(B, 3, N, C)
return x
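# Usage sketch (illustrative only; shapes are assumed). Every triplane token attends to the
# row/column of the other two planes that shares its (i, j) coordinate:
#
#   attn3d = Conv3D_Aware_CrossAttention(dim=384, num_heads=6).cuda()
#   tokens = torch.randn(2, 3, 16 * 16, 384, device='cuda')   # (B, 3 planes, N=p*p, C)
#   y = attn3d(tokens)                                         # -> (2, 3, 256, 384)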
class xformer_Conv3D_Aware_CrossAttention(nn.Module):
# https://github.dev/facebookresearch/dinov2
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):
super().__init__()
# https://pytorch.org/blog/accelerated-generative-diffusion-models/
self.num_heads = num_heads
self.wq = nn.Linear(dim, dim * 1, bias=qkv_bias)
self.w_kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.index_mesh_grid = None
def forward(self, x, attn_bias=None):
B, group_size, N, C = x.shape # B 3 N C
p = int(N**0.5) # patch size
assert p**2 == N, 'check input dim, no [cls] needed here'
assert group_size == 3, 'designed for triplane here'
x = x.reshape(B, group_size, p, p, C) # expand patch token dim
q_x = torch.empty(B * group_size * N, 1, C, device=x.device)
context = torch.empty(B * group_size * N, 2 * p, C,
device=x.device) # k_x=v_x
if self.index_mesh_grid is None: # further accelerate
index_i, index_j = torch.meshgrid(torch.arange(0, p),
torch.arange(0, p),
indexing='ij') # 16*16
index_mesh_grid = torch.stack([index_i, index_j], 0).to(
x.device).unsqueeze(0).repeat_interleave(B, 0).reshape(
B, 2, p, p) # B 2 p p.
self.index_mesh_grid = index_mesh_grid[0:1]
else:
index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave(
B, 0)
assert index_mesh_grid.shape == (
B, 2, p, p), 'check index_mesh_grid dimension'
for i in range(group_size):
q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute(
0, 2, 3, 1, 4).reshape(B * N, 1, C) # B 1 p p C -> B*N, 1, C
# TODO, how to batchify gather ops?
plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size +
1] # B 1 p p C
plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1]
assert plane_yz.shape == plane_zx.shape == (
B, 1, p, p, C), 'check sub plane dimensions'
pooling_plane_yz = torch.gather(
plane_yz,
dim=2,
index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand(
-1, -1, -1, p,
C)).permute(0, 2, 1, 3, 4) # B 1 256 16 C => B 256 1 16 C
pooling_plane_zx = torch.gather(
plane_zx,
dim=3,
index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand(
-1, -1, p, -1,
C)).permute(0, 3, 1, 2, 4) # B 1 16 256 C => B 256 1 16 C
context[B * i * N:B * (i + 1) * N] = torch.cat(
[pooling_plane_yz, pooling_plane_zx],
dim=2).reshape(B * N, 2 * p,
C) # B 256 2 16 C => (B*256) 2*16 C
# B, N, C = x.shape
q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads,
C // self.num_heads)
kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2,
self.num_heads, C // self.num_heads)
k, v = unbind(kv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp)
x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class xformer_Conv3D_Aware_CrossAttention_xygrid(
xformer_Conv3D_Aware_CrossAttention):
"""implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention
"""
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0):
super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop,
proj_drop)
def forward(self, x, attn_bias=None):
B, group_size, N, C = x.shape # B 3 N C
p = int(N**0.5) # patch size
assert p**2 == N, 'check input dim, no [cls] needed here'
assert group_size == 3, 'designed for triplane here'
x = x.reshape(B, group_size, p, p, C) # expand patch token dim
q_x = torch.empty(B * group_size * N, 1, C, device=x.device)
context = torch.empty(B * group_size * N, 2 * p, C,
device=x.device) # k_x=v_x
if self.index_mesh_grid is None: # further accelerate
index_u, index_v = torch.meshgrid(
torch.arange(0, p), torch.arange(0, p),
indexing='xy') # ! switch to 'xy' here to match uv coordinate
index_mesh_grid = torch.stack([index_u, index_v], 0).to(
x.device).unsqueeze(0).repeat_interleave(B, 0).reshape(
B, 2, p, p) # B 2 p p.
self.index_mesh_grid = index_mesh_grid[0:1]
else:
index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave(
B, 0)
assert index_mesh_grid.shape == (
B, 2, p, p), 'check index_mesh_grid dimension'
for i in range(group_size):
q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute(
0, 2, 3, 1, 4).reshape(B * N, 1, C) # B 1 p p C -> B*N, 1, C
# TODO, how to batchify gather ops?
plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size +
1] # B 1 p p C
plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1]
assert plane_yz.shape == plane_zx.shape == (
B, 1, p, p, C), 'check sub plane dimensions'
pooling_plane_yz = torch.gather(
plane_yz,
dim=2,
index=index_mesh_grid[:, 1:2].reshape(B, 1, N, 1, 1).expand(
-1, -1, -1, p,
C)).permute(0, 2, 1, 3, 4) # B 1 256 16 C => B 256 1 16 C
pooling_plane_zx = torch.gather(
plane_zx,
dim=3,
index=index_mesh_grid[:, 0:1].reshape(B, 1, 1, N, 1).expand(
-1, -1, p, -1,
C)).permute(0, 3, 1, 2, 4) # B 1 16 256 C => B 256 1 16 C
context[B * i * N:B * (i + 1) * N] = torch.cat(
[pooling_plane_yz, pooling_plane_zx],
dim=2).reshape(B * N, 2 * p,
C) # B 256 2 16 C => (B*256) 2*16 C
# B, N, C = x.shape
q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads,
C // self.num_heads)
kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2,
self.num_heads, C // self.num_heads)
k, v = unbind(kv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp)
x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(
xformer_Conv3D_Aware_CrossAttention_xygrid):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0,
proj_drop=0):
super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop,
proj_drop)
def forward(self, x, attn_bias=None):
# ! split x: B N C into B 3 N C//3
B, N, C = x.shape
x = x.reshape(B, N, C // 3, 3).permute(0, 3, 1,
2) # B N C 3 -> B 3 N C
x_out = super().forward(x, attn_bias) # B 3 N C
x_out = x_out.permute(0, 2, 3, 1)# B 3 N C -> B N C 3
x_out = x_out.reshape(*x_out.shape[:2], -1) # B N C 3 -> B N C3
return x_out.contiguous()
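# Usage sketch (illustrative only; shapes are assumed). The "withinC" variant treats a single
# (B, N, C) sequence as three planes stacked along the channel dim, with dim = C // 3:
#
#   attn = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(dim=256, num_heads=8).cuda()
#   x = torch.randn(2, 16 * 16, 3 * 256, device='cuda')   # (B, N, 3*C_plane)
#   y = attn(x)                                             # -> (2, 256, 768)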
class self_cross_attn(nn.Module):
def __init__(self, dino_attn, cross_attn, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.dino_attn = dino_attn
self.cross_attn = cross_attn
def forward(self, x_norm):
y = self.dino_attn(x_norm) + x_norm
return self.cross_attn(y) # will add x in the original code
# class RodinRollOutConv(nn.Module):
# """implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention
# Use Group Conv
# """
# def __init__(self, in_chans, out_chans=None):
# super().__init__()
# # input: B 3C H W
# if out_chans is None:
# out_chans = in_chans
# self.roll_out_convs = nn.Conv2d(in_chans,
# out_chans,
# kernel_size=3,
# groups=3,
# padding=1)
# def forward(self, x):
# return self.roll_out_convs(x)
class RodinRollOutConv3D(nn.Module):
"""implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention
"""
def __init__(self, in_chans, out_chans=None):
super().__init__()
if out_chans is None:
out_chans = in_chans
self.out_chans = out_chans // 3
self.roll_out_convs = nn.Conv2d(in_chans,
self.out_chans,
kernel_size=3,
padding=1)
def forward(self, x):
# todo, reshape before input?
B, C3, p, p = x.shape # B 3C H W
C = C3 // 3
group_size = C3 // C
assert group_size == 3
x = x.reshape(B, 3, C, p, p)
roll_out_x = torch.empty(B, group_size * C, p, 3 * p,
device=x.device) # B, 3C, H, 3W
for i in range(group_size):
plane_xy = x[:, i] # B C H W
# TODO, simply do the average pooling?
plane_yz_pooling = x[:, (i + 1) % group_size].mean(
dim=-1, keepdim=True).repeat_interleave(
p, dim=-1) # B C H W -> B C H 1 -> B C H W, reduce z dim
plane_zx_pooling = x[:, (i + 2) % group_size].mean(
dim=-2, keepdim=True).repeat_interleave(
p, dim=-2) # B C H W -> B C 1 W -> B C H W, reduce z dim
roll_out_x[..., i * p:(i + 1) * p] = torch.cat(
[plane_xy, plane_yz_pooling, plane_zx_pooling],
1) # fill in the 3W dim
x = self.roll_out_convs(roll_out_x) # B C H 3W
x = x.reshape(B, self.out_chans, p, 3, p)
x = x.permute(0, 3, 1, 2, 4).reshape(B, 3 * self.out_chans, p,
p) # B 3C H W
return x
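# Usage sketch (illustrative only; channel counts and resolution are assumed). Input and
# output keep the rolled-out (B, 3*C, H, W) triplane layout:
#
#   conv3d = RodinRollOutConv3D(in_chans=3 * 32).cuda()
#   planes = torch.randn(2, 3 * 32, 64, 64, device='cuda')   # (B, 3C, H, W)
#   y = conv3d(planes)                                        # -> (2, 96, 64, 64)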
class RodinRollOutConv3D_GroupConv(nn.Module):
"""implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention
"""
def __init__(self,
in_chans,
out_chans=None,
kernel_size=3,
stride=1,
padding=1):
super().__init__()
if out_chans is None:
out_chans = in_chans
self.roll_out_convs = nn.Conv2d(
in_chans * 3,
out_chans,
kernel_size=kernel_size,
groups=3, # B 9C H W
stride=stride,
padding=padding)
# @torch.autocast(device_type='cuda')
def forward(self, x):
# todo, reshape before input?
B, C3, p, p = x.shape # B 3C H W
C = C3 // 3
group_size = C3 // C
assert group_size == 3
x = x.reshape(B, 3, C, p, p)
        roll_out_x = torch.empty(B, group_size * C * 3, p, p,
                                 device=x.device)  # B, 9C, H, W
for i in range(group_size):
plane_xy = x[:, i] # B C H W
# # TODO, simply do the average pooling?
plane_yz_pooling = x[:, (i + 1) % group_size].mean(
dim=-1, keepdim=True).repeat_interleave(
p, dim=-1) # B C H W -> B C H 1 -> B C H W, reduce z dim
plane_zx_pooling = x[:, (i + 2) % group_size].mean(
dim=-2, keepdim=True).repeat_interleave(
p, dim=-2) # B C H W -> B C 1 W -> B C H W, reduce z dim
            roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat(
                [plane_xy, plane_yz_pooling, plane_zx_pooling],
                1)  # fill in the 3C slice for plane i
# ! directly cat, avoid intermediate vars
# ? why OOM
# roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat(
# [
# x[:, i],
# x[:, (i + 1) % group_size].mean(
# dim=-1, keepdim=True).repeat_interleave(p, dim=-1),
# x[:, (i + 2) % group_size].mean(
# dim=-2, keepdim=True).repeat_interleave(
# p, dim=-2
# ) # B C H W -> B C 1 W -> B C H W, reduce z dim
# ],
# 1) # fill in the 3C dim
x = self.roll_out_convs(roll_out_x) # B 3C H W
return x
class RodinRollOut_GroupConv_noConv3D(nn.Module):
"""only roll out and do Conv on individual planes
"""
def __init__(self,
in_chans,
out_chans=None,
kernel_size=3,
stride=1,
padding=1):
super().__init__()
if out_chans is None:
out_chans = in_chans
self.roll_out_inplane_conv = nn.Conv2d(
in_chans,
out_chans,
kernel_size=kernel_size,
groups=3, # B 3C H W
stride=stride,
padding=padding)
def forward(self, x):
x = self.roll_out_inplane_conv(x) # B 3C H W
return x
# class RodinConv3D_SynthesisLayer_withact(nn.Module):
# def __init__(self, in_chans, out_chans) -> None:
# super().__init__()
# self.act = nn.LeakyReLU(inplace=True)
# self.conv = nn.Sequential(
# RodinRollOutConv3D_GroupConv(in_chans, out_chans),
# nn.LeakyReLU(inplace=True),
# )
# if in_chans != out_chans:
# self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration.
# else:
# self.short_cut = None
# def forward(self, feats):
# if self.short_cut is not None:
# res_feats = self.short_cut(feats)
# else:
# res_feats = feats
# # return res_feats + self.conv(feats)
# feats = res_feats + self.conv(feats)
# return self.act(feats) # as in resnet, add an act before return
class RodinConv3D_SynthesisLayer_mlp_unshuffle_as_residual(nn.Module):
def __init__(self, in_chans, out_chans) -> None:
super().__init__()
self.act = nn.LeakyReLU(inplace=True)
self.conv = nn.Sequential(
RodinRollOutConv3D_GroupConv(in_chans, out_chans),
nn.LeakyReLU(inplace=True),
)
self.out_chans = out_chans
if in_chans != out_chans:
# self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration.
self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W
in_chans // 3, # 144 / 3 = 48
out_chans // 3 * 4 * 4, # 32 * 16
bias=True) # decoder to pat
# RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration.
else:
self.short_cut = None
def shortcut_unpatchify_triplane(self,
x,
p=None,
unpatchify_out_chans=None):
"""separate triplane version; x shape: B (3*257) 768
"""
assert self.short_cut is not None
# B, L, C = x.shape
B, C3, h, w = x.shape
assert h == w
L = h * w
x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3,
1) # (B, 3, L // 3, C)
x = self.short_cut(x)
        if p is None:
            p = 4  # the shortcut linear decodes each token into a 4x4 patch
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.out_chans // 3  # per-plane channels
        x = x.reshape(shape=(B, 3, h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('ndhwpqc->ndchpwq',
                         x)  # nplanes, C order in the renderer.py
        x = x.reshape(shape=(B, 3 * unpatchify_out_chans, h * p, w * p))
return x
def forward(self, feats):
if self.short_cut is not None:
res_feats = self.shortcut_unpatchify_triplane(feats)
else:
res_feats = feats
# return res_feats + self.conv(feats)
feats = res_feats + self.conv(feats)
return self.act(feats) # as in resnet, add an act before return
# class RodinConv3D_SynthesisLayer(nn.Module):
# def __init__(self, in_chans, out_chans) -> None:
# super().__init__()
# self.act = nn.LeakyReLU(inplace=True)
# self.conv = nn.Sequential(
# RodinRollOutConv3D_GroupConv(in_chans, out_chans),
# nn.LeakyReLU(inplace=True),
# )
# if in_chans != out_chans:
# self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration.
# else:
# self.short_cut = None
# def forward(self, feats):
# if self.short_cut is not None:
# res_feats = self.short_cut(feats)
# else:
# res_feats = feats
# # return res_feats + self.conv(feats)
# feats = res_feats + self.conv(feats)
# # return self.act(feats) # as in resnet, add an act before return
# return feats # ! old behaviour, no act
# previous worked version
class RodinConv3D_SynthesisLayer(nn.Module):
def __init__(self, in_chans, out_chans) -> None:
super().__init__()
# x2 SR + 1x1 Conv Residual BLK
# self.conv3D = RodinRollOutConv3D(in_chans, out_chans)
self.act = nn.LeakyReLU(inplace=True)
self.conv = nn.Sequential(
RodinRollOutConv3D_GroupConv(in_chans, out_chans),
nn.LeakyReLU(inplace=True),
)
if in_chans != out_chans:
self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans)
else:
self.short_cut = None
def forward(self, feats):
feats_out = self.conv(feats)
if self.short_cut is not None:
# ! failed below
feats_out = self.short_cut(
feats
) + feats_out # ! only difference here, no act() compared with baseline
# feats_out = self.act(self.short_cut(feats)) + feats_out # ! only difference here, no act() compared with baseline
else:
feats_out = feats_out + feats
return feats_out
class RodinRollOutConv3DSR2X(nn.Module):
def __init__(self, in_chans, **kwargs) -> None:
super().__init__()
self.conv3D = RodinRollOutConv3D_GroupConv(in_chans)
# self.conv3D = RodinRollOutConv3D(in_chans)
self.act = nn.LeakyReLU(inplace=True)
self.input_resolution = 224
def forward(self, x):
# x: B 3 112*112 C
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
group_size = C3 // C
assert group_size == 3
# p = int(N**0.5) # patch size
# assert p**2 == N, 'check input dim, no [cls] needed here'
assert group_size == 3, 'designed for triplane here'
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
if x.shape[-1] != self.input_resolution:
x = torch.nn.functional.interpolate(x,
size=(self.input_resolution,
self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=True)
x = x + self.conv3D(x)
return x
class RodinRollOutConv3DSR4X_lite(nn.Module):
def __init__(self, in_chans, input_resolutiopn=256, **kwargs) -> None:
super().__init__()
self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans)
self.conv3D_1 = RodinRollOutConv3D_GroupConv(in_chans)
self.act = nn.LeakyReLU(inplace=True)
self.input_resolution = input_resolutiopn
def forward(self, x):
# x: B 3 112*112 C
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
group_size = C3 // C
assert group_size == 3
# p = int(N**0.5) # patch size
# assert p**2 == N, 'check input dim, no [cls] needed here'
assert group_size == 3, 'designed for triplane here'
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
if x.shape[-1] != self.input_resolution:
x = torch.nn.functional.interpolate(x,
size=(self.input_resolution,
self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=True)
# ! still not convering, not bug here?
# x = x + self.conv3D_0(x)
# x = x + self.conv3D_1(x)
x = x + self.act(self.conv3D_0(x))
x = x + self.act(self.conv3D_1(x))
# TODO: which is better, bilinear + conv or PixelUnshuffle?
return x
# class RodinConv3D2X_lite_mlp_as_residual(nn.Module):
# """lite 4X version, with MLP unshuffle to change the dimention
# """
# def __init__(self, in_chans, out_chans, input_resolution=256) -> None:
# super().__init__()
# self.act = nn.LeakyReLU(inplace=True)
# self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans)
# self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans)
# self.act = nn.LeakyReLU(inplace=True)
# self.input_resolution = input_resolution
# self.out_chans = out_chans
# if in_chans != out_chans: # ! only change the dimension
# self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W
# in_chans//3, # 144 / 3 = 48
# out_chans//3, # 32 * 16
# bias=True) # decoder to pat
# else:
# self.short_cut = None
# def shortcut_unpatchify_triplane(self, x, p=None):
# """separate triplane version; x shape: B (3*257) 768
# """
# assert self.short_cut is not None
# # B, L, C = x.shape
# B, C3, h, w = x.shape
# assert h == w
# L = h*w
# x = x.reshape(B, C3//3, 3, L).permute(0,2,3,1) # (B, 3, L // 3, C_in)
# x = self.short_cut(x) # B 3 L//3 C_out
# x = x.permute(0,1,3,2) # B 3 C_out L//3
# x = x.reshape(shape=(B, self.out_chans, h, w))
# # directly resize to the target, no unpatchify here since no 3D ViT is included here
# if w != self.input_resolution:
# x = torch.nn.functional.interpolate(x, # 4X SR
# size=(self.input_resolution,
# self.input_resolution),
# mode='bilinear',
# align_corners=False,
# antialias=True)
# return x
# def forward(self, x):
# # x: B 3 112*112 C
# B, C3, p, p = x.shape # after unpachify triplane
# C = C3 // 3
# if self.short_cut is not None:
# res_feats = self.shortcut_unpatchify_triplane(x)
# else:
# res_feats = x
# """following forward code copied from lite4x version
# """
# x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
# p) # B 3 C N -> B 3C h W
# if x.shape[-1] != self.input_resolution:
# x = torch.nn.functional.interpolate(x, # 4X SR
# size=(self.input_resolution,
# self.input_resolution),
# mode='bilinear',
# align_corners=False,
# antialias=True)
# x = res_feats + self.act(self.conv3D_0(x))
# x = x + self.act(self.conv3D_1(x))
# return x
class RodinConv3D4X_lite_mlp_as_residual(nn.Module):
"""lite 4X version, with MLP unshuffle to change the dimention
"""
def __init__(self,
in_chans,
out_chans,
input_resolution=256,
interp_mode='bilinear',
bcg_triplane=False) -> None:
super().__init__()
self.interp_mode = interp_mode
self.act = nn.LeakyReLU(inplace=True)
self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans)
self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans)
self.bcg_triplane = bcg_triplane
if bcg_triplane:
self.conv3D_1_bg = RodinRollOutConv3D_GroupConv(
out_chans, out_chans)
self.act = nn.LeakyReLU(inplace=True)
self.input_resolution = input_resolution
self.out_chans = out_chans
if in_chans != out_chans: # ! only change the dimension
self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W
in_chans // 3, # 144 / 3 = 48
out_chans // 3, # 32 * 16
bias=True) # decoder to pat
else:
self.short_cut = None
def shortcut_unpatchify_triplane(self, x, p=None):
"""separate triplane version; x shape: B (3*257) 768
"""
assert self.short_cut is not None
B, C3, h, w = x.shape
assert h == w
L = h * w
x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3,
1) # (B, 3, L // 3, C_in)
x = self.short_cut(x) # B 3 L//3 C_out
x = x.permute(0, 1, 3, 2) # B 3 C_out L//3
x = x.reshape(shape=(B, self.out_chans, h, w))
# directly resize to the target, no unpatchify here since no 3D ViT is included here
if w != self.input_resolution:
x = torch.nn.functional.interpolate(
x, # 4X SR
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=True)
return x
def interpolate(self, feats):
if self.interp_mode == 'bilinear':
return torch.nn.functional.interpolate(
feats, # 4X SR
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=True)
else:
return torch.nn.functional.interpolate(
feats, # 4X SR
size=(self.input_resolution, self.input_resolution),
mode='nearest',
)
def forward(self, x):
# x: B 3 112*112 C
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
if self.short_cut is not None:
res_feats = self.shortcut_unpatchify_triplane(x)
else:
res_feats = x
if res_feats.shape[-1] != self.input_resolution:
res_feats = self.interpolate(res_feats)
"""following forward code copied from lite4x version
"""
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
if x.shape[-1] != self.input_resolution:
x = self.interpolate(x)
x0 = res_feats + self.act(self.conv3D_0(x)) # the base feature
x = x0 + self.act(self.conv3D_1(x0))
if self.bcg_triplane:
x_bcg = x0 + self.act(self.conv3D_1_bg(x0))
return torch.cat([x, x_bcg], 1)
else:
return x
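# Usage sketch (illustrative only; channel sizes and resolutions are assumed):
#
#   sr = RodinConv3D4X_lite_mlp_as_residual(in_chans=3 * 144, out_chans=3 * 32,
#                                           input_resolution=256).cuda()
#   feats = torch.randn(2, 3 * 144, 64, 64, device='cuda')   # (B, 3C, h, w) triplane features
#   y = sr(feats)                                             # -> (2, 96, 256, 256)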
class RodinConv3D4X_lite_mlp_as_residual_litev2(
RodinConv3D4X_lite_mlp_as_residual):
def __init__(self,
in_chans,
out_chans,
num_feat=128,
input_resolution=256,
interp_mode='bilinear',
bcg_triplane=False) -> None:
super().__init__(in_chans, out_chans, input_resolution, interp_mode,
bcg_triplane)
self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, in_chans)
self.conv_before_upsample = RodinRollOut_GroupConv_noConv3D(
in_chans, num_feat * 3)
self.conv3D_1 = RodinRollOut_GroupConv_noConv3D(
num_feat * 3, num_feat * 3)
self.conv_last = RodinRollOut_GroupConv_noConv3D(
num_feat * 3, out_chans)
self.short_cut = None
def forward(self, x):
# x: B 3 112*112 C
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
# if self.short_cut is not None:
# res_feats = self.shortcut_unpatchify_triplane(x)
# else:
# res_feats = x
# if res_feats.shape[-1] != self.input_resolution:
# res_feats = self.interpolate(res_feats)
"""following forward code copied from lite4x version
"""
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
x = x + self.conv3D_0(x) # the base feature
x = self.act(self.conv_before_upsample(x))
# if x.shape[-1] != self.input_resolution:
x = self.conv_last(self.act(self.conv3D_1(self.interpolate(x))))
return x
class RodinConv3D4X_lite_mlp_as_residual_lite(
RodinConv3D4X_lite_mlp_as_residual):
def __init__(self,
in_chans,
out_chans,
input_resolution=256,
interp_mode='bilinear') -> None:
super().__init__(in_chans, out_chans, input_resolution, interp_mode)
"""replace the first Rodin Conv 3D with ordinary rollout conv to save memory
"""
self.conv3D_0 = RodinRollOut_GroupConv_noConv3D(in_chans, out_chans)
class SR3D(nn.Module):
# https://github.com/SeanChenxy/Mimic3D/blob/77d313656df3cd5536d2c4c5766db3a56208eea6/training/networks_stylegan2.py#L629
# roll-out and apply two deconv/pixelUnshuffle layer
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
class RodinConv3D4X_lite_mlp_as_residual_improved(nn.Module):
def __init__(self,
in_chans,
num_feat,
out_chans,
input_resolution=256) -> None:
super().__init__()
assert in_chans == 4 * out_chans
assert num_feat == 2 * out_chans
self.input_resolution = input_resolution
# refer to https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750
self.upscale = 4
self.conv_after_body = RodinRollOutConv3D_GroupConv(
in_chans, in_chans, 3, 1, 1)
self.conv_before_upsample = nn.Sequential(
RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1,
1)
if self.upscale == 4:
self.conv_up2 = RodinRollOutConv3D_GroupConv(
num_feat, num_feat, 3, 1, 1)
self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1,
1)
self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3,
1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
# x: B 3 112*112 C
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
"""following forward code copied from lite4x version
"""
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
# ? nearest or bilinear
x = self.conv_after_body(x) + x
x = self.conv_before_upsample(x)
x = self.lrelu(
self.conv_up1(
torch.nn.functional.interpolate(
x,
scale_factor=2,
mode='nearest',
# align_corners=False,
# antialias=True
)))
if self.upscale == 4:
x = self.lrelu(
self.conv_up2(
torch.nn.functional.interpolate(
x,
scale_factor=2,
mode='nearest',
# align_corners=False,
# antialias=True
)))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
assert x.shape[-1] == self.input_resolution
return x
class RodinConv3D4X_lite_improved_lint_withresidual(nn.Module):
def __init__(self,
in_chans,
num_feat,
out_chans,
input_resolution=256) -> None:
super().__init__()
assert in_chans == 4 * out_chans
assert num_feat == 2 * out_chans
self.input_resolution = input_resolution
# refer to https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750
self.upscale = 4
self.conv_after_body = RodinRollOutConv3D_GroupConv(
in_chans, in_chans, 3, 1, 1)
self.conv_before_upsample = nn.Sequential(
RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1,
1)
if self.upscale == 4:
self.conv_up2 = RodinRollOutConv3D_GroupConv(
num_feat, num_feat, 3, 1, 1)
self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1,
1)
self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3,
1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
# x: B 3 112*112 C
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
"""following forward code copied from lite4x version
"""
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
# ? nearest or bilinear
x = self.conv_after_body(x) + x
x = self.conv_before_upsample(x)
x = self.lrelu(
self.conv_up1(
torch.nn.functional.interpolate(
x,
scale_factor=2,
mode='nearest',
# align_corners=False,
# antialias=True
)))
if self.upscale == 4:
x = self.lrelu(
self.conv_up2(
torch.nn.functional.interpolate(
x,
scale_factor=2,
mode='nearest',
# align_corners=False,
# antialias=True
)))
x = self.conv_last(self.lrelu(self.conv_hr(x) + x))
assert x.shape[-1] == self.input_resolution
return x
class RodinRollOutConv3DSR_FlexibleChannels(nn.Module):
def __init__(self,
in_chans,
num_out_ch=96,
input_resolution=256,
**kwargs) -> None:
super().__init__()
self.block0 = RodinConv3D_SynthesisLayer(in_chans,
num_out_ch) # in_chans=48
self.block1 = RodinConv3D_SynthesisLayer(num_out_ch, num_out_ch)
self.input_resolution = input_resolution # 64 -> 256 SR
def forward(self, x):
# x: B 3 112*112 C
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
# group_size = C3 // C
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
if x.shape[-1] != self.input_resolution:
x = torch.nn.functional.interpolate(x,
size=(self.input_resolution,
self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=True)
x = self.block0(x)
x = self.block1(x)
return x
# previous worked version
class RodinRollOutConv3DSR4X(nn.Module):
# follow PixelUnshuffleUpsample
def __init__(self, in_chans, **kwargs) -> None:
super().__init__()
# self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96 * 2) # TODO, match the old behaviour now.
# self.block1 = RodinConv3D_SynthesisLayer(96 * 2, 96)
self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96)
self.block1 = RodinConv3D_SynthesisLayer(
96, 96) # baseline choice, validate with no LPIPS loss here
self.input_resolution = 64 # 64 -> 256
def forward(self, x):
# x: B 3 112*112 C
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
# group_size = C3 // C
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
if x.shape[-1] != self.input_resolution:
x = torch.nn.functional.interpolate(x,
size=(self.input_resolution,
self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=True)
x = self.block0(x)
x = self.block1(x)
return x
class Upsample3D(nn.Module):
"""Upsample module.
Args:
        scale (int): Scale factor. Supported scales: powers of two (2^n).
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
super().__init__()
m_convs = []
m_pixelshuffle = []
assert (scale & (scale - 1)) == 0, 'scale = 2^n'
self.scale = scale
for _ in range(int(math.log(scale, 2))):
m_convs.append(
RodinRollOutConv3D_GroupConv(num_feat, 4 * num_feat, 3, 1, 1))
m_pixelshuffle.append(nn.PixelShuffle(2))
self.m_convs = nn.ModuleList(m_convs)
self.m_pixelshuffle = nn.ModuleList(m_pixelshuffle)
# @torch.autocast(device_type='cuda')
def forward(self, x):
for scale_idx in range(int(math.log(self.scale, 2))):
x = self.m_convs[scale_idx](x) # B 3C H W
# x =
# B, C3, H, W = x.shape
x = x.reshape(x.shape[0] * 3, x.shape[1] // 3, *x.shape[2:])
x = self.m_pixelshuffle[scale_idx](x)
x = x.reshape(x.shape[0] // 3, x.shape[1] * 3, *x.shape[2:])
return x
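# Usage sketch (illustrative only; num_feat is assumed and must be divisible by 3, since the
# pixel shuffle is applied per plane):
#
#   up = Upsample3D(scale=4, num_feat=32 * 6).cuda()
#   feats = torch.randn(2, 32 * 6, 64, 64, device='cuda')   # (B, 3C, H, W)
#   y = up(feats)                                            # -> (2, 192, 256, 256)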
class RodinConv3DPixelUnshuffleUpsample(nn.Module):
def __init__(self,
output_dim,
num_feat=32 * 6,
num_out_ch=32 * 3,
sr_ratio=4,
*args,
**kwargs) -> None:
super().__init__()
self.conv_after_body = RodinRollOutConv3D_GroupConv(
output_dim, output_dim, 3, 1, 1)
self.conv_before_upsample = nn.Sequential(
RodinRollOutConv3D_GroupConv(output_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.upsample = Upsample3D(sr_ratio, num_feat) # 4 time SR
self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, num_out_ch, 3,
1, 1)
# @torch.autocast(device_type='cuda')
def forward(self, x, input_skip_connection=True, *args, **kwargs):
# x = self.conv_first(x)
if input_skip_connection:
x = self.conv_after_body(x) + x
else:
x = self.conv_after_body(x)
x = self.conv_before_upsample(x)
x = self.upsample(x)
x = self.conv_last(x)
return x
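# Usage sketch (illustrative only; channel sizes follow the defaults and are assumed):
#
#   upsampler = RodinConv3DPixelUnshuffleUpsample(output_dim=3 * 32).cuda()
#   triplane = torch.randn(2, 3 * 32, 64, 64, device='cuda')   # (B, 3C, H, W)
#   y = upsampler(triplane)                                     # -> (2, 96, 256, 256), 4x SR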
class RodinConv3DPixelUnshuffleUpsample_improvedVersion(nn.Module):
def __init__(
self,
output_dim,
num_out_ch=32 * 3,
sr_ratio=4,
input_resolution=256,
) -> None:
super().__init__()
self.input_resolution = input_resolution
# self.conv_first = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch,
# 3, 1, 1)
self.upsample = Upsample3D(sr_ratio, output_dim) # 4 time SR
self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch,
3, 1, 1)
def forward(self, x, bilinear_upsample=True):
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
group_size = C3 // C
assert group_size == 3, 'designed for triplane here'
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
if bilinear_upsample and x.shape[-1] != self.input_resolution:
x_bilinear_upsample = torch.nn.functional.interpolate(
x,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=True)
x = self.upsample(x) + x_bilinear_upsample
else:
# x_bilinear_upsample = x
x = self.upsample(x)
x = self.conv_last(x)
return x
class RodinConv3DPixelUnshuffleUpsample_improvedVersion2(nn.Module):
"""removed nearest neighbour residual conenctions, add a conv layer residual conenction
"""
def __init__(
self,
output_dim,
num_out_ch=32 * 3,
sr_ratio=4,
input_resolution=256,
) -> None:
super().__init__()
self.input_resolution = input_resolution
self.conv_after_body = RodinRollOutConv3D_GroupConv(
output_dim, num_out_ch, 3, 1, 1)
self.upsample = Upsample3D(sr_ratio, output_dim) # 4 time SR
self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch,
3, 1, 1)
def forward(self, x, input_skip_connection=True):
B, C3, p, p = x.shape # after unpachify triplane
C = C3 // 3
group_size = C3 // C
assert group_size == 3, 'designed for triplane here'
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
p) # B 3 C N -> B 3C h W
if input_skip_connection:
x = self.conv_after_body(x) + x
else:
x = self.conv_after_body(x)
x = self.upsample(x)
x = self.conv_last(x)
return x
class CLSCrossAttentionBlock(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
has_mlp=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = CrossAttention(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.has_mlp = has_mlp
if has_mlp:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
if self.has_mlp:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class Conv3DCrossAttentionBlock(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
has_mlp=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Conv3D_Aware_CrossAttention(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.has_mlp = has_mlp
if has_mlp:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
if self.has_mlp:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class Conv3DCrossAttentionBlockXformerMHA(Conv3DCrossAttentionBlock):
def __init__(self,
dim,
num_heads,
mlp_ratio=4,
qkv_bias=False,
qk_scale=None,
drop=0,
attn_drop=0,
drop_path=0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
has_mlp=False):
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
attn_drop, drop_path, act_layer, norm_layer, has_mlp)
# self.attn = xformer_Conv3D_Aware_CrossAttention(dim,
self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
class Conv3DCrossAttentionBlockXformerMHANested(
Conv3DCrossAttentionBlockXformerMHA):
def __init__(self,
dim,
num_heads,
mlp_ratio=4,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
has_mlp=False):
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
attn_drop, drop_path, act_layer, norm_layer, has_mlp)
"""for in-place replaing the internal attn in Dino ViT.
"""
def forward(self, x):
Bx3, N, C = x.shape
B, group_size = Bx3 // 3, 3
x = x.reshape(B, group_size, N, C) # in plane vit
x = super().forward(x)
return x.reshape(B * group_size, N,
C) # to match the original attn size
class Conv3DCrossAttentionBlockXformerMHANested_withinC(
Conv3DCrossAttentionBlockXformerMHANested):
def __init__(self,
dim,
num_heads,
mlp_ratio=4,
qkv_bias=False,
qk_scale=None,
drop=0,
attn_drop=0,
drop_path=0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
has_mlp=False):
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
attn_drop, drop_path, act_layer, norm_layer, has_mlp)
self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
def forward(self, x):
# basic TX attention forward function
x = x + self.drop_path(self.attn(self.norm1(x)))
if self.has_mlp:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class TriplaneFusionBlock(nn.Module):
"""4 ViT blocks + 1 CrossAttentionBlock
"""
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
cross_attention_blk=CLSCrossAttentionBlock,
*args,
**kwargs) -> None:
super().__init__(*args, **kwargs)
self.num_branches = 3 # triplane
self.vit_blks = vit_blks
if use_fusion_blk:
self.fusion = nn.ModuleList()
# copied vit settings from https://github.dev/facebookresearch/dinov2
nh = num_heads
dim = embed_dim
mlp_ratio = 4 # defined for all dino2 model
qkv_bias = True
norm_layer = partial(nn.LayerNorm, eps=1e-6)
drop_path_rate = 0.3 # default setting
attn_drop = proj_drop = 0.0
qk_scale = None # TODO, double check
for d in range(self.num_branches):
self.fusion.append(
cross_attention_blk(
dim=dim,
num_heads=nh,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
# drop=drop,
drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path_rate,
norm_layer=norm_layer, # type: ignore
has_mlp=False))
else:
self.fusion = None
def forward(self, x):
# modified from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132
"""x: B 3 N C, where N = H*W tokens
"""
# self attention, by merging the triplane channel into B for parallel computation
# ! move the below to the front of the first call
B, group_size, N, C = x.shape # has [cls] token in N
assert group_size == 3, 'triplane'
x = x.view(B * group_size, N, C)
for blk in self.vit_blks:
x = blk(x) # B 3 N C
if self.fusion is None:
return x.view(B, group_size, N, C)
# outs_b = x.view(B, group_size, N,
# C).chunk(chunks=3,
# dim=1) # 3 * [B, 1, N//3, C] Tensors, for fusion
outs_b = x.chunk(chunks=3,
dim=0) # 3 * [B, N//3, C] Tensors, for fusion
# only take the cls token out
proj_cls_token = [x[:, 0:1] for x in outs_b]
# cross attention
outs = []
for i in range(self.num_branches):
tmp = torch.cat(
(proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:,
...]),
dim=1)
tmp = self.fusion[i](tmp)
# reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...])
reverted_proj_cls_token = tmp[:, 0:1, ...]
tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]),
dim=1)
outs.append(tmp)
# outs = ? needs to merge back?
outs = torch.stack(outs, 1) # B 3 N C
return outs
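# Usage sketch (illustrative only; `vit_blks` stands in for transformer blocks sliced from a
# pre-trained DINO/DINOv2 backbone, and all shapes below are assumed):
#
#   vit_blks = [...]  # e.g. two consecutive blocks taken from the backbone
#   fusion = TriplaneFusionBlock(vit_blks, num_heads=12, embed_dim=768)
#   tokens = torch.randn(2, 3, 257, 768)   # (B, 3 planes, 1 [cls] + 16*16 patch tokens, C)
#   out = fusion(tokens)                    # -> (2, 3, 257, 768)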
class TriplaneFusionBlockv2(nn.Module):
"""4 ViT blocks + 1 CrossAttentionBlock
"""
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=Conv3DCrossAttentionBlock,
*args,
**kwargs) -> None:
super().__init__(*args, **kwargs)
self.num_branches = 3 # triplane
self.vit_blks = vit_blks
if use_fusion_blk:
# self.fusion = nn.ModuleList()
# copied vit settings from https://github.dev/facebookresearch/dinov2
nh = num_heads
dim = embed_dim
mlp_ratio = 4 # defined for all dino2 model
qkv_bias = True
norm_layer = partial(nn.LayerNorm, eps=1e-6)
drop_path_rate = 0.3 # default setting
attn_drop = proj_drop = 0.0
qk_scale = None # TODO, double check
# for d in range(self.num_branches):
self.fusion = fusion_ca_blk( # one fusion is enough
dim=dim,
num_heads=nh,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
# drop=drop,
drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path_rate,
norm_layer=norm_layer, # type: ignore
has_mlp=False)
else:
self.fusion = None
def forward(self, x):
# modified from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132
"""x: B 3 N C, where N = H*W tokens
"""
# self attention, by merging the triplane channel into B for parallel computation
# ! move the below to the front of the first call
B, group_size, N, C = x.shape # has [cls] token in N
assert group_size == 3, 'triplane'
x = x.reshape(B * group_size, N, C)
for blk in self.vit_blks:
x = blk(x) # B 3 N C
if self.fusion is None:
return x.reshape(B, group_size, N, C)
x = x.reshape(B, group_size, N, C) # .chunk(chunks=3,
# dim=1) # 3 * [B, N//3, C] Tensors, for fusion
return self.fusion(x)
class TriplaneFusionBlockv3(TriplaneFusionBlockv2):
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA,
*args,
**kwargs) -> None:
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
fusion_ca_blk, *args, **kwargs)
class TriplaneFusionBlockv4(TriplaneFusionBlockv3):
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA,
*args,
**kwargs) -> None:
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
fusion_ca_blk, *args, **kwargs)
"""OOM? directly replace the atten here
"""
assert len(vit_blks) == 2
# del self.vit_blks[1].attn
del self.vit_blks[1].attn, self.vit_blks[1].ls1, self.vit_blks[1].norm1
def ffn_residual_func(self, tx_blk, x: Tensor) -> Tensor:
return tx_blk.ls2(
tx_blk.mlp(tx_blk.norm2(x))
) # https://github.com/facebookresearch/dinov2/blob/c3c2683a13cde94d4d99f523cf4170384b00c34c/dinov2/layers/block.py#L86C1-L87C53
def forward(self, x):
"""x: B 3 N C, where N = H*W tokens
"""
assert self.fusion is not None
B, group_size, N, C = x.shape # has [cls] token in N
x = x.reshape(B * group_size, N, C) # in plane vit
# in plane self attention
x = self.vit_blks[0](x)
# 3D cross attention blk + ffn
x = x + self.fusion(x.reshape(B, group_size, N, C)).reshape(
B * group_size, N, C)
x = x + self.ffn_residual_func(self.vit_blks[1], x)
return x.reshape(B, group_size, N, C)
class TriplaneFusionBlockv4_nested(nn.Module):
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
*args,
**kwargs) -> None:
super().__init__()
self.num_branches = 3 # triplane
self.vit_blks = vit_blks
assert use_fusion_blk
assert len(vit_blks) == 2
# ! replace vit_blks[1] attn layer with 3D aware attention
del self.vit_blks[
1].attn # , self.vit_blks[1].ls1, self.vit_blks[1].norm1
# copied vit settings from https://github.dev/facebookresearch/dinov2
nh = num_heads
dim = embed_dim
mlp_ratio = 4 # same for all dinov2 models
qkv_bias = True
norm_layer = partial(nn.LayerNorm, eps=1e-6)
drop_path_rate = 0.3 # default setting
attn_drop = proj_drop = 0.0
qk_scale = None # TODO, double check
self.vit_blks[1].attn = fusion_ca_blk( # one fusion is enough
dim=dim,
num_heads=nh,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
# drop=drop,
drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path_rate,
norm_layer=norm_layer, # type: ignore
has_mlp=False)
def forward(self, x):
"""x: B 3 N C, where N = H*W tokens
"""
# self attention, by merging the triplane channel into B for parallel computation
# ! move the below to the front of the first call
B, group_size, N, C = x.shape # has [cls] token in N
assert group_size == 3, 'triplane'
x = x.reshape(B * group_size, N, C)
for blk in self.vit_blks:
x = blk(x) # B 3 N C
# TODO, avoid the reshape overhead?
return x.reshape(B, group_size, N, C)
class TriplaneFusionBlockv4_nested_init_from_dino(nn.Module):
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
init_from_dino=True,
*args,
**kwargs) -> None:
super().__init__()
self.num_branches = 3 # triplane
self.vit_blks = vit_blks
assert use_fusion_blk
assert len(vit_blks) == 2
# copied vit settings from https://github.dev/facebookresearch/dinov2
nh = num_heads
dim = embed_dim
mlp_ratio = 4 # same for all dinov2 models
qkv_bias = True
norm_layer = partial(nn.LayerNorm, eps=1e-6)
drop_path_rate = 0.3 # default setting
attn_drop = proj_drop = 0.0
qk_scale = None # TODO, double check
attn_3d = fusion_ca_blk( # one fusion is enough
dim=dim,
num_heads=nh,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
# drop=drop,
drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path_rate,
norm_layer=norm_layer, # type: ignore
has_mlp=False)
# ! initialize 3dattn from dino attn
if init_from_dino:
merged_qkv_linear = self.vit_blks[1].attn.qkv
attn_3d.attn.proj.load_state_dict(
self.vit_blks[1].attn.proj.state_dict())
# Initialize the separate Q and joint KV linear layers from the merged QKV weights
attn_3d.attn.wq.weight.data = merged_qkv_linear.weight.data[:dim, :]
attn_3d.attn.w_kv.weight.data = merged_qkv_linear.weight.data[dim:, :]
# Optionally, you can initialize the biases as well (if your QKV linear layer has biases)
if qkv_bias:
attn_3d.attn.wq.bias.data = merged_qkv_linear.bias.data[:dim]
attn_3d.attn.w_kv.bias.data = merged_qkv_linear.bias.data[dim:]
del self.vit_blks[1].attn
# ! assign
self.vit_blks[1].attn = attn_3d
def forward(self, x):
"""x: B 3 N C, where N = H*W tokens
"""
# self attention, by merging the triplane channel into B for parallel computation
# ! move the below to the front of the first call
B, group_size, N, C = x.shape # has [cls] token in N
assert group_size == 3, 'triplane'
x = x.reshape(B * group_size, N, C)
for blk in self.vit_blks:
x = blk(x) # B 3 N C
# TODO, avoid the reshape overhead?
return x.reshape(B, group_size, N, C)
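# Hedged sketch of the weight surgery in TriplaneFusionBlockv4_nested_init_from_dino.__init__:
# a fused qkv projection (out_features = 3*dim) is split into a query projection (first dim
# rows) and a joint key/value projection (remaining 2*dim rows). `_sketch_split_fused_qkv` is
# a self-contained check of that split, not used by the model.
def _sketch_split_fused_qkv(dim=8):
    fused = nn.Linear(dim, 3 * dim, bias=True)
    wq = nn.Linear(dim, dim, bias=True)
    w_kv = nn.Linear(dim, 2 * dim, bias=True)
    wq.weight.data = fused.weight.data[:dim, :].clone()
    w_kv.weight.data = fused.weight.data[dim:, :].clone()
    wq.bias.data = fused.bias.data[:dim].clone()
    w_kv.bias.data = fused.bias.data[dim:].clone()
    x = torch.randn(2, 5, dim)
    q_ref, kv_ref = fused(x).split([dim, 2 * dim], dim=-1)
    assert torch.allclose(wq(x), q_ref, atol=1e-6)
    assert torch.allclose(w_kv(x), kv_ref, atol=1e-6)
    return wq, w_kv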
class TriplaneFusionBlockv4_nested_init_from_dino_lite(nn.Module):
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=None,
*args,
**kwargs) -> None:
super().__init__()
self.num_branches = 3 # triplane
self.vit_blks = vit_blks
assert use_fusion_blk
assert len(vit_blks) == 2
# copied vit settings from https://github.dev/facebookresearch/dinov2
nh = num_heads
dim = embed_dim
mlp_ratio = 4 # same for all dinov2 models
qkv_bias = True
norm_layer = partial(nn.LayerNorm, eps=1e-6)
drop_path_rate = 0.3 # default setting
attn_drop = proj_drop = 0.0
qk_scale = None # TODO, double check
attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( # ! raw 3D attn layer
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop)
del self.vit_blks[1].attn
# ! assign
self.vit_blks[1].attn = attn_3d
def forward(self, x):
"""x: B N C, where N = H*W tokens. Just raw ViT forward pass
"""
# ! move the below to the front of the first call
B, N, C = x.shape # has [cls] token in N
for blk in self.vit_blks:
x = blk(x) # B N C
return x
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge(nn.Module):
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=None,
*args,
**kwargs) -> None:
super().__init__()
self.vit_blks = vit_blks
assert use_fusion_blk
assert len(vit_blks) == 2
# copied vit settings from https://github.dev/facebookresearch/dinov2
nh = num_heads
dim = embed_dim
qkv_bias = True
attn_drop = proj_drop = 0.0
qk_scale = None # TODO, double check
if False: # ablation (disabled): wrap each block's attention with a 3D-aware cross attention
for blk in self.vit_blks:
attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( # ! raw 3D attn layer
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop)
blk.attn = self_cross_attn(blk.attn, attn_3d)
def forward(self, x):
"""x: B N C, where N = H*W tokens. Just raw ViT forward pass
"""
# ! move the below to the front of the first call
B, N, C = x.shape # has [cls] token in N
for blk in self.vit_blks:
x = blk(x) # B N C
return x
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge):
# no roll out + B 3L C
def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None:
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs)
def forward(self, x):
"""x: B 3 N C, where N = H*W tokens
"""
# ! move the below to the front of the first call
# B, N, C = x.shape # has [cls] token in N
B, group_size, N, C = x.shape # has [cls] token in N
x = x.reshape(B, group_size*N, C)
for blk in self.vit_blks:
x = blk(x) # B N C
x = x.reshape(B, group_size, N, C) # outer loop tradition
return x
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge):
# roll out + B 3L C
def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None:
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs)
def forward(self, x):
"""x: B 3 N C, where N = H*W tokens
"""
# ! move the below to the front of the first call
# B, N, C = x.shape # has [cls] token in N
B, group_size, N, C = x.shape # has [cls] token in N
x = x.reshape(B*group_size, N, C)
x = self.vit_blks[0](x)
x = x.reshape(B,group_size*N, C)
x = self.vit_blks[1](x)
x = x.reshape(B, group_size, N, C) # outer loop tradition
return x
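# Hedged note on the two token layouts used by the *_B_3L_C variants above: reshaping to
# (B*3, N, C) keeps attention inside each plane, while reshaping to (B, 3*N, C) lets a block's
# attention mix tokens across all three planes. `_sketch_rollout_layouts` is a standalone demo
# of the two views, not used at runtime.
def _sketch_rollout_layouts(x):
    """x: (B, 3, N, C) triplane tokens."""
    B, group_size, N, C = x.shape
    per_plane = x.reshape(B * group_size, N, C)      # attention stays inside one plane
    across_planes = x.reshape(B, group_size * N, C)  # attention spans all three planes
    return per_plane, across_planes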
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_add3DAttn(TriplaneFusionBlockv4_nested_init_from_dino):
# no roll out + 3D Attention
def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None:
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs)
def forward(self, x):
"""x: B 3 N C, where N = H*W tokens
"""
B, group_size, N, C = x.shape # has [cls] token in N
x = x.reshape(B, group_size*N, C)
x = self.vit_blks[0](x) # B 3 L C
# ! move the below to the front of the first call
x = x.reshape(B, group_size, N, C).reshape(B*group_size, N, C)
x = self.vit_blks[1](x) # has 3D attention
return x.reshape(B, group_size, N, C)
class TriplaneFusionBlockv5_ldm_addCA(nn.Module):
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
*args,
**kwargs) -> None:
super().__init__()
self.num_branches = 3 # triplane
self.vit_blks = vit_blks
assert use_fusion_blk
assert len(vit_blks) == 2
# ! rather than replacing, add a 3D attention block after.
# del self.vit_blks[
# 1].attn # , self.vit_blks[1].ls1, self.vit_blks[1].norm1
self.norm_for_atten_3d = deepcopy(self.vit_blks[1].norm1)
# copied vit settings from https://github.dev/facebookresearch/dinov2
nh = num_heads
dim = embed_dim
mlp_ratio = 4 # same for all dinov2 models
qkv_bias = True
norm_layer = partial(nn.LayerNorm, eps=1e-6)
drop_path_rate = 0.3 # default setting
attn_drop = proj_drop = 0.0
qk_scale = None # TODO, double check
self.attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop)
def forward(self, x):
"""x: B 3 N C, where N = H*W tokens
"""
# self attention, by merging the triplane channel into B for parallel computation
# ! move the below to the front of the first call
B, group_size, N, C = x.shape # has [cls] token in N
assert group_size == 3, 'triplane'
flatten_token = lambda x: x.reshape(B * group_size, N, C)
unflatten_token = lambda x: x.reshape(B, group_size, N, C)
x = flatten_token(x)
x = self.vit_blks[0](x)
x = unflatten_token(x)
x = self.attn_3d(self.norm_for_atten_3d(x)) + x
x = flatten_token(x)
x = self.vit_blks[1](x)
return unflatten_token(x)
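# Hedged sketch of the residual insertion used by TriplaneFusionBlockv5_ldm_addCA: instead of
# replacing vit_blks[1].attn, an extra token mixer is added in pre-norm residual form,
# x <- x + mixer(norm(x)). `mixer` below stands in for the 3D attention layer; this helper is
# illustrative only.
def _sketch_prenorm_residual(x, mixer=None):
    dim = x.shape[-1]
    norm = nn.LayerNorm(dim, eps=1e-6)
    mixer = mixer if mixer is not None else nn.Identity()
    return x + mixer(norm(x))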
class TriplaneFusionBlockv6_ldm_addCA_Init3DAttnfrom2D(
TriplaneFusionBlockv5_ldm_addCA):
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
*args,
**kwargs) -> None:
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
fusion_ca_blk, *args, **kwargs)
def forward(self, x):
"""x: B 3 N C, where N = H*W tokens
"""
# self attention, by merging the triplane channel into B for parallel computation
# ! move the below to the front of the first call
B, group_size, N, C = x.shape # has [cls] token in N
assert group_size == 3, 'triplane'
flatten_token = lambda x: x.reshape(B * group_size, N, C)
unflatten_token = lambda x: x.reshape(B, group_size, N, C)
x = flatten_token(x)
x = self.vit_blks[0](x)
x = unflatten_token(x)
x = self.attn_3d(self.norm_for_atten_3d(x)) + x
x = flatten_token(x)
x = self.vit_blks[1](x)
return unflatten_token(x)
class TriplaneFusionBlockv5_ldm_add_dualCA(nn.Module):
def __init__(self,
vit_blks,
num_heads,
embed_dim,
use_fusion_blk=True,
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
*args,
**kwargs) -> None:
super().__init__()
self.num_branches = 3 # triplane
self.vit_blks = vit_blks
assert use_fusion_blk
assert len(vit_blks) == 2
# ! rather than replacing, add a 3D attention block after.
# del self.vit_blks[
# 1].attn # , self.vit_blks[1].ls1, self.vit_blks[1].norm1
self.norm_for_atten_3d_0 = deepcopy(self.vit_blks[0].norm1)
self.norm_for_atten_3d_1 = deepcopy(self.vit_blks[1].norm1)
# copied vit settings from https://github.dev/facebookresearch/dinov2
nh = num_heads
dim = embed_dim
mlp_ratio = 4 # same for all dinov2 models
qkv_bias = True
norm_layer = partial(nn.LayerNorm, eps=1e-6)
drop_path_rate = 0.3 # default setting
attn_drop = proj_drop = 0.0
qk_scale = None # TODO, double check
self.attn_3d_0 = xformer_Conv3D_Aware_CrossAttention_xygrid(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop)
self.attn_3d_1 = deepcopy(self.attn_3d_0)
def forward(self, x):
"""x: B 3 N C, where N = H*W tokens
"""
# self attention, by merging the triplane channel into B for parallel computation
# ! move the below to the front of the first call
B, group_size, N, C = x.shape # has [cls] token in N
assert group_size == 3, 'triplane'
flatten_token = lambda x: x.reshape(B * group_size, N, C)
unflatten_token = lambda x: x.reshape(B, group_size, N, C)
x = flatten_token(x)
x = self.vit_blks[0](x)
x = unflatten_token(x)
x = self.attn_3d_0(self.norm_for_atten_3d_0(x)) + x
x = flatten_token(x)
x = self.vit_blks[1](x)
x = unflatten_token(x)
x = self.attn_3d_1(self.norm_for_atten_3d_1(x)) + x
return unflatten_token(x)
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
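# Hedged numerical check of drop_path above: each sample's residual branch is zeroed with
# probability drop_prob and the survivors are rescaled by 1/keep_prob, so the expected output
# equals the input. `_sketch_drop_path_expectation` is a demo, not used by the model.
def _sketch_drop_path_expectation(drop_prob=0.3, n_samples=10000):
    x = torch.ones(n_samples, 4)
    y = drop_path(x, drop_prob=drop_prob, training=True)
    # with many samples the mean stays close to 1.0; at eval time y would equal x exactly
    return y.mean().item()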
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
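# Hedged usage note for Mlp: the hidden width is mlp_ratio times the embedding dim (4x in the
# blocks below) and the token shape is preserved end to end. Demo only, not used by the model.
def _sketch_mlp_shapes(dim=384, mlp_ratio=4):
    mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio))
    tokens = torch.randn(2, 197, dim)  # (B, N, C)
    return mlp(tokens).shape           # (2, 197, dim): unchanged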
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
# self.attn = Attention(dim,
self.attn = MemEffAttention(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x, return_attention=False):
y, attn = self.attn(self.norm1(x))
if return_attention:
return attn
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2) # B, C, L -> B, L, C
return x
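# Hedged shape walk-through for PatchEmbed: the stride-patch_size Conv2d cuts the image into
# non-overlapping patches, so a 224x224 input with patch 16 gives (224 // 16) ** 2 = 196 tokens
# of width embed_dim. `_sketch_patch_embed_shapes` is a demo, not used by the model.
def _sketch_patch_embed_shapes():
    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    tokens = embed(torch.randn(2, 3, 224, 224))
    assert tokens.shape == (2, 196, 768) and embed.num_patches == 196
    return tokens.shape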
class VisionTransformer(nn.Module):
""" Vision Transformer """
def __init__(self,
img_size=[224],
patch_size=16,
in_chans=3,
num_classes=0,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer='nn.LayerNorm',
patch_embedding=True,
cls_token=True,
pixel_unshuffle=False,
**kwargs):
super().__init__()
self.num_features = self.embed_dim = embed_dim
self.patch_size = patch_size
# note: the norm_layer argument is currently ignored; LayerNorm(eps=1e-6) is always used
norm_layer = partial(nn.LayerNorm, eps=1e-6)
if patch_embedding:
self.patch_embed = PatchEmbed(img_size=img_size[0],
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.img_size = self.patch_embed.img_size
else:
self.patch_embed = None
self.img_size = img_size[0]
num_patches = (img_size[0] // patch_size) * (img_size[0] //
patch_size)
if cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim))
else:
self.cls_token = None
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer) for i in range(depth)
])
self.norm = norm_layer(embed_dim)
# Classifier head
self.head = nn.Linear(
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
if cls_token:
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
# if pixel_unshuffle:
# self.decoder_pred = nn.Linear(embed_dim,
# patch_size**2 * out_chans,
# bias=True) # decoder to patch
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
patch_pos_embed = self.pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)),
dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(
h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)  # (1, w0*h0, dim)
if self.cls_token is not None:
class_pos_embed = self.pos_embed[:, 0]
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed),
dim=1)
return patch_pos_embed
def prepare_tokens(self, x):
B, nc, w, h = x.shape
x = self.patch_embed(x) # patch linear embedding
# add the [CLS] token to the embed patch tokens
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# add positional encoding to each token
x = x + self.interpolate_pos_encoding(x, w, h)
return self.pos_drop(x)
def forward(self, x):
x = self.prepare_tokens(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x[:, 1:] # return spatial feature maps, not the [CLS] token
# return x[:, 0]
def get_last_selfattention(self, x):
x = self.prepare_tokens(x)
for i, blk in enumerate(self.blocks):
if i < len(self.blocks) - 1:
x = blk(x)
else:
# return attention of the last block
return blk(x, return_attention=True)
def get_intermediate_layers(self, x, n=1):
x = self.prepare_tokens(x)
# we return the output tokens from the `n` last blocks
output = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if len(self.blocks) - i <= n:
output.append(self.norm(x))
return output
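# Hedged standalone sketch of what VisionTransformer.interpolate_pos_encoding does: the learned
# patch-position grid (14x14 when trained at 224px with patch 16) is bicubically resized to the
# test-time grid (e.g. 16x16 for 256px inputs) so other resolutions can be used. This helper
# uses a fixed target size rather than the scale_factor form above; illustrative only.
def _sketch_pos_embed_interpolation(embed_dim=768, train_grid=14, test_grid=16):
    pos = torch.randn(1, train_grid * train_grid, embed_dim)
    grid = pos.reshape(1, train_grid, train_grid, embed_dim).permute(0, 3, 1, 2)
    resized = nn.functional.interpolate(grid, size=(test_grid, test_grid),
                                        mode='bicubic', align_corners=False)
    return resized.permute(0, 2, 3, 1).reshape(1, test_grid * test_grid, embed_dim)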
def vit_tiny(patch_size=16, **kwargs):
model = VisionTransformer(patch_size=patch_size,
embed_dim=192,
depth=12,
num_heads=3,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
return model
def vit_small(patch_size=16, **kwargs):
model = VisionTransformer(
patch_size=patch_size,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), # type: ignore
**kwargs)
return model
def vit_base(patch_size=16, **kwargs):
model = VisionTransformer(patch_size=patch_size,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
return model
vits = vit_small
vitb = vit_base
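# Hedged usage sketch for the factories above (assumes the attention layers defined earlier in
# this file behave as Block.forward expects): vit_small uses embed_dim=384 with 6 heads, so a
# 224px image at patch 16 yields a (B, 196, 384) feature map, since forward() drops [CLS].
def _sketch_vit_small_features():
    model = vit_small(patch_size=16, img_size=[224])
    feats = model(torch.randn(1, 3, 224, 224))
    return feats.shape  # expected (1, 196, 384)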