geekyrakshit committed on
Commit
f4bfa55
2 Parent(s): cfb00a4 12456dc

Merge pull request #2 from soumik12345/zero-dce

Browse files

Zero-reference Deep Curve Estimation for Low-light Image Enhancement

.gitignore CHANGED
@@ -131,4 +131,6 @@ dmypy.json
131
  # Datasets
132
  datasets/
133
  **.zip
134
- **.h5
 
 
 
131
  # Datasets
132
  datasets/
133
  **.zip
134
+ **.h5
135
+ lol_dataset_**
136
+ wandb**
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Pull Base Image
FROM tensorflow/tensorflow:latest-gpu-jupyter

# Set Working Directory
# WORKDIR creates the directory when it does not exist, so no separate mkdir is needed.
WORKDIR /usr/src/enhance-me

# Set Environment Variables
# Use the key=value form; the space-separated "ENV key value" form is legacy syntax.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# Install Python dependencies in a single layer and keep pip's download cache out of the image.
RUN pip install --no-cache-dir --upgrade pip setuptools wheel && \
    pip install --no-cache-dir gdown matplotlib streamlit tqdm wandb

# Copy the project sources into the working directory.
COPY . /usr/src/enhance-me/

# Start a Jupyter notebook server reachable from outside the container on port 8888.
CMD ["jupyter", "notebook", "--port=8888", "--no-browser", "--ip=0.0.0.0", "--allow-root"]
README.md CHANGED
@@ -8,4 +8,14 @@ app_file: app.py
8
  pinned: false
9
  ---
10
 
11
- # enhance-me
 
 
 
 
 
 
 
 
 
 
 
8
  pinned: false
9
  ---
10
 
11
+ # Enhance Me
12
+
13
+ A unified platform for image enhancement.
14
+
15
+ ## Usage
16
+
17
+ ### Train using Docker
18
+
19
+ - Build image using `docker build -t enhance-image .`
20
+
21
+ - Run notebook using `docker run -it --gpus all -p 8888:8888 -v $(pwd):/usr/src/enhance-me enhance-image`
app.py CHANGED
@@ -1,23 +1,36 @@
 
1
  from PIL import Image
2
  import streamlit as st
3
  from tensorflow.keras import utils, backend
4
 
5
- from enhance_me.mirnet import MIRNet
6
 
7
 
8
  def get_mirnet_object() -> MIRNet:
9
- mirnet = MIRNet()
10
- mirnet.build_model()
11
  utils.get_file(
12
  "weights_lol_128.h5",
13
  "https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
14
  cache_dir=".",
15
  cache_subdir="weights",
16
  )
 
 
17
  mirnet.load_weights("./weights/weights_lol_128.h5")
18
  return mirnet
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def main():
22
  st.markdown("# Enhance Me")
23
  st.markdown("Made with :heart: by [geekyRakshit](http://github.com/soumik12345)")
@@ -30,14 +43,39 @@ def main():
30
  if uploaded_file is not None:
31
  original_image = Image.open(uploaded_file)
32
  st.image(original_image, caption="original image")
33
- st.sidebar.info("Loading MIRNet...")
34
- mirnet = get_mirnet_object()
35
- st.sidebar.info("Done!")
36
- st.sidebar.info("Processing Image...")
37
- enhanced_image = mirnet.infer(original_image)
38
- st.sidebar.info("Done!")
39
- st.image(enhanced_image, caption="enhanced image")
40
- backend.clear_session()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  if __name__ == "__main__":
 
1
+ import os
2
  from PIL import Image
3
  import streamlit as st
4
  from tensorflow.keras import utils, backend
5
 
6
+ from enhance_me import MIRNet, ZeroDCE
7
 
8
 
9
  def get_mirnet_object() -> MIRNet:
 
 
10
  utils.get_file(
11
  "weights_lol_128.h5",
12
  "https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
13
  cache_dir=".",
14
  cache_subdir="weights",
15
  )
16
+ mirnet = MIRNet()
17
+ mirnet.build_model()
18
  mirnet.load_weights("./weights/weights_lol_128.h5")
19
  return mirnet
20
 
21
 
