meow2018 committed on
Commit
a0053b8
·
verified ·
1 Parent(s): 1d5dd06

Update yolo_world/models/dense_heads/yolo_world_head.py

Browse files
yolo_world/models/dense_heads/yolo_world_head.py CHANGED
@@ -35,18 +35,32 @@ class ContrastiveHead(BaseModule):
35
  """
36
  def __init__(self,
37
  embed_dims: int,
38
- init_cfg: OptConfigType = None) -> None:
 
39
 
40
  super().__init__(init_cfg=init_cfg)
41
 
42
  self.bias = nn.Parameter(torch.zeros([]))
43
  self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
 
44
 
45
  def forward(self, x: Tensor, w: Tensor) -> Tensor:
46
  """Forward function of contrastive learning."""
47
  x = F.normalize(x, dim=1, p=2)
48
  w = F.normalize(w, dim=-1, p=2)
49
- x = torch.einsum('bchw,bkc->bkhw', x, w)
 
 
 
 
 
 
 
 
 
 
 
 
50
  x = x * self.logit_scale.exp() + self.bias
51
  return x
52
 
@@ -62,19 +76,33 @@ class BNContrastiveHead(BaseModule):
62
  def __init__(self,
63
  embed_dims: int,
64
  norm_cfg: ConfigDict,
65
- init_cfg: OptConfigType = None) -> None:
 
66
 
67
  super().__init__(init_cfg=init_cfg)
68
  self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
69
  self.bias = nn.Parameter(torch.zeros([]))
70
  # use -1.0 is more stable
71
  self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
 
72
 
73
  def forward(self, x: Tensor, w: Tensor) -> Tensor:
74
  """Forward function of contrastive learning."""
75
  x = self.norm(x)
76
  w = F.normalize(w, dim=-1, p=2)
77
- x = torch.einsum('bchw,bkc->bkhw', x, w)
 
 
 
 
 
 
 
 
 
 
 
 
78
  x = x * self.logit_scale.exp() + self.bias
79
  return x
80
 
@@ -92,9 +120,11 @@ class YOLOWorldHeadModule(YOLOv8HeadModule):
92
  *args,
93
  embed_dims: int,
94
  use_bn_head: bool = False,
 
95
  **kwargs) -> None:
96
  self.embed_dims = embed_dims
97
  self.use_bn_head = use_bn_head
 
98
  super().__init__(*args, **kwargs)
99
 
100
  def init_weights(self, prior_prob=0.01):
@@ -161,9 +191,9 @@ class YOLOWorldHeadModule(YOLOv8HeadModule):
161
  kernel_size=1)))
162
  if self.use_bn_head:
163
  self.cls_contrasts.append(
164
- BNContrastiveHead(self.embed_dims, self.norm_cfg))
165
  else:
166
- self.cls_contrasts.append(ContrastiveHead(self.embed_dims))
167
 
168
  proj = torch.arange(self.reg_max, dtype=torch.float)
169
  self.register_buffer('proj', proj, persistent=False)
@@ -252,7 +282,6 @@ class YOLOWorldHead(YOLOv8Head):
252
  def forward(self, img_feats: Tuple[Tensor],
253
  txt_feats: Tensor) -> Tuple[List]:
254
  """Forward features from the upstream network."""
255
- self.num_classes = txt_feats.shape[1]
256
  return self.head_module(img_feats, txt_feats)
257
 
258
  def predict(self,
@@ -593,4 +622,4 @@ class YOLOWorldHead(YOLOv8Head):
593
  results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
594
 
595
  results_list.append(results)
596
- return results_list
 
35
  """
36
  def __init__(self,
37
  embed_dims: int,
38
+ init_cfg: OptConfigType = None,
39
+ use_einsum: bool = True) -> None:
40
 
41
  super().__init__(init_cfg=init_cfg)
42
 
43
  self.bias = nn.Parameter(torch.zeros([]))
44
  self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
45
+ self.use_einsum = use_einsum
46
 
47
  def forward(self, x: Tensor, w: Tensor) -> Tensor:
48
  """Forward function of contrastive learning."""
49
  x = F.normalize(x, dim=1, p=2)
50
  w = F.normalize(w, dim=-1, p=2)
51
+
52
+ if self.use_einsum:
53
+ x = torch.einsum('bchw,bkc->bkhw', x, w)
54
+ else:
55
+ batch, channel, height, width = x.shape
56
+ _, k, _ = w.shape
57
+ x = x.permute(0, 2, 3, 1) # bchw->bhwc
58
+ x = x.reshape(batch, -1, channel) # bhwc->b(hw)c
59
+ w = w.permute(0, 2, 1) # bkc->bck
60
+ x = torch.matmul(x, w)
61
+ x = x.reshape(batch, height, width, k)
62
+ x = x.permute(0, 3, 1, 2)
63
+
64
  x = x * self.logit_scale.exp() + self.bias
65
  return x
66
 
 
76
  def __init__(self,
77
  embed_dims: int,
78
  norm_cfg: ConfigDict,
79
+ init_cfg: OptConfigType = None,
80
+ use_einsum: bool = True) -> None:
81
 
82
  super().__init__(init_cfg=init_cfg)
83
  self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
84
  self.bias = nn.Parameter(torch.zeros([]))
85
  # use -1.0 is more stable
86
  self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
87
+ self.use_einsum = use_einsum
88
 
89
  def forward(self, x: Tensor, w: Tensor) -> Tensor:
90
  """Forward function of contrastive learning."""
91
  x = self.norm(x)
92
  w = F.normalize(w, dim=-1, p=2)
93
+
94
+ if self.use_einsum:
95
+ x = torch.einsum('bchw,bkc->bkhw', x, w)
96
+ else:
97
+ batch, channel, height, width = x.shape
98
+ _, k, _ = w.shape
99
+ x = x.permute(0, 2, 3, 1) # bchw->bhwc
100
+ x = x.reshape(batch, -1, channel) # bhwc->b(hw)c
101
+ w = w.permute(0, 2, 1) # bkc->bck
102
+ x = torch.matmul(x, w)
103
+ x = x.reshape(batch, height, width, k)
104
+ x = x.permute(0, 3, 1, 2)
105
+
106
  x = x * self.logit_scale.exp() + self.bias
107
  return x
108
 
 
120
  *args,
121
  embed_dims: int,
122
  use_bn_head: bool = False,
123
+ use_einsum: bool = True,
124
  **kwargs) -> None:
125
  self.embed_dims = embed_dims
126
  self.use_bn_head = use_bn_head
127
+ self.use_einsum = use_einsum
128
  super().__init__(*args, **kwargs)
129
 
130
  def init_weights(self, prior_prob=0.01):
 
191
  kernel_size=1)))
192
  if self.use_bn_head:
193
  self.cls_contrasts.append(
194
+ BNContrastiveHead(self.embed_dims, self.norm_cfg, use_einsum=self.use_einsum))
195
  else:
196
+ self.cls_contrasts.append(ContrastiveHead(self.embed_dims, use_einsum=self.use_einsum))
197
 
198
  proj = torch.arange(self.reg_max, dtype=torch.float)
199
  self.register_buffer('proj', proj, persistent=False)
 
282
  def forward(self, img_feats: Tuple[Tensor],
283
  txt_feats: Tensor) -> Tuple[List]:
284
  """Forward features from the upstream network."""
 
285
  return self.head_module(img_feats, txt_feats)
286
 
287
  def predict(self,
 
622
  results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
623
 
624
  results_list.append(results)
625
+ return results_list