# Scrape artifact — Hugging Face Spaces page header, commented out so the file parses:
# Spaces:
# Running
# Running
import math
import time

import pytorch_lightning as L
import torch
import webdataset as wds
from PIL import Image
from torch.utils.data import DataLoader, default_collate, random_split
class ImageDataModule(L.LightningDataModule):
    """LightningDataModule wrapping WebDataset pipelines for train/val/test.

    The datasets are supplied as zero-argument builder callables so that
    construction is deferred to ``setup()`` (i.e. after DDP workers fork).
    Batch sizes are given globally (``full_batch_size``) and divided evenly
    across ``num_nodes * num_devices`` GPUs.

    Args:
        train_dataset: zero-arg callable returning the training WebDataset.
        val_dataset: zero-arg callable returning the validation WebDataset.
        test_dataset: zero-arg callable returning the test WebDataset.
        full_batch_size (int): global training batch size across all GPUs.
        num_workers (int): dataloader workers for the training loader.
        eval_batch_size (int, optional): global eval batch size; defaults to
            ``full_batch_size`` when None.
        num_nodes (int): number of nodes in the cluster.
        num_devices (int): number of devices (GPUs) per node.
        val_proportion (float): stored for callers; not used internally here.
    """

    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        full_batch_size,
        num_workers,
        eval_batch_size=None,
        num_nodes=1,
        num_devices=1,
        val_proportion=0.1,
    ):
        super().__init__()
        # Builders are called lazily in setup(), keyed by stage.
        self._builders = {
            "train": train_dataset,
            "val": val_dataset,
            "test": test_dataset,
        }
        self.num_workers = num_workers
        self.collate_fn = dict_collate_fn()
        self.full_batch_size = full_batch_size
        # Per-GPU share of the global batch.
        self.train_batch_size = full_batch_size // (num_nodes * num_devices)
        if eval_batch_size is None:
            self.eval_batch_size = self.train_batch_size
            self.full_eval_batch_size = self.full_batch_size
        else:
            self.eval_batch_size = eval_batch_size // (num_nodes * num_devices)
            self.full_eval_batch_size = eval_batch_size
        print(f"Each GPU will receive {self.train_batch_size} images for training")
        print(f"Each GPU will receive {self.eval_batch_size} images for evaluation")
        self.val_proportion = val_proportion
        self.world_size = num_nodes * num_devices

    def setup(self, stage=None):
        """Setup the datamodule.

        Args:
            stage (str): stage of the datamodule.
                One of "fit" or "test" or None.
        """
        print("Stage", stage)
        start_time = time.time()
        if stage == "fit" or stage is None:
            self.train_dataset = self._builders["train"]()
            # BUGFIX: pass the real worker count so the per-worker epoch
            # rounding matches train_dataloader(), which runs with
            # self.num_workers (the original passed the default 0 here).
            self.train_dataset, self.num_train_batches = self.get_webdataset_length(
                self.train_dataset,
                self.collate_fn,
                self.full_batch_size,
                self.train_batch_size,
                self.num_workers,
            )
            self.val_dataset = self._builders["val"]()
            # val_dataloader() runs with num_workers=0, so no worker rounding.
            self.val_dataset, self.num_val_batches = self.get_webdataset_length(
                self.val_dataset,
                self.collate_fn,
                self.full_eval_batch_size,
                self.eval_batch_size,
                0,
            )
            print(f"Train dataset size: {len(self.train_dataset)}")
            print(f"Val dataset size: {len(self.val_dataset)}")
        else:
            self.test_dataset = self._builders["test"]()
            # BUGFIX: test_dataloader() runs with num_workers=0, so compute
            # the length with 0 workers (the original passed
            # self.num_workers, inflating the reported length when
            # num_workers > 1).
            self.test_dataset, self.num_test_batches = self.get_webdataset_length(
                self.test_dataset,
                self.collate_fn,
                self.full_eval_batch_size,
                self.eval_batch_size,
                0,
            )
            print(f"Test dataset size: {len(self.test_dataset)}")
        end_time = time.time()
        print(f"Setup took {(end_time - start_time):.2f} seconds")

    def train_dataloader(self):
        """Return the training WebLoader; batching is done in the dataset."""
        return wds.WebLoader(
            self.train_dataset,
            batch_size=None,  # dataset pipeline already emits batches
            shuffle=False,
            num_workers=self.num_workers,
        ).with_length(self.num_train_batches)

    def val_dataloader(self):
        """Return the validation WebLoader (single-process loading)."""
        return wds.WebLoader(
            self.val_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=0,
        ).with_length(self.num_val_batches)

    def test_dataloader(self):
        """Return the test WebLoader (single-process loading)."""
        return wds.WebLoader(
            self.test_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=0,
        ).with_length(self.num_test_batches)

    def get_webdataset_length(
        self, dataset, collate_fn, full_batch_size, batch_size, num_workers=0
    ):
        """Attach batching to ``dataset`` and compute its epoch length.

        In the distributed case (world_size > 1) the number of batches is
        rounded up so every worker of every rank yields the same number of
        batches per epoch — required for DDP to stay in lockstep.

        Args:
            dataset: WebDataset pipeline exposing ``num_samples``.
            collate_fn: callable used to collate each batch.
            full_batch_size (int): global batch size across all GPUs.
            batch_size (int): per-GPU batch size.
            num_workers (int): dataloader workers that will consume this
                dataset (0 means single-process loading).

        Returns:
            tuple: ``(dataset, num_batches)`` where the dataset has batching,
            epoch length and ``__len__`` attached.
        """
        dataset = dataset.compose(
            wds.batched(
                batch_size,
                # Partial batches are only allowed in the distributed case.
                partial=self.world_size > 1,
                collation_fn=collate_fn,
            )
        )
        num_samples = dataset.num_samples
        if self.world_size > 1:
            num_batches = math.ceil(num_samples / full_batch_size)
            # Round batches up to a multiple of the worker count so every
            # worker produces the same number of batches.
            num_workers = max(1, num_workers)
            num_worker_batches = math.ceil(num_batches / num_workers)
            num_batches = num_worker_batches * num_workers
            num_samples = num_batches * full_batch_size
            # Each worker sees num_worker_batches batches per epoch.
            dataset = dataset.with_epoch(num_worker_batches).with_length(
                num_worker_batches
            )
        else:
            num_batches = math.ceil(num_samples / batch_size)
            dataset = dataset.with_epoch(num_batches).with_length(num_batches)
        return dataset, num_batches
def dict_collate_fn():
    """Return a collate function that merges a list of samples recursively.

    Dicts are collated key-by-key (recursing into nested dicts); lists of
    PIL images are passed through unchanged as a plain list; any other leaf
    is handed to torch's ``default_collate``.
    """

    def dict_collate(samples):
        first = samples[0]
        if isinstance(first, dict):
            # Recurse per key so nested dict structures collate correctly.
            return {key: dict_collate([s[key] for s in samples]) for key in first}
        if isinstance(first, Image.Image):
            # PIL images cannot be stacked; keep them as a list.
            return samples
        return default_collate(samples)

    return dict_collate