22
def get_zero_dce_object(model_alias: str) -> ZeroDCE:
    """Fetch the released Zero-DCE weights for ``model_alias`` and load them.

    The weight file ``<model_alias>.h5`` is requested through
    ``utils.get_file`` with ``cache_dir="."`` and ``cache_subdir="weights"``,
    and is then loaded from ``./weights/<model_alias>.h5`` into a freshly
    constructed ``ZeroDCE`` instance.

    Args:
        model_alias: Base name (without extension) of the released weight file.

    Returns:
        A ``ZeroDCE`` instance with the requested weights loaded.
    """
    weights_file = f"{model_alias}.h5"
    utils.get_file(
        weights_file,
        f"https://github.com/soumik12345/enhance-me/releases/download/v0.4/{model_alias}.h5",
        cache_dir=".",
        cache_subdir="weights",
    )
    zero_dce_model = ZeroDCE()
    zero_dce_model.load_weights(os.path.join("./weights", weights_file))
    return zero_dce_model
32
+
33
+
34
  def main():
35
  st.markdown("# Enhance Me")
36
  st.markdown("Made with :heart: by [geekyRakshit](http://github.com/soumik12345)")
 
43
  if uploaded_file is not None:
44
  original_image = Image.open(uploaded_file)
45
  st.image(original_image, caption="original image")
46
+ model_option = st.sidebar.selectbox(
47
+ "Please select the model:",
48
+ (
49
+ "",
50
+ "MIRNet",
51
+ "Zero-DCE (dce_weights_lol_128)",
52
+ "Zero-DCE (dce_weights_lol_128_resize)",
53
+ "Zero-DCE (dce_weights_lol_256)",
54
+ "Zero-DCE (dce_weights_lol_256_resize)",
55
+ "Zero-DCE (dce_weights_unpaired_128)",
56
+ "Zero-DCE (dce_weights_unpaired_128_resize)",
57
+ "Zero-DCE (dce_weights_unpaired_256)",
58
+ "Zero-DCE (dce_weights_unpaired_256_resize)"
59
+ ),
60
+ )
61
+ if model_option != "":
62
+ if model_option == "MIRNet":
63
+ st.sidebar.info("Loading MIRNet...")
64
+ mirnet = get_mirnet_object()
65
+ st.sidebar.info("Done!")
66
+ st.sidebar.info("Processing Image...")
67
+ enhanced_image = mirnet.infer(original_image)
68
+ st.sidebar.info("Done!")
69
+ st.image(enhanced_image, caption="enhanced image")
70
+ elif "Zero-DCE" in model_option:
71
+ model_alias = model_option[model_option.find("(") + 1: model_option.find(")")]
72
+ st.sidebar.info("Loading Zero-DCE...")
73
+ zero_dce = get_zero_dce_object(model_alias)
74
+ st.sidebar.info("Done!")
75
+ enhanced_image = zero_dce.infer(original_image)
76
+ st.sidebar.info("Done!")
77
+ st.image(enhanced_image, caption="enhanced image")
78
+ backend.clear_session()
79
 
80
 
81
  if __name__ == "__main__":
