|
import os |
|
import torch |
|
from torch.utils.data import Dataset |
|
|
|
class FlatTileDataset(Dataset):
    """Map-style dataset over the regular files directly inside one directory.

    Each file is read with ``torch.load`` and is expected to hold a mapping
    with the keys ``'tile_data'`` and ``'file_data'``. Item ``i`` is the pair
    ``(torch.from_numpy(data['tile_data'][0]), data['file_data'])``.

    Subdirectories are skipped (no recursion). Files are indexed in sorted
    path order so that index ``i`` refers to the same file across runs and
    machines; raw directory-listing order is filesystem-dependent.

    Args:
        data_dir: Directory containing the serialized tile files.

    Raises:
        FileNotFoundError / NotADirectoryError: from ``os.scandir`` if
            ``data_dir`` is missing or is not a directory.
    """

    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        # Sort: listing order is arbitrary, which would otherwise make the
        # index -> file mapping (and any seeded sampling or index-based
        # split built on it) non-reproducible.
        with os.scandir(data_dir) as entries:
            # DirEntry.is_file() follows symlinks, matching os.path.isfile.
            self.files = sorted(
                entry.path for entry in entries if entry.is_file()
            )

    def __len__(self):
        """Return the number of files found in ``data_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(tile_tensor, file_data)``.

        Returns:
            tile_tensor: ``torch.from_numpy(data['tile_data'][0])``; it shares
                memory with the loaded NumPy array (no copy is made).
            file_data: ``data['file_data']``, returned unchanged.

        Raises:
            IndexError: if ``idx`` is out of range.
            KeyError: if the loaded mapping lacks either expected key.
            TypeError: if ``data['tile_data'][0]`` is not a NumPy ndarray.
        """
        file_path = self.files[idx]
        # weights_only=False keeps the pre-torch-2.6 default: the payload
        # holds NumPy arrays (required by torch.from_numpy below), which the
        # weights-only unpickler rejects.
        # SECURITY: this performs full pickle deserialization, which can
        # execute arbitrary code -- only point this dataset at trusted files.
        data = torch.load(file_path, weights_only=False)
        tile_data = torch.from_numpy(data['tile_data'][0])
        file_data = data['file_data']
        return tile_data, file_data
|
|