mart9992's picture
m
2cd560a
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer, trunc_normal_init
from mmcv.cnn.bricks.transformer import build_dropout
try:
from torch.cuda.amp import autocast
WITH_AUTOCAST = True
except ImportError:
WITH_AUTOCAST = False
def get_grid_index(init_grid_size, map_size, device):
"""For every initial grid, get its index in the feature map.
Note:
[H_init, W_init]: shape of initial grid
[H, W]: shape of feature map
N_init: numbers of initial token
Args:
init_grid_size (list[int] or tuple[int]): initial grid resolution in
format [H_init, W_init].
map_size (list[int] or tuple[int]): feature map resolution in format
[H, W].
device: the device of output
Returns:
idx (torch.LongTensor[B, N_init]): index in flattened feature map.
"""
H_init, W_init = init_grid_size
H, W = map_size
idx = torch.arange(H * W, device=device).reshape(1, 1, H, W)
idx = F.interpolate(idx.float(), [H_init, W_init], mode='nearest').long()
return idx.flatten()
def index_points(points, idx):
"""Sample features following the index.
Note:
B: batch size
N: point number
C: channel number of each point
Ns: sampled point number
Args:
points (torch.Tensor[B, N, C]): input points data
idx (torch.LongTensor[B, S]): sample index
Returns:
new_points (torch.Tensor[B, Ns, C]):, indexed points data
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(
B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
def token2map(token_dict):
"""Transform vision tokens to feature map. This function only works when
the resolution of the feature map is not higher than the initial grid
structure.
Note:
B: batch size
C: channel number of each token
[H, W]: shape of feature map
N_init: numbers of initial token
Args:
token_dict (dict): dict for token information.
Returns:
x_out (Tensor[B, C, H, W]): feature map.
"""
x = token_dict['x']
H, W = token_dict['map_size']
H_init, W_init = token_dict['init_grid_size']
idx_token = token_dict['idx_token']
B, N, C = x.shape
N_init = H_init * W_init
device = x.device
if N_init == N and N == H * W:
# for the initial tokens with grid structure, just reshape
return x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
# for each initial grid, get the corresponding index in
# the flattened feature map.
idx_hw = get_grid_index([H_init, W_init], [H, W],
device=device)[None, :].expand(B, -1)
idx_batch = torch.arange(B, device=device)[:, None].expand(B, N_init)
value = x.new_ones(B * N_init)
# choose the way with fewer flops.
if N_init < N * H * W:
# use sparse matrix multiplication
# Flops: B * N_init * (C+2)
idx_hw = idx_hw + idx_batch * H * W
idx_tokens = idx_token + idx_batch * N
coor = torch.stack([idx_hw, idx_tokens], dim=0).reshape(2, B * N_init)
# torch.sparse do not support gradient for
# sparse tensor, so we detach it
value = value.detach().to(torch.float32)
# build a sparse matrix with the shape [B * H * W, B * N]
A = torch.sparse.FloatTensor(coor, value,
torch.Size([B * H * W, B * N]))
# normalize the weight for each row
if WITH_AUTOCAST:
with autocast(enabled=False):
all_weight = A @ x.new_ones(B * N, 1).type(
torch.float32) + 1e-6
else:
all_weight = A @ x.new_ones(B * N, 1).type(torch.float32) + 1e-6
value = value / all_weight[idx_hw.reshape(-1), 0]
# update the matrix with normalize weight
A = torch.sparse.FloatTensor(coor, value,
torch.Size([B * H * W, B * N]))
# sparse matrix multiplication
if WITH_AUTOCAST:
with autocast(enabled=False):
x_out = A @ x.reshape(B * N, C).to(torch.float32) # [B*H*W, C]
else:
x_out = A @ x.reshape(B * N, C).to(torch.float32) # [B*H*W, C]
else:
# use dense matrix multiplication
# Flops: B * N * H * W * (C+2)
coor = torch.stack([idx_batch, idx_hw, idx_token],
dim=0).reshape(3, B * N_init)
# build a matrix with shape [B, H*W, N]
A = torch.sparse.FloatTensor(coor, value, torch.Size([B, H * W,
N])).to_dense()
# normalize the weight
A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)
x_out = A @ x # [B, H*W, C]
x_out = x_out.type(x.dtype)
x_out = x_out.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
return x_out
def map2token(feature_map, token_dict):
"""Transform feature map to vision tokens. This function only works when
the resolution of the feature map is not higher than the initial grid
structure.
Note:
B: batch size
C: channel number
[H, W]: shape of feature map
N_init: numbers of initial token
Args:
feature_map (Tensor[B, C, H, W]): feature map.
token_dict (dict): dict for token information.
Returns:
out (Tensor[B, N, C]): token features.
"""
idx_token = token_dict['idx_token']
N = token_dict['token_num']
H_init, W_init = token_dict['init_grid_size']
N_init = H_init * W_init
B, C, H, W = feature_map.shape
device = feature_map.device
if N_init == N and N == H * W:
# for the initial tokens with grid structure, just reshape
return feature_map.flatten(2).permute(0, 2, 1).contiguous()
idx_hw = get_grid_index([H_init, W_init], [H, W],
device=device)[None, :].expand(B, -1)
idx_batch = torch.arange(B, device=device)[:, None].expand(B, N_init)
value = feature_map.new_ones(B * N_init)
# choose the way with fewer flops.
if N_init < N * H * W:
# use sparse matrix multiplication
# Flops: B * N_init * (C+2)
idx_token = idx_token + idx_batch * N
idx_hw = idx_hw + idx_batch * H * W
indices = torch.stack([idx_token, idx_hw], dim=0).reshape(2, -1)
# sparse mm do not support gradient for sparse matrix
value = value.detach().to(torch.float32)
# build a sparse matrix with shape [B*N, B*H*W]
A = torch.sparse_coo_tensor(indices, value, (B * N, B * H * W))
# normalize the matrix
if WITH_AUTOCAST:
with autocast(enabled=False):
all_weight = A @ torch.ones(
[B * H * W, 1], device=device, dtype=torch.float32) + 1e-6
else:
all_weight = A @ torch.ones(
[B * H * W, 1], device=device, dtype=torch.float32) + 1e-6
value = value / all_weight[idx_token.reshape(-1), 0]
A = torch.sparse_coo_tensor(indices, value, (B * N, B * H * W))
# out: [B*N, C]
if WITH_AUTOCAST:
with autocast(enabled=False):
out = A @ feature_map.permute(0, 2, 3, 1).contiguous().reshape(
B * H * W, C).float()
else:
out = A @ feature_map.permute(0, 2, 3, 1).contiguous().reshape(
B * H * W, C).float()
else:
# use dense matrix multiplication
# Flops: B * N * H * W * (C+2)
indices = torch.stack([idx_batch, idx_token, idx_hw],
dim=0).reshape(3, -1)
value = value.detach() # To reduce the training time, we detach here.
A = torch.sparse_coo_tensor(indices, value, (B, N, H * W)).to_dense()
# normalize the matrix
A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)
out = A @ feature_map.permute(0, 2, 3, 1).reshape(B, H * W,
C).contiguous()
out = out.type(feature_map.dtype)
out = out.reshape(B, N, C)
return out
def token_interp(target_dict, source_dict):
"""Transform token features between different distribution.
Note:
B: batch size
N: token number
C: channel number
Args:
target_dict (dict): dict for target token information
source_dict (dict): dict for source token information.
Returns:
x_out (Tensor[B, N, C]): token features.
"""
x_s = source_dict['x']
idx_token_s = source_dict['idx_token']
idx_token_t = target_dict['idx_token']
T = target_dict['token_num']
B, S, C = x_s.shape
N_init = idx_token_s.shape[1]
weight = target_dict['agg_weight'] if 'agg_weight' in target_dict.keys(
) else None
if weight is None:
weight = x_s.new_ones(B, N_init, 1)
weight = weight.reshape(-1)
# choose the way with fewer flops.
if N_init < T * S:
# use sparse matrix multiplication
# Flops: B * N_init * (C+2)
idx_token_t = idx_token_t + torch.arange(
B, device=x_s.device)[:, None] * T
idx_token_s = idx_token_s + torch.arange(
B, device=x_s.device)[:, None] * S
coor = torch.stack([idx_token_t, idx_token_s],
dim=0).reshape(2, B * N_init)
# torch.sparse does not support grad for sparse matrix
weight = weight.float().detach().to(torch.float32)
# build a matrix with shape [B*T, B*S]
A = torch.sparse.FloatTensor(coor, weight, torch.Size([B * T, B * S]))
# normalize the matrix
if WITH_AUTOCAST:
with autocast(enabled=False):
all_weight = A.type(torch.float32) @ x_s.new_ones(
B * S, 1).type(torch.float32) + 1e-6
else:
all_weight = A.type(torch.float32) @ x_s.new_ones(B * S, 1).type(
torch.float32) + 1e-6
weight = weight / all_weight[idx_token_t.reshape(-1), 0]
A = torch.sparse.FloatTensor(coor, weight, torch.Size([B * T, B * S]))
# sparse matmul
if WITH_AUTOCAST:
with autocast(enabled=False):
x_out = A.type(torch.float32) @ x_s.reshape(B * S, C).type(
torch.float32)
else:
x_out = A.type(torch.float32) @ x_s.reshape(B * S, C).type(
torch.float32)
else:
# use dense matrix multiplication
# Flops: B * T * S * (C+2)
idx_batch = torch.arange(
B, device=x_s.device)[:, None].expand(B, N_init)
coor = torch.stack([idx_batch, idx_token_t, idx_token_s],
dim=0).reshape(3, B * N_init)
weight = weight.detach() # detach to reduce training time
# build a matrix with shape [B, T, S]
A = torch.sparse.FloatTensor(coor, weight, torch.Size([B, T,
S])).to_dense()
# normalize the matrix
A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)
# dense matmul
x_out = A @ x_s
x_out = x_out.reshape(B, T, C).type(x_s.dtype)
return x_out
def cluster_dpc_knn(token_dict, cluster_num, k=5, token_mask=None):
"""Cluster tokens with DPC-KNN algorithm.
Note:
B: batch size
N: token number
C: channel number
Args:
token_dict (dict): dict for token information
cluster_num (int): cluster number
k (int): number of the nearest neighbor used for local density.
token_mask (Tensor[B, N]): mask indicating which token is the
padded empty token. Non-zero value means the token is meaningful,
zero value means the token is an empty token. If set to None, all
tokens are regarded as meaningful.
Return:
idx_cluster (Tensor[B, N]): cluster index of each token.
cluster_num (int): actual cluster number. In this function, it equals
to the input cluster number.
"""
with torch.no_grad():
x = token_dict['x']
B, N, C = x.shape
dist_matrix = torch.cdist(x, x) / (C**0.5)
if token_mask is not None:
token_mask = token_mask > 0
# in order to not affect the local density, the
# distance between empty tokens and any other
# tokens should be the maximal distance.
dist_matrix = \
dist_matrix * token_mask[:, None, :] +\
(dist_matrix.max() + 1) * (~token_mask[:, None, :])
# get local density
dist_nearest, index_nearest = torch.topk(
dist_matrix, k=k, dim=-1, largest=False)
density = (-(dist_nearest**2).mean(dim=-1)).exp()
# add a little noise to ensure no tokens have the same density.
density = density + torch.rand(
density.shape, device=density.device, dtype=density.dtype) * 1e-6
if token_mask is not None:
# the density of empty token should be 0
density = density * token_mask
# get distance indicator
mask = density[:, None, :] > density[:, :, None]
mask = mask.type(x.dtype)
dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None]
dist, index_parent = (dist_matrix * mask + dist_max *
(1 - mask)).min(dim=-1)
# select clustering center according to score
score = dist * density
_, index_down = torch.topk(score, k=cluster_num, dim=-1)
# assign tokens to the nearest center
dist_matrix = index_points(dist_matrix, index_down)
idx_cluster = dist_matrix.argmin(dim=1)
# make sure cluster center merge to itself
idx_batch = torch.arange(
B, device=x.device)[:, None].expand(B, cluster_num)
idx_tmp = torch.arange(
cluster_num, device=x.device)[None, :].expand(B, cluster_num)
idx_cluster[idx_batch.reshape(-1),
index_down.reshape(-1)] = idx_tmp.reshape(-1)
return idx_cluster, cluster_num
def merge_tokens(token_dict, idx_cluster, cluster_num, token_weight=None):
"""Merge tokens in the same cluster to a single cluster. Implemented by
torch.index_add(). Flops: B*N*(C+2)
Note:
B: batch size
N: token number
C: channel number
Args:
token_dict (dict): dict for input token information
idx_cluster (Tensor[B, N]): cluster index of each token.
cluster_num (int): cluster number
token_weight (Tensor[B, N, 1]): weight for each token.
Return:
out_dict (dict): dict for output token information
"""
x = token_dict['x']
idx_token = token_dict['idx_token']
agg_weight = token_dict['agg_weight']
B, N, C = x.shape
if token_weight is None:
token_weight = x.new_ones(B, N, 1)
idx_batch = torch.arange(B, device=x.device)[:, None]
idx = idx_cluster + idx_batch * cluster_num
all_weight = token_weight.new_zeros(B * cluster_num, 1)
all_weight.index_add_(
dim=0, index=idx.reshape(B * N), source=token_weight.reshape(B * N, 1))
all_weight = all_weight + 1e-6
norm_weight = token_weight / all_weight[idx]
# average token features
x_merged = x.new_zeros(B * cluster_num, C)
source = x * norm_weight
x_merged.index_add_(
dim=0,
index=idx.reshape(B * N),
source=source.reshape(B * N, C).type(x.dtype))
x_merged = x_merged.reshape(B, cluster_num, C)
idx_token_new = index_points(idx_cluster[..., None], idx_token).squeeze(-1)
weight_t = index_points(norm_weight, idx_token)
agg_weight_new = agg_weight * weight_t
agg_weight_new / agg_weight_new.max(dim=1, keepdim=True)[0]
out_dict = {}
out_dict['x'] = x_merged
out_dict['token_num'] = cluster_num
out_dict['map_size'] = token_dict['map_size']
out_dict['init_grid_size'] = token_dict['init_grid_size']
out_dict['idx_token'] = idx_token_new
out_dict['agg_weight'] = agg_weight_new
return out_dict
class MLP(nn.Module):
"""FFN with Depthwise Conv of TCFormer.
Args:
in_features (int): The feature dimension.
hidden_features (int, optional): The hidden dimension of FFNs.
Defaults: The same as in_features.
out_features (int, optional): The output feature dimension.
Defaults: The same as in_features.
act_layer (nn.Module, optional): The activation config for FFNs.
Default: nn.GELU.
drop (float, optional): drop out rate. Default: 0.
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def init_weights(self):
"""init weights."""
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class DWConv(nn.Module):
"""Depthwise Conv for regular grid-based tokens.
Args:
dim (int): The feature dimension.
"""
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class TCFormerRegularAttention(nn.Module):
"""Spatial Reduction Attention for regular grid-based tokens.
Args:
dim (int): The feature dimension of tokens,
num_heads (int): Parallel attention heads.
qkv_bias (bool): enable bias for qkv if True. Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after attention process.
Default: 0.0.
sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
Attention. Default: 1.
use_sr_conv (bool): If True, use a conv layer for spatial reduction.
If False, use a pooling process for spatial reduction. Defaults:
True.
"""
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1,
use_sr_conv=True,
):
super().__init__()
assert dim % num_heads == 0, \
f'dim {dim} should be divided by num_heads {num_heads}.'
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
self.use_sr_conv = use_sr_conv
if sr_ratio > 1 and self.use_sr_conv:
self.sr = nn.Conv2d(
dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
kv = x.permute(0, 2, 1).reshape(B, C, H, W)
if self.use_sr_conv:
kv = self.sr(kv).reshape(B, C, -1).permute(0, 2,
1).contiguous()
kv = self.norm(kv)
else:
kv = F.avg_pool2d(
kv, kernel_size=self.sr_ratio, stride=self.sr_ratio)
kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous()
else:
kv = x
kv = self.kv(kv).reshape(B, -1, 2, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1,
4).contiguous()
k, v = kv[0], kv[1]
attn = (q * self.scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class TCFormerRegularBlock(nn.Module):
"""Transformer block for regular grid-based tokens.
Args:
dim (int): The feature dimension.
num_heads (int): Parallel attention heads.
mlp_ratio (int): The expansion ratio for the FFNs.
qkv_bias (bool): enable bias for qkv if True. Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop (float): Dropout layers after attention process and in FFN.
Default: 0.0.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
drop_path (int, optional): The drop path rate of transformer block.
Default: 0.0
act_layer (nn.Module, optional): The activation config for FFNs.
Default: nn.GELU.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
Attention. Default: 1.
use_sr_conv (bool): If True, use a conv layer for spatial reduction.
If False, use a pooling process for spatial reduction. Defaults:
True.
"""
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_cfg=dict(type='LN'),
sr_ratio=1,
use_sr_conv=True):
super().__init__()
self.norm1 = build_norm_layer(norm_cfg, dim)[1]
self.attn = TCFormerRegularAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
sr_ratio=sr_ratio,
use_sr_conv=use_sr_conv)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path))
self.norm2 = build_norm_layer(norm_cfg, dim)[1]
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLP(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class TokenConv(nn.Conv2d):
"""Conv layer for dynamic tokens.
A skip link is added between the input and output tokens to reserve detail
tokens.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
groups = kwargs['groups'] if 'groups' in kwargs.keys() else 1
self.skip = nn.Conv1d(
in_channels=kwargs['in_channels'],
out_channels=kwargs['out_channels'],
kernel_size=1,
bias=False,
groups=groups)
def forward(self, token_dict):
x = token_dict['x']
x = self.skip(x.permute(0, 2, 1)).permute(0, 2, 1)
x_map = token2map(token_dict)
x_map = super().forward(x_map)
x = x + map2token(x_map, token_dict)
return x
class TCMLP(nn.Module):
"""FFN with Depthwise Conv for dynamic tokens.
Args:
in_features (int): The feature dimension.
hidden_features (int, optional): The hidden dimension of FFNs.
Defaults: The same as in_features.
out_features (int, optional): The output feature dimension.
Defaults: The same as in_features.
act_layer (nn.Module, optional): The activation config for FFNs.
Default: nn.GELU.
drop (float, optional): drop out rate. Default: 0.
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = TokenConv(
in_channels=hidden_features,
out_channels=hidden_features,
kernel_size=3,
padding=1,
stride=1,
bias=True,
groups=hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def init_weights(self):
"""init weights."""
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, token_dict):
token_dict['x'] = self.fc1(token_dict['x'])
x = self.dwconv(token_dict)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class TCFormerDynamicAttention(TCFormerRegularAttention):
"""Spatial Reduction Attention for dynamic tokens."""
def forward(self, q_dict, kv_dict):
"""Attention process for dynamic tokens.
Dynamic tokens are represented by a dict with the following keys:
x (torch.Tensor[B, N, C]): token features.
token_num(int): token number.
map_size(list[int] or tuple[int]): feature map resolution in
format [H, W].
init_grid_size(list[int] or tuple[int]): initial grid resolution
in format [H_init, W_init].
idx_token(torch.LongTensor[B, N_init]): indicates which token
the initial grid belongs to.
agg_weight(torch.LongTensor[B, N_init] or None): weight for
aggregation. Indicates the weight of each token in its
cluster. If set to None, uniform weight is used.
Note:
B: batch size
N: token number
C: channel number
Ns: sampled point number
[H_init, W_init]: shape of initial grid
[H, W]: shape of feature map
N_init: numbers of initial token
Args:
q_dict (dict): dict for query token information
kv_dict (dict): dict for key and value token information
Return:
x (torch.Tensor[B, N, C]): output token features.
"""
q = q_dict['x']
kv = kv_dict['x']
B, Nq, C = q.shape
Nkv = kv.shape[1]
conf_kv = kv_dict['token_score'] if 'token_score' in kv_dict.keys(
) else kv.new_zeros(B, Nkv, 1)
q = self.q(q).reshape(B, Nq, self.num_heads,
C // self.num_heads).permute(0, 2, 1,
3).contiguous()
if self.sr_ratio > 1:
tmp = torch.cat([kv, conf_kv], dim=-1)
tmp_dict = kv_dict.copy()
tmp_dict['x'] = tmp
tmp_dict['map_size'] = q_dict['map_size']
tmp = token2map(tmp_dict)
kv = tmp[:, :C]
conf_kv = tmp[:, C:]
if self.use_sr_conv:
kv = self.sr(kv)
_, _, h, w = kv.shape
kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous()
kv = self.norm(kv)
else:
kv = F.avg_pool2d(
kv, kernel_size=self.sr_ratio, stride=self.sr_ratio)
kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous()
conf_kv = F.avg_pool2d(
conf_kv, kernel_size=self.sr_ratio, stride=self.sr_ratio)
conf_kv = conf_kv.reshape(B, 1, -1).permute(0, 2, 1).contiguous()
kv = self.kv(kv).reshape(B, -1, 2, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1,
4).contiguous()
k, v = kv[0], kv[1]
attn = (q * self.scale) @ k.transpose(-2, -1)
conf_kv = conf_kv.squeeze(-1)[:, None, None, :]
attn = attn + conf_kv
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
# Transformer block for dynamic tokens
class TCFormerDynamicBlock(TCFormerRegularBlock):
"""Transformer block for dynamic tokens.
Args:
dim (int): The feature dimension.
num_heads (int): Parallel attention heads.
mlp_ratio (int): The expansion ratio for the FFNs.
qkv_bias (bool): enable bias for qkv if True. Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop (float): Dropout layers after attention process and in FFN.
Default: 0.0.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
drop_path (int, optional): The drop path rate of transformer block.
Default: 0.0
act_layer (nn.Module, optional): The activation config for FFNs.
Default: nn.GELU.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
Attention. Default: 1.
use_sr_conv (bool): If True, use a conv layer for spatial reduction.
If False, use a pooling process for spatial reduction. Defaults:
True.
"""
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_cfg=dict(type='LN'),
sr_ratio=1,
use_sr_conv=True):
super(TCFormerRegularBlock, self).__init__()
self.norm1 = build_norm_layer(norm_cfg, dim)[1]
self.attn = TCFormerDynamicAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
sr_ratio=sr_ratio,
use_sr_conv=use_sr_conv)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path))
self.norm2 = build_norm_layer(norm_cfg, dim)[1]
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = TCMLP(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, inputs):
"""Forward function.
Args:
inputs (dict or tuple[dict] or list[dict]): input dynamic
token information. If a single dict is provided, it's
regraded as query and key, value. If a tuple or list
of dict is provided, the first one is regarded as key
and the second one is regarded as key, value.
Return:
q_dict (dict): dict for output token information
"""
if isinstance(inputs, tuple) or isinstance(inputs, list):
q_dict, kv_dict = inputs
else:
q_dict, kv_dict = inputs, None
x = q_dict['x']
# norm1
q_dict['x'] = self.norm1(q_dict['x'])
if kv_dict is None:
kv_dict = q_dict
else:
kv_dict['x'] = self.norm1(kv_dict['x'])
# attn
x = x + self.drop_path(self.attn(q_dict, kv_dict))
# mlp
q_dict['x'] = self.norm2(x)
x = x + self.drop_path(self.mlp(q_dict))
q_dict['x'] = x
return q_dict