NaViT-style aspect-ratio-preserving image treatment

#2
by VictorSanh - opened
No description provided.

Minimal smoke test (padded batch of two images with different valid regions):

import torch

DEVICE = torch.device("cuda:0")
PATCH_SIZE = 14


def build_pixel_attention_mask(valid_sizes, height, width, device=None):
    """Build a per-pixel boolean mask for a padded image batch.

    Args:
        valid_sizes: one ``(valid_height, valid_width)`` pair per sample; the
            top-left ``valid_height x valid_width`` region of that sample is
            marked as real pixels (True), everything else as padding (False).
        height: padded image height shared by the whole batch.
        width: padded image width shared by the whole batch.
        device: optional torch device for the returned tensor.

    Returns:
        A ``torch.bool`` tensor of shape ``(len(valid_sizes), height, width)``.
    """
    mask = torch.zeros(len(valid_sizes), height, width, dtype=torch.bool, device=device)
    for i, (valid_h, valid_w) in enumerate(valid_sizes):
        mask[i, :valid_h, :valid_w] = True
    return mask


def to_patch_attention_mask(pixel_attention_mask, patch_size=PATCH_SIZE):
    """Reduce a ``(batch, H, W)`` pixel mask to a ``(batch, H // p, W // p)`` patch mask.

    A patch is kept (True) if at least one pixel inside it is real. The input is
    split into non-overlapping ``patch_size x patch_size`` tiles via two
    ``unfold`` calls (rows, then columns); trailing pixels that do not fill a
    whole patch are dropped by ``unfold``.
    """
    patches_subgrid = pixel_attention_mask.unfold(
        dimension=1, size=patch_size, step=patch_size
    ).unfold(dimension=2, size=patch_size, step=patch_size)
    # The comparison already produces a bool tensor; no extra cast needed.
    return patches_subgrid.sum(dim=(-1, -2)) > 0


def main():
    """Run a flash-attention-2 SigLIP vision forward pass on a padded batch."""
    # Imported here (not at module top) so the mask helpers above can be
    # imported without the local modeling file, CUDA, or model weights.
    from modeling_siglip import SiglipVisionModel

    # Padded canvas: 2 x 3 patches. Sizes derive from PATCH_SIZE so the mask
    # and pixel tensor can never disagree.
    height, width = 2 * PATCH_SIZE, 3 * PATCH_SIZE
    pixel_values = torch.randn(2, 3, height, width, dtype=torch.bfloat16, device=DEVICE)

    # Sample 0: only the top patch row is real (landscape image).
    # Sample 1: only the left two patch columns are real (narrower image).
    pixel_attention_mask = build_pixel_attention_mask(
        [(PATCH_SIZE, width), (height, 2 * PATCH_SIZE)],
        height,
        width,
        device=DEVICE,
    )
    patch_attention_mask = to_patch_attention_mask(pixel_attention_mask, PATCH_SIZE)

    model = SiglipVisionModel.from_pretrained(
        "LOCAL_PATH/siglip-so400m-14-384-flash-attn2/", _flash_attn_2_enabled=True
    )
    model.train()
    model.vision_model.to(DEVICE, dtype=torch.bfloat16)

    output = model.vision_model(
        pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
    )
    return output


if __name__ == "__main__":
    main()

Deactivated three checks inside `modeling_siglip.py` for debugging; they will be re-added later.

Looks good to me!

Publish this branch
This branch is in draft mode, publish it to be able to merge.

Sign up or log in to comment