enhance_me/__init__.py CHANGED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .mirnet import MIRNet
2
+ from .zero_dce import ZeroDCE
enhance_me/augmentation.py CHANGED
@@ -49,3 +49,38 @@ class AugmentationFactory:
49
  return tf.image.rot90(input_image, condition), tf.image.rot90(
50
  enhanced_image, condition
51
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  return tf.image.rot90(input_image, condition), tf.image.rot90(
50
  enhanced_image, condition
51
  )
52
+
53
+
54
class UnpairedAugmentationFactory:
    """Random augmentations that operate on a single image tensor.

    The first two axes of ``image`` are treated as height and width by
    ``random_crop``.
    """

    def __init__(self, image_size) -> None:
        # Side length of the square patch produced by ``random_crop``.
        self.image_size = image_size

    def random_crop(self, image):
        """Return a random ``image_size`` x ``image_size`` crop of ``image``."""
        height_and_width = tf.shape(image)[:2]
        # The horizontal offset is drawn before the vertical one.
        left = tf.random.uniform(
            shape=(),
            maxval=height_and_width[1] - self.image_size + 1,
            dtype=tf.int32,
        )
        top = tf.random.uniform(
            shape=(),
            maxval=height_and_width[0] - self.image_size + 1,
            dtype=tf.int32,
        )
        bottom = top + self.image_size
        right = left + self.image_size
        return image[top:bottom, left:right]

    def random_horizontal_flip(self, image):
        """Flip ``image`` left-right unless a uniform draw falls below 0.5."""
        keep_as_is = tf.random.uniform(shape=(), maxval=1) < 0.5
        return tf.cond(
            keep_as_is,
            true_fn=lambda: image,
            false_fn=lambda: tf.image.flip_left_right(image),
        )

    def random_vertical_flip(self, image):
        """Flip ``image`` up-down unless a uniform draw falls below 0.5."""
        keep_as_is = tf.random.uniform(shape=(), maxval=1) < 0.5
        return tf.cond(
            keep_as_is,
            true_fn=lambda: image,
            false_fn=lambda: tf.image.flip_up_down(image),
        )

    def random_rotate(self, image):
        """Rotate ``image`` by a random number (0-3) of 90-degree turns."""
        quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        return tf.image.rot90(image, quarter_turns)
enhance_me/commons.py CHANGED
@@ -61,3 +61,18 @@ def download_lol_dataset():
61
  test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
62
  assert len(test_low_images) == len(test_enhanced_images)
63
  return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
62
  assert len(test_low_images) == len(test_enhanced_images)
63
  return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
64
+
65
+
66
def download_unpaired_low_light_dataset():
    """Download and extract the unpaired low-light dataset, then list its files.

    The archive is fetched with ``utils.get_file`` into ``./datasets`` and
    extracted there. All file lists are sorted so that callers which split
    the training list by index get the same split on every machine; the
    order returned by ``glob`` alone depends on the filesystem.

    Returns:
        A tuple ``(low_images, (test_low_images, test_enhanced_images))`` of
        sorted file-path lists.

    Raises:
        AssertionError: If the numbers of low and enhanced test images differ.
    """
    utils.get_file(
        "low_light_dataset.zip",
        "https://github.com/soumik12345/enhance-me/releases/download/v0.3/low_light_dataset.zip",
        cache_dir="./",
        cache_subdir="./datasets",
        extract=True,
    )
    low_images = sorted(glob("./datasets/low_light_dataset/*.png"))
    # NOTE(review): the eval15 layout mirrors the LoL dataset paths; confirm
    # that the unpaired archive really ships these folders.
    test_low_images = sorted(glob("./datasets/low_light_dataset/eval15/low/*"))
    test_enhanced_images = sorted(glob("./datasets/low_light_dataset/eval15/high/*"))
    assert len(test_low_images) == len(test_enhanced_images)
    return low_images, (test_low_images, test_enhanced_images)
enhance_me/zero_dce/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .zero_dce import ZeroDCE
enhance_me/zero_dce/dataloader.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from typing import List
3
+
4
+ from ..commons import read_image
5
+ from ..augmentation import UnpairedAugmentationFactory
6
+
7
+
8
class UnpairedLowLightDataset:
    """Builds batched ``tf.data`` pipelines from lists of image file paths.

    Each image is read with ``read_image``, brought to ``image_size`` either
    by resizing or by a random crop, optionally augmented (training only) and
    batched with ``drop_remainder=True``.
    """

    def __init__(
        self,
        image_size: int = 256,
        apply_resize: bool = False,
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
    ) -> None:
        self.image_size = image_size
        self.apply_resize = apply_resize
        self.apply_random_horizontal_flip = apply_random_horizontal_flip
        self.apply_random_vertical_flip = apply_random_vertical_flip
        self.apply_random_rotation = apply_random_rotation
        self.augmentation_factory = UnpairedAugmentationFactory(image_size=image_size)

    def _resize(self, image):
        """Resize ``image`` to a square of side ``image_size``."""
        target_size = (self.image_size, self.image_size)
        return tf.image.resize(image, target_size)

    def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
        """Create the pipeline for ``images``; augment only when ``is_train``."""
        factory = self.augmentation_factory
        dataset = tf.data.Dataset.from_tensor_slices((images))
        dataset = dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE)

        # Resizing and random cropping are mutually exclusive sizing strategies.
        sizing_fn = self._resize if self.apply_resize else factory.random_crop
        dataset = dataset.map(sizing_fn, num_parallel_calls=tf.data.AUTOTUNE)

        if is_train:
            # Applied in this fixed order: horizontal flip, vertical flip, rotation.
            augmentations = (
                (self.apply_random_horizontal_flip, factory.random_horizontal_flip),
                (self.apply_random_vertical_flip, factory.random_vertical_flip),
                (self.apply_random_rotation, factory.random_rotate),
            )
            for enabled, augment_fn in augmentations:
                if enabled:
                    dataset = dataset.map(
                        augment_fn, num_parallel_calls=tf.data.AUTOTUNE
                    )

        return dataset.batch(batch_size, drop_remainder=True)

    def get_datasets(
        self,
        images: List[str],
        val_split: float = 0.2,
        batch_size: int = 16,
    ):
        """Split ``images`` by position and return ``(train_dataset, val_dataset)``.

        The last ``val_split`` fraction of ``images`` becomes the validation set.
        """
        boundary = int(len(images) * (1 - val_split))
        train_images, val_images = images[:boundary], images[boundary:]
        print(f"Number of train data points: {len(train_images)}")
        print(f"Number of validation data points: {len(val_images)}")
        train_dataset = self._get_dataset(train_images, batch_size, is_train=True)
        val_dataset = self._get_dataset(val_images, batch_size, is_train=False)
        return train_dataset, val_dataset
