File size: 909 Bytes
02ba63a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from dataset import MyDataset, load_filenames  # dataset.pyに基づく

class DataModule(pl.LightningDataModule):
    """LightningDataModule that serves a training split built from an image directory.

    Filenames come from ``load_filenames(img_dir)``. The first ``file_num`` of them
    are wrapped in ``MyDataset``, which is served through a shuffling DataLoader.
    """

    def __init__(self, img_dir, batch_size, img_size=112, num_workers=0, file_num=1000):
        """Store the data-loading configuration; no data is read here.

        Args:
            img_dir: Directory passed to ``load_filenames`` and to ``MyDataset``.
            batch_size: Number of samples per training batch.
            img_size: Image size forwarded to ``MyDataset`` (default 112).
            num_workers: Number of DataLoader worker processes. 0 (the default)
                loads data in the main process.
            file_num: Maximum number of filenames used for training. The default
                is 1000; 3400 was noted as an alternative in the original code.
                ``None`` uses every file returned by ``load_filenames``.
        """
        super().__init__()
        self.img_dir = img_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers
        self.file_num = file_num

    def setup(self, stage=None):
        """Build ``self.train_dataset`` from the first ``file_num`` filenames.

        ``stage`` is accepted to match the LightningDataModule hook signature but
        is not used: the same training dataset is built for every stage.
        """
        filenames = load_filenames(self.img_dir)
        # Slicing with file_num=None keeps every filename.
        self.train_dataset = MyDataset(
            filenames[: self.file_num],
            img_dir=self.img_dir,
            img_size=self.img_size,
        )

    def train_dataloader(self):
        """Return a shuffling DataLoader over ``self.train_dataset``.

        Requires ``setup`` to have been called first, since it creates
        ``self.train_dataset``.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError if persistent_workers=True while
            # num_workers == 0 (this class's default), so only keep workers
            # alive when worker processes actually exist.
            persistent_workers=self.num_workers > 0,
        )