File size: 9,192 Bytes
34d1f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
import copy
import warnings
import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import (build_attention,
build_feedforward_network)
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, ModuleList
from mmengine.registry import MODELS
@MODELS.register_module()
class TPVFormerLayer(BaseModule):
    """Configurable transformer layer operating on tri-perspective-view
    (TPV) queries.

    The layer is assembled from config dicts: every entry of
    ``operation_order`` is one of ``'self_attn'``, ``'cross_attn'``,
    ``'norm'`` or ``'ffn'`` and the corresponding sub-module is built and
    executed in that order. It supports `prenorm` when ``'norm'`` is the
    first element of ``operation_order``. More details about the
    `prenorm`: `On Layer Normalization in the Transformer Architecture
    <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
            Configs for the `self_attn` / `cross_attn` modules. If it is a
            list, its order must match the order of the attention entries
            in ``operation_order``. If it is a single dict, every attention
            module is built from a copy of it. The signature default is
            ``None`` but a value is required.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
            Configs for the FFNs. If it is a list, its order must match the
            order of the ``'ffn'`` entries in ``operation_order``. If it is
            a single dict, every FFN is built from a copy of it.
        operation_order (tuple[str]): The execution order of operations,
            such as ('self_attn', 'norm', 'ffn', 'norm'). The signature
            default is ``None`` but a value is required.
        norm_cfg (dict): Config dict for the normalization layers.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The config for initialization.
            Default: None.
        batch_first (bool): Whether Key, Query and Value have shape
            (batch, n, embed_dim) rather than (n, batch, embed_dim).
            Default: True.
        **kwargs: Only the deprecated FFN arguments
            ``feedforward_channels``, ``ffn_dropout`` and ``ffn_num_fcs``
            are consumed; they are forwarded into ``ffn_cfgs`` with a
            warning.

    Raises:
        ValueError: If ``operation_order`` or ``attn_cfgs`` is ``None``.
        AssertionError: If the number of configs does not match
            ``operation_order``, or a config contradicts ``batch_first`` /
            the embedding dimension of the first attention module.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=True,
                 **kwargs):
        # Work on a private copy: the default `ffn_cfgs` dict is shared by
        # every call, and configs handed in by the caller may be reused for
        # other layers, so neither must be modified in place.
        ffn_cfgs = copy.deepcopy(ffn_cfgs)

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in TPVFormerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ')
                ffn_cfgs[new_name] = kwargs[ori_name]

        super().__init__(init_cfg)

        if operation_order is None:
            raise ValueError(
                '`operation_order` must be a sequence of operation names, '
                'but got None.')
        if attn_cfgs is None:
            raise ValueError(
                '`attn_cfgs` must be a config dict or a list of config '
                'dicts, but got None.')

        self.batch_first = batch_first

        num_attn = operation_order.count('self_attn') + \
            operation_order.count('cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), \
                f'The length of attn_cfgs {len(attn_cfgs)} is not ' \
                f'consistent with the number of attention {num_attn} ' \
                f'in operation_order {operation_order}.'
            # `batch_first` may be written into each config below; copy so
            # the caller's configs stay untouched.
            attn_cfgs = [copy.deepcopy(cfg) for cfg in attn_cfgs]

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'

        self.attentions = ModuleList()
        index = 0
        for operation_name in operation_order:
            if operation_name not in ('self_attn', 'cross_attn'):
                continue
            attn_cfg = attn_cfgs[index]
            if 'batch_first' in attn_cfg:
                assert self.batch_first == attn_cfg['batch_first']
            else:
                attn_cfg['batch_first'] = self.batch_first
            attention = build_attention(attn_cfg)
            # Some custom attentions used as `self_attn`
            # or `cross_attn` can have different behavior.
            attention.operation_name = operation_name
            self.attentions.append(attention)
            index += 1

        # The first attention module defines the embedding size used by the
        # FFNs and the norm layers.
        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            base_ffn_cfg = ConfigDict(ffn_cfgs)
            ffn_cfgs = [copy.deepcopy(base_ffn_cfg) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_cfg in ffn_cfgs:
            if 'embed_dims' not in ffn_cfg:
                ffn_cfg['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfg['embed_dims'] == self.embed_dims
            self.ffns.append(build_feedforward_network(ffn_cfg))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                tpv_pos=None,
                ref_2d=None,
                tpv_h=None,
                tpv_w=None,
                tpv_z=None,
                reference_points_cams=None,
                tpv_masks=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Run the operations listed in ``self.operation_order``.

        **kwargs contains some specific arguments of attentions and is
        forwarded unchanged to every attention module.

        Args:
            query (Tensor | list[Tensor] | tuple[Tensor]): Either the three
                per-plane query tensors or a single tensor holding them
                concatenated along dim 1. A sequence is concatenated along
                dim 1 up front when the first operation is `cross_attn`;
                a single tensor is split into the three planes before
                `self_attn`.
                NOTE(review): concatenation/splitting on dim 1 implies a
                [bs, num_queries, embed_dims] layout - confirm against
                callers when ``batch_first`` is False.
            key (Tensor): Passed as the key to `cross_attn`.
            value (Tensor): Passed as the value to `cross_attn`.
            tpv_pos (Tensor): Passed to `self_attn` as ``query_pos``.
            ref_2d (Tensor): Passed to `self_attn` as
                ``reference_points``.
            tpv_h (int): Size used together with ``tpv_w`` / ``tpv_z`` to
                derive the per-plane query counts ``tpv_h * tpv_w``,
                ``tpv_z * tpv_h`` and ``tpv_w * tpv_z``. Required.
            tpv_w (int): See ``tpv_h``. Required.
            tpv_z (int): See ``tpv_h``. Required.
            reference_points_cams: Forwarded to `cross_attn`.
            tpv_masks: Forwarded to `cross_attn`.
            spatial_shapes: Forwarded to `cross_attn`.
            level_start_index: Forwarded to `cross_attn`.

        Returns:
            tuple[Tensor]: The result split along dim 1 into three tensors
            of ``tpv_h * tpv_w``, ``tpv_z * tpv_h`` and ``tpv_w * tpv_z``
            queries, one per tpv plane.
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0

        # Number of queries in each of the three planes, in the order the
        # planes are concatenated along dim 1.
        plane_sizes = [tpv_h * tpv_w, tpv_z * tpv_h, tpv_w * tpv_z]

        if self.operation_order[0] == 'cross_attn':
            query = torch.cat(query, dim=1)
        identity = query

        for layer in self.operation_order:
            # cross view hybrid-attention
            if layer == 'self_attn':
                # `query[0]` yields a tensor for both the sequence form and
                # the single-tensor form, so its device is always valid.
                device = query[0].device
                ss = torch.tensor(
                    [[tpv_h, tpv_w], [tpv_z, tpv_h], [tpv_w, tpv_z]],
                    device=device)
                lsi = torch.tensor(
                    [0, plane_sizes[0], plane_sizes[0] + plane_sizes[1]],
                    device=device)

                if not isinstance(query, (list, tuple)):
                    query = torch.split(query, plane_sizes, dim=1)

                query = self.attentions[attn_index](
                    query,
                    identity if self.pre_norm else None,
                    query_pos=tpv_pos,
                    reference_points=ref_2d,
                    spatial_shapes=ss,
                    level_start_index=lsi,
                    **kwargs)
                attn_index += 1
                # The self attention returns per-plane outputs; merge them
                # back into one tensor for the following operations.
                query = torch.cat(query, dim=1)
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            # image cross attention
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    reference_points_cams=reference_points_cams,
                    tpv_masks=tpv_masks,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        query = torch.split(query, plane_sizes, dim=1)
        return query
|