TripletGeoEncoder-demo / datamodule.py
yeq6x's picture
init
02ba63a
raw
history blame contribute delete
909 Bytes
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from dataset import MyDataset, load_filenames # dataset.pyに基づく
class DataModule(pl.LightningDataModule):
    """Lightning data module serving training images found under ``img_dir``.

    Args:
        img_dir: Directory passed to ``load_filenames`` and to ``MyDataset``.
        batch_size: Number of samples per training batch.
        img_size: Image size forwarded to ``MyDataset`` (default 112).
        num_workers: Number of ``DataLoader`` worker processes; 0 loads data
            in the main process.
        file_num: How many filenames, taken from the front of the sequence
            returned by ``load_filenames``, are used for training. Defaults to
            1000 (the original code mentioned 3400 as an alternative). ``None``
            keeps every filename, assuming ``load_filenames`` returns a list.
    """

    def __init__(self, img_dir, batch_size, img_size=112, num_workers=0, file_num=1000):
        super().__init__()
        self.img_dir = img_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers
        # Formerly hard-coded to 1000; now configurable, same default.
        self.file_num = file_num

    def setup(self, stage=None):
        """Build ``self.train_dataset`` from the first ``file_num`` filenames.

        Args:
            stage: Lightning stage name. Unused: the training dataset is
                built regardless of stage.
        """
        filenames = load_filenames(self.img_dir)
        self.train_dataset = MyDataset(
            filenames[: self.file_num],
            img_dir=self.img_dir,
            img_size=self.img_size,
        )

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over ``self.train_dataset``.

        ``setup`` must have run first, because it creates ``train_dataset``.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # DataLoader raises ValueError if persistent_workers=True while
            # num_workers == 0 (this class's default), so only keep workers
            # alive between epochs when worker processes actually exist.
            persistent_workers=self.num_workers > 0,
        )