VictorSanh commited on
Commit
a9d91fb
1 Parent(s): 6601ecb

from pixel_values_attention_masks to patch_attention_mask

Browse files
Files changed (1) hide show
  1. modeling_siglip.py +12 -6
modeling_siglip.py CHANGED
@@ -1077,7 +1077,7 @@ class SiglipVisionTransformer(nn.Module):
1077
  def forward(
1078
  self,
1079
  pixel_values,
1080
- pixel_attention_mask: Optional[torch.BoolTensor] = None,
1081
  output_attentions: Optional[bool] = None,
1082
  output_hidden_states: Optional[bool] = None,
1083
  return_dict: Optional[bool] = None,
@@ -1093,12 +1093,18 @@ class SiglipVisionTransformer(nn.Module):
1093
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1094
 
1095
  batch_size = pixel_values.size(0)
1096
- if pixel_attention_mask is None:
1097
- # assuming `pixel_attention_mask` is of size bs x h x w
1098
- pixel_attention_mask = torch.ones(size=(batch_size, pixel_values.size(2), pixel_values.size(3)), dtype=torch.bool, device=pixel_values.device)
 
 
 
 
 
 
1099
 
1100
- subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
1101
- patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
1102
 
1103
  hidden_states = self.embeddings(
1104
  pixel_values=pixel_values,
 
1077
  def forward(
1078
  self,
1079
  pixel_values,
1080
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
1081
  output_attentions: Optional[bool] = None,
1082
  output_hidden_states: Optional[bool] = None,
1083
  return_dict: Optional[bool] = None,
 
1093
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1094
 
1095
  batch_size = pixel_values.size(0)
1096
+ if patch_attention_mask is None:
1097
+ patch_attention_mask = torch.ones(
1098
+ size=(batch_size, pixel_values.size(2)//self.config.patch_size, pixel_values.size(3)//self.config.patch_size),
1099
+ dtype=torch.bool,
1100
+ device=pixel_values.device,
1101
+ )
1102
+ # if pixel_attention_mask is None:
1103
+ # # assuming `pixel_attention_mask` is of size bs x h x w
1104
+ # pixel_attention_mask = torch.ones(size=(batch_size, pixel_values.size(2), pixel_values.size(3)), dtype=torch.bool, device=pixel_values.device)
1105
 
1106
+ # subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
1107
+ # patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
1108
 
1109
  hidden_states = self.embeddings(
1110
  pixel_values=pixel_values,