import tensorflow as tf
from typing import List, Tuple

from ..commons import read_image
from ..augmentation import AugmentationFactory


class LowLightDataset:
    """Builds paired (low-light, enhanced) tf.data pipelines for training and validation."""

    def __init__(
        self,
        image_size: int = 256,
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
    ) -> None:
        self.augmentation_factory = AugmentationFactory(image_size=image_size)
        self.apply_random_horizontal_flip = apply_random_horizontal_flip
        self.apply_random_vertical_flip = apply_random_vertical_flip
        self.apply_random_rotation = apply_random_rotation

    def load_data(self, low_light_image_path: str, enhanced_image_path: str):
        # Read the paired images. Random cropping is applied once in the
        # tf.data pipeline (see _get_dataset) rather than here, so each
        # example is not cropped twice.
        low_light_image = read_image(low_light_image_path)
        enhanced_image = read_image(enhanced_image_path)
        return low_light_image, enhanced_image

    def _get_dataset(
        self,
        low_light_images: List[str],
        enhanced_images: List[str],
        batch_size: int = 16,
        is_train: bool = True,
    ) -> tf.data.Dataset:
        dataset = tf.data.Dataset.from_tensor_slices(
            (low_light_images, enhanced_images)
        )
        # Decode the image pairs, then crop both images with the same random window.
        dataset = dataset.map(self.load_data, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.map(
            self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
        )
        if is_train:
            # Stochastic augmentations are applied to the training split only.
            if self.apply_random_horizontal_flip:
                dataset = dataset.map(
                    self.augmentation_factory.random_horizontal_flip,
                    num_parallel_calls=tf.data.AUTOTUNE,
                )
            if self.apply_random_vertical_flip:
                dataset = dataset.map(
                    self.augmentation_factory.random_vertical_flip,
                    num_parallel_calls=tf.data.AUTOTUNE,
                )
            if self.apply_random_rotation:
                dataset = dataset.map(
                    self.augmentation_factory.random_rotate,
                    num_parallel_calls=tf.data.AUTOTUNE,
                )
        dataset = dataset.batch(batch_size, drop_remainder=True)
        return dataset

    def get_datasets(
        self,
        low_light_images: List[str],
        enhanced_images: List[str],
        val_split: float = 0.2,
        batch_size: int = 16,
    ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        assert len(low_light_images) == len(
            enhanced_images
        ), "Expected the same number of low-light and enhanced image paths."
        # Split the image paths into training and validation subsets.
        split_index = int(len(low_light_images) * (1 - val_split))
        train_low_light_images = low_light_images[:split_index]
        train_enhanced_images = enhanced_images[:split_index]
        val_low_light_images = low_light_images[split_index:]
        val_enhanced_images = enhanced_images[split_index:]
        print(f"Number of train data points: {len(train_low_light_images)}")
        print(f"Number of validation data points: {len(val_low_light_images)}")
        train_dataset = self._get_dataset(
            train_low_light_images, train_enhanced_images, batch_size, is_train=True
        )
        val_dataset = self._get_dataset(
            val_low_light_images, val_enhanced_images, batch_size, is_train=False
        )
        return train_dataset, val_dataset
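

# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal example of how this class might be wired up, assuming the paired
# file lists come from a low-light dataset laid out as `low/*.png` and
# `high/*.png`; the directory names and glob patterns below are hypothetical,
# and the shape comment assumes read_image decodes 3-channel images.
#
#     from glob import glob
#
#     low_light_images = sorted(glob("./dataset/low/*.png"))
#     enhanced_images = sorted(glob("./dataset/high/*.png"))
#
#     dataset_factory = LowLightDataset(image_size=256)
#     train_dataset, val_dataset = dataset_factory.get_datasets(
#         low_light_images, enhanced_images, val_split=0.2, batch_size=16
#     )
#
#     # Each element is a (low_light_batch, enhanced_batch) pair, each of
#     # shape (batch_size, image_size, image_size, 3).
#     for low_light_batch, enhanced_batch in train_dataset.take(1):
#         print(low_light_batch.shape, enhanced_batch.shape)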