import os
from datetime import datetime
from typing import List

import numpy as np
from PIL import Image
from tensorflow import keras
from tensorflow.keras import optimizers, models
from wandb.keras import WandbCallback

from .dataloader import LowLightDataset
from .models import build_mirnet_model
from .losses import CharbonnierLoss
from ..commons import (
    peak_signal_noise_ratio,
    closest_number,
    init_wandb,
    download_lol_dataset,
)


class MIRNet:
    """High-level wrapper around a MIRNet Keras model.

    Groups dataset preparation, model construction/compilation, model and
    weight (de)serialisation, training, and single-image inference behind
    one object. The typical call order is ``build_datasets`` ->
    ``build_model`` (or ``load_model``) -> ``train`` -> ``infer``.
    """

    def __init__(self, experiment_name=None, wandb_api_key=None) -> None:
        """Store the experiment name and optionally initialise Weights & Biases.

        Args:
            experiment_name: Name of the experiment. ``train`` uses it as the
                directory under which logs and checkpoints are written.
            wandb_api_key: When not ``None``, ``init_wandb`` is called with
                the project name ``"mirnet"`` and a ``WandbCallback`` is
                attached during training.
        """
        self.experiment_name = experiment_name
        if wandb_api_key is not None:
            init_wandb("mirnet", experiment_name, wandb_api_key)
            self.using_wandb = True
        else:
            self.using_wandb = False

    def build_datasets(
        self,
        image_size: int = 256,
        dataset_label: str = "lol",
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
        val_split: float = 0.2,
        batch_size: int = 16,
    ):
        """Create ``self.train_dataset`` and ``self.val_dataset``.

        Args:
            image_size: Forwarded to ``LowLightDataset``.
            dataset_label: ``"lol"`` triggers ``download_lol_dataset()`` and
                stores its train/test splits on the instance. For any other
                label the caller must have set ``self.low_images`` and
                ``self.enhanced_images`` beforehand.
            apply_random_horizontal_flip: Forwarded to ``LowLightDataset``.
            apply_random_vertical_flip: Forwarded to ``LowLightDataset``.
            apply_random_rotation: Forwarded to ``LowLightDataset``.
            val_split: Forwarded to ``LowLightDataset.get_datasets``.
            batch_size: Forwarded to ``LowLightDataset.get_datasets``.

        Raises:
            ValueError: If ``dataset_label`` is not ``"lol"`` and no images
                have been assigned to the instance.
        """
        if dataset_label == "lol":
            (self.low_images, self.enhanced_images), (
                self.test_low_images,
                self.test_enhanced_images,
            ) = download_lol_dataset()
        elif not (
            hasattr(self, "low_images") and hasattr(self, "enhanced_images")
        ):
            # Without this guard the call below fails with an opaque
            # AttributeError on `self.low_images`.
            raise ValueError(
                f"Unsupported dataset_label {dataset_label!r}: only 'lol' can "
                "be downloaded automatically. Set `low_images` and "
                "`enhanced_images` on the instance before calling "
                "build_datasets with any other label."
            )
        self.data_loader = LowLightDataset(
            image_size=image_size,
            apply_random_horizontal_flip=apply_random_horizontal_flip,
            apply_random_vertical_flip=apply_random_vertical_flip,
            apply_random_rotation=apply_random_rotation,
        )
        (self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
            low_light_images=self.low_images,
            enhanced_images=self.enhanced_images,
            val_split=val_split,
            batch_size=batch_size,
        )

    def build_model(
        self,
        num_recursive_residual_groups: int = 3,
        num_multi_scale_residual_blocks: int = 2,
        channels: int = 64,
        learning_rate: float = 1e-4,
        epsilon: float = 1e-3,
    ):
        """Build the network and compile it, storing it as ``self.model``.

        The model is compiled with the Adam optimizer, ``CharbonnierLoss``
        and ``peak_signal_noise_ratio`` as the only metric.

        Args:
            num_recursive_residual_groups: Passed to ``build_mirnet_model``
                as ``num_rrg``.
            num_multi_scale_residual_blocks: Passed to ``build_mirnet_model``
                as ``num_mrb``.
            channels: Passed to ``build_mirnet_model`` as ``channels``.
            learning_rate: Learning rate for the Adam optimizer.
            epsilon: ``epsilon`` argument of ``CharbonnierLoss``.
        """
        self.model = build_mirnet_model(
            num_rrg=num_recursive_residual_groups,
            num_mrb=num_multi_scale_residual_blocks,
            channels=channels,
        )
        self.model.compile(
            optimizer=optimizers.Adam(learning_rate=learning_rate),
            loss=CharbonnierLoss(epsilon=epsilon),
            metrics=[peak_signal_noise_ratio],
        )

    def load_model(
        self, filepath, custom_objects=None, compile=True, options=None
    ) -> None:
        """Load a saved model into ``self.model``.

        All arguments are forwarded unchanged to ``keras.models.load_model``.
        """
        self.model = models.load_model(
            filepath=filepath,
            custom_objects=custom_objects,
            compile=compile,
            options=options,
        )

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Save the weights of ``self.model``.

        All arguments are forwarded unchanged to ``Model.save_weights``.
        """
        self.model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Load weights into ``self.model``.

        All arguments are forwarded unchanged to ``Model.load_weights``.
        """
        self.model.load_weights(
            filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
        )

    def train(self, epochs: int):
        """Fit ``self.model`` on the datasets created by ``build_datasets``.

        Attaches three callbacks: TensorBoard logging under
        ``<experiment_name>/logs/<timestamp>``, best-only weight checkpoints
        at ``<experiment_name>/weights.h5``, and learning-rate halving when
        ``val_peak_signal_noise_ratio`` stops improving. A ``WandbCallback``
        is appended when W&B was initialised in the constructor.

        Args:
            epochs: Number of epochs passed to ``Model.fit``.

        Returns:
            The history object returned by ``Model.fit``.

        Raises:
            ValueError: If the instance was created without an
                ``experiment_name``.
        """
        if self.experiment_name is None:
            # os.path.join(None, ...) would otherwise raise a TypeError that
            # does not point at the real cause.
            raise ValueError(
                "experiment_name must be set to train: it is used as the "
                "output directory for logs and checkpoints."
            )

        log_dir = os.path.join(
            self.experiment_name,
            "logs",
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
        model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
            os.path.join(self.experiment_name, "weights.h5"),
            save_best_only=True,
            save_weights_only=True,
        )
        reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
            monitor="val_peak_signal_noise_ratio",
            factor=0.5,
            patience=5,
            verbose=1,
            min_delta=1e-7,
            mode="max",  # PSNR is a higher-is-better metric
        )
        callbacks = [
            tensorboard_callback,
            model_checkpoint_callback,
            reduce_lr_callback,
        ]
        if self.using_wandb:
            callbacks += [WandbCallback()]
        history = self.model.fit(
            self.train_dataset,
            validation_data=self.val_dataset,
            epochs=epochs,
            callbacks=callbacks,
        )
        return history

    def infer(
        self,
        original_image,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        """Run the model on a single PIL image and return the result.

        Args:
            original_image: Input ``PIL.Image.Image``.
            image_resize_factor: Each input dimension is floor-divided by
                this factor before being passed through ``closest_number``
                with ``4`` as the second argument.
            resize_output: When true, the prediction is resized back to the
                input's original ``(width, height)``.

        Returns:
            The model output as a ``PIL.Image.Image``.
        """
        width, height = original_image.size
        # Floor division by a float factor yields a float; cast so that
        # PIL always receives integer pixel dimensions.
        # NOTE(review): closest_number is defined outside this file;
        # presumably it snaps to a multiple of 4 - confirm.
        target_width, target_height = (
            int(closest_number(width // image_resize_factor, 4)),
            int(closest_number(height // image_resize_factor, 4)),
        )
        # The output is reshaped to 3 channels below, so feed the network
        # RGB input even when the source is RGBA, greyscale or palettised.
        # NOTE(review): assumes the model takes 3-channel input - confirm.
        if original_image.mode != "RGB":
            original_image = original_image.convert("RGB")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resized_image = original_image.resize(
            (target_width, target_height), Image.LANCZOS
        )
        image = keras.preprocessing.image.img_to_array(resized_image)
        image = image.astype("float32") / 255.0
        image = np.expand_dims(image, axis=0)  # add the batch dimension
        output = self.model.predict(image)
        output_image = output[0] * 255.0
        output_image = output_image.clip(0, 255)
        output_image = output_image.reshape(
            (np.shape(output_image)[0], np.shape(output_image)[1], 3)
        )
        output_image = Image.fromarray(np.uint8(output_image))
        if resize_output:
            output_image = output_image.resize((width, height), Image.LANCZOS)
        return output_image

    def infer_from_file(
        self,
        original_image_file: str,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        """Open an image file and run ``infer`` on it.

        Args:
            original_image_file: Path of the image to open with PIL.
            image_resize_factor: Forwarded to ``infer``.
            resize_output: Forwarded to ``infer``.

        Returns:
            The model output as a ``PIL.Image.Image``.
        """
        # The context manager closes the underlying file once inference has
        # consumed the pixel data.
        with Image.open(original_image_file) as original_image:
            return self.infer(original_image, image_resize_factor, resize_output)