Spaces:
Runtime error
Runtime error
jwyang
commited on
Commit
•
294ea6c
1
Parent(s):
92144aa
average over multiple prompts
Browse files
detectron2/modeling/meta_arch/clip_rcnn.py
CHANGED
@@ -751,8 +751,10 @@ class CLIPFastRCNN(nn.Module):
|
|
751 |
text_features = self.backbone.encode_text(queries)
|
752 |
else:
|
753 |
features = self.backbone(images.tensor)
|
754 |
-
token_embeddings = pre_tokenize([queries])
|
755 |
text_features = self.lang_encoder.encode_text(token_embeddings)
|
|
|
|
|
756 |
|
757 |
if self.backbone_type == "resnet":
|
758 |
head = self.backbone.layer4
|
|
|
751 |
text_features = self.backbone.encode_text(queries)
|
752 |
else:
|
753 |
features = self.backbone(images.tensor)
|
754 |
+
token_embeddings = pre_tokenize([queries]).to(images.tensor.device)[0]
|
755 |
text_features = self.lang_encoder.encode_text(token_embeddings)
|
756 |
+
text_features = text_features.mean(0, keepdim=True)
|
757 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
758 |
|
759 |
if self.backbone_type == "resnet":
|
760 |
head = self.backbone.layer4
|