meow2018 committed
Commit ae673af · verified · 1 Parent(s): a0053b8

Update yolo_world/models/layers/yolo_bricks.py

yolo_world/models/layers/yolo_bricks.py CHANGED
@@ -29,7 +29,8 @@ class MaxSigmoidAttnBlock(BaseModule):
                  norm_cfg: ConfigType = dict(type='BN',
                                              momentum=0.03,
                                              eps=0.001),
-                 init_cfg: OptMultiConfig = None) -> None:
+                 init_cfg: OptMultiConfig = None,
+                 use_einsum: bool = True) -> None:
         super().__init__(init_cfg=init_cfg)
         conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
 
@@ -38,6 +39,7 @@ class MaxSigmoidAttnBlock(BaseModule):
             'out_channels and embed_channels should be divisible by num_heads.'
         self.num_heads = num_heads
         self.head_channels = out_channels // num_heads
+        self.use_einsum = use_einsum
 
         self.embed_conv = ConvModule(
             in_channels,
@@ -71,7 +73,17 @@ class MaxSigmoidAttnBlock(BaseModule):
         embed = self.embed_conv(x) if self.embed_conv is not None else x
         embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)
 
-        attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
+        if self.use_einsum:
+            attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
+        else:
+            batch, m, channel, height, width = embed.shape
+            _, n, _, _ = guide.shape
+            embed = embed.permute(0, 1, 3, 4, 2)
+            embed = embed.reshape(batch, m, -1, channel)
+            guide = guide.permute(0, 2, 3, 1)
+            attn_weight = torch.matmul(embed, guide)
+            attn_weight = attn_weight.reshape(batch, m, height, width, n)
+
         attn_weight = attn_weight.max(dim=-1)[0]
         attn_weight = attn_weight / (self.head_channels**0.5)
         attn_weight = attn_weight + self.bias[None, :, None, None]
@@ -101,7 +113,8 @@ class MaxSigmoidCSPLayerWithTwoConv(CSPLayerWithTwoConv):
                  conv_cfg: OptConfigType = None,
                  norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
                  act_cfg: ConfigType = dict(type='SiLU', inplace=True),
-                 init_cfg: OptMultiConfig = None) -> None:
+                 init_cfg: OptMultiConfig = None,
+                 use_einsum: bool = True) -> None:
         super().__init__(in_channels=in_channels,
                          out_channels=out_channels,
                          expand_ratio=expand_ratio,
@@ -126,7 +139,8 @@ class MaxSigmoidCSPLayerWithTwoConv(CSPLayerWithTwoConv):
             num_heads=num_heads,
             with_scale=with_scale,
             conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg)
+            norm_cfg=norm_cfg,
+            use_einsum=use_einsum)
 
     def forward(self, x: Tensor, guide: Tensor) -> Tensor:
         """Forward process."""
@@ -298,4 +312,4 @@ class EfficientCSPLayerWithTwoConv(CSPLayerWithTwoConv):
         x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1))
         x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
         x_main.append(self.attn_block(x_main[-1], guide))
-        return self.final_conv(torch.cat(x_main, 1))
+        return self.final_conv(torch.cat(x_main, 1))
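The substance of the commit is the new use_einsum switch: the original torch.einsum call remains the default, and an equivalent permute/reshape/matmul path is added for when it is disabled. A plausible motivation, not stated in the commit message, is deployment: some export backends handle a plain batched matmul more reliably than einsum. The following standalone sketch, with arbitrary test shapes of my own choosing, checks that the two paths in the diff produce the same attention weights:

    import torch

    # Shapes follow the diff: embed is (B, num_heads, head_channels, H, W),
    # guide is (B, N, num_heads, head_channels). The concrete sizes below
    # are arbitrary test values, not taken from the commit.
    B, m, c, H, W, n = 2, 4, 8, 5, 5, 3
    embed = torch.randn(B, m, c, H, W)
    guide = torch.randn(B, n, m, c)

    # use_einsum=True path: per-head similarity between each pixel and each guide token.
    ref = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)

    # use_einsum=False path from the diff: flatten the spatial grid, then batched matmul.
    e = embed.permute(0, 1, 3, 4, 2).reshape(B, m, H * W, c)  # (B, m, H*W, c)
    g = guide.permute(0, 2, 3, 1)                             # (B, m, c, n)
    alt = torch.matmul(e, g).reshape(B, m, H, W, n)           # (B, m, H, W, n)

    print(torch.allclose(ref, alt, atol=1e-6))  # expect: True

One small readability note: the fallback branch in the diff reuses the embed and guide names for the permuted tensors, which is harmless inside the method; the sketch keeps them separate (e, g) only for clarity.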