Spaces:
Running
on
Zero
Running
on
Zero
import pytorch_lightning as pl | |
from torch.utils.data import DataLoader | |
from dataset import MyDataset, load_filenames # dataset.pyに基づく | |
class DataModule(pl.LightningDataModule):
    """LightningDataModule that serves a training set built from ``img_dir``.

    Args:
        img_dir: Directory passed to both ``load_filenames`` and ``MyDataset``.
        batch_size: Number of samples per training batch.
        img_size: Image size forwarded to ``MyDataset`` (default 112).
        num_workers: Number of DataLoader worker processes (default 0,
            meaning data is loaded in the main process).
        file_num: How many filenames, taken from the front of the
            ``load_filenames`` result, are used for training. Defaults to
            1000 (the previously hard-coded value; 3400 was noted as an
            alternative). Pass ``None`` to use every file.
    """

    def __init__(self, img_dir, batch_size, img_size=112, num_workers=0, file_num=1000):
        super().__init__()
        self.img_dir = img_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers
        self.file_num = file_num

    def setup(self, stage=None):
        """Build ``self.train_dataset`` from the first ``file_num`` filenames.

        ``stage`` is accepted to match the Lightning hook signature but is
        not inspected: the training dataset is built for every stage.
        """
        # NOTE(review): assumes load_filenames returns a sliceable sequence
        # (e.g. a list); a slice end of None keeps every element.
        filenames = load_filenames(self.img_dir)
        self.train_dataset = MyDataset(
            filenames[: self.file_num],
            img_dir=self.img_dir,
            img_size=self.img_size,
        )

    def train_dataloader(self):
        """Return a shuffled DataLoader over ``self.train_dataset``.

        ``setup`` must have been called first so that ``train_dataset`` exists.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError when persistent_workers=True and
            # num_workers == 0 (this class's default), so only keep workers
            # alive between epochs when worker processes actually exist.
            persistent_workers=self.num_workers > 0,
        )