import tensorflow as tf
from typing import List, Tuple

from ..commons import read_image
from ..augmentation import AugmentationFactory


class LowLightDataset:
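    """tf.data pipeline for paired low-light / enhanced image datasets.

    Loads image pairs from file paths, applies a paired random crop, and
    optionally augments the training split with random flips and rotations.
    """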
    def __init__(
        self,
        image_size: int = 256,
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
    ) -> None:
        self.augmentation_factory = AugmentationFactory(image_size=image_size)
        self.apply_random_horizontal_flip = apply_random_horizontal_flip
        self.apply_random_vertical_flip = apply_random_vertical_flip
        self.apply_random_rotation = apply_random_rotation

    def load_data(self, low_light_image_path, enhanced_image_path):
        # Paths arrive as string tensors when this function is mapped over a
        # tf.data.Dataset; read_image decodes them into image tensors.
        low_light_image = read_image(low_light_image_path)
        enhanced_image = read_image(enhanced_image_path)
        # Crop both images with the same offsets so the pair stays aligned.
        low_light_image, enhanced_image = self.augmentation_factory.random_crop(
            low_light_image, enhanced_image
        )
        return low_light_image, enhanced_image

    def _get_dataset(
        self,
        low_light_images: List[str],
        enhanced_images: List[str],
        batch_size: int = 16,
        is_train: bool = True,
    ) -> tf.data.Dataset:
        dataset = tf.data.Dataset.from_tensor_slices(
            (low_light_images, enhanced_images)
        )
        # load_data already performs the paired random crop, so no separate
        # random_crop map is needed afterwards.
        dataset = dataset.map(self.load_data, num_parallel_calls=tf.data.AUTOTUNE)
        if is_train:
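            # Flip and rotation augmentations are applied only to the training split.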
            dataset = (
                dataset.map(
                    self.augmentation_factory.random_horizontal_flip,
                    num_parallel_calls=tf.data.AUTOTUNE,
                )
                if self.apply_random_horizontal_flip
                else dataset
            )
            dataset = (
                dataset.map(
                    self.augmentation_factory.random_vertical_flip,
                    num_parallel_calls=tf.data.AUTOTUNE,
                )
                if self.apply_random_vertical_flip
                else dataset
            )
            dataset = (
                dataset.map(
                    self.augmentation_factory.random_rotate,
                    num_parallel_calls=tf.data.AUTOTUNE,
                )
                if self.apply_random_rotation
                else dataset
            )
        dataset = dataset.batch(batch_size, drop_remainder=True)
        return dataset

    def get_datasets(
        self,
        low_light_images: List[str],
        enhanced_images: List[str],
        val_split: float = 0.2,
        batch_size: int = 16,
    ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        assert len(low_light_images) == len(
            enhanced_images
        ), "Number of low-light and enhanced image paths must match."
        split_index = int(len(low_light_images) * (1 - val_split))
        train_low_light_images = low_light_images[:split_index]
        train_enhanced_images = enhanced_images[:split_index]
        val_low_light_images = low_light_images[split_index:]
        val_enhanced_images = enhanced_images[split_index:]
        print(f"Number of train data points: {len(train_low_light_images)}")
        print(f"Number of validation data points: {len(val_low_light_images)}")
        train_dataset = self._get_dataset(
            train_low_light_images, train_enhanced_images, batch_size, is_train=True
        )
        val_dataset = self._get_dataset(
            val_low_light_images, val_enhanced_images, batch_size, is_train=False
        )
        return train_dataset, val_dataset
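

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: the glob
    # patterns and directory names below are hypothetical and only show how
    # the class is intended to be driven. Because this module uses relative
    # imports, run it within its package (e.g. `python -m <package>.dataset`).
    from glob import glob

    low_light_paths = sorted(glob("./data/low_light/*.png"))  # hypothetical paths
    enhanced_paths = sorted(glob("./data/enhanced/*.png"))  # hypothetical paths

    dataset_factory = LowLightDataset(image_size=256)
    train_dataset, val_dataset = dataset_factory.get_datasets(
        low_light_paths, enhanced_paths, val_split=0.2, batch_size=16
    )

    # Each element is a (low_light_batch, enhanced_batch) pair of image tensors.
    for low_light_batch, enhanced_batch in train_dataset.take(1):
        print(low_light_batch.shape, enhanced_batch.shape)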