import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# Preprocess a local test image and tokenize the candidate captions.
image = preprocess(Image.open("utils.torch_utils/CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

# Standard CLIP usage: encode both modalities, then get image-text similarity logits.
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
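
# For reference, a minimal sketch of what model(image, text) computes internally
# (based on openai/CLIP's model.py): L2-normalize both embeddings and scale the
# cosine similarities by the learned temperature logit_scale. The names img_n,
# txt_n and logits_manual are introduced here for illustration only.
with torch.no_grad():
    img_n = image_features / image_features.norm(dim=-1, keepdim=True)
    txt_n = text_features / text_features.norm(dim=-1, keepdim=True)
    logits_manual = model.logit_scale.exp() * img_n @ txt_n.t()  # should match logits_per_image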

# Manually step through the visual transformer to inspect intermediate shapes.
with torch.no_grad():
    x = image.type(model.dtype)  # [1, 3, 224, 224]
    visual = model.visual
    x = visual.conv1(x)  # shape = [*, width, grid, grid]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
    x = torch.cat([visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
    x = x + visual.positional_embedding.to(x.dtype)
    x = visual.ln_pre(x)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = visual.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD; [1, 197, 768] for ViT-B/16 at 224x224
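
    # Sketch of the remaining steps in CLIP's VisionTransformer.forward
    # (openai/CLIP): take the class token, apply the final LayerNorm, then the
    # learned projection; this should reproduce model.encode_image(image).
    x = visual.ln_post(x[:, 0, :])
    if visual.proj is not None:
        x = x @ visual.proj  # shape = [1, 512] for ViT-B/16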
print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]] |