import os

import numpy as np
from PIL import Image
from typing import List
from datetime import datetime

from tensorflow import keras
from tensorflow.keras import optimizers, models, mixed_precision

from wandb.keras import WandbCallback

from .dataloader import LowLightDataset
from .models import build_mirnet_model
from .losses import CharbonnierLoss
from ..commons import (
    peak_signal_noise_ratio,
    closest_number,
    init_wandb,
    download_lol_dataset,
)

class MIRNet:
    """High-level wrapper around the MIRNet low-light enhancement model:
    dataset construction, training, and single-image inference."""

    def __init__(self, experiment_name=None, wandb_api_key=None) -> None:
        self.experiment_name = experiment_name
        # Log to Weights & Biases only when an API key is provided.
        if wandb_api_key is not None:
            init_wandb("mirnet", experiment_name, wandb_api_key)
            self.using_wandb = True
        else:
            self.using_wandb = False

    def build_datasets(
        self,
        image_size: int = 256,
        dataset_label: str = "lol",
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
        val_split: float = 0.2,
        batch_size: int = 16,
    ):
        if dataset_label == "lol":
            (self.low_images, self.enhanced_images), (
                self.test_low_images,
                self.test_enhanced_images,
            ) = download_lol_dataset()
        self.data_loader = LowLightDataset(
            image_size=image_size,
            apply_random_horizontal_flip=apply_random_horizontal_flip,
            apply_random_vertical_flip=apply_random_vertical_flip,
            apply_random_rotation=apply_random_rotation,
        )
        (self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
            low_light_images=self.low_images,
            enhanced_images=self.enhanced_images,
            val_split=val_split,
            batch_size=batch_size,
        )

    def build_model(
        self,
        use_mixed_precision: bool = False,
        num_recursive_residual_groups: int = 3,
        num_multi_scale_residual_blocks: int = 2,
        channels: int = 64,
        learning_rate: float = 1e-4,
        epsilon: float = 1e-3,
    ):
        if use_mixed_precision:
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
        self.model = build_mirnet_model(
            num_rrg=num_recursive_residual_groups,
            num_mrb=num_multi_scale_residual_blocks,
            channels=channels,
        )
        self.model.compile(
            optimizer=optimizers.Adam(learning_rate=learning_rate),
            loss=CharbonnierLoss(epsilon=epsilon),
            metrics=[peak_signal_noise_ratio],
        )

    def load_model(
        self, filepath, custom_objects=None, compile=True, options=None
    ) -> None:
        self.model = models.load_model(
            filepath=filepath,
            custom_objects=custom_objects,
            compile=compile,
            options=options,
        )

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        self.model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        self.model.load_weights(
            filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
        )

    def train(self, epochs: int):
        log_dir = os.path.join(
            self.experiment_name,
            "logs",
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
        # Keep only the best weights, judged by validation loss.
        model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
            os.path.join(self.experiment_name, "weights.h5"),
            save_best_only=True,
            save_weights_only=True,
        )
        # Halve the learning rate when validation PSNR stops improving.
        reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
            monitor="val_peak_signal_noise_ratio",
            factor=0.5,
            patience=5,
            verbose=1,
            min_delta=1e-7,
            mode="max",
        )
        callbacks = [
            tensorboard_callback,
            model_checkpoint_callback,
            reduce_lr_callback,
        ]
        if self.using_wandb:
            callbacks += [WandbCallback()]
        history = self.model.fit(
            self.train_dataset,
            validation_data=self.val_dataset,
            epochs=epochs,
            callbacks=callbacks,
        )
        return history

    def infer(
        self,
        original_image,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        width, height = original_image.size
        # Work at a resolution that is a multiple of 4, optionally downscaled
        # by image_resize_factor.
        target_width, target_height = (
            closest_number(int(width // image_resize_factor), 4),
            closest_number(int(height // image_resize_factor), 4),
        )
        original_image = original_image.resize(
            (target_width, target_height), Image.LANCZOS
        )
        image = keras.preprocessing.image.img_to_array(original_image)
        image = image.astype("float32") / 255.0
        image = np.expand_dims(image, axis=0)
        output = self.model.predict(image)
        # Convert the single prediction back to an 8-bit RGB image.
        output_image = output[0] * 255.0
        output_image = output_image.clip(0, 255)
        output_image = output_image.reshape(
            (np.shape(output_image)[0], np.shape(output_image)[1], 3)
        )
        output_image = Image.fromarray(np.uint8(output_image))
        if resize_output:
            output_image = output_image.resize((width, height), Image.LANCZOS)
        return output_image

    def infer_from_file(
        self,
        original_image_file: str,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        original_image = Image.open(original_image_file)
        return self.infer(original_image, image_resize_factor, resize_output)
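

# A minimal usage sketch, not part of the library itself: it assumes this
# module is run via `python -m` from the package root so the relative imports
# resolve, that download_lol_dataset() can fetch the LoL dataset, and that
# "sample.png" (a hypothetical placeholder) exists on disk.
if __name__ == "__main__":
    experiment_name = "mirnet-lol"
    # The experiment directory holds checkpoints and TensorBoard logs.
    os.makedirs(experiment_name, exist_ok=True)
    mirnet = MIRNet(experiment_name=experiment_name)  # no wandb_api_key: W&B logging off
    mirnet.build_datasets(image_size=256, val_split=0.2, batch_size=16)
    mirnet.build_model(learning_rate=1e-4)
    mirnet.train(epochs=50)
    enhanced = mirnet.infer_from_file("sample.png", resize_output=True)
    enhanced.save("sample_enhanced.png")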