Update yolo_world/models/dense_heads/yolo_world_head.py
Browse files
yolo_world/models/dense_heads/yolo_world_head.py
CHANGED
@@ -35,18 +35,32 @@ class ContrastiveHead(BaseModule):
|
|
35 |
"""
|
36 |
def __init__(self,
|
37 |
embed_dims: int,
|
38 |
-
init_cfg: OptConfigType = None
|
|
|
39 |
|
40 |
super().__init__(init_cfg=init_cfg)
|
41 |
|
42 |
self.bias = nn.Parameter(torch.zeros([]))
|
43 |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
|
44 |
|
45 |
def forward(self, x: Tensor, w: Tensor) -> Tensor:
|
46 |
"""Forward function of contrastive learning."""
|
47 |
x = F.normalize(x, dim=1, p=2)
|
48 |
w = F.normalize(w, dim=-1, p=2)
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
x = x * self.logit_scale.exp() + self.bias
|
51 |
return x
|
52 |
|
@@ -62,19 +76,33 @@ class BNContrastiveHead(BaseModule):
|
|
62 |
def __init__(self,
|
63 |
embed_dims: int,
|
64 |
norm_cfg: ConfigDict,
|
65 |
-
init_cfg: OptConfigType = None
|
|
|
66 |
|
67 |
super().__init__(init_cfg=init_cfg)
|
68 |
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
69 |
self.bias = nn.Parameter(torch.zeros([]))
|
70 |
# use -1.0 is more stable
|
71 |
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
|
|
|
72 |
|
73 |
def forward(self, x: Tensor, w: Tensor) -> Tensor:
|
74 |
"""Forward function of contrastive learning."""
|
75 |
x = self.norm(x)
|
76 |
w = F.normalize(w, dim=-1, p=2)
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
x = x * self.logit_scale.exp() + self.bias
|
79 |
return x
|
80 |
|
@@ -92,9 +120,11 @@ class YOLOWorldHeadModule(YOLOv8HeadModule):
|
|
92 |
*args,
|
93 |
embed_dims: int,
|
94 |
use_bn_head: bool = False,
|
|
|
95 |
**kwargs) -> None:
|
96 |
self.embed_dims = embed_dims
|
97 |
self.use_bn_head = use_bn_head
|
|
|
98 |
super().__init__(*args, **kwargs)
|
99 |
|
100 |
def init_weights(self, prior_prob=0.01):
|
@@ -161,9 +191,9 @@ class YOLOWorldHeadModule(YOLOv8HeadModule):
|
|
161 |
kernel_size=1)))
|
162 |
if self.use_bn_head:
|
163 |
self.cls_contrasts.append(
|
164 |
-
BNContrastiveHead(self.embed_dims, self.norm_cfg))
|
165 |
else:
|
166 |
-
self.cls_contrasts.append(ContrastiveHead(self.embed_dims))
|
167 |
|
168 |
proj = torch.arange(self.reg_max, dtype=torch.float)
|
169 |
self.register_buffer('proj', proj, persistent=False)
|
@@ -252,7 +282,6 @@ class YOLOWorldHead(YOLOv8Head):
|
|
252 |
def forward(self, img_feats: Tuple[Tensor],
            txt_feats: Tensor) -> Tuple[List]:
    """Forward features from the upstream network.

    Pre-change version of the method: before delegating, it overwrites
    ``self.num_classes`` with the size of axis 1 of ``txt_feats``.

    Args:
        img_feats (Tuple[Tensor]): Image features, passed through
            unchanged to ``self.head_module``.
        txt_feats (Tensor): Text features; axis 1 is taken as the number
            of classes (presumably one entry per text prompt -- confirm
            against the caller).

    Returns:
        Tuple[List]: Whatever ``self.head_module`` returns.
    """
    # Keep the class count in sync with the prompts supplied for this call.
    self.num_classes = txt_feats.shape[1]
    return self.head_module(img_feats, txt_feats)
|
257 |
|
258 |
def predict(self,
|
@@ -593,4 +622,4 @@ class YOLOWorldHead(YOLOv8Head):
|
|
593 |
results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
|
594 |
|
595 |
results_list.append(results)
|
596 |
-
return results_list
|
|
|
35 |
"""
|
36 |
def __init__(self,
             embed_dims: int,
             init_cfg: OptConfigType = None,
             use_einsum: bool = True) -> None:
    """Set up the learnable scale/bias of the contrastive head.

    Args:
        embed_dims (int): Embedding dimension. Accepted for interface
            parity; it is not read anywhere in this constructor.
        init_cfg (OptConfigType): Initialization config forwarded to the
            parent constructor. Defaults to None.
        use_einsum (bool): Stored on the instance as ``use_einsum``;
            ``forward`` reads it to choose between ``torch.einsum`` and
            an explicit permute/matmul sequence. Defaults to True.
    """
    super().__init__(init_cfg=init_cfg)

    # Scalar (0-dim) learnable parameters.
    self.bias = nn.Parameter(torch.zeros([]))
    # Stored in log space; initial value is log(1 / 0.07).
    self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
    self.use_einsum = use_einsum
