from pathlib import Path

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import v2
from transformers import AutoImageProcessor, AutoModel


class TransformDino(v2.Transform):
    """Replace a batch's images with embeddings from a pretrained model.

    The image processor and model are both loaded with ``from_pretrained``
    using ``model_name`` (a DINOv2 checkpoint by default).
    """

    def __init__(self, model_name="facebook/dinov2-base"):
        """Load the image processor and the model for ``model_name``.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both
                the processor and the model.
        """
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, batch):
        """Embed ``batch["features"]`` and store the result back in ``batch``.

        Args:
            batch: Mapping with a ``"features"`` entry holding the images to
                embed. The mapping is modified in place.

        Returns:
            The same ``batch`` object, with ``"features"`` replaced by the
            first token (index 0 along the sequence axis) of the model's
            ``last_hidden_state``. The result stays on the model's device.
        """
        # NOTE(review): ImageDataset in this file yields float tensors that
        # v2.ToTensor has already scaled to [0, 1]. If the processor's own
        # rescaling step is enabled for this checkpoint, such inputs would be
        # rescaled twice -- confirm, and consider do_rescale=False upstream.
        model_inputs = self.processor(images=batch["features"], return_tensors="pt")

        # This transform is an nn.Module, so callers may move it (and thus
        # the wrapped model) to an accelerator. The processor always returns
        # CPU tensors; align them with the model to avoid a device mismatch.
        model_inputs = model_inputs.to(self.model.device)

        # Pure feature extraction: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # extract the cls token
        batch["features"] = outputs.last_hidden_state[:, 0]
        return batch


class ImageDataset(Dataset):
    """Dataset of RGB images listed in a CSV metadata file.

    The CSV must provide a ``filename`` column (path relative to
    ``images_root_path``) and an ``observation_id`` column.
    """

    def __init__(self, metadata_path, images_root_path):
        """Read the metadata CSV and remember where the images live.

        Args:
            metadata_path: Path of the CSV file, read with ``pd.read_csv``.
            images_root_path: Directory that ``filename`` values are
                resolved against.
        """
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path
        # Built once instead of once per item.
        # NOTE: v2.ToTensor is deprecated upstream; it is kept so that the
        # output type and values stay exactly what callers already receive.
        self._to_tensor = v2.ToTensor()

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the image for row ``idx``.

        Args:
            idx: Positional row index into the metadata table.

        Returns:
            A dict with ``"features"`` (the RGB image converted by
            ``v2.ToTensor``) and ``"observation_id"`` (taken from the row).
        """
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row["filename"]

        # Close the underlying file deterministically; convert() returns an
        # independent in-memory image that remains valid afterwards.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert("RGB")

        return {
            "features": self._to_tensor(rgb_image),
            "observation_id": row["observation_id"],
        }