# hw04 / create_image_embeddnigs.py
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

class ImageDataset(Dataset):
    """Dataset that yields a preprocessed CLIP pixel tensor for each image in a directory."""

    def __init__(self, image_dir, processor):
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
                            if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale/palette images don't break the processor.
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.processor(images=image, return_tensors="pt")['pixel_values'][0]

def get_clip_embeddings_batch(image_dir, batch_size=32, device='cuda'):
    """Embed every image in `image_dir` with CLIP and return an (N, 512) NumPy array."""
    # Load the CLIP model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Create dataset and dataloader
    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            # One image embedding vector per item in the batch
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())
    return np.concatenate(all_embeddings)

def create_embeddings(image_dir, batch_size, device):
    """Compute CLIP embeddings for all images in `image_dir` and save them to image_embeddings.npy."""
    embeddings = get_clip_embeddings_batch(image_dir, batch_size, device)
    np.save("image_embeddings.npy", embeddings)
    return embeddings
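

# Minimal usage sketch (not part of the original script): the image directory,
# batch size, and device choice below are placeholder assumptions for illustration.
if __name__ == "__main__":
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embs = create_embeddings("images", batch_size=32, device=run_device)
    print(f"Saved {embs.shape[0]} embeddings of dimension {embs.shape[1]} to image_embeddings.npy")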