VictorSanh commited on
Commit
545fbb4
1 Parent(s): a4ce5f5

working version

Browse files
Files changed (1) hide show
  1. modeling_siglip.py +15 -20
modeling_siglip.py CHANGED
@@ -292,32 +292,25 @@ class SiglipVisionEmbeddings(nn.Module):
292
  batch_size = pixel_values.size(0)
293
 
294
  patch_embeds = self.patch_embedding(pixel_values)
295
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
296
 
297
- patches_to_select = patch_attention_mask.view(batch_size, -1)
298
- max_num_patches = patches_to_select.sum(dim=-1).max()
299
- embeddings = torch.zeros((batch_size, max_num_patches, patch_embeds.size(2)), device=patch_embeds.device, dtype=patch_embeds.dtype)
300
- for b_idx, (p_embeds, p_to_select) in enumerate(zip(patch_embeds, patches_to_select)):
301
- sub_p_embds = p_embeds[p_to_select]
302
- embeddings[b_idx][:len(sub_p_embds)] = sub_p_embds
303
-
304
- boundaries = torch.arange(1/self.num_patches_per_side, 1., 1/self.num_patches_per_side)
305
  max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
306
  max_nb_patches_h, max_nb_patches_w = max_im_h//self.patch_size, max_im_w//self.patch_size
307
- position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
 
308
 
309
  for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
310
- nb_patches_h = p_attn_mask[0].sum()
311
- nb_patches_w = p_attn_mask[:, 0].sum()
312
 
313
- fractional_coords_h = torch.arange(0, 1, 1/nb_patches_h)
314
- fractional_coords_w = torch.arange(0, 1, 1/nb_patches_w)
315
 
316
  bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
317
  bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
318
 
319
  pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
320
- position_ids[batch_idx][:len(pos_ids)] = pos_ids
321
 
322
  position_ids = position_ids.to(self.position_embedding.weight.device)
323
 
@@ -1099,11 +1092,11 @@ class SiglipVisionTransformer(nn.Module):
1099
  )
1100
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1101
 
 
1102
  if pixel_attention_mask is None:
1103
- #TODO
1104
- pass
1105
 
1106
- batch_size = pixel_attention_mask.size(0) # assuming `pixel_attention_mask` is of size bs x h x w
1107
  subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
1108
  patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
1109
 
@@ -1112,9 +1105,11 @@ class SiglipVisionTransformer(nn.Module):
1112
  patch_attention_mask=patch_attention_mask
1113
  )
1114
 
 
 
1115
  encoder_outputs = self.encoder(
1116
  inputs_embeds=hidden_states,
1117
- attention_mask=patch_attention_mask.view(batch_size, -1),
1118
  output_attentions=output_attentions,
1119
  output_hidden_states=output_hidden_states,
1120
  return_dict=return_dict,
@@ -1125,7 +1120,7 @@ class SiglipVisionTransformer(nn.Module):
1125
 
1126
  pooled_output = self.head(
1127
  hidden_state=last_hidden_state,
1128
- attention_mask=patch_attention_mask.view(batch_size, -1)
1129
  )
1130
 
1131
  if not return_dict:
 
292
  batch_size = pixel_values.size(0)
293
 
294
  patch_embeds = self.patch_embedding(pixel_values)
295
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
296
 
 
 
 
 
 
 
 
 
297
  max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
298
  max_nb_patches_h, max_nb_patches_w = max_im_h//self.patch_size, max_im_w//self.patch_size
299
+ boundaries = torch.arange(1/self.num_patches_per_side, 1., 1/self.num_patches_per_side)
300
+ position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w,), fill_value=0)
301
 
302
  for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
303
+ nb_patches_h = p_attn_mask[:, 0].sum()
304
+ nb_patches_w = p_attn_mask[0].sum()
305
 
306
+ fractional_coords_h = torch.arange(0, 1-1e-6, 1/nb_patches_h)
307
+ fractional_coords_w = torch.arange(0, 1-1e-6, 1/nb_patches_w)
308
 
309
  bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
310
  bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
311
 
312
  pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
313
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
314
 
315
  position_ids = position_ids.to(self.position_embedding.weight.device)
316
 
 
1092
  )
1093
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1094
 
1095
+ batch_size = pixel_values.size(0)
1096
  if pixel_attention_mask is None:
1097
+ # assuming `pixel_attention_mask` is of size bs x h x w
1098
+ pixel_attention_mask = torch.ones(size=(batch_size, pixel_values.size(2), pixel_values.size(3)), dtype=torch.bool, device=pixel_values.device)
1099
 
 
1100
  subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
1101
  patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
1102
 
 
1105
  patch_attention_mask=patch_attention_mask
1106
  )
1107
 
1108
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
1109
+
1110
  encoder_outputs = self.encoder(
1111
  inputs_embeds=hidden_states,
1112
+ attention_mask=_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self.config._flash_attn_2_enabled else patch_attention_mask,
1113
  output_attentions=output_attentions,
1114
  output_hidden_states=output_hidden_states,
1115
  return_dict=return_dict,
 
1120
 
1121
  pooled_output = self.head(
1122
  hidden_state=last_hidden_state,
1123
+ attention_mask=patch_attention_mask,
1124
  )
1125
 
1126
  if not return_dict: