import pytorch_lightning as pl | |
from torch.utils.data import DataLoader | |
from torchvision.transforms import v2 | |
from .data import ImageDataset, TransformDino | |
class InferenceDataModel(pl.LightningDataModule):
    """Lightning data module that serves DINO-transformed batches for prediction.

    Samples come from ``ImageDataset(metadata_path, images_root_path)``. They
    are collated into batches with PyTorch's default collation, and each
    collated batch is passed through ``v2.Compose([TransformDino(...)])``
    before reaching the model.

    Args:
        metadata_path: Forwarded unchanged to ``ImageDataset``.
        images_root_path: Forwarded unchanged to ``ImageDataset``.
        batch_size: Number of samples per batch.
        num_workers: Number of ``DataLoader`` worker processes. The default
            of 4 matches the previously hard-coded value.
        dino_model_name: Name passed to ``TransformDino``. The default matches
            the previously hard-coded ``"facebook/dinov2-base"``.
    """

    def __init__(
        self,
        metadata_path,
        images_root_path,
        batch_size=32,
        num_workers=4,
        dino_model_name="facebook/dinov2-base",
    ):
        super().__init__()
        self.metadata_path = metadata_path
        self.images_root_path = images_root_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dino_model_name = dino_model_name

    def setup(self, stage=None):
        """Create the dataset and a plain, untransformed ``DataLoader`` over it.

        The same loader is built for every ``stage``. ``self.dataloader`` yields
        raw collated batches. ``predict_dataloader`` reuses its dataset to build
        the transformed loader.
        """
        self.dataloader = DataLoader(
            ImageDataset(self.metadata_path, self.images_root_path),
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    @staticmethod
    def _collate_and_transform(transform, samples):
        """Collate ``samples`` as DataLoader does by default, then apply ``transform``."""
        from torch.utils.data import default_collate

        return transform(default_collate(samples))

    def predict_dataloader(self):
        """Return a ``DataLoader`` whose batches have already been DINO-transformed.

        Previously this was a generator that transformed batches of
        ``self.dataloader`` in the main process. Lightning can only inject a
        distributed sampler into, and take ``len()`` of, a real ``DataLoader``.
        The transform is therefore applied inside ``collate_fn`` instead:

        - Each batch is still default-collated and then transformed.
        - Multi-device sharding now works.
        - The progress bar now has a length.
        - Preprocessing runs in the worker processes.

        Requires ``setup()`` to have been called first. Lightning does this
        before requesting the predict dataloader.
        """
        import functools

        transform = v2.Compose([TransformDino(self.dino_model_name)])
        # NOTE(review): the transform now runs inside DataLoader workers when
        # num_workers > 0. This assumes TransformDino is CPU-only and picklable
        # (the latter is needed under the "spawn" start method). Confirm
        # against .data.TransformDino.
        return DataLoader(
            self.dataloader.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            # staticmethod + partial rather than a lambda, so the collate_fn
            # can be pickled into worker processes.
            collate_fn=functools.partial(self._collate_and_transform, transform),
        )