import torch
from transformers import CLIPVisionModel

from resampler import Resampler

BATCH_SIZE = 2
OUTPUT_DIM = 1280
NUM_QUERIES = 8
NUM_LATENTS_MEAN_POOLED = 4
APPLY_POS_EMB = True
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"


def main():
    # Load the CLIP vision tower; its hidden size sets the Resampler's input width.
    image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
    embedding_dim = image_encoder.config.hidden_size
    print(f"image_encoder hidden size: {embedding_dim}")

    image_proj_model = Resampler(
        dim=1024,
        depth=2,
        dim_head=64,
        heads=16,
        num_queries=NUM_QUERIES,
        embedding_dim=embedding_dim,
        output_dim=OUTPUT_DIM,
        ff_mult=2,
        max_seq_len=257,  # ViT-H/14 at 224px: (224 / 14) ** 2 patches + 1 CLS token
        apply_pos_emb=APPLY_POS_EMB,
        num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
    )
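    # Assuming an IP-Adapter-Plus-style Resampler: inputs of width embedding_dim are
    # projected to the internal width `dim`, attended over by learned latent queries,
    # and projected out to width output_dim.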

    # Random tensor standing in for a batch of preprocessed 224x224 pixel values.
    dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
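    # With real images, the pixel values would come from the matching processor
    # instead; a minimal sketch, assuming transformers' CLIPImageProcessor and a
    # hypothetical list `pil_images` of PIL images:
    #   processor = CLIPImageProcessor.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
    #   dummy_images = processor(images=pil_images, return_tensors="pt").pixel_values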
    with torch.no_grad():
        # Take the penultimate hidden layer rather than the final output, as is
        # common for IP-Adapter-style image projections.
        image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
    print("image_embeds shape:", image_embeds.shape)

    with torch.no_grad():
        ip_tokens = image_proj_model(image_embeds)
    print("ip_tokens shape:", ip_tokens.shape)
    # num_latents_mean_pooled extra tokens are derived from the mean-pooled input,
    # so the output should carry NUM_QUERIES + NUM_LATENTS_MEAN_POOLED tokens.
    assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)


if __name__ == "__main__":
    main()
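
# Expected console output with these settings (assuming ViT-H/14's hidden size of
# 1280 and 257 tokens at 224x224 input; shapes only, values are random):
#   image_encoder hidden size: 1280
#   image_embeds shape: torch.Size([2, 257, 1280])
#   ip_tokens shape: torch.Size([2, 12, 1280])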