from typing import Dict, Optional

import torch
import numpy as np
import pytorch_lightning as pl
from yacs.config import CfgNode

from ..configs import to_lower
from .dataset import Dataset

# NOTE(review): `MixedWebDataset` and `MoCapDataset` are used in
# `HAMERDataModule.setup` but no import or definition for them is visible in
# this part of the file -- confirm they are in scope at runtime.


class HAMERDataModule(pl.LightningDataModule):
    """LightningDataModule that builds the datasets and data loaders for HAMER training."""

    def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode) -> None:
        """
        Initialize LightningDataModule for HAMER training
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
            dataset_cfg (CfgNode): Dataset configuration file
        """
        super().__init__()
        self.cfg = cfg
        self.dataset_cfg = dataset_cfg
        # Datasets are created lazily in `setup`; `test_dataset` is never
        # assigned anywhere in this class.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.mocap_dataset = None

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Load datasets necessary for training
        The datasets are only built while `self.train_dataset` is still None, so
        calling this method again reuses the existing datasets.
        Args:
            stage (Optional[str]): Stage name passed in by the caller. It is not used here.
        """
        # Identity check: `== None` would dispatch to a dataset-defined `__eq__`.
        if self.train_dataset is None:
            self.train_dataset = MixedWebDataset(
                self.cfg, self.dataset_cfg, train=True
            ).with_epoch(100_000).shuffle(4000)
            self.val_dataset = MixedWebDataset(
                self.cfg, self.dataset_cfg, train=False
            ).shuffle(4000)
            # The mocap dataset is built from the `dataset_cfg` entry named by
            # `cfg.DATASETS.MOCAP`, passed through `to_lower` and expanded as kwargs.
            self.mocap_dataset = MoCapDataset(
                **to_lower(self.dataset_cfg[self.cfg.DATASETS.MOCAP])
            )

    def train_dataloader(self) -> Dict:
        """
        Setup training data loader.
        Returns:
            Dict: Dictionary containing image and mocap data dataloaders
                (keys 'img' and 'mocap').
        """
        num_workers = self.cfg.GENERAL.NUM_WORKERS
        # `DataLoader` rejects `prefetch_factor` when loading in the main
        # process (`num_workers == 0`), so only forward it when workers are used.
        worker_kwargs = {}
        if num_workers > 0:
            worker_kwargs['prefetch_factor'] = self.cfg.GENERAL.PREFETCH_FACTOR
        train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            self.cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            num_workers=num_workers,
            **worker_kwargs,
        )
        # The mocap loader's batch size is NUM_TRAIN_SAMPLES * BATCH_SIZE.
        mocap_dataloader = torch.utils.data.DataLoader(
            self.mocap_dataset,
            self.cfg.TRAIN.NUM_TRAIN_SAMPLES * self.cfg.TRAIN.BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=1,
        )
        return {'img': train_dataloader, 'mocap': mocap_dataloader}

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """
        Setup val data loader.
        Returns:
            torch.utils.data.DataLoader: Validation dataloader
        """
        # Uses the training batch size; incomplete final batches are dropped.
        val_dataloader = torch.utils.data.DataLoader(
            self.val_dataset,
            self.cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            num_workers=self.cfg.GENERAL.NUM_WORKERS,
        )
        return val_dataloader