import pytorch_lightning as pl
from torch.utils.data import DataLoader

from dataset import MyDataset, load_filenames  # based on the project's dataset.py


class DataModule(pl.LightningDataModule):
    """LightningDataModule serving a training split built from ``img_dir``.

    Args:
        img_dir: Directory passed to ``load_filenames`` and to ``MyDataset``.
        batch_size: Number of samples per training batch.
        img_size: Image size forwarded to ``MyDataset``.
        num_workers: Number of DataLoader worker processes; ``0`` loads data
            in the main process.
        file_num: How many filenames, taken from the start of the list
            returned by ``load_filenames``, are used for training.
            Defaults to 1000 (the original code noted 3400 as an
            alternative). ``None`` keeps every filename.
    """

    def __init__(self, img_dir, batch_size, img_size=112, num_workers=0,
                 file_num=1000):
        super().__init__()
        self.img_dir = img_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers
        self.file_num = file_num

    def setup(self, stage=None):
        """Build ``self.train_dataset`` from the first ``file_num`` filenames.

        ``stage`` is accepted for the Lightning hook signature but is not
        used; the training dataset is built regardless of stage.
        """
        # NOTE(review): assumes load_filenames returns a sliceable sequence
        # (e.g. a list) — confirm against dataset.py.
        filenames = load_filenames(self.img_dir)
        self.train_dataset = MyDataset(
            filenames[:self.file_num],
            img_dir=self.img_dir,
            img_size=self.img_size,
        )

    def train_dataloader(self):
        """Return a shuffling DataLoader over ``self.train_dataset``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError if persistent_workers=True while
            # num_workers == 0 (the default here), so only enable it when
            # worker processes actually exist.
            persistent_workers=self.num_workers > 0,
        )