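# Page header captured together with this file (not part of the module):
# author: geekyrakshit — commit ffad30c "added mixed precision"
# views: raw | history | blame — file size: 6 kB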
import os
import numpy as np
from PIL import Image
from typing import List
from datetime import datetime
from tensorflow import keras
from tensorflow.keras import optimizers, models, mixed_precision
from wandb.keras import WandbCallback
from .dataloader import LowLightDataset
from .models import build_mirnet_model
from .losses import CharbonnierLoss
from ..commons import (
peak_signal_noise_ratio,
closest_number,
init_wandb,
download_lol_dataset,
)
class MIRNet:
    """High-level wrapper around a MIRNet low-light image enhancement model.

    Bundles dataset construction, model building/compilation, training with
    TensorBoard / checkpoint / learning-rate-schedule callbacks (plus a
    Weights & Biases callback when an API key is supplied) and single-image
    inference.
    """

    def __init__(self, experiment_name=None, wandb_api_key=None) -> None:
        """Store the experiment name and optionally initialise Weights & Biases.

        Args:
            experiment_name: Used as the output directory for TensorBoard logs
                and weight checkpoints in :meth:`train`, and forwarded to
                ``init_wandb``. When ``None``, :meth:`train` writes into the
                current working directory.
            wandb_api_key: When given, ``init_wandb`` is called and a
                ``WandbCallback`` is attached during :meth:`train`.
        """
        self.experiment_name = experiment_name
        if wandb_api_key is not None:
            init_wandb("mirnet", experiment_name, wandb_api_key)
            self.using_wandb = True
        else:
            self.using_wandb = False

    def _experiment_dir(self) -> str:
        """Return the directory that training artifacts are written to."""
        # os.path.join(None, ...) raises TypeError, so an unset experiment name
        # falls back to the current working directory.
        return self.experiment_name or "."

    def build_datasets(
        self,
        image_size: int = 256,
        dataset_label: str = "lol",
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
        val_split: float = 0.2,
        batch_size: int = 16,
    ):
        """Fetch the dataset and build the training/validation pipelines.

        The data is fetched with ``download_lol_dataset()``. The held-out test
        split is kept on ``self.test_low_images`` / ``self.test_enhanced_images``.
        The pipelines produced by ``LowLightDataset.get_datasets`` are stored on
        ``self.train_dataset`` / ``self.val_dataset``.

        Args:
            image_size: Forwarded to ``LowLightDataset``.
            dataset_label: Dataset identifier; only ``"lol"`` is supported.
            apply_random_horizontal_flip: Forwarded to ``LowLightDataset``.
            apply_random_vertical_flip: Forwarded to ``LowLightDataset``.
            apply_random_rotation: Forwarded to ``LowLightDataset``.
            val_split: Validation fraction passed to ``get_datasets``.
            batch_size: Batch size passed to ``get_datasets``.

        Raises:
            ValueError: If ``dataset_label`` is not ``"lol"``.
        """
        if dataset_label != "lol":
            # Fail early and clearly instead of hitting an AttributeError on
            # self.low_images (or silently reusing data from an earlier call).
            raise ValueError(
                f"Unsupported dataset_label {dataset_label!r}; only 'lol' is supported."
            )
        (self.low_images, self.enhanced_images), (
            self.test_low_images,
            self.test_enhanced_images,
        ) = download_lol_dataset()
        self.data_loader = LowLightDataset(
            image_size=image_size,
            apply_random_horizontal_flip=apply_random_horizontal_flip,
            apply_random_vertical_flip=apply_random_vertical_flip,
            apply_random_rotation=apply_random_rotation,
        )
        (self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
            low_light_images=self.low_images,
            enhanced_images=self.enhanced_images,
            val_split=val_split,
            batch_size=batch_size,
        )

    def build_model(
        self,
        use_mixed_precision: bool = False,
        num_recursive_residual_groups: int = 3,
        num_multi_scale_residual_blocks: int = 2,
        channels: int = 64,
        learning_rate: float = 1e-4,
        epsilon: float = 1e-3,
    ):
        """Build and compile the MIRNet model into ``self.model``.

        Args:
            use_mixed_precision: When True, sets the *global* Keras policy to
                ``mixed_float16``. The policy is not reset afterwards.
            num_recursive_residual_groups: Forwarded as ``num_rrg``.
            num_multi_scale_residual_blocks: Forwarded as ``num_mrb``.
            channels: Feature channels forwarded to ``build_mirnet_model``.
            learning_rate: Adam learning rate.
            epsilon: Epsilon of the Charbonnier loss.
        """
        if use_mixed_precision:
            # Must happen before the model is built so its layers pick up the policy.
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
        self.model = build_mirnet_model(
            num_rrg=num_recursive_residual_groups,
            num_mrb=num_multi_scale_residual_blocks,
            channels=channels,
        )
        self.model.compile(
            optimizer=optimizers.Adam(learning_rate=learning_rate),
            loss=CharbonnierLoss(epsilon=epsilon),
            metrics=[peak_signal_noise_ratio],
        )

    def load_model(
        self, filepath, custom_objects=None, compile=True, options=None
    ) -> None:
        """Replace ``self.model`` with ``keras.models.load_model(...)``.

        All arguments are forwarded unchanged.
        """
        self.model = models.load_model(
            filepath=filepath,
            custom_objects=custom_objects,
            compile=compile,
            options=options,
        )

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Save ``self.model``'s weights; arguments are forwarded unchanged."""
        self.model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Load weights into ``self.model``; arguments are forwarded unchanged."""
        self.model.load_weights(
            filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
        )

    def train(self, epochs: int):
        """Fit ``self.model`` on the datasets built by :meth:`build_datasets`.

        Writes TensorBoard logs to ``<experiment_dir>/logs/<timestamp>`` and
        keeps the best weights (per Keras' default checkpoint monitor) in
        ``<experiment_dir>/weights.h5``. The learning rate is halved when
        validation PSNR plateaus for 5 epochs.

        Args:
            epochs: Number of training epochs.

        Returns:
            The ``History`` object returned by ``model.fit``.
        """
        experiment_dir = self._experiment_dir()
        # Ensure the checkpoint directory exists before the first save.
        os.makedirs(experiment_dir, exist_ok=True)
        log_dir = os.path.join(
            experiment_dir,
            "logs",
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
        model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
            os.path.join(experiment_dir, "weights.h5"),
            save_best_only=True,
            save_weights_only=True,
        )
        # Assumes the PSNR metric is logged under its function name with the
        # "val_" prefix; higher is better, hence mode="max".
        reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
            monitor="val_peak_signal_noise_ratio",
            factor=0.5,
            patience=5,
            verbose=1,
            min_delta=1e-7,
            mode="max",
        )
        callbacks = [
            tensorboard_callback,
            model_checkpoint_callback,
            reduce_lr_callback,
        ]
        if self.using_wandb:
            callbacks.append(WandbCallback())
        return self.model.fit(
            self.train_dataset,
            validation_data=self.val_dataset,
            epochs=epochs,
            callbacks=callbacks,
        )

    def infer(
        self,
        original_image,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        """Enhance a single PIL image with ``self.model``.

        The image is downscaled by ``image_resize_factor``, with each side
        snapped via ``closest_number(..., 4)`` (presumably to a multiple of 4).
        It is then converted to RGB, scaled to [0, 1] and run through the
        model. The output is rescaled to [0, 255].

        Args:
            original_image: Input ``PIL.Image`` (any mode; converted to RGB).
            image_resize_factor: Divisor applied to width and height before
                inference.
            resize_output: When True, resize the result back to the original
                width and height.

        Returns:
            The enhanced image as an 8-bit ``PIL.Image``.
        """
        width, height = original_image.size
        target_width = int(closest_number(width // image_resize_factor, 4))
        target_height = int(closest_number(height // image_resize_factor, 4))
        # Convert to RGB so RGBA/grayscale/palette inputs give the 3-channel
        # array the model and the reshape below expect. Image.ANTIALIAS was
        # removed in Pillow 10; LANCZOS is the same resampling filter.
        resized_image = original_image.convert("RGB").resize(
            (target_width, target_height), Image.LANCZOS
        )
        image = keras.preprocessing.image.img_to_array(resized_image)
        image = image.astype("float32") / 255.0
        image = np.expand_dims(image, axis=0)
        output = self.model.predict(image)
        output_image = output[0] * 255.0
        output_image = output_image.clip(0, 255)
        output_image = output_image.reshape(
            (np.shape(output_image)[0], np.shape(output_image)[1], 3)
        )
        output_image = Image.fromarray(np.uint8(output_image))
        if resize_output:
            output_image = output_image.resize((width, height), Image.LANCZOS)
        return output_image

    def infer_from_file(
        self,
        original_image_file: str,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        """Open ``original_image_file`` and run :meth:`infer` on it.

        The file is closed once inference finishes. The returned image is built
        from new pixel data, so it does not depend on the closed file.
        """
        with Image.open(original_image_file) as original_image:
            return self.infer(original_image, image_resize_factor, resize_output)