Spaces:
Build error
Build error
from typing import Dict, Optional | |
import torch | |
import numpy as np | |
import pytorch_lightning as pl | |
from yacs.config import CfgNode | |
from ..configs import to_lower | |
from .dataset import Dataset | |
class HAMERDataModule(pl.LightningDataModule):
    """LightningDataModule that builds the HAMER train/val image datasets and the mocap dataset."""

    def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode) -> None:
        """
        Initialize LightningDataModule for HAMER training.
        Datasets are not built here; they are created lazily in `setup`.
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
            dataset_cfg (CfgNode): Dataset configuration file.
        """
        super().__init__()
        self.cfg = cfg
        self.dataset_cfg = dataset_cfg
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None  # not assigned anywhere in this class
        self.mocap_dataset = None

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Load datasets necessary for training and validation.
        The datasets are built on the first call only; later calls (e.g. for other
        Lightning stages) reuse the already-built objects.
        Args:
            stage (Optional[str]): Lightning stage name. Ignored: the same datasets
                are built regardless of stage.
        """
        # NOTE(review): MixedWebDataset and MoCapDataset are not imported in the visible
        # import block of this file -- confirm they are defined/imported elsewhere in the module.
        if self.train_dataset is None:
            self.train_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=True).with_epoch(100_000).shuffle(4000)
            self.val_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=False).shuffle(4000)
            self.mocap_dataset = MoCapDataset(**to_lower(self.dataset_cfg[self.cfg.DATASETS.MOCAP]))

    def train_dataloader(self) -> Dict[str, torch.utils.data.DataLoader]:
        """
        Setup training data loaders.
        Returns:
            Dict[str, torch.utils.data.DataLoader]: 'img' maps to the image dataloader,
                'mocap' maps to the mocap dataloader.
        """
        num_workers = self.cfg.GENERAL.NUM_WORKERS
        img_loader_kwargs = {}
        # DataLoader raises ValueError when prefetch_factor is given together with
        # num_workers == 0, so only forward it when worker processes are used.
        if num_workers > 0:
            img_loader_kwargs['prefetch_factor'] = self.cfg.GENERAL.PREFETCH_FACTOR
        train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            self.cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            num_workers=num_workers,
            **img_loader_kwargs,
        )
        # Each mocap batch holds NUM_TRAIN_SAMPLES * BATCH_SIZE samples.
        mocap_dataloader = torch.utils.data.DataLoader(
            self.mocap_dataset,
            self.cfg.TRAIN.NUM_TRAIN_SAMPLES * self.cfg.TRAIN.BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=1,
        )
        return {'img': train_dataloader, 'mocap': mocap_dataloader}

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """
        Setup val data loader.
        Returns:
            torch.utils.data.DataLoader: Validation dataloader
        """
        # drop_last=True discards the final partial batch, so a few validation
        # samples may be skipped each epoch.
        val_dataloader = torch.utils.data.DataLoader(
            self.val_dataset,
            self.cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            num_workers=self.cfg.GENERAL.NUM_WORKERS,
        )
        return val_dataloader