enhance_me/zero_dce/dce_net.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import layers, Input, Model
3
+
4
+
5
def build_dce_net() -> Model:
    """Build the DCE-Net: seven 3x3 convolutions with three skip concatenations.

    The first six convolutions have 32 filters and ReLU activations; the last
    one has 24 filters and a tanh activation. The input accepts 3-channel
    images of any spatial size.

    Returns:
        A Keras ``Model`` mapping the input image to the 24-channel output.
    """

    def conv3x3(inputs, filters=32, activation="relu"):
        # Every convolution in the network shares kernel size, stride and padding.
        return layers.Conv2D(
            filters, (3, 3), strides=(1, 1), activation=activation, padding="same"
        )(inputs)

    def concat(first, second):
        return layers.Concatenate(axis=-1)([first, second])

    # Layers are created in the same order as the data flows so that the
    # automatically generated layer names stay stable.
    input_image = Input(shape=[None, None, 3])
    features_1 = conv3x3(input_image)
    features_2 = conv3x3(features_1)
    features_3 = conv3x3(features_2)
    features_4 = conv3x3(features_3)
    features_5 = conv3x3(concat(features_4, features_3))
    features_6 = conv3x3(concat(features_5, features_2))
    curve_maps = conv3x3(concat(features_6, features_1), filters=24, activation="tanh")
    return Model(inputs=input_image, outputs=curve_maps)
enhance_me/zero_dce/losses/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+ from .spatial_constancy import SpatialConsistencyLoss
4
+
5
+
6
def color_constancy_loss(x):
    """Penalise differences between the per-image means of channels 0, 1 and 2.

    The channel means are taken over axes 1 and 2. Each pairwise difference
    is squared, and the result is the square root of the sum of those squared
    differences squared once more.
    """
    channel_means = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    red = channel_means[:, :, :, 0]
    green = channel_means[:, :, :, 1]
    blue = channel_means[:, :, :, 2]

    gap_red_green = tf.square(red - green)
    gap_red_blue = tf.square(red - blue)
    gap_green_blue = tf.square(blue - green)

    # The already-squared gaps are squared again before the square root.
    sum_of_squares = (
        tf.square(gap_red_green) + tf.square(gap_red_blue) + tf.square(gap_green_blue)
    )
    return tf.sqrt(sum_of_squares)
17
+
18
+
19
def exposure_loss(x, mean_val=0.6):
    """Mean squared distance between 16x16 patch averages and ``mean_val``.

    ``x`` is first averaged over axis 3, then average-pooled with a 16x16
    window and stride 16 (``VALID`` padding).
    """
    channel_average = tf.reduce_mean(x, axis=3, keepdims=True)
    patch_means = tf.nn.avg_pool2d(
        channel_average, ksize=16, strides=16, padding="VALID"
    )
    deviation = patch_means - mean_val
    return tf.reduce_mean(tf.square(deviation))
