VictorSanh committed
Commit a9d91fb
Parent(s): 6601ecb

from pixel_values_attention_masks to patch_attention_mask"

modeling_siglip.py (+12 -6)
modeling_siglip.py CHANGED

@@ -1077,7 +1077,7 @@ class SiglipVisionTransformer(nn.Module):
     def forward(
         self,
         pixel_values,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
+        patch_attention_mask: Optional[torch.BoolTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1093,12 +1093,18 @@ class SiglipVisionTransformer(nn.Module):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         batch_size = pixel_values.size(0)
-        if pixel_attention_mask is None:
-            # assuming `pixel_attention_mask` is of size bs x h x w
-            pixel_attention_mask = torch.ones(size=(batch_size, pixel_values.size(2), pixel_values.size(3)), dtype=torch.bool, device=pixel_values.device)
+        if patch_attention_mask is None:
+            patch_attention_mask = torch.ones(
+                size=(batch_size, pixel_values.size(2)//self.config.patch_size, pixel_values.size(3)//self.config.patch_size),
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+        # if pixel_attention_mask is None:
+        #     # assuming `pixel_attention_mask` is of size bs x h x w
+        #     pixel_attention_mask = torch.ones(size=(batch_size, pixel_values.size(2), pixel_values.size(3)), dtype=torch.bool, device=pixel_values.device)

-        subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
-        patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
+        # subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
+        # patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()

         hidden_states = self.embeddings(
             pixel_values=pixel_values,
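After this commit, forward() no longer pools a pixel-level mask into a patch-level one; callers that previously passed a pixel mask must do the pooling themselves. Below is a minimal sketch (not part of the commit) of how a caller could build the patch_attention_mask argument, reusing the same unfold-based pooling that the commit comments out of forward(). The shapes, the patch_size value, and the model variable are illustrative assumptions, not values fixed by this diff.

import torch

# Illustrative values; in practice patch_size comes from model.config.patch_size.
batch_size, height, width, patch_size = 2, 384, 384, 16

pixel_values = torch.randn(batch_size, 3, height, width)
# Pixel-level mask of shape (bs, h, w): True where the image holds real content.
pixel_attention_mask = torch.ones(batch_size, height, width, dtype=torch.bool)

# Pool the pixel mask onto the patch grid: keep a patch if any of its pixels is kept.
# This mirrors the unfold logic that the commit removed from forward().
subgrids = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size).unfold(
    dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = subgrids.sum(dim=(-1, -2)) > 0  # (bs, h // patch_size, w // patch_size)

# Hypothetical call, assuming `model` is a SiglipVisionTransformer instance.
# If patch_attention_mask is omitted, forward() now defaults to an all-ones patch mask.
# outputs = model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)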