from typing import Optional

from omegaconf import DictConfig
import pytorch_lightning as L
import torch.utils.data as torchdata

from .torch import collate, worker_init_fn


def get_dataset(name):
    """Resolve a dataset name from the config to its data module class.

    Imports are deferred so the dependencies of unused datasets are never
    loaded.
    """
    if name == "mapillary":
        from .mapillary.data_module import MapillaryDataModule
        return MapillaryDataModule
    elif name == "nuscenes":
        from .nuscenes.data_module import NuScenesData
        return NuScenesData
    elif name == "kitti":
        from .kitti.data_module import BEVKitti360Data
        return BEVKitti360Data
    else:
        raise NotImplementedError(f"Dataset {name} not implemented.")
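

# The factory above returns a class, not an instance; callers instantiate it
# with the same config object, e.g.:
#
#     DataModuleCls = get_dataset("nuscenes")  # -> NuScenesData
#     data_module = DataModuleCls(cfg)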


class GenericDataModule(L.LightningDataModule):
    """Generic wrapper that delegates to the dataset-specific data module
    selected by ``cfg.name``."""

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.data_module = get_dataset(cfg.name)(cfg)

    def prepare_data(self) -> None:
        self.data_module.prepare_data()

    def setup(self, stage: Optional[str] = None):
        self.data_module.setup(stage)

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        num_workers: Optional[int] = None,
        sampler: Optional[torchdata.Sampler] = None,
    ):
        dataset = self.data_module.dataset(stage)
        cfg = self.cfg["loading"][stage]
        num_workers = cfg["num_workers"] if num_workers is None else num_workers
        loader = torchdata.DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            num_workers=num_workers,
            # PyTorch forbids combining shuffle with a custom sampler, so
            # shuffling (on by default for training) is disabled whenever a
            # sampler is given.
            shuffle=(shuffle or stage == "train") and sampler is None,
            pin_memory=True,
            # Keep workers alive between epochs to avoid respawn overhead.
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
            sampler=sampler,
        )
        return loader
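
    # A custom sampler can be threaded through ``dataloader``, assuming
    # Lightning's automatic distributed-sampler injection is disabled in the
    # surrounding setup (``dataset`` below stands for the stage's dataset):
    #
    #     sampler = torchdata.DistributedSampler(dataset, shuffle=True)
    #     loader = data.dataloader("train", sampler=sampler)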

    def train_dataloader(self, **kwargs):
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        return self.dataloader("test", **kwargs)
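

# Minimal usage sketch (assumed config values): ``name`` and
# ``loading.<stage>`` are the only keys read in this file; the dataset-specific
# data module will typically require further keys of its own.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "name": "mapillary",  # or "nuscenes", "kitti"
            "loading": {
                "train": {"batch_size": 4, "num_workers": 4},
                "val": {"batch_size": 4, "num_workers": 2},
                "test": {"batch_size": 1, "num_workers": 0},
            },
            # ... plus whatever MapillaryDataModule itself expects
        }
    )
    data = GenericDataModule(cfg)
    data.prepare_data()
    data.setup("fit")
    train_loader = data.train_dataloader()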