23
+
24
+
25
def illumination_smoothness_loss(x):
    """Total-variation style smoothness penalty on ``x``.

    Squared differences between neighbours are summed along axis 1 and along
    axis 2 of ``x`` (treated here as height and width). Each sum is
    normalised by the number of neighbour pairs per channel along that axis,
    ``(H - 1) * W`` and ``H * (W - 1)``, and the result is divided by the
    batch size.

    The normalisers previously used axes 2 and 3 (width and channels), which
    does not match the axes the differences are taken over.

    Args:
        x: 4-D tensor; axis 0 is used as the batch axis.

    Returns:
        A scalar tensor with the smoothness penalty.
    """
    shape = tf.shape(x)
    batch_size, h_x, w_x = shape[0], shape[1], shape[2]

    # Number of vertical / horizontal neighbour pairs in one channel.
    count_h = (h_x - 1) * w_x
    count_w = h_x * (w_x - 1)

    h_tv = tf.reduce_sum(tf.square(x[:, 1:, :, :] - x[:, : h_x - 1, :, :]))
    w_tv = tf.reduce_sum(tf.square(x[:, :, 1:, :] - x[:, :, : w_x - 1, :]))

    batch_size = tf.cast(batch_size, dtype=tf.float32)
    count_h = tf.cast(count_h, dtype=tf.float32)
    count_w = tf.cast(count_w, dtype=tf.float32)
    return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