|
46 |
|
47 |
def forward(self, x: Tensor, w: Tensor) -> Tensor:
    """Forward function of contrastive learning.

    Both inputs are L2-normalized, contracted over the channel axis, and
    the result is scaled by ``exp(self.logit_scale)`` and shifted by
    ``self.bias``.

    Args:
        x (Tensor): Image features laid out as (b, c, h, w); normalized
            along axis 1.
        w (Tensor): Text embeddings laid out as (b, k, c); normalized
            along the last axis.

    Returns:
        Tensor: Similarity logits laid out as (b, k, h, w).
    """
    x = F.normalize(x, dim=1, p=2)
    w = F.normalize(w, dim=-1, p=2)

    if self.use_einsum:
        x = torch.einsum('bchw,bkc->bkhw', x, w)
    else:
        # Same contraction as the einsum branch, written out with
        # permute/reshape/matmul.
        batch, channel, height, width = x.shape
        _, k, _ = w.shape
        x = x.permute(0, 2, 3, 1)  # bchw->bhwc
        x = x.reshape(batch, -1, channel)  # bhwc->b(hw)c
        w = w.permute(0, 2, 1)  # bkc->bck
        x = torch.matmul(x, w)  # b(hw)c @ bck -> b(hw)k
        x = x.reshape(batch, height, width, k)  # b(hw)k->bhwk
        x = x.permute(0, 3, 1, 2)  # bhwk->bkhw

    x = x * self.logit_scale.exp() + self.bias
    return x
|
66 |
|
|
|
76 |
def __init__(self,
             embed_dims: int,
             norm_cfg: ConfigDict,
             init_cfg: OptConfigType = None,
             use_einsum: bool = True) -> None:
    """Set up the norm layer and learnable scale/bias of the head.

    Args:
        embed_dims (int): Feature count handed to ``build_norm_layer``.
        norm_cfg (ConfigDict): Config handed to ``build_norm_layer`` to
            create ``self.norm``.
        init_cfg (OptConfigType): Initialization config forwarded to the
            parent constructor. Defaults to None.
        use_einsum (bool): Stored on the instance as ``use_einsum``;
            ``forward`` reads it to choose between ``torch.einsum`` and
            an explicit permute/matmul sequence. Defaults to True.
    """
    super().__init__(init_cfg=init_cfg)
    # build_norm_layer returns a pair; the layer itself is element 1.
    self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
    self.bias = nn.Parameter(torch.zeros([]))
    # use -1.0 is more stable
    self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
    self.use_einsum = use_einsum
|
88 |
|
89 |
def forward(self, x: Tensor, w: Tensor) -> Tensor:
    """Forward function of contrastive learning.

    The image features go through ``self.norm`` (instead of an L2
    normalization), the text embeddings are L2-normalized, the two are
    contracted over the channel axis, and the result is scaled by
    ``exp(self.logit_scale)`` and shifted by ``self.bias``.

    Args:
        x (Tensor): Image features laid out as (b, c, h, w).
        w (Tensor): Text embeddings laid out as (b, k, c); normalized
            along the last axis.

    Returns:
        Tensor: Similarity logits laid out as (b, k, h, w).
    """
    x = self.norm(x)
    w = F.normalize(w, dim=-1, p=2)

    if self.use_einsum:
        x = torch.einsum('bchw,bkc->bkhw', x, w)
    else:
        # Same contraction as the einsum branch, written out with
        # permute/reshape/matmul.
        batch, channel, height, width = x.shape
        _, k, _ = w.shape
        x = x.permute(0, 2, 3, 1)  # bchw->bhwc
        x = x.reshape(batch, -1, channel)  # bhwc->b(hw)c
        w = w.permute(0, 2, 1)  # bkc->bck
        x = torch.matmul(x, w)  # b(hw)c @ bck -> b(hw)k
        x = x.reshape(batch, height, width, k)  # b(hw)k->bhwk
        x = x.permute(0, 3, 1, 2)  # bhwk->bkhw

    x = x * self.logit_scale.exp() + self.bias
    return x
|
108 |
|
|
|
120 |
*args,
|
121 |
embed_dims: int,
|
122 |
use_bn_head: bool = False,
|
123 |
+
use_einsum: bool = True,
|
124 |
**kwargs) -> None:
|
125 |
self.embed_dims = embed_dims
|
126 |
self.use_bn_head = use_bn_head
|
127 |
+
self.use_einsum = use_einsum
|
128 |
super().__init__(*args, **kwargs)
|
129 |
|
130 |
def init_weights(self, prior_prob=0.01):
|
|
|
191 |
kernel_size=1)))
|
192 |
if self.use_bn_head:
|
193 |
self.cls_contrasts.append(
|
194 |
+
BNContrastiveHead(self.embed_dims, self.norm_cfg, use_einsum=self.use_einsum))
|
195 |
else:
|
196 |
+
self.cls_contrasts.append(ContrastiveHead(self.embed_dims, use_einsum=self.use_einsum))
|
197 |
|
198 |
proj = torch.arange(self.reg_max, dtype=torch.float)
|
199 |
self.register_buffer('proj', proj, persistent=False)
|
|
|
282 |
def forward(self, img_feats: Tuple[Tensor],
            txt_feats: Tensor) -> Tuple[List]:
    """Forward features from the upstream network.

    Pure delegation: both arguments are handed to ``self.head_module``
    unchanged and no state on ``self`` is modified.

    Args:
        img_feats (Tuple[Tensor]): Image features from upstream.
        txt_feats (Tensor): Text features from upstream.

    Returns:
        Tuple[List]: Whatever ``self.head_module`` returns.
    """
    return self.head_module(img_feats, txt_feats)
|
286 |
|
287 |
def predict(self,
|
|
|
622 |
results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
|
623 |
|
624 |
results_list.append(results)
|
625 |
+
return results_list
|