geekyrakshit committed on
Commit
c8d52e7
1 Parent(s): 6fd61b9

added mirnet class for training and inference

Browse files
enhance_me/commons.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  import tensorflow as tf
 
2
 
3
 
4
  def read_image(image_path):
@@ -11,3 +14,30 @@ def read_image(image_path):
11
 
12
  def peak_signal_noise_ratio(y_true, y_pred):
13
  return tf.image.psnr(y_pred, y_true, max_val=255.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import wandb
3
  import tensorflow as tf
4
+ import matplotlib.pyplot as plt
5
 
6
 
7
  def read_image(image_path):
 
14
 
15
  def peak_signal_noise_ratio(y_true, y_pred):
16
  return tf.image.psnr(y_pred, y_true, max_val=255.0)
17
+
18
+
19
def plot_results(images, titles, figure_size=(12, 12)):
    """Show ``images`` side by side in a single row, one titled subplot each.

    Args:
        images: Sequence of images accepted by ``plt.imshow``.
        titles: Sequence of subplot titles, indexed in step with ``images``.
        figure_size: Size forwarded to ``plt.figure(figsize=...)``.
    """
    figure = plt.figure(figsize=figure_size)
    column_count = len(images)
    for position, image in enumerate(images, start=1):
        axes = figure.add_subplot(1, column_count, position)
        axes.set_title(titles[position - 1])
        # imshow draws on the current axes, i.e. the subplot just added.
        plt.imshow(image)
        plt.axis("off")
    plt.show()
26
+
27
+
28
def closest_number(n, m):
    """Return the multiple of ``m`` that lies closest to ``n``.

    Two candidates are compared: ``m`` times the truncated quotient, and the
    next multiple one step further away from zero. When both are equally
    distant from ``n`` the one further from zero is returned.

    Raises:
        ZeroDivisionError: If ``m`` is zero.
    """
    quotient = int(n / m)
    # Step away from zero: +1 when n and m share a sign, otherwise -1.
    step = 1 if (n * m) > 0 else -1
    toward_zero = m * quotient
    away_from_zero = m * (quotient + step)
    if abs(n - toward_zero) < abs(n - away_from_zero):
        return toward_zero
    return away_from_zero
38
+
39
+
40
def init_wandb(project_name, experiment_name, wandb_api_key):
    """Start a Weights & Biases run when both names are provided.

    Does nothing if ``project_name`` or ``experiment_name`` is ``None``.
    Otherwise the API key is exported through the ``WANDB_API_KEY``
    environment variable before ``wandb.init`` is called.

    Args:
        project_name: Passed to ``wandb.init`` as ``project``.
        experiment_name: Passed to ``wandb.init`` as ``name``.
        wandb_api_key: Value stored in ``os.environ['WANDB_API_KEY']``.
    """
    if project_name is None or experiment_name is None:
        return
    os.environ['WANDB_API_KEY'] = wandb_api_key
    wandb.init(project=project_name, name=experiment_name)
enhance_me/mirnet/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ from .mirnet import MIRNet
enhance_me/mirnet/mirnet.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import List
5
+ from datetime import datetime
6
+
7
+ from tensorflow import keras
8
+ from tensorflow.keras import optimizers
9
+
10
+ from wandb.keras import WandbCallback
11
+
12
+ from .dataloader import LowLightDataset
13
+ from .models import build_mirnet_model
14
+ from .losses import CharbonnierLoss
15
+ from ..commons import peak_signal_noise_ratio, closest_number, init_wandb
16
+
17
+
18
class MIRNet:
    """Convenience wrapper bundling MIRNet data loading, training and inference.

    Typical order of use: construct, ``build_datasets``, ``build_model``,
    then ``train`` and/or ``load_weights`` followed by ``infer``.
    """

    def __init__(
        self,
        experiment_name: str,
        image_size: int = 256,
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
        wandb_api_key=None,
    ) -> None:
        """Create the data loader and optionally start a Weights & Biases run.

        Args:
            experiment_name: Used as the directory for logs and checkpoints
                written by ``train`` and as the W&B run name.
            image_size: Forwarded to ``LowLightDataset``.
            apply_random_horizontal_flip: Forwarded to ``LowLightDataset``.
            apply_random_vertical_flip: Forwarded to ``LowLightDataset``.
            apply_random_rotation: Forwarded to ``LowLightDataset``.
            wandb_api_key: When not ``None``, ``init_wandb`` is called for the
                ``"mirnet"`` project and ``train`` adds a ``WandbCallback``.
        """
        self.experiment_name = experiment_name
        self.data_loader = LowLightDataset(
            image_size=image_size,
            apply_random_horizontal_flip=apply_random_horizontal_flip,
            apply_random_vertical_flip=apply_random_vertical_flip,
            apply_random_rotation=apply_random_rotation,
        )
        # W&B logging is opt-in: it is enabled only when a key is supplied.
        self.using_wandb = wandb_api_key is not None
        if self.using_wandb:
            init_wandb("mirnet", experiment_name, wandb_api_key)

    def build_datasets(
        self,
        low_light_images: List[str],
        enhanced_images: List[str],
        val_split: float = 0.2,
        batch_size: int = 16,
    ):
        """Build ``self.train_dataset`` and ``self.val_dataset``.

        All arguments are forwarded unchanged to
        ``LowLightDataset.get_datasets``.
        """
        (self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
            low_light_images=low_light_images,
            enhanced_images=enhanced_images,
            val_split=val_split,
            batch_size=batch_size,
        )

    def build_model(
        self,
        num_recursive_residual_groups: int = 3,
        num_multi_scale_residual_blocks: int = 2,
        channels: int = 64,
        learning_rate: float = 1e-4,
        epsilon: float = 1e-3,
    ):
        """Build ``self.model`` and compile it.

        The model is compiled with an Adam optimizer, a ``CharbonnierLoss``
        and ``peak_signal_noise_ratio`` as the only metric.

        Args:
            num_recursive_residual_groups: Passed as ``num_rrg``.
            num_multi_scale_residual_blocks: Passed as ``num_mrb``.
            channels: Passed as ``channels``.
            learning_rate: Learning rate for Adam.
            epsilon: Passed to ``CharbonnierLoss``.
        """
        self.model = build_mirnet_model(
            num_rrg=num_recursive_residual_groups,
            num_mrb=num_multi_scale_residual_blocks,
            channels=channels,
        )
        self.model.compile(
            optimizer=optimizers.Adam(learning_rate=learning_rate),
            loss=CharbonnierLoss(epsilon=epsilon),
            metrics=[peak_signal_noise_ratio],
        )

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Forward to ``self.model.save_weights`` with the same arguments."""
        self.model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Forward to ``self.model.load_weights`` with the same arguments."""
        self.model.load_weights(
            filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
        )

    def train(self, epochs: int):
        """Fit the model on the datasets created by ``build_datasets``.

        TensorBoard logs go to ``<experiment_name>/logs/<timestamp>`` and the
        checkpoint callback writes weights to ``<experiment_name>/weights.h5``.
        The learning rate is halved when ``val_peak_signal_noise_ratio`` has
        not improved for 5 epochs.

        Args:
            epochs: Number of epochs passed to ``model.fit``.

        Returns:
            The history object returned by ``model.fit``.
        """
        # ``datetime`` here is the class (``from datetime import datetime``),
        # so ``now`` is called on it directly.
        log_dir = os.path.join(
            self.experiment_name,
            "logs",
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
        model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
            os.path.join(self.experiment_name, "weights.h5"),
            save_best_only=True,
            save_weights_only=True,
        )
        reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
            monitor="val_peak_signal_noise_ratio",
            factor=0.5,
            patience=5,
            verbose=1,
            min_delta=1e-7,
            mode="max",
        )
        callbacks = [
            tensorboard_callback,
            model_checkpoint_callback,
            reduce_lr_callback,
        ]
        if self.using_wandb:
            callbacks.append(WandbCallback())
        history = self.model.fit(
            self.train_dataset,
            validation_data=self.val_dataset,
            epochs=epochs,
            callbacks=callbacks,
        )
        return history

    def infer(
        self,
        original_image,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        """Run the model on a PIL image and return the result as a PIL image.

        The input is resized so that each side is the multiple of 4 closest
        to ``side // image_resize_factor``, scaled to ``[0, 1]``, passed
        through the model, rescaled by 255 and clipped to ``[0, 255]``.

        Args:
            original_image: PIL image to process.
            image_resize_factor: Divisor applied to width and height before
                rounding to a multiple of 4.
            resize_output: When true, the output is resized back to the
                original width and height.

        Returns:
            The model output as a ``PIL.Image.Image``.
        """
        width, height = original_image.size
        target_size = (
            closest_number(width // image_resize_factor, 4),
            closest_number(height // image_resize_factor, 4),
        )
        # The model input is fixed at 3 channels, so normalise the mode first
        # (no effect on images that are already RGB). LANCZOS is the filter
        # formerly exposed as ANTIALIAS, which newer Pillow releases removed.
        resized_image = original_image.convert("RGB").resize(
            target_size, Image.LANCZOS
        )
        image = keras.preprocessing.image.img_to_array(resized_image)
        image = image.astype("float32") / 255.0
        image = np.expand_dims(image, axis=0)  # add the batch dimension
        output = self.model.predict(image)
        output_image = (output[0] * 255.0).clip(0, 255)
        output_image = output_image.reshape(
            (np.shape(output_image)[0], np.shape(output_image)[1], 3)
        )
        output_image = Image.fromarray(np.uint8(output_image))
        if resize_output:
            output_image = output_image.resize((width, height), Image.LANCZOS)
        return output_image

    def infer_from_file(
        self,
        original_image_file: str,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        """Open an image file and run ``infer`` on it.

        Args:
            original_image_file: Path of the image to open.
            image_resize_factor: Forwarded to ``infer``.
            resize_output: Forwarded to ``infer``.

        Returns:
            The result of ``infer``.
        """
        # The context manager closes the underlying file once inference is
        # done; ``infer`` returns a new image that does not depend on it.
        with Image.open(original_image_file) as original_image:
            return self.infer(original_image, image_resize_factor, resize_output)
enhance_me/mirnet/models/mirnet_model.py CHANGED
@@ -1,10 +1,9 @@
1
- import tensorflow as tf
2
  from tensorflow.keras import layers, Input, Model
3
 
4
  from .recursive_residual_blocks import recursive_residual_group
5
 
6
 
7
- def mirnet_model(num_rrg, num_mrb, channels):
8
  input_tensor = Input(shape=[None, None, 3])
9
  x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
10
  for _ in range(num_rrg):
 
 
1
  from tensorflow.keras import layers, Input, Model
2
 
3
  from .recursive_residual_blocks import recursive_residual_group
4
 
5
 
6
+ def build_mirnet_model(num_rrg, num_mrb, channels):
7
  input_tensor = Input(shape=[None, None, 3])
8
  x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
9
  for _ in range(num_rrg):