enhance_me/zero_dce/losses/spatial_constancy.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import losses
3
+
4
+
5
class SpatialConsistencyLoss(losses.Loss):
    """Compares neighbour differences of two images after 4x4 average pooling.

    Both inputs are averaged over axis 3 and average-pooled (4x4, stride 4).
    For each of the four directions (left, right, up, down) the difference
    between a pooled cell and its neighbour is computed with a 3x3
    convolution, and the squared mismatch between the two images'
    differences is accumulated.

    The kernels are stored with shape ``(3, 3, 1, 1)``, i.e.
    ``[filter_height, filter_width, in_channels, out_channels]`` as expected
    by ``tf.nn.conv2d``. With the former ``(1, 3, 1, 3)`` nesting the rows of
    each kernel were interpreted as output channels and no neighbouring cell
    was ever involved.
    """

    def __init__(self, **kwargs):
        # Unreduced output is the default, but an explicit ``reduction`` (and
        # any other ``Loss`` argument such as ``name``) is now honoured
        # instead of being silently discarded.
        kwargs.setdefault("reduction", "none")
        super(SpatialConsistencyLoss, self).__init__(**kwargs)

        self.left_kernel = self._make_kernel([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
        self.right_kernel = self._make_kernel([[0, 0, 0], [0, 1, -1], [0, 0, 0]])
        self.up_kernel = self._make_kernel([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
        self.down_kernel = self._make_kernel([[0, 0, 0], [0, 1, 0], [0, -1, 0]])

    @staticmethod
    def _make_kernel(taps):
        """Turn a 3x3 list of taps into a (3, 3, 1, 1) float32 conv filter."""
        return tf.reshape(tf.constant(taps, dtype=tf.float32), (3, 3, 1, 1))

    @staticmethod
    def _pooled_mean(image):
        """Average over axis 3, then 4x4 average-pool with stride 4."""
        channel_mean = tf.reduce_mean(image, 3, keepdims=True)
        return tf.nn.avg_pool2d(channel_mean, ksize=4, strides=4, padding="VALID")

    @staticmethod
    def _neighbour_difference(pooled, kernel):
        """Difference between each pooled cell and one neighbour."""
        return tf.nn.conv2d(pooled, kernel, strides=[1, 1, 1, 1], padding="SAME")

    def call(self, y_true, y_pred):
        """Return the per-cell loss map with a single channel.

        Args:
            y_true: First image batch (4-D tensor).
            y_pred: Second image batch with the same shape as ``y_true``.

        Returns:
            Sum over the four directions of the squared mismatch between the
            neighbour differences of the two pooled images.
        """
        original_pool = self._pooled_mean(y_true)
        enhanced_pool = self._pooled_mean(y_pred)

        total = None
        directional_kernels = (
            self.left_kernel,
            self.right_kernel,
            self.up_kernel,
            self.down_kernel,
        )
        for kernel in directional_kernels:
            d_original = self._neighbour_difference(original_pool, kernel)
            d_enhanced = self._neighbour_difference(enhanced_pool, kernel)
            mismatch = tf.square(d_original - d_enhanced)
            total = mismatch if total is None else total + mismatch
        return total
enhance_me/zero_dce/zero_dce.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from datetime import datetime
5
+
6
+ import tensorflow as tf
7
+ from tensorflow import keras
8
+ from tensorflow.keras import optimizers, mixed_precision, Model
9
+ from wandb.keras import WandbCallback
10
+
11
+ from .dce_net import build_dce_net
12
+ from .dataloader import UnpairedLowLightDataset
13
+ from .losses import (
14
+ color_constancy_loss,
15
+ exposure_loss,
16
+ illumination_smoothness_loss,
17
+ SpatialConsistencyLoss,
18
+ )
19
+ from ..commons import (
20
+ download_lol_dataset,
21
+ download_unpaired_low_light_dataset,
22
+ init_wandb,
23
+ )
24
+
25
+
26
class ZeroDCE(Model):
    """Keras model that enhances images with curves predicted by a DCE-Net.

    The wrapped network (``build_dce_net``) produces a 24-channel output that
    is applied to the input as eight successive 3-channel curve adjustments.
    Training uses the four losses computed in ``compute_losses``; no target
    images are passed to them.

    Args:
        experiment_name: Name used for the Weights & Biases run and as the
            root of the TensorBoard log directory.
        wandb_api_key: When given, ``init_wandb`` is called and a
            ``WandbCallback`` is added during training.
        use_mixed_precision: When true, the global Keras policy is set to
            ``mixed_float16``.
        **kwargs: Forwarded to ``keras.Model``.
    """

    # Fallback log directory root used by ``train`` when no experiment name was given.
    _DEFAULT_EXPERIMENT_NAME = "zero_dce"

    def __init__(
        self,
        experiment_name=None,
        wandb_api_key=None,
        use_mixed_precision: bool = False,
        **kwargs
    ):
        super(ZeroDCE, self).__init__(**kwargs)
        self.experiment_name = experiment_name
        if use_mixed_precision:
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
        self.using_wandb = wandb_api_key is not None
        if self.using_wandb:
            init_wandb("zero-dce", experiment_name, wandb_api_key)
        self.dce_model = build_dce_net()

    def compile(self, learning_rate, **kwargs):
        """Compile the model and create the Adam optimizer and spatial loss."""
        super(ZeroDCE, self).compile(**kwargs)
        self.optimizer = optimizers.Adam(learning_rate=learning_rate)
        self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")

    def get_enhanced_image(self, data, output):
        """Apply the eight 3-channel curve maps in ``output`` to ``data``.

        Each step computes ``x + r * (x**2 - x)`` with the next 3-channel
        slice ``r`` of ``output``.
        """
        enhanced_image = data
        for start in range(0, 24, 3):
            curve = output[:, :, :, start : start + 3]
            enhanced_image = enhanced_image + curve * (
                tf.square(enhanced_image) - enhanced_image
            )
        return enhanced_image

    def call(self, data):
        """Run the DCE-Net on ``data`` and return the enhanced images."""
        dce_net_output = self.dce_model(data)
        return self.get_enhanced_image(data, dce_net_output)

    def compute_losses(self, data, output):
        """Return a dict with the weighted individual losses and their sum."""
        enhanced_image = self.get_enhanced_image(data, output)
        loss_illumination = 200 * illumination_smoothness_loss(output)
        loss_spatial_constancy = tf.reduce_mean(
            self.spatial_constancy_loss(enhanced_image, data)
        )
        loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
        loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
        total_loss = (
            loss_illumination
            + loss_spatial_constancy
            + loss_color_constancy
            + loss_exposure
        )
        return {
            "total_loss": total_loss,
            "illumination_smoothness_loss": loss_illumination,
            "spatial_constancy_loss": loss_spatial_constancy,
            "color_constancy_loss": loss_color_constancy,
            "exposure_loss": loss_exposure,
        }

    def train_step(self, data):
        """Run one optimisation step on the DCE-Net weights."""
        with tf.GradientTape() as tape:
            output = self.dce_model(data)
            losses = self.compute_losses(data, output)
        gradients = tape.gradient(
            losses["total_loss"], self.dce_model.trainable_weights
        )
        self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
        return losses

    def test_step(self, data):
        """Compute the losses for a validation batch without updating weights."""
        output = self.dce_model(data)
        return self.compute_losses(data, output)

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """While saving the weights, we simply save the weights of the DCE-Net"""
        self.dce_model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """While loading the weights, we simply load the weights of the DCE-Net"""
        self.dce_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )

    def build_datasets(
        self,
        image_size: int = 256,
        dataset_label: str = "lol",
        apply_resize: bool = False,
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
        val_split: float = 0.2,
        batch_size: int = 16,
    ) -> None:
        """Download the selected dataset and build train/validation pipelines.

        Args:
            dataset_label: Either ``"lol"`` or ``"unpaired"``.

        Raises:
            ValueError: If ``dataset_label`` is not one of the supported labels.
        """
        if dataset_label == "lol":
            (self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
        elif dataset_label == "unpaired":
            self.low_images, (
                self.test_low_images,
                _,
            ) = download_unpaired_low_light_dataset()
        else:
            # Fail here with a clear message instead of a later AttributeError
            # on the missing ``low_images`` attribute.
            raise ValueError(
                f"Unknown dataset_label {dataset_label!r}; "
                "expected 'lol' or 'unpaired'."
            )
        data_loader = UnpairedLowLightDataset(
            image_size,
            apply_resize,
            apply_random_horizontal_flip,
            apply_random_vertical_flip,
            apply_random_rotation,
        )
        self.train_dataset, self.val_dataset = data_loader.get_datasets(
            self.low_images, val_split, batch_size
        )

    def train(self, epochs: int):
        """Fit the model on the datasets created by ``build_datasets``.

        TensorBoard logs are written to
        ``<experiment_name>/logs/<timestamp>``. When no experiment name was
        given, a default directory name is used instead of failing while
        building the path.

        Returns:
            The ``History`` object returned by ``fit``.
        """
        experiment_dir = (
            self.experiment_name
            if self.experiment_name is not None
            else self._DEFAULT_EXPERIMENT_NAME
        )
        log_dir = os.path.join(
            experiment_dir,
            "logs",
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
        callbacks = [tensorboard_callback]
        if self.using_wandb:
            callbacks += [WandbCallback()]
        history = self.fit(
            self.train_dataset,
            validation_data=self.val_dataset,
            epochs=epochs,
            callbacks=callbacks,
        )
        return history

    def infer(self, original_image):
        """Enhance a single image and return the result as a PIL image.

        PIL inputs are converted to RGB first because the network input has
        exactly three channels; RGBA or grayscale files would otherwise fail.
        """
        if isinstance(original_image, Image.Image):
            original_image = original_image.convert("RGB")
        image = keras.preprocessing.image.img_to_array(original_image)
        image = image.astype("float32") / 255.0
        image = np.expand_dims(image, axis=0)
        output_image = self.call(image)
        # Clip before the uint8 cast so rounding error cannot push values
        # outside the representable range.
        output_image = tf.clip_by_value(output_image[0, :, :, :], 0.0, 1.0)
        output_image = tf.cast((output_image * 255), dtype=np.uint8)
        output_image = Image.fromarray(output_image.numpy())
        return output_image

    def infer_from_file(self, original_image_file: str):
        """Open ``original_image_file`` with PIL and enhance it via ``infer``."""
        original_image = Image.open(original_image_file)
        return self.infer(original_image)
notebooks/enhance_me_train.ipynb CHANGED
@@ -37,11 +37,12 @@
37
  "import os\n",
38
  "import sys\n",
39
  "\n",
40
- "sys.path.append(\"./enhance-me\")\n",
41
  "\n",
42
  "from PIL import Image\n",
43
  "from enhance_me import commons\n",
44
- "from enhance_me.mirnet import MIRNet"
 
45
  ]
46
  },
47
  {
@@ -170,7 +171,7 @@
170
  " enhanced_image = mirnet.infer(original_image)\n",
171
  " ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
172
  " commons.plot_results(\n",
173
- " [original_image, ground_truth, ground_truth],\n",
174
  " [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
175
  " (18, 18),\n",
176
  " )"
@@ -183,6 +184,92 @@
183
  "id": "dO-IbNQHkB3R"
184
  },
185
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  "source": []
187
  }
188
  ],
@@ -197,13 +284,23 @@
197
  "provenance": []
198
  },
