import os
import warnings

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPModel, CLIPProcessor

# Suppress the noisy UserWarnings that transformers emits while loading.
warnings.filterwarnings("ignore", category=UserWarning)


class ImageDataset(Dataset):
    """Lazily loads and preprocesses the images in a directory for CLIP."""

    def __init__(self, image_dir, processor):
        self.image_paths = [
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        ]
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale, RGBA, or palette images
        # don't break the processor's normalization.
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.processor(images=image, return_tensors="pt")["pixel_values"][0]


def get_clip_embeddings_batch(image_dir, batch_size=32, device="cuda"):
    """Embed every image in image_dir with CLIP ViT-B/32; returns an (N, 512) array."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    dataset = ImageDataset(image_dir, processor)
    # shuffle=False keeps embeddings aligned with the dataset's file order.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())

    return np.concatenate(all_embeddings)


def create_embeddings(image_dir, batch_size, device):
    """Compute CLIP embeddings for image_dir and cache them to disk."""
    embeddings = get_clip_embeddings_batch(image_dir, batch_size, device)
    np.save("image_embeddings.npy", embeddings)
    return embeddings
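

# Minimal usage sketch. The "images" directory name is illustrative, not part
# of the original script. The __main__ guard also matters for correctness here:
# with num_workers=4, DataLoader spawns worker processes, which requires an
# importable entry point on platforms that use the "spawn" start method.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    embeddings = create_embeddings("images", batch_size=32, device=device)
    print(f"Saved {embeddings.shape[0]} embeddings of dim {embeddings.shape[1]}")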