import pytorch_lightning as pl
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import SegformerFeatureExtractor, BatchFeature

from typing import Optional


class SegmentationDataset(Dataset):
    """Image Segmentation Dataset"""
    def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):
        """
        Dataset for image segmentation.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Tensor of shape (N, C, H, W) containing the pixel values of the images.
        labels : torch.Tensor
            Tensor of shape (N, H, W) containing the segmentation label map for each image.
        """
        self.pixel_values = pixel_values
        self.labels = labels
        assert pixel_values.shape[0] == labels.shape[0], \
            "pixel_values and labels must contain the same number of samples"
        self.length = pixel_values.shape[0]
        print(f"Created dataset with {self.length} samples")
    

    def __len__(self):
        return self.length
    

    def __getitem__(self, index):
        image = self.pixel_values[index]
        label = self.labels[index]

        # Wrap the pair in a BatchFeature so samples expose the same dict-style
        # interface ("pixel_values", "labels") as the processor's batched output.
        encoded_inputs = BatchFeature({"pixel_values": image, "labels": label})

        return encoded_inputs
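

# Usage sketch for SegmentationDataset (kept as a comment so it does not run at
# import time). The toy tensors below are assumptions standing in for real
# preprocessed data; SegFormer-style inputs are (N, C, H, W) images and
# (N, H, W) integer label maps:
#
#     images = torch.randn(4, 3, 512, 512)
#     masks = torch.randint(0, 35, (4, 512, 512), dtype=torch.long)
#     ds = SegmentationDataset(images, masks)
#     sample = ds[0]  # BatchFeature with "pixel_values" and "labels" entries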


class SidewalkSegmentationDataLoader(pl.LightningDataModule):
    def __init__(
        self, hub_dir: str, batch_size: int, split: Optional[str] = None,
    ):
        super().__init__()
        self.hub_dir = hub_dir
        self.batch_size = batch_size
        # Despite the attribute name, this is an image processor, not a text
        # tokenizer. reduce_labels=True shifts every label down by one and maps
        # the background class (0) to 255, which segmentation losses typically ignore.
        self.tokenizer = SegformerFeatureExtractor(reduce_labels=True)
        self.dataset = load_dataset(self.hub_dir, split=split)
        self.len = len(self.dataset)


    def tokenize_data(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)

    
    def setup(self, stage: Optional[str] = None):
        encoded_dataset = self.tokenize_data(
            images=self.dataset["pixel_values"], segmentation_maps=self.dataset["label"], return_tensors="pt"
        )
        dataset = SegmentationDataset(encoded_dataset["pixel_values"], encoded_dataset["labels"])

        # 80/20 train/validation split. The validation length is computed as the
        # remainder so the two lengths always sum to the dataset size, which
        # random_split requires.
        train_len = int(self.len * 0.8)
        val_len = self.len - train_len
        self.train_dataset, self.val_dataset = random_split(dataset, [train_len, val_len])


    def train_dataloader(self):
        # num_workers is hard-coded; tune it to the CPU cores available on your machine.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=12)

    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=12)
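

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. "segments/sidewalk-semantic" is an
    # assumed example dataset id with "pixel_values" and "label" columns; swap
    # in your own hub repository. Requires network access to the Hugging Face Hub.
    datamodule = SidewalkSegmentationDataLoader(
        hub_dir="segments/sidewalk-semantic", batch_size=8, split="train"
    )
    datamodule.setup()
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["pixel_values"].shape)  # (8, 3, 512, 512) with the default feature extractor size
    print(batch["labels"].shape)        # (8, 512, 512)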