199
  "kernelspec": {
200
- "display_name": "Python 3",
 
201
  "name": "python3"
202
  },
203
  "language_info": {
204
- "name": "python"
 
 
 
 
 
 
 
 
 
205
  }
206
  },
207
  "nbformat": 4,
208
- "nbformat_minor": 0
209
  }
 
37
  "import os\n",
38
  "import sys\n",
39
  "\n",
40
+ "sys.path.append(\"..\")\n",
41
  "\n",
42
  "from PIL import Image\n",
43
  "from enhance_me import commons\n",
44
+ "from enhance_me.mirnet import MIRNet\n",
45
+ "from enhance_me.zero_dce import ZeroDCE"
46
  ]
47
  },
48
  {
 
171
  " enhanced_image = mirnet.infer(original_image)\n",
172
  " ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
173
  " commons.plot_results(\n",
174
+ " [original_image, ground_truth, enhanced_image],\n",
175
  " [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
176
  " (18, 18),\n",
177
  " )"
 
184
  "id": "dO-IbNQHkB3R"
185
  },
186
  "outputs": [],
187
+ "source": [
188
+ "# @title Zero-DCE Train Configs\n",
189
+ "\n",
190
+ "experiment_name = \"unpaired_low_light_256_resize\" # @param {type:\"string\"}\n",
191
+ "image_size = 256 # @param {type:\"integer\"}\n",
192
+ "dataset_label = \"unpaired\" # @param [\"lol\", \"unpaired\"]\n",
193
+ "use_mixed_precision = False # @param {type:\"boolean\"}\n",
194
+ "apply_resize = True # @param {type:\"boolean\"}\n",
195
+ "apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
196
+ "apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
197
+ "apply_random_rotation = True # @param {type:\"boolean\"}\n",
198
+ "wandb_api_key = \"\" # @param {type:\"string\"}\n",
199
+ "val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
200
+ "batch_size = 16 # @param {type:\"integer\"}\n",
201
+ "learning_rate = 1e-4 # @param {type:\"number\"}\n",
202
+ "epsilon = 1e-3 # @param {type:\"number\"}\n",
203
+ "epochs = 100 # @param {type:\"slider\", min:10, max:100, step:5}"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "zero_dce = ZeroDCE(\n",
213
+ " experiment_name=experiment_name,\n",
214
+ " wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key,\n",
215
+ " use_mixed_precision=use_mixed_precision\n",
216
+ ")"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": null,
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": [
225
+ "zero_dce.build_datasets(\n",
226
+ " image_size=image_size,\n",
227
+ " dataset_label=dataset_label,\n",
228
+ " apply_resize=apply_resize,\n",
229
+ " apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
230
+ " apply_random_vertical_flip=apply_random_vertical_flip,\n",
231
+ " apply_random_rotation=apply_random_rotation,\n",
232
+ " val_split=val_split,\n",
233
+ " batch_size=batch_size\n",
234
+ ")"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": null,
240
+ "metadata": {
241
+ "scrolled": false
242
+ },
243
+ "outputs": [],
244
+ "source": [
245
+ "zero_dce.compile(learning_rate=learning_rate)\n",
246
+ "history = zero_dce.train(epochs=epochs)\n",
247
+ "zero_dce.save_weights(os.path.join(experiment_name, \"weights.h5\"))"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": null,
253
+ "metadata": {
254
+ "scrolled": false
255
+ },
256
+ "outputs": [],
257
+ "source": [
258
+ "for index, low_image_file in enumerate(zero_dce.test_low_images):\n",
259
+ " original_image = Image.open(low_image_file)\n",
260
+ " enhanced_image = zero_dce.infer(original_image)\n",
261
+ " commons.plot_results(\n",
262
+ " [original_image, enhanced_image],\n",
263
+ " [\"Original Image\", \"Enhanced Image\"],\n",
264
+ " (18, 18),\n",
265
+ " )"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "metadata": {},
272
+ "outputs": [],
273
  "source": []
274
  }
275
  ],
 
284
  "provenance": []
285
  },
286
  "kernelspec": {
287
+ "display_name": "Python 3 (ipykernel)",
288
+ "language": "python",
289
  "name": "python3"
290
  },
291
  "language_info": {
292
+ "codemirror_mode": {
293
+ "name": "ipython",
294
+ "version": 3
295
+ },
296
+ "file_extension": ".py",
297
+ "mimetype": "text/x-python",
298
+ "name": "python",
299
+ "nbconvert_exporter": "python",
300
+ "pygments_lexer": "ipython3",
301
+ "version": "3.8.10"
302
  }
303
  },
304
  "nbformat": 4,
305
+ "nbformat_minor": 1
306
  }
test.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from enhance_me.commons import download_unpaired_low_light_dataset
2
+
3
+
4
+ download_unpaired_low_light_dataset()