|
import torch.nn as nn |
|
from .net_utils import (PosEnSine, softmax_attention, dotproduct_attention,
                        long_range_attention, short_range_attention, patch_attention)
|
|
|
|
|
class OurMultiheadAttention(nn.Module): |
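    """Multi-head attention over 2-D feature maps of shape (B, C, H, W).

    Queries, keys and values are projected with 1x1 convolutions and split into
    n_head heads of d_k / d_v channels; the attention variant is chosen per call
    via attn_type ('softmax', 'dotproduct', 'patch', 'sparse_long', 'sparse_short').
    """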
|
def __init__(self, q_feat_dim, k_feat_dim, out_feat_dim, n_head, d_k=None, d_v=None): |
|
super(OurMultiheadAttention, self).__init__() |
|
if d_k is None: |
|
d_k = out_feat_dim // n_head |
|
if d_v is None: |
|
d_v = out_feat_dim // n_head |
|
|
|
self.n_head = n_head |
|
self.d_k = d_k |
|
self.d_v = d_v |
|
|
|
|
|
        # 1x1 convolutions producing the per-head query/key/value projections.
        self.w_qs = nn.Conv2d(q_feat_dim, n_head * d_k, 1, bias=False)
        self.w_ks = nn.Conv2d(k_feat_dim, n_head * d_k, 1, bias=False)
        self.w_vs = nn.Conv2d(out_feat_dim, n_head * d_v, 1, bias=False)
|
|
|
|
|
self.fc = nn.Conv2d(n_head * d_v, out_feat_dim, 1, bias=False) |
|
|
|
def forward(self, q, k, v, attn_type='softmax', **kwargs): |
|
|
|
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head |
|
|
|
|
|
|
|
        # Project q, k, v with 1x1 convolutions and split the channels into heads:
        # (B, C, H, W) -> (B, n_head, d_k or d_v, H, W).
        q = self.w_qs(q).view(q.shape[0], n_head, d_k, q.shape[2], q.shape[3])
        k = self.w_ks(k).view(k.shape[0], n_head, d_k, k.shape[2], k.shape[3])
        v = self.w_vs(v).view(v.shape[0], n_head, d_v, v.shape[2], v.shape[3])
|
|
|
|
|
        # Dispatch to the requested attention variant; extra arguments (P, ah, aw)
        # are forwarded through **kwargs.
        if attn_type == 'softmax':
|
q, attn = softmax_attention(q, k, v) |
|
elif attn_type == 'dotproduct': |
|
q, attn = dotproduct_attention(q, k, v) |
|
elif attn_type == 'patch': |
|
q, attn = patch_attention(q, k, v, P=kwargs['P']) |
|
elif attn_type == 'sparse_long': |
|
q, attn = long_range_attention(q, k, v, P_h=kwargs['ah'], P_w=kwargs['aw']) |
|
elif attn_type == 'sparse_short': |
|
q, attn = short_range_attention(q, k, v, Q_h=kwargs['ah'], Q_w=kwargs['aw']) |
|
else: |
|
raise NotImplementedError(f'Unknown attention type {attn_type}') |
|
|
|
|
|
|
|
        # Merge the heads back to (B, n_head * d_v, H, W) and apply the output projection.
        q = q.reshape(q.shape[0], -1, q.shape[3], q.shape[4])
        q = self.fc(q)
|
|
|
return q, attn |
|
|
|
|
|
class TransformerEncoderUnit(nn.Module): |
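    """Transformer encoder block for (B, C, H, W) features: multi-head self-attention
    followed by a 1x1-conv feed-forward, each with a residual connection and BatchNorm.
    """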
|
def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None): |
|
super(TransformerEncoderUnit, self).__init__() |
|
self.feat_dim = feat_dim |
|
self.attn_type = attn_type |
|
self.pos_en_flag = pos_en_flag |
|
self.P = P |
|
|
|
self.pos_en = PosEnSine(self.feat_dim // 2) |
|
        # q, k and v all come from the same feature map, so every dimension equals feat_dim.
        self.attn = OurMultiheadAttention(feat_dim, feat_dim, feat_dim, n_head)
|
|
|
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
|
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
|
self.activation = nn.ReLU(inplace=True) |
|
|
|
self.norm1 = nn.BatchNorm2d(self.feat_dim) |
|
self.norm2 = nn.BatchNorm2d(self.feat_dim) |
|
|
|
def forward(self, src): |
|
if self.pos_en_flag: |
|
pos_embed = self.pos_en(src) |
|
else: |
|
pos_embed = 0 |
|
|
|
|
|
        # Multi-head self-attention; the sinusoidal positional encoding is added to q and k only.
        src2 = self.attn(
            q=src + pos_embed, k=src + pos_embed, v=src, attn_type=self.attn_type, P=self.P
        )[0]
|
src = src + src2 |
|
src = self.norm1(src) |
|
|
|
|
|
        # Position-wise feed-forward network implemented with 1x1 convolutions.
        src2 = self.linear2(self.activation(self.linear1(src)))
|
src = src + src2 |
|
src = self.norm2(src) |
|
|
|
return src |
|
|
|
|
|
class TransformerEncoderUnitSparse(nn.Module): |
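    """Encoder block with factorized sparse self-attention: a long-range pass followed
    by a short-range pass, then a 1x1-conv feed-forward. `ahw` must provide four values:
    the (ah, aw) pair for the long-range pass and the pair for the short-range pass.
    """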
|
def __init__(self, feat_dim, n_head=8, pos_en_flag=True, ahw=None): |
|
super(TransformerEncoderUnitSparse, self).__init__() |
|
self.feat_dim = feat_dim |
|
self.pos_en_flag = pos_en_flag |
|
self.ahw = ahw |
|
|
|
self.pos_en = PosEnSine(self.feat_dim // 2) |
|
        self.attn1 = OurMultiheadAttention(feat_dim, feat_dim, feat_dim, n_head)  # long-range self-attention
|
        self.attn2 = OurMultiheadAttention(feat_dim, feat_dim, feat_dim, n_head)  # short-range self-attention
|
|
|
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
|
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
|
self.activation = nn.ReLU(inplace=True) |
|
|
|
self.norm1 = nn.BatchNorm2d(self.feat_dim) |
|
self.norm2 = nn.BatchNorm2d(self.feat_dim) |
|
|
|
def forward(self, src): |
|
if self.pos_en_flag: |
|
pos_embed = self.pos_en(src) |
|
else: |
|
pos_embed = 0 |
|
|
|
|
|
src2 = self.attn1( |
|
q=src + pos_embed, |
|
k=src + pos_embed, |
|
v=src, |
|
attn_type='sparse_long', |
|
ah=self.ahw[0], |
|
aw=self.ahw[1] |
|
)[0] |
|
src = src + src2 |
|
|
|
|
|
src2 = self.attn2( |
|
q=src + pos_embed, |
|
k=src + pos_embed, |
|
v=src, |
|
attn_type='sparse_short', |
|
ah=self.ahw[2], |
|
aw=self.ahw[3] |
|
)[0] |
|
src = src + src2 |
|
src = self.norm1(src) |
|
|
|
|
|
src2 = self.linear2(self.activation(self.linear1(src))) |
|
src = src + src2 |
|
src = self.norm2(src) |
|
|
|
return src |
|
|
|
|
|
class TransformerDecoderUnit(nn.Module): |
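    """Transformer decoder block: self-attention on the target, cross-attention from the
    target to the source, then a 1x1-conv feed-forward; each stage has a residual
    connection and BatchNorm.
    """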
|
def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None): |
|
super(TransformerDecoderUnit, self).__init__() |
|
self.feat_dim = feat_dim |
|
self.attn_type = attn_type |
|
self.pos_en_flag = pos_en_flag |
|
self.P = P |
|
|
|
self.pos_en = PosEnSine(self.feat_dim // 2) |
|
        self.attn1 = OurMultiheadAttention(feat_dim, feat_dim, feat_dim, n_head)  # self-attention
|
        self.attn2 = OurMultiheadAttention(feat_dim, feat_dim, feat_dim, n_head)  # cross-attention
|
|
|
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
|
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
|
self.activation = nn.ReLU(inplace=True) |
|
|
|
self.norm1 = nn.BatchNorm2d(self.feat_dim) |
|
self.norm2 = nn.BatchNorm2d(self.feat_dim) |
|
self.norm3 = nn.BatchNorm2d(self.feat_dim) |
|
|
|
def forward(self, tgt, src): |
|
if self.pos_en_flag: |
|
src_pos_embed = self.pos_en(src) |
|
tgt_pos_embed = self.pos_en(tgt) |
|
else: |
|
src_pos_embed = 0 |
|
tgt_pos_embed = 0 |
|
|
|
|
|
        # Self-attention on the target features.
        tgt2 = self.attn1(
            q=tgt + tgt_pos_embed, k=tgt + tgt_pos_embed, v=tgt, attn_type=self.attn_type, P=self.P
        )[0]
|
tgt = tgt + tgt2 |
|
tgt = self.norm1(tgt) |
|
|
|
|
|
        # Cross-attention: queries from the target, keys/values from the source.
        tgt2 = self.attn2(
            q=tgt + tgt_pos_embed, k=src + src_pos_embed, v=src, attn_type=self.attn_type, P=self.P
        )[0]
|
tgt = tgt + tgt2 |
|
tgt = self.norm2(tgt) |
|
|
|
|
|
tgt2 = self.linear2(self.activation(self.linear1(tgt))) |
|
tgt = tgt + tgt2 |
|
tgt = self.norm3(tgt) |
|
|
|
return tgt |
|
|
|
|
|
class TransformerDecoderUnitSparse(nn.Module): |
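    """Decoder block with factorized sparse attention: long- then short-range self-attention
    on the target, long-range self-attention on the source, and short-range cross-attention
    from target to source, followed by a 1x1-conv feed-forward. `ahw` must provide eight
    values: the (ah, aw) pairs for these four attention passes, in that order.
    """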
|
def __init__(self, feat_dim, n_head=8, pos_en_flag=True, ahw=None): |
|
super(TransformerDecoderUnitSparse, self).__init__() |
|
self.feat_dim = feat_dim |
|
self.ahw = ahw |
|
self.pos_en_flag = pos_en_flag |
|
|
|
self.pos_en = PosEnSine(self.feat_dim // 2) |
|
        self.attn1_1 = OurMultiheadAttention(feat_dim, feat_dim, feat_dim, n_head)  # long-range self-attention on tgt
|
        self.attn1_2 = OurMultiheadAttention(feat_dim, feat_dim, feat_dim, n_head)  # short-range self-attention on tgt
|
|
|
        self.attn2_1 = OurMultiheadAttention(feat_dim, feat_dim, feat_dim, n_head)  # long-range self-attention on src
        self.attn2_2 = OurMultiheadAttention(feat_dim, feat_dim, feat_dim, n_head)  # short-range cross-attention
|
|
|
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
|
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) |
|
self.activation = nn.ReLU(inplace=True) |
|
|
|
self.norm1 = nn.BatchNorm2d(self.feat_dim) |
|
self.norm2 = nn.BatchNorm2d(self.feat_dim) |
|
self.norm3 = nn.BatchNorm2d(self.feat_dim) |
|
|
|
def forward(self, tgt, src): |
|
if self.pos_en_flag: |
|
src_pos_embed = self.pos_en(src) |
|
tgt_pos_embed = self.pos_en(tgt) |
|
else: |
|
src_pos_embed = 0 |
|
tgt_pos_embed = 0 |
|
|
|
|
|
tgt2 = self.attn1_1( |
|
q=tgt + tgt_pos_embed, |
|
k=tgt + tgt_pos_embed, |
|
v=tgt, |
|
attn_type='sparse_long', |
|
ah=self.ahw[0], |
|
aw=self.ahw[1] |
|
)[0] |
|
tgt = tgt + tgt2 |
|
|
|
tgt2 = self.attn1_2( |
|
q=tgt + tgt_pos_embed, |
|
k=tgt + tgt_pos_embed, |
|
v=tgt, |
|
attn_type='sparse_short', |
|
ah=self.ahw[2], |
|
aw=self.ahw[3] |
|
)[0] |
|
tgt = tgt + tgt2 |
|
tgt = self.norm1(tgt) |
|
|
|
|
|
src2 = self.attn2_1( |
|
q=src + src_pos_embed, |
|
k=src + src_pos_embed, |
|
v=src, |
|
attn_type='sparse_long', |
|
ah=self.ahw[4], |
|
aw=self.ahw[5] |
|
)[0] |
|
src = src + src2 |
|
|
|
tgt2 = self.attn2_2( |
|
q=tgt + tgt_pos_embed, |
|
k=src + src_pos_embed, |
|
v=src, |
|
attn_type='sparse_short', |
|
ah=self.ahw[6], |
|
aw=self.ahw[7] |
|
)[0] |
|
tgt = tgt + tgt2 |
|
tgt = self.norm2(tgt) |
|
|
|
|
|
tgt2 = self.linear2(self.activation(self.linear1(tgt))) |
|
tgt = tgt + tgt2 |
|
tgt = self.norm3(tgt) |
|
|
|
return tgt |
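

# Minimal smoke test (illustrative sketch, not part of the original module). It assumes that
# PosEnSine and the attention helpers imported from .net_utils accept the tensor shapes used
# above, and that this file is executed as part of its package (the relative import requires it).
if __name__ == '__main__':
    import torch

    feat_dim, n_head = 64, 8
    encoder = TransformerEncoderUnit(feat_dim, n_head, pos_en_flag=True, attn_type='softmax')
    decoder = TransformerDecoderUnit(feat_dim, n_head, pos_en_flag=True, attn_type='softmax')

    src = torch.randn(2, feat_dim, 16, 16)  # (B, C, H, W) source features
    tgt = torch.randn(2, feat_dim, 16, 16)  # (B, C, H, W) target features

    memory = encoder(src)
    out = decoder(tgt, memory)
    print(out.shape)  # expected: torch.Size([2, 64, 16, 16])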
|
|