# segformer-sidewalk / dataloader.py
# Author: chainyo — commit 3110ea7 ("create dataset and dataloader"), 2.62 kB
import numpy as np
import pytorch_lightning as pl
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, random_split, Subset
from transformers import SegformerFeatureExtractor, BatchFeature
from typing import Optional
class SegmentationDataset(Dataset):
    """Image Segmentation Dataset.

    Wraps pre-encoded image and label tensors that are indexed along their
    first dimension, and yields one ``BatchFeature`` per sample with the keys
    ``"pixel_values"`` and ``"labels"``.
    """

    def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):
        """
        Dataset for image segmentation.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Tensor whose first dimension indexes samples. Presumably (N, C, H, W)
            as produced by ``SegformerFeatureExtractor``; only dim 0 is relied on.
        labels : torch.Tensor
            Tensor whose first dimension indexes samples. Presumably (N, H, W)
            segmentation maps; only dim 0 is relied on.

        Raises
        ------
        ValueError
            If ``pixel_values`` and ``labels`` contain a different number of samples.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if pixel_values.shape[0] != labels.shape[0]:
            raise ValueError(
                "pixel_values and labels must have the same number of samples, "
                f"got {pixel_values.shape[0]} and {labels.shape[0]}"
            )
        self.pixel_values = pixel_values
        self.labels = labels
        self.length = pixel_values.shape[0]
        print(f"Created dataset with {self.length} samples")

    def __len__(self):
        """Return the number of samples."""
        return self.length

    def __getitem__(self, index):
        """Return sample ``index`` as a ``BatchFeature`` with image and label."""
        image = self.pixel_values[index]
        label = self.labels[index]
        encoded_inputs = BatchFeature({"pixel_values": image, "labels": label})
        return encoded_inputs
class SidewalkSegmentationDataLoader(pl.LightningDataModule):
    """Lightning data module for a Hugging Face Hub sidewalk-segmentation dataset.

    The loaded split is encoded with ``SegformerFeatureExtractor`` in ``setup``
    and then randomly partitioned into training and validation subsets.
    """

    def __init__(
        self,
        hub_dir: str,
        batch_size: int,
        split: Optional[str] = None,
        val_fraction: float = 0.2,
        num_workers: int = 12,
        seed: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        hub_dir : str
            Dataset name or path passed to ``datasets.load_dataset``.
        batch_size : int
            Batch size used by both the train and validation dataloaders.
        split : str, optional
            Split to load (e.g. ``"train"``). This should name a single split.
            With ``None``, ``load_dataset`` returns a ``DatasetDict`` and the
            column lookups in ``setup`` will fail.
        val_fraction : float
            Fraction of samples held out for validation (default 0.2).
        num_workers : int
            Worker processes for each ``DataLoader`` (default 12).
        seed : int, optional
            Seed for the train/validation split. Set it to get the same split
            across runs and across processes. ``None`` uses torch's global RNG.

        Raises
        ------
        ValueError
            If ``val_fraction`` is outside ``[0, 1]``.
        """
        super().__init__()
        if not 0.0 <= val_fraction <= 1.0:
            raise ValueError(f"val_fraction must be in [0, 1], got {val_fraction}")
        self.hub_dir = hub_dir
        self.batch_size = batch_size
        self.val_fraction = val_fraction
        self.num_workers = num_workers
        self.seed = seed
        # Named ``tokenizer`` for historical reasons; it is an image feature extractor.
        # NOTE(review): reduce_labels=True presumably remaps label ids (background 0 ->
        # ignore index) per the transformers Segformer docs — confirm against the model config.
        self.tokenizer = SegformerFeatureExtractor(reduce_labels=True)
        self.dataset = load_dataset(self.hub_dir, split=split)
        self.len = len(self.dataset)
        # Populated by ``setup``; ``None`` means setup has not run yet.
        self.train_dataset = None
        self.val_dataset = None

    def tokenize_data(self, *args, **kwargs):
        """Forward all arguments to the Segformer feature extractor."""
        return self.tokenizer(*args, **kwargs)

    @staticmethod
    def _split_lengths(total: int, val_fraction: float) -> tuple:
        """Return ``(train_len, val_len)`` that always sum to ``total``."""
        val_len = int(total * val_fraction)
        # Derive train from the remainder so the lengths always sum to ``total``,
        # which random_split requires (int(0.8n) + int(0.2n) can fall short).
        return total - val_len, val_len

    def setup(self, stage: Optional[str] = None):
        """Encode the dataset and split it into train/validation subsets.

        Lightning may call this once per stage. Encoding and splitting happen
        only on the first call, so later stages reuse the same split; a fresh
        random split would leak training samples into validation.
        """
        if self.train_dataset is not None and self.val_dataset is not None:
            return
        encoded_dataset = self.tokenize_data(
            images=self.dataset["pixel_values"],
            segmentation_maps=self.dataset["label"],
            return_tensors="pt",
        )
        dataset = SegmentationDataset(encoded_dataset["pixel_values"], encoded_dataset["labels"])
        train_len, val_len = self._split_lengths(len(dataset), self.val_fraction)
        if self.seed is None:
            generator = torch.default_generator
        else:
            generator = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            dataset, [train_len, val_len], generator=generator
        )

    def train_dataloader(self):
        """Return a shuffling DataLoader over the training subset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return a non-shuffling DataLoader over the validation subset."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )