import copy
import warnings

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import (build_attention,
                                         build_feedforward_network)
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, ModuleList
from mmengine.registry import MODELS


@MODELS.register_module()
class TPVFormerLayer(BaseModule):
    """Base `TPVFormerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and supports more flexible
    customization, for example, using any number of `FFN` or `LN` layers
    and different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specify `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            the corresponding attentions in `operation_order`.
            If it is a dict, all of the attention modules in
            `operation_order` will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for FFN. The order of the configs in the list should be
            consistent with the corresponding FFNs in `operation_order`.
            If it is a dict, all of the FFN modules in `operation_order`
            will be built with this config.
        operation_order (tuple[str]): The execution order of operations
            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
            `prenorm` is supported when the first element is `norm`.
            Default: None.
        norm_cfg (dict): Config dict for the normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The config for initialization.
            Default: None.
        batch_first (bool): If True, Key, Query and Value have shape
            (batch, n, embed_dim); otherwise (n, batch, embed_dim).
            Default: True.
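
    Example:
        A construction-only sketch. The generic `MultiheadAttention` and
        `FFN` configs below are assumptions standing in for the
        project-specific TPV attention modules that are normally supplied
        through the config file, and they require the mmcv attention
        registry to be available.

        >>> layer = TPVFormerLayer(
        ...     attn_cfgs=dict(
        ...         type='MultiheadAttention', embed_dims=256, num_heads=8),
        ...     ffn_cfgs=dict(
        ...         type='FFN', embed_dims=256, feedforward_channels=1024),
        ...     operation_order=('self_attn', 'norm', 'ffn', 'norm'))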
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=True,
                 **kwargs):
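        # Fold deprecated FFN-related keyword arguments into `ffn_cfgs` so
        # that configs written for older versions keep working.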
        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The argument `{ori_name}` in TPVFormerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`.')
                ffn_cfgs[new_name] = kwargs[ori_name]

        super().__init__(init_cfg)

        self.batch_first = batch_first

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
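        # A single attention config is shared (deep-copied) across every
        # attention entry in `operation_order`; a list must match that
        # count exactly.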
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), \
                f'The length of attn_cfgs {len(attn_cfgs)} is ' \
                f'not consistent with the number of attention ' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()
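
        # Instantiate the attention modules in the order they appear in
        # `operation_order`, propagating `batch_first` into each config.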
        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims
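
        # Build one FFN per 'ffn' entry in `operation_order`; a single dict
        # config is copied for all of them and, unless given, inherits
        # `embed_dims` from the attention modules.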
        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(build_feedforward_network(ffn_cfgs[ffn_index]))
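
        # One normalization layer is built per 'norm' entry in
        # `operation_order`.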
        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                tpv_pos=None,
                ref_2d=None,
                tpv_h=None,
                tpv_w=None,
                tpv_z=None,
                reference_points_cams=None,
                tpv_masks=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """
        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
            value (Tensor): The value tensor with the same shape as `key`.
            tpv_pos (Tensor): The positional encoding for self attention.

        Returns:
            tuple[Tensor]: Forwarded results, one tensor of shape
                [bs, num_queries, embed_dims] for each of the 3 tpv planes.
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
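        # When the layer starts with cross attention, the three TPV plane
        # queries are first concatenated along the token dimension and
        # handled as a single sequence.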
        if self.operation_order[0] == 'cross_attn':
            query = torch.cat(query, dim=1)
        identity = query

        for layer in self.operation_order:
            if layer == 'self_attn':
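                # Each TPV plane acts as its own attention "level": the HW,
                # ZH and WZ token runs sit back to back along dim 1, so the
                # per-level spatial shapes and start offsets are built here.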
                ss = torch.tensor(
                    [[tpv_h, tpv_w], [tpv_z, tpv_h], [tpv_w, tpv_z]],
                    device=query[0].device)
                lsi = torch.tensor(
                    [0, tpv_h * tpv_w, tpv_h * tpv_w + tpv_z * tpv_h],
                    device=query[0].device)

                if not isinstance(query, (list, tuple)):
                    query = torch.split(
                        query, [tpv_h * tpv_w, tpv_z * tpv_h, tpv_w * tpv_z],
                        dim=1)

                query = self.attentions[attn_index](
                    query,
                    identity if self.pre_norm else None,
                    query_pos=tpv_pos,
                    reference_points=ref_2d,
                    spatial_shapes=ss,
                    level_start_index=lsi,
                    **kwargs)
                attn_index += 1
                query = torch.cat(query, dim=1)
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

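            # Cross attention attends from the TPV queries to `key`/`value`
            # using the precomputed per-camera reference points and masks.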
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    reference_points_cams=reference_points_cams,
                    tpv_masks=tpv_masks,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1
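
        # Split the concatenated tokens back into the three TPV planes
        # (HW, ZH and WZ) before returning.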
        query = torch.split(
            query, [tpv_h * tpv_w, tpv_z * tpv_h, tpv_w * tpv_z], dim=1)
        return query