import time

import pytorch_lightning as L
import torch
from torch.utils.data import DataLoader, random_split


class ImageDataModule(L.LightningDataModule):
    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        global_batch_size,
        num_workers,
        num_nodes=1,
        num_devices=1,
        val_proportion=0.1,
    ):
        super().__init__()
        # Store zero-argument dataset builders; the datasets themselves are
        # constructed lazily in `setup()`, which Lightning runs on every process.
        self._builders = {
            "train": train_dataset,
            "val": val_dataset,
            "test": test_dataset,
        }
        self.num_workers = num_workers
        # The global batch size is split evenly across all devices, e.g. a
        # global batch of 512 on 2 nodes with 4 GPUs each gives 64 per GPU.
        self.batch_size = global_batch_size // (num_nodes * num_devices)
        print(f"Each GPU will receive {self.batch_size} images")
        self.val_proportion = val_proportion

    @property
    def num_classes(self):
        # Use the already-built train dataset if `setup()` has run; otherwise
        # build a temporary instance just to read its `num_classes` attribute.
        if hasattr(self, "train_dataset"):
            return self.train_dataset.num_classes
        else:
            return self._builders["train"]().num_classes

    def setup(self, stage=None):
        """Set up the datamodule.

        Args:
            stage (str): Stage of the datamodule.
                Must be one of "fit", "test", or None.
        """
        print("Stage", stage)
        start_time = time.time()
        if stage == "fit" or stage is None:
            self.train_dataset = self._builders["train"]()
            self.val_dataset = self._builders["val"]()
            print(f"Train dataset size: {len(self.train_dataset)}")
            print(f"Val dataset size: {len(self.val_dataset)}")
        else:
            self.test_dataset = self._builders["test"]()
            print(f"Test dataset size: {len(self.test_dataset)}")
        end_time = time.time()
        print(f"Setup took {(end_time - start_time):.2f} seconds")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=False,
            drop_last=True,
            num_workers=self.num_workers,
            collate_fn=self.train_dataset.collate_fn_density,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=self.num_workers,
            collate_fn=self.val_dataset.collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=self.num_workers,
            collate_fn=self.test_dataset.collate_fn,
        )

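

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). It shows how
# the zero-argument dataset builders are passed in and how the per-device
# batch size falls out of `global_batch_size`. `ToyImageDataset` and
# `toy_collate` are placeholders: the real datasets are assumed to expose
# `num_classes`, `collate_fn`, and `collate_fn_density` themselves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import Dataset

    def toy_collate(batch):
        # Stand-in collate function; the real `collate_fn` and
        # `collate_fn_density` are dataset-specific.
        images, labels = zip(*batch)
        return torch.stack(images), torch.tensor(labels)

    class ToyImageDataset(Dataset):
        """Minimal dataset exposing the interface ImageDataModule expects."""

        num_classes = 10
        collate_fn = staticmethod(toy_collate)
        collate_fn_density = staticmethod(toy_collate)

        def __len__(self):
            return 64

        def __getitem__(self, idx):
            return torch.randn(3, 224, 224), idx % self.num_classes

    dm = ImageDataModule(
        train_dataset=ToyImageDataset,  # builders are called with no arguments
        val_dataset=ToyImageDataset,
        test_dataset=ToyImageDataset,
        global_batch_size=32,  # split across num_nodes * num_devices GPUs
        num_workers=0,
        num_nodes=1,
        num_devices=1,
    )
    dm.setup("fit")
    images, labels = next(iter(dm.train_dataloader()))
    print(images.shape, labels.shape)  # torch.Size([32, 3, 224, 224]) torch.Size([32])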