Update yolo_world/models/layers/yolo_bricks.py
yolo_world/models/layers/yolo_bricks.py CHANGED
@@ -29,7 +29,8 @@ class MaxSigmoidAttnBlock(BaseModule):
                  norm_cfg: ConfigType = dict(type='BN',
                                              momentum=0.03,
                                              eps=0.001),
-                 init_cfg: OptMultiConfig = None) -> None:
+                 init_cfg: OptMultiConfig = None,
+                 use_einsum: bool = True) -> None:
         super().__init__(init_cfg=init_cfg)
         conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
 
@@ -38,6 +39,7 @@ class MaxSigmoidAttnBlock(BaseModule):
             'out_channels and embed_channels should be divisible by num_heads.'
         self.num_heads = num_heads
         self.head_channels = out_channels // num_heads
+        self.use_einsum = use_einsum
 
         self.embed_conv = ConvModule(
             in_channels,
@@ -71,7 +73,17 @@ class MaxSigmoidAttnBlock(BaseModule):
         embed = self.embed_conv(x) if self.embed_conv is not None else x
         embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)
 
-        attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
+        if self.use_einsum:
+            attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
+        else:
+            batch, m, channel, height, width = embed.shape
+            _, n, _, _ = guide.shape
+            embed = embed.permute(0, 1, 3, 4, 2)
+            embed = embed.reshape(batch, m, -1, channel)
+            guide = guide.permute(0, 2, 3, 1)
+            attn_weight = torch.matmul(embed, guide)
+            attn_weight = attn_weight.reshape(batch, m, height, width, n)
+
         attn_weight = attn_weight.max(dim=-1)[0]
         attn_weight = attn_weight / (self.head_channels**0.5)
         attn_weight = attn_weight + self.bias[None, :, None, None]
@@ -101,7 +113,8 @@ class MaxSigmoidCSPLayerWithTwoConv(CSPLayerWithTwoConv):
                  conv_cfg: OptConfigType = None,
                  norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
                  act_cfg: ConfigType = dict(type='SiLU', inplace=True),
-                 init_cfg: OptMultiConfig = None) -> None:
+                 init_cfg: OptMultiConfig = None,
+                 use_einsum: bool = True) -> None:
         super().__init__(in_channels=in_channels,
                          out_channels=out_channels,
                          expand_ratio=expand_ratio,
@@ -126,7 +139,8 @@ class MaxSigmoidCSPLayerWithTwoConv(CSPLayerWithTwoConv):
                                               num_heads=num_heads,
                                               with_scale=with_scale,
                                               conv_cfg=conv_cfg,
-                                              norm_cfg=norm_cfg)
+                                              norm_cfg=norm_cfg,
+                                              use_einsum=use_einsum)
 
     def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         """Forward process."""
@@ -298,4 +312,4 @@ class EfficientCSPLayerWithTwoConv(CSPLayerWithTwoConv):
         x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1))
         x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
         x_main.append(self.attn_block(x_main[-1], guide))
-        return self.final_conv(torch.cat(x_main, 1))
+        return self.final_conv(torch.cat(x_main, 1))
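Note: the new `use_einsum=False` branch is simply a permute/reshape/matmul re-expression of the original `torch.einsum('bmchw,bnmc->bmhwn', embed, guide)` call; keeping an einsum-free path is typically useful when exporting to runtimes that handle einsum poorly. Below is a minimal standalone sketch (not part of the commit) that checks the two paths agree, assuming the shapes implied by the code above: `embed` is (B, num_heads, head_channels, H, W) and `guide` is (B, N, num_heads, head_channels); all sizes here are made up for the test.

import torch

# Assumed illustrative sizes: B=batch, m=num_heads, c=head_channels,
# H/W=feature map size, n=number of guide (text) embeddings.
B, m, c, H, W, n = 2, 4, 8, 10, 12, 5
embed = torch.randn(B, m, c, H, W)
guide = torch.randn(B, n, m, c)

# einsum path (use_einsum=True): output shape (B, m, H, W, n)
ref = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)

# matmul path, mirroring the use_einsum=False branch
e = embed.permute(0, 1, 3, 4, 2).reshape(B, m, -1, c)  # (B, m, H*W, c)
g = guide.permute(0, 2, 3, 1)                          # (B, m, c, n)
alt = torch.matmul(e, g).reshape(B, m, H, W, n)        # (B, m, H, W, n)

assert torch.allclose(ref, alt, atol=1e-6)
print('einsum and matmul paths match:', torch.allclose(ref, alt, atol=1e-6))

Since both paths feed the same `max(dim=-1)` / scaling / bias sequence afterwards, toggling `use_einsum` should not change model outputs, only which ops appear in the traced graph.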