# Source: dsgt-snakeclef / evaluate/data_lightning.py
# Author: Anthony Miyaguchi
# Last commit: "Remove lightning dependency from submission" (a0583df)
# Repository viewer residue: raw | history | blame | contribute | delete; file size 925 Bytes
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from .data import ImageDataset, TransformDino
class InferenceDataModel(pl.LightningDataModule):
    """Lightning data module that serves transformed batches for prediction.

    ``setup`` builds an unshuffled ``DataLoader`` over an ``ImageDataset``;
    ``predict_dataloader`` then lazily yields each batch from that loader
    after passing it through a ``TransformDino`` pipeline.

    Args:
        metadata_path: Forwarded as the first argument to ``ImageDataset``.
        images_root_path: Forwarded as the second argument to
            ``ImageDataset``.
        batch_size: Batch size handed to the ``DataLoader``.
        num_workers: Number of ``DataLoader`` worker processes. Defaults to
            4, the value that was previously hard-coded.
        model_name: Identifier handed to ``TransformDino``. Defaults to
            ``"facebook/dinov2-base"``, the value that was previously
            hard-coded. NOTE(review): presumably a Hugging Face checkpoint
            name -- confirm against ``TransformDino``.
    """

    def __init__(
        self,
        metadata_path,
        images_root_path,
        batch_size: int = 32,
        num_workers: int = 4,
        model_name: str = "facebook/dinov2-base",
    ):
        super().__init__()
        self.metadata_path = metadata_path
        self.images_root_path = images_root_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.model_name = model_name

    def setup(self, stage=None):
        """Create the underlying ``DataLoader`` and store it on the instance.

        Args:
            stage: Accepted for interface compatibility; not used here.
        """
        dataset = ImageDataset(self.metadata_path, self.images_root_path)
        self.dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            # Keep the dataset order so outputs line up with the inputs.
            shuffle=False,
            num_workers=self.num_workers,
        )

    def predict_dataloader(self):
        """Yield each batch of the loader after applying the DINO transform.

        This is a generator function: the transform is only constructed once
        iteration starts, and every call returns a new single-pass iterator.
        ``self.dataloader`` is assigned in ``setup``, so ``setup`` must have
        run before the returned generator is consumed.

        Yields:
            The result of applying the composed transform to one batch.
        """
        transform = v2.Compose([TransformDino(self.model_name)])
        for batch in self.dataloader:
            yield transform(batch)