VictorSanh committed
Commit 545fbb4 · Parent(s): a4ce5f5

working version

modeling_siglip.py CHANGED (+15 -20)
@@ -292,32 +292,25 @@ class SiglipVisionEmbeddings(nn.Module):
         batch_size = pixel_values.size(0)

         patch_embeds = self.patch_embedding(pixel_values)
-
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)

-        patches_to_select = patch_attention_mask.view(batch_size, -1)
-        max_num_patches = patches_to_select.sum(dim=-1).max()
-        embeddings = torch.zeros((batch_size, max_num_patches, patch_embeds.size(2)), device=patch_embeds.device, dtype=patch_embeds.dtype)
-        for b_idx, (p_embeds, p_to_select) in enumerate(zip(patch_embeds, patches_to_select)):
-            sub_p_embds = p_embeds[p_to_select]
-            embeddings[b_idx][:len(sub_p_embds)] = sub_p_embds
-
-        boundaries = torch.arange(1/self.num_patches_per_side, 1., 1/self.num_patches_per_side)
         max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
         max_nb_patches_h, max_nb_patches_w = max_im_h//self.patch_size, max_im_w//self.patch_size
-
+        boundaries = torch.arange(1/self.num_patches_per_side, 1., 1/self.num_patches_per_side)
+        position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w,), fill_value=0)

         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
-            nb_patches_h = p_attn_mask[0].sum()
-            nb_patches_w = p_attn_mask[
+            nb_patches_h = p_attn_mask[:, 0].sum()
+            nb_patches_w = p_attn_mask[0].sum()

-            fractional_coords_h = torch.arange(0, 1, 1/nb_patches_h)
-            fractional_coords_w = torch.arange(0, 1, 1/nb_patches_w)
+            fractional_coords_h = torch.arange(0, 1-1e-6, 1/nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1-1e-6, 1/nb_patches_w)

             bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

             pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
-            position_ids[batch_idx][
+            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

         position_ids = position_ids.to(self.position_embedding.weight.device)

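Note: the rewritten block above buckets each real patch's fractional (row, column) coordinate into the model's fixed position grid, so an image smaller than the maximum resolution still receives position ids spread across the full grid. Below is a minimal standalone sketch of that bucketing with hypothetical sizes (num_patches_per_side = 4, an image covering only a 2 x 3 patch grid), not values taken from the model config:

import torch

# Hypothetical sizes: the position grid is 4x4 (num_patches_per_side = 4) while the
# actual image only covers 2 patches vertically and 3 horizontally.
num_patches_per_side = 4
nb_patches_h, nb_patches_w = 2, 3

boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)

# The 1 - 1e-6 upper bound (the change in the `+` lines) guards against floating-point
# rounding in arange producing an extra coordinate at (or just below) 1.0.
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()

print(bucket_coords_h)  # tensor([0, 2])
print(bucket_coords_w)  # tensor([0, 1, 2])
print(pos_ids)          # tensor([ 0,  2,  4,  6,  8, 10])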
@@ -1099,11 +1092,11 @@ class SiglipVisionTransformer(nn.Module):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        batch_size = pixel_values.size(0)
         if pixel_attention_mask is None:
-            #
-
+            # assuming `pixel_attention_mask` is of size bs x h x w
+            pixel_attention_mask = torch.ones(size=(batch_size, pixel_values.size(2), pixel_values.size(3)), dtype=torch.bool, device=pixel_values.device)

-        batch_size = pixel_attention_mask.size(0) # assuming `pixel_attention_mask` is of size bs x h x w
         subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
         patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()

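Note: the context lines above derive the per-patch mask from the pixel-level mask by unfolding it into patch-sized subgrids and keeping any patch that contains at least one real pixel. A standalone sketch with hypothetical sizes (one 8x8 image, patch_size = 4, only the left half covered by real pixels):

import torch

patch_size = 4
# Hypothetical: a batch of one 8x8 image whose right half is padding.
pixel_attention_mask = torch.zeros(1, 8, 8, dtype=torch.bool)
pixel_attention_mask[:, :, :4] = True

# Cut the pixel mask into patch_size x patch_size subgrids, one per patch position:
# (bs, h, w) -> (bs, h // patch_size, w // patch_size, patch_size, patch_size)
subgrids = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
subgrids = subgrids.unfold(dimension=2, size=patch_size, step=patch_size)

# A patch is kept if at least one of its pixels is real.
patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
print(patch_attention_mask.shape)  # torch.Size([1, 2, 2])
print(patch_attention_mask[0])
# tensor([[ True, False],
#         [ True, False]])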
@@ -1112,9 +1105,11 @@ class SiglipVisionTransformer(nn.Module):
             patch_attention_mask=patch_attention_mask
         )

+        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
-            attention_mask=patch_attention_mask
+            attention_mask=_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self.config._flash_attn_2_enabled else patch_attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
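Note: on the non-flash-attention path the flattened boolean patch mask is expanded by `_prepare_4d_attention_mask` into the additive mask the encoder expects, while with flash attention 2 the 2D boolean mask is passed through unchanged. Below is a rough sketch of that expansion (in the spirit of the transformers helper, not a copy of it), assuming the usual (batch, 1, tgt_len, src_len) convention with 0 for attended positions and the dtype minimum for masked ones:

import torch

def expand_patch_mask(patch_attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # (bs, num_patches) bool -> (bs, 1, num_patches, num_patches) additive mask.
    bsz, src_len = patch_attention_mask.shape
    expanded = patch_attention_mask[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

mask = torch.tensor([[True, True, False]])  # hypothetical: one image, 3 patches, last one is padding
print(expand_patch_mask(mask, torch.float32).shape)  # torch.Size([1, 1, 3, 3])
print(expand_patch_mask(mask, torch.float32)[0, 0, 0])
# tensor([ 0.0000e+00,  0.0000e+00, -3.4028e+38])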
@@ -1125,7 +1120,7 @@ class SiglipVisionTransformer(nn.Module):

         pooled_output = self.head(
             hidden_state=last_hidden_state,
-            attention_mask=patch_attention_mask
+            attention_mask=patch_attention_mask,
         )

         if not return_dict: