Spaces:
Runtime error
Runtime error
geekyrakshit
committed on
Merge pull request #2 from soumik12345/zero-dce
Browse filesZero-reference Deep Curve Estimation for Low-light Image Enhancement
- .gitignore +3 -1
- Dockerfile +17 -0
- README.md +11 -1
- app.py +49 -11
- enhance_me/__init__.py +2 -0
- enhance_me/augmentation.py +35 -0
- enhance_me/commons.py +15 -0
- enhance_me/zero_dce/__init__.py +1 -0
- enhance_me/zero_dce/dataloader.py +79 -0
- enhance_me/zero_dce/dce_net.py +31 -0
- enhance_me/zero_dce/losses/__init__.py +36 -0
- enhance_me/zero_dce/losses/spatial_constancy.py +63 -0
- enhance_me/zero_dce/zero_dce.py +183 -0
- notebooks/enhance_me_train.ipynb +103 -6
- test.py +4 -0
.gitignore
CHANGED
@@ -131,4 +131,6 @@ dmypy.json
|
|
131 |
# Datasets
|
132 |
datasets/
|
133 |
**.zip
|
134 |
-
**.h5
|
|
|
|
|
|
131 |
# Datasets
|
132 |
datasets/
|
133 |
**.zip
|
134 |
+
**.h5
|
135 |
+
lol_dataset_**
|
136 |
+
wandb**
|
Dockerfile
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Pull Base Image
|
2 |
+
FROM tensorflow/tensorflow:latest-gpu-jupyter
|
3 |
+
|
4 |
+
# Set Working Directory
|
5 |
+
RUN mkdir /usr/src/enhance-me
|
6 |
+
WORKDIR /usr/src/enhance-me
|
7 |
+
|
8 |
+
# Set Environment Variables
|
9 |
+
ENV PYTHONDONTWRITEBYTECODE 1
|
10 |
+
ENV PYTHONUNBUFFERED 1
|
11 |
+
|
12 |
+
RUN pip install --upgrade pip setuptools wheel
|
13 |
+
RUN pip install gdown matplotlib streamlit tqdm wandb
|
14 |
+
|
15 |
+
COPY . /usr/src/enhance-me/
|
16 |
+
|
17 |
+
CMD ["jupyter", "notebook", "--port=8888", "--no-browser", "--ip=0.0.0.0", "--allow-root"]
|
README.md
CHANGED
@@ -8,4 +8,14 @@ app_file: app.py
|
|
8 |
pinned: false
|
9 |
---
|
10 |
|
11 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
pinned: false
|
9 |
---
|
10 |
|
11 |
+
# Enhance Me
|
12 |
+
|
13 |
+
A unified platform for image enhancement.
|
14 |
+
|
15 |
+
## Usage
|
16 |
+
|
17 |
+
### Train using Docker
|
18 |
+
|
19 |
+
- Build image using `docker build -t enhance-image .`
|
20 |
+
|
21 |
+
- Run notebook using `docker run -it --gpus all -p 8888:8888 -v $(pwd):/usr/src/enhance-me enhance-image`
|
app.py
CHANGED
@@ -1,23 +1,36 @@
|
|
|
|
1 |
from PIL import Image
|
2 |
import streamlit as st
|
3 |
from tensorflow.keras import utils, backend
|
4 |
|
5 |
-
from enhance_me
|
6 |
|
7 |
|
8 |
def get_mirnet_object() -> MIRNet:
|
9 |
-
mirnet = MIRNet()
|
10 |
-
mirnet.build_model()
|
11 |
utils.get_file(
|
12 |
"weights_lol_128.h5",
|
13 |
"https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
|
14 |
cache_dir=".",
|
15 |
cache_subdir="weights",
|
16 |
)
|
|
|
|
|
17 |
mirnet.load_weights("./weights/weights_lol_128.h5")
|
18 |
return mirnet
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def main():
|
22 |
st.markdown("# Enhance Me")
|
23 |
st.markdown("Made with :heart: by [geekyRakshit](http://github.com/soumik12345)")
|
@@ -30,14 +43,39 @@ def main():
|
|
30 |
if uploaded_file is not None:
|
31 |
original_image = Image.open(uploaded_file)
|
32 |
st.image(original_image, caption="original image")
|
33 |
-
st.sidebar.
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
|
43 |
if __name__ == "__main__":
|
|
|
1 |
+
import os
|
2 |
from PIL import Image
|
3 |
import streamlit as st
|
4 |
from tensorflow.keras import utils, backend
|
5 |
|
6 |
+
from enhance_me import MIRNet, ZeroDCE
|
7 |
|
8 |
|
9 |
def get_mirnet_object() -> MIRNet:
    """Build a MIRNet model and load the released LoL-128 weights into it.

    The weights file is fetched with ``utils.get_file`` (cached under
    ``./weights``), the model is built, and the weights are loaded.

    Returns:
        The ``MIRNet`` instance with weights loaded.
    """
    # Use the path reported by `get_file` rather than re-spelling it by hand,
    # so loading still works if Keras stored the file somewhere other than
    # the requested cache directory.
    weights_path = utils.get_file(
        "weights_lol_128.h5",
        "https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
        cache_dir=".",
        cache_subdir="weights",
    )
    mirnet = MIRNet()
    mirnet.build_model()
    mirnet.load_weights(weights_path)
    return mirnet
|
20 |
|
21 |
|
22 |
+
def get_zero_dce_object(model_alias: str) -> ZeroDCE:
    """Build a Zero-DCE model and load the released weights named ``model_alias``.

    Args:
        model_alias: Base name (without ``.h5``) of a weights file published
            under the ``v0.4`` release, e.g. ``"dce_weights_lol_128"``.

    Returns:
        A ``ZeroDCE`` instance with the requested weights loaded.
    """
    weights_name = f"{model_alias}.h5"
    # Use the path reported by `get_file` rather than rebuilding it by hand,
    # so loading still works if Keras stored the file somewhere other than
    # the requested cache directory.
    weights_path = utils.get_file(
        weights_name,
        f"https://github.com/soumik12345/enhance-me/releases/download/v0.4/{model_alias}.h5",
        cache_dir=".",
        cache_subdir="weights",
    )
    dce = ZeroDCE()
    dce.load_weights(weights_path)
    return dce
|
32 |
+
|
33 |
+
|
34 |
def main():
|
35 |
st.markdown("# Enhance Me")
|
36 |
st.markdown("Made with :heart: by [geekyRakshit](http://github.com/soumik12345)")
|
|
|
43 |
if uploaded_file is not None:
|
44 |
original_image = Image.open(uploaded_file)
|
45 |
st.image(original_image, caption="original image")
|
46 |
+
model_option = st.sidebar.selectbox(
|
47 |
+
"Please select the model:",
|
48 |
+
(
|
49 |
+
"",
|
50 |
+
"MIRNet",
|
51 |
+
"Zero-DCE (dce_weights_lol_128)",
|
52 |
+
"Zero-DCE (dce_weights_lol_128_resize)",
|
53 |
+
"Zero-DCE (dce_weights_lol_256)",
|
54 |
+
"Zero-DCE (dce_weights_lol_256_resize)",
|
55 |
+
"Zero-DCE (dce_weights_unpaired_128)",
|
56 |
+
"Zero-DCE (dce_weights_unpaired_128_resize)",
|
57 |
+
"Zero-DCE (dce_weights_unpaired_256)",
|
58 |
+
"Zero-DCE (dce_weights_unpaired_256_resize)"
|
59 |
+
),
|
60 |
+
)
|
61 |
+
if model_option != "":
|
62 |
+
if model_option == "MIRNet":
|
63 |
+
st.sidebar.info("Loading MIRNet...")
|
64 |
+
mirnet = get_mirnet_object()
|
65 |
+
st.sidebar.info("Done!")
|
66 |
+
st.sidebar.info("Processing Image...")
|
67 |
+
enhanced_image = mirnet.infer(original_image)
|
68 |
+
st.sidebar.info("Done!")
|
69 |
+
st.image(enhanced_image, caption="enhanced image")
|
70 |
+
elif "Zero-DCE" in model_option:
|
71 |
+
model_alias = model_option[model_option.find("(") + 1: model_option.find(")")]
|
72 |
+
st.sidebar.info("Loading Zero-DCE...")
|
73 |
+
zero_dce = get_zero_dce_object(model_alias)
|
74 |
+
st.sidebar.info("Done!")
|
75 |
+
enhanced_image = zero_dce.infer(original_image)
|
76 |
+
st.sidebar.info("Done!")
|
77 |
+
st.image(enhanced_image, caption="enhanced image")
|
78 |
+
backend.clear_session()
|
79 |
|
80 |
|
81 |
if __name__ == "__main__":
|
enhance_me/__init__.py
CHANGED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .mirnet import MIRNet
|
2 |
+
from .zero_dce import ZeroDCE
|
enhance_me/augmentation.py
CHANGED
@@ -49,3 +49,38 @@ class AugmentationFactory:
|
|
49 |
return tf.image.rot90(input_image, condition), tf.image.rot90(
|
50 |
enhanced_image, condition
|
51 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
return tf.image.rot90(input_image, condition), tf.image.rot90(
|
50 |
enhanced_image, condition
|
51 |
)
|
52 |
+
|
53 |
+
|
54 |
+
class UnpairedAugmentationFactory:
    """Random augmentations for single (unpaired) images, built from TensorFlow ops."""

    def __init__(self, image_size) -> None:
        # Side length of the square window cut out by `random_crop`.
        self.image_size = image_size

    def random_crop(self, image):
        """Return a randomly positioned ``image_size`` x ``image_size`` crop.

        The first two axes of ``image`` are used as height and width.
        """
        height_width = tf.shape(image)[:2]
        # The horizontal offset is sampled before the vertical one.
        left = tf.random.uniform(
            shape=(),
            maxval=height_width[1] - self.image_size + 1,
            dtype=tf.int32,
        )
        top = tf.random.uniform(
            shape=(),
            maxval=height_width[0] - self.image_size + 1,
            dtype=tf.int32,
        )
        bottom = top + self.image_size
        right = left + self.image_size
        return image[top:bottom, left:right]

    def random_horizontal_flip(self, image):
        """Return ``image`` unchanged or mirrored left-right, chosen at random."""
        keep_original = tf.random.uniform(shape=(), maxval=1) < 0.5
        return tf.cond(
            keep_original,
            lambda: image,
            lambda: tf.image.flip_left_right(image),
        )

    def random_vertical_flip(self, image):
        """Return ``image`` unchanged or flipped upside down, chosen at random."""
        keep_original = tf.random.uniform(shape=(), maxval=1) < 0.5
        return tf.cond(
            keep_original,
            lambda: image,
            lambda: tf.image.flip_up_down(image),
        )

    def random_rotate(self, image):
        """Rotate ``image`` by a random number (0-3) of 90-degree turns."""
        quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        return tf.image.rot90(image, quarter_turns)
|
enhance_me/commons.py
CHANGED
@@ -61,3 +61,18 @@ def download_lol_dataset():
|
|
61 |
test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
|
62 |
assert len(test_low_images) == len(test_enhanced_images)
|
63 |
return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
|
62 |
assert len(test_low_images) == len(test_enhanced_images)
|
63 |
return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
|
64 |
+
|
65 |
+
|
66 |
+
def download_unpaired_low_light_dataset():
    """Download and extract the unpaired low-light dataset, then list its files.

    The archive is fetched with ``utils.get_file`` and extracted under
    ``./datasets``.

    Returns:
        A tuple ``(low_images, (test_low_images, test_enhanced_images))`` of
        file-path lists. All lists are sorted.

    Raises:
        AssertionError: If the low and high evaluation lists differ in length.
    """
    utils.get_file(
        "low_light_dataset.zip",
        "https://github.com/soumik12345/enhance-me/releases/download/v0.3/low_light_dataset.zip",
        cache_dir="./",
        cache_subdir="./datasets",
        extract=True,
    )
    # `glob` returns files in arbitrary order; sort so that a positional
    # train/validation split of this list is reproducible across runs.
    low_images = sorted(glob("./datasets/low_light_dataset/*.png"))
    test_low_images = sorted(glob("./datasets/low_light_dataset/eval15/low/*"))
    test_enhanced_images = sorted(glob("./datasets/low_light_dataset/eval15/high/*"))
    assert len(test_low_images) == len(test_enhanced_images)
    return low_images, (test_low_images, test_enhanced_images)
|
enhance_me/zero_dce/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .zero_dce import ZeroDCE
|
enhance_me/zero_dce/dataloader.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from ..commons import read_image
|
5 |
+
from ..augmentation import UnpairedAugmentationFactory
|
6 |
+
|
7 |
+
|
8 |
+
class UnpairedLowLightDataset:
    """Builds batched ``tf.data`` pipelines of low-light images without targets."""

    def __init__(
        self,
        image_size: int = 256,
        apply_resize: bool = False,
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
    ) -> None:
        # Configuration flags consumed by `_get_dataset`.
        self.image_size = image_size
        self.apply_resize = apply_resize
        self.apply_random_horizontal_flip = apply_random_horizontal_flip
        self.apply_random_vertical_flip = apply_random_vertical_flip
        self.apply_random_rotation = apply_random_rotation
        self.augmentation_factory = UnpairedAugmentationFactory(image_size=image_size)

    def _resize(self, image):
        """Resize ``image`` to ``image_size`` x ``image_size``."""
        target_size = (self.image_size, self.image_size)
        return tf.image.resize(image, target_size)

    def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
        """Create a batched dataset from image paths.

        Every image is read with ``read_image`` and then either resized
        (``apply_resize``) or randomly cropped. The enabled random flips and
        rotation are applied only when ``is_train`` is true. Incomplete final
        batches are dropped.
        """
        autotune = tf.data.AUTOTUNE
        factory = self.augmentation_factory

        pipeline = tf.data.Dataset.from_tensor_slices(images)
        pipeline = pipeline.map(read_image, num_parallel_calls=autotune)

        if self.apply_resize:
            pipeline = pipeline.map(self._resize, num_parallel_calls=autotune)
        else:
            pipeline = pipeline.map(factory.random_crop, num_parallel_calls=autotune)

        if is_train:
            # (enabled?, transform) pairs, applied in this fixed order.
            train_augmentations = (
                (self.apply_random_horizontal_flip, factory.random_horizontal_flip),
                (self.apply_random_vertical_flip, factory.random_vertical_flip),
                (self.apply_random_rotation, factory.random_rotate),
            )
            for enabled, transform in train_augmentations:
                if enabled:
                    pipeline = pipeline.map(transform, num_parallel_calls=autotune)

        return pipeline.batch(batch_size, drop_remainder=True)

    def get_datasets(
        self,
        images: List[str],
        val_split: float = 0.2,
        batch_size: int = 16,
    ):
        """Split ``images`` positionally and return ``(train_dataset, val_dataset)``.

        The last ``val_split`` fraction of ``images`` becomes the validation
        set; the rest is used for training.
        """
        boundary = int(len(images) * (1 - val_split))
        train_images, val_images = images[:boundary], images[boundary:]
        print(f"Number of train data points: {len(train_images)}")
        print(f"Number of validation data points: {len(val_images)}")
        return (
            self._get_dataset(train_images, batch_size, is_train=True),
            self._get_dataset(val_images, batch_size, is_train=False),
        )
|
enhance_me/zero_dce/dce_net.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers, Input, Model
|
3 |
+
|
4 |
+
|
5 |
+
def build_dce_net() -> Model:
    """Build the DCE-Net: a seven-layer conv net with symmetric skip concatenations.

    The network accepts 3-channel images of any spatial size and outputs a
    24-channel, tanh-activated map at the same resolution.
    """

    def conv3x3(tensor, filters=32, activation="relu"):
        # Stride-1, same-padded 3x3 convolution shared by every layer below.
        return layers.Conv2D(
            filters, (3, 3), strides=(1, 1), activation=activation, padding="same"
        )(tensor)

    # Layers are created in the same order as before so that automatically
    # assigned layer names (and therefore saved weight files) still line up.
    input_image = Input(shape=[None, None, 3])
    feat1 = conv3x3(input_image)
    feat2 = conv3x3(feat1)
    feat3 = conv3x3(feat2)
    feat4 = conv3x3(feat3)
    merged_4_3 = layers.Concatenate(axis=-1)([feat4, feat3])
    feat5 = conv3x3(merged_4_3)
    merged_5_2 = layers.Concatenate(axis=-1)([feat5, feat2])
    feat6 = conv3x3(merged_5_2)
    merged_6_1 = layers.Concatenate(axis=-1)([feat6, feat1])
    curve_maps = conv3x3(merged_6_1, filters=24, activation="tanh")
    return Model(inputs=input_image, outputs=curve_maps)
|
enhance_me/zero_dce/losses/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from .spatial_constancy import SpatialConsistencyLoss
|
4 |
+
|
5 |
+
|
6 |
+
def color_constancy_loss(x):
    """Penalise differences between the per-image mean values of channels 0, 1, 2.

    The channel means are taken over axes 1 and 2. Each pairwise squared
    difference is squared once more before the three terms are summed and
    square-rooted.
    """
    channel_means = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    mean_r = channel_means[:, :, :, 0]
    mean_g = channel_means[:, :, :, 1]
    mean_b = channel_means[:, :, :, 2]

    sq_rg = tf.square(mean_r - mean_g)
    sq_rb = tf.square(mean_r - mean_b)
    sq_gb = tf.square(mean_b - mean_g)

    sum_of_fourth_powers = tf.square(sq_rg) + tf.square(sq_rb) + tf.square(sq_gb)
    return tf.sqrt(sum_of_fourth_powers)
|
17 |
+
|
18 |
+
|
19 |
+
def exposure_loss(x, mean_val=0.6):
    """Mean squared deviation of 16x16 patch intensities from ``mean_val``.

    ``x`` is first averaged over axis 3, then average-pooled in
    non-overlapping 16x16 windows.
    """
    channel_average = tf.reduce_mean(x, axis=3, keepdims=True)
    patch_means = tf.nn.avg_pool2d(
        channel_average, ksize=16, strides=16, padding="VALID"
    )
    deviation = patch_means - mean_val
    return tf.reduce_mean(tf.square(deviation))
|
23 |
+
|
24 |
+
|
25 |
+
def illumination_smoothness_loss(x):
    """Total-variation style smoothness penalty on a 4-D tensor.

    Squared differences between neighbours along axis 1 (treated as height)
    and axis 2 (treated as width) are summed, each normalised by a
    neighbour-pair count, and the result is averaged over the batch (axis 0).

    The normalisers are now derived from the same height/width axes the
    differences are taken over. Previously they were computed from axes 2
    and 3, which appears to be a leftover from a channels-first layout and
    made the scale of the loss depend on the channel count.

    Returns:
        A scalar float32 tensor.
    """
    shape = tf.shape(x)
    batch_size, h_x, w_x = shape[0], shape[1], shape[2]

    # Vertical / horizontal neighbour pairs within a single channel plane.
    count_h = (h_x - 1) * w_x
    count_w = h_x * (w_x - 1)

    h_tv = tf.reduce_sum(tf.square(x[:, 1:, :, :] - x[:, : h_x - 1, :, :]))
    w_tv = tf.reduce_sum(tf.square(x[:, :, 1:, :] - x[:, :, : w_x - 1, :]))

    batch_size = tf.cast(batch_size, dtype=tf.float32)
    count_h = tf.cast(count_h, dtype=tf.float32)
    count_w = tf.cast(count_w, dtype=tf.float32)
    return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
|
enhance_me/zero_dce/losses/spatial_constancy.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import losses
|
3 |
+
|
4 |
+
|
5 |
+
class SpatialConsistencyLoss(losses.Loss):
    """Compares neighbour differences of two images after 4x4 average pooling.

    The loss is always constructed with ``reduction="none"``; keyword
    arguments passed to the constructor are accepted but not forwarded.
    """

    def __init__(self, **kwargs):
        super(SpatialConsistencyLoss, self).__init__(reduction="none")

        # Fixed difference kernels, one per neighbour direction.
        self.left_kernel = tf.constant(
            [[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.right_kernel = tf.constant(
            [[[[0, 0, 0]], [[0, 1, -1]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.up_kernel = tf.constant(
            [[[[0, -1, 0]], [[0, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.down_kernel = tf.constant(
            [[[[0, 0, 0]], [[0, 1, 0]], [[0, -1, 0]]]], dtype=tf.float32
        )

    def _pooled_intensity(self, image):
        """Average ``image`` over axis 3, then 4x4 average-pool it."""
        intensity = tf.reduce_mean(image, 3, keepdims=True)
        return tf.nn.avg_pool2d(intensity, ksize=4, strides=4, padding="VALID")

    def _neighbour_difference(self, pooled, kernel):
        """Convolve ``pooled`` with one of the fixed difference kernels."""
        return tf.nn.conv2d(pooled, kernel, strides=[1, 1, 1, 1], padding="SAME")

    def call(self, y_true, y_pred):
        """Return the unreduced sum of squared directional-difference mismatches."""
        original_pool = self._pooled_intensity(y_true)
        enhanced_pool = self._pooled_intensity(y_pred)

        directions = (
            self.left_kernel,
            self.right_kernel,
            self.up_kernel,
            self.down_kernel,
        )
        total = None
        # Accumulate left, right, up, down in that order.
        for kernel in directions:
            mismatch = tf.square(
                self._neighbour_difference(original_pool, kernel)
                - self._neighbour_difference(enhanced_pool, kernel)
            )
            total = mismatch if total is None else total + mismatch
        return total
|
enhance_me/zero_dce/zero_dce.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from datetime import datetime
|
5 |
+
|
6 |
+
import tensorflow as tf
|
7 |
+
from tensorflow import keras
|
8 |
+
from tensorflow.keras import optimizers, mixed_precision, Model
|
9 |
+
from wandb.keras import WandbCallback
|
10 |
+
|
11 |
+
from .dce_net import build_dce_net
|
12 |
+
from .dataloader import UnpairedLowLightDataset
|
13 |
+
from .losses import (
|
14 |
+
color_constancy_loss,
|
15 |
+
exposure_loss,
|
16 |
+
illumination_smoothness_loss,
|
17 |
+
SpatialConsistencyLoss,
|
18 |
+
)
|
19 |
+
from ..commons import (
|
20 |
+
download_lol_dataset,
|
21 |
+
download_unpaired_low_light_dataset,
|
22 |
+
init_wandb,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class ZeroDCE(Model):
    """Zero-DCE low-light enhancer: a DCE-Net plus iterative curve application.

    The wrapped ``dce_model`` predicts a 24-channel map that is consumed as
    eight 3-channel curve maps; each is applied in turn to the image via
    ``x + r * (x**2 - x)``. Training uses four losses that need no reference
    image (illumination smoothness, spatial constancy, colour constancy,
    exposure).
    """

    def __init__(
        self,
        experiment_name=None,
        wandb_api_key=None,
        use_mixed_precision: bool = False,
        **kwargs
    ):
        """Create the model.

        Args:
            experiment_name: Name used for the W&B run and as the root of the
                TensorBoard log directory.
            wandb_api_key: When given, Weights & Biases is initialised and a
                ``WandbCallback`` is attached during ``train``.
            use_mixed_precision: When true, sets the global Keras policy to
                ``mixed_float16``.
            **kwargs: Forwarded to ``keras.Model``.
        """
        super(ZeroDCE, self).__init__(**kwargs)
        self.experiment_name = experiment_name
        if use_mixed_precision:
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
        if wandb_api_key is not None:
            init_wandb("zero-dce", experiment_name, wandb_api_key)
            self.using_wandb = True
        else:
            self.using_wandb = False
        self.dce_model = build_dce_net()

    def compile(self, learning_rate, **kwargs):
        """Compile the model with an Adam optimizer at ``learning_rate``."""
        super(ZeroDCE, self).compile(**kwargs)
        self.optimizer = optimizers.Adam(learning_rate=learning_rate)
        self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")

    def get_enhanced_image(self, data, output):
        """Apply the eight curve maps in ``output`` to ``data``, one after another.

        Args:
            data: Input image batch.
            output: DCE-Net output; channels ``[0:24]`` are read as eight
                consecutive 3-channel curve maps.

        Returns:
            The image batch after all eight curve iterations.
        """
        enhanced_image = data
        for start in range(0, 24, 3):
            curve = output[:, :, :, start : start + 3]
            enhanced_image = enhanced_image + curve * (
                tf.square(enhanced_image) - enhanced_image
            )
        return enhanced_image

    def call(self, data):
        """Run the DCE-Net on ``data`` and return the enhanced image batch."""
        dce_net_output = self.dce_model(data)
        return self.get_enhanced_image(data, dce_net_output)

    def compute_losses(self, data, output):
        """Compute the weighted loss terms and their total.

        Returns:
            A dict with keys ``total_loss``, ``illumination_smoothness_loss``,
            ``spatial_constancy_loss``, ``color_constancy_loss`` and
            ``exposure_loss``.
        """
        enhanced_image = self.get_enhanced_image(data, output)
        loss_illumination = 200 * illumination_smoothness_loss(output)
        loss_spatial_constancy = tf.reduce_mean(
            self.spatial_constancy_loss(enhanced_image, data)
        )
        loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
        loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
        total_loss = (
            loss_illumination
            + loss_spatial_constancy
            + loss_color_constancy
            + loss_exposure
        )
        return {
            "total_loss": total_loss,
            "illumination_smoothness_loss": loss_illumination,
            "spatial_constancy_loss": loss_spatial_constancy,
            "color_constancy_loss": loss_color_constancy,
            "exposure_loss": loss_exposure,
        }

    def train_step(self, data):
        """One optimisation step on the DCE-Net weights; returns the loss dict."""
        with tf.GradientTape() as tape:
            output = self.dce_model(data)
            losses = self.compute_losses(data, output)
        gradients = tape.gradient(
            losses["total_loss"], self.dce_model.trainable_weights
        )
        self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
        return losses

    def test_step(self, data):
        """Evaluate the losses on ``data`` without updating weights."""
        output = self.dce_model(data)
        return self.compute_losses(data, output)

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """While saving the weights, we simply save the weights of the DCE-Net"""
        self.dce_model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """While loading the weights, we simply load the weights of the DCE-Net"""
        self.dce_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )

    def build_datasets(
        self,
        image_size: int = 256,
        dataset_label: str = "lol",
        apply_resize: bool = False,
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
        val_split: float = 0.2,
        batch_size: int = 16,
    ) -> None:
        """Download the chosen dataset and create train/validation pipelines.

        Sets ``low_images``, ``test_low_images``, ``train_dataset`` and
        ``val_dataset`` on the instance.

        Args:
            dataset_label: ``"lol"`` or ``"unpaired"``.

        Raises:
            ValueError: If ``dataset_label`` is not one of the supported values.
        """
        if dataset_label == "lol":
            (self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
        elif dataset_label == "unpaired":
            self.low_images, (
                self.test_low_images,
                _,
            ) = download_unpaired_low_light_dataset()
        else:
            # Fail here with a clear message instead of an AttributeError on
            # `self.low_images` a few lines below.
            raise ValueError(
                f"Unknown dataset_label {dataset_label!r}; expected 'lol' or 'unpaired'."
            )
        data_loader = UnpairedLowLightDataset(
            image_size,
            apply_resize,
            apply_random_horizontal_flip,
            apply_random_vertical_flip,
            apply_random_rotation,
        )
        self.train_dataset, self.val_dataset = data_loader.get_datasets(
            self.low_images, val_split, batch_size
        )

    def train(self, epochs: int):
        """Fit the model on the datasets created by ``build_datasets``.

        TensorBoard logs go to ``<experiment_name>/logs/<timestamp>``; when no
        experiment name was given, the current directory is used as the root.

        Returns:
            The Keras ``History`` object returned by ``fit``.
        """
        log_dir = os.path.join(
            # `experiment_name` defaults to None, which os.path.join rejects.
            self.experiment_name or ".",
            "logs",
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
        callbacks = [tensorboard_callback]
        if self.using_wandb:
            callbacks += [WandbCallback()]
        history = self.fit(
            self.train_dataset,
            validation_data=self.val_dataset,
            epochs=epochs,
            callbacks=callbacks,
        )
        return history

    def infer(self, original_image):
        """Enhance a single image and return the result as a PIL image.

        Args:
            original_image: The image to enhance. PIL images in a mode other
                than RGB (e.g. RGBA or grayscale) are converted to RGB first,
                because the DCE-Net input layer has three channels.
        """
        if isinstance(original_image, Image.Image) and original_image.mode != "RGB":
            original_image = original_image.convert("RGB")
        image = keras.preprocessing.image.img_to_array(original_image)
        image = image.astype("float32") / 255.0
        image = np.expand_dims(image, axis=0)
        output_image = self.call(image)
        # Clamp before the uint8 cast so small numeric overshoots cannot wrap.
        output_image = tf.clip_by_value(output_image[0, :, :, :] * 255, 0, 255)
        output_image = tf.cast(output_image, dtype=np.uint8)
        output_image = Image.fromarray(output_image.numpy())
        return output_image

    def infer_from_file(self, original_image_file: str):
        """Open ``original_image_file`` with PIL and enhance it via ``infer``."""
        original_image = Image.open(original_image_file)
        return self.infer(original_image)
|
notebooks/enhance_me_train.ipynb
CHANGED
@@ -37,11 +37,12 @@
|
|
37 |
"import os\n",
|
38 |
"import sys\n",
|
39 |
"\n",
|
40 |
-
"sys.path.append(\"
|
41 |
"\n",
|
42 |
"from PIL import Image\n",
|
43 |
"from enhance_me import commons\n",
|
44 |
-
"from enhance_me.mirnet import MIRNet"
|
|
|
45 |
]
|
46 |
},
|
47 |
{
|
@@ -170,7 +171,7 @@
|
|
170 |
" enhanced_image = mirnet.infer(original_image)\n",
|
171 |
" ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
|
172 |
" commons.plot_results(\n",
|
173 |
-
" [original_image, ground_truth,
|
174 |
" [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
|
175 |
" (18, 18),\n",
|
176 |
" )"
|
@@ -183,6 +184,92 @@
|
|
183 |
"id": "dO-IbNQHkB3R"
|
184 |
},
|
185 |
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
"source": []
|
187 |
}
|
188 |
],
|
@@ -197,13 +284,23 @@
|
|
197 |
"provenance": []
|
198 |
},
|
199 |
"kernelspec": {
|
200 |
-
"display_name": "Python 3",
|
|
|
201 |
"name": "python3"
|
202 |
},
|
203 |
"language_info": {
|
204 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
}
|
206 |
},
|
207 |
"nbformat": 4,
|
208 |
-
"nbformat_minor":
|
209 |
}
|
|
|
37 |
"import os\n",
|
38 |
"import sys\n",
|
39 |
"\n",
|
40 |
+
"sys.path.append(\"..\")\n",
|
41 |
"\n",
|
42 |
"from PIL import Image\n",
|
43 |
"from enhance_me import commons\n",
|
44 |
+
"from enhance_me.mirnet import MIRNet\n",
|
45 |
+
"from enhance_me.zero_dce import ZeroDCE"
|
46 |
]
|
47 |
},
|
48 |
{
|
|
|
171 |
" enhanced_image = mirnet.infer(original_image)\n",
|
172 |
" ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
|
173 |
" commons.plot_results(\n",
|
174 |
+
" [original_image, ground_truth, enhanced_image],\n",
|
175 |
" [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
|
176 |
" (18, 18),\n",
|
177 |
" )"
|
|
|
184 |
"id": "dO-IbNQHkB3R"
|
185 |
},
|
186 |
"outputs": [],
|
187 |
+
"source": [
|
188 |
+
"# @title Zero-DCE Train Configs\n",
|
189 |
+
"\n",
|
190 |
+
"experiment_name = \"unpaired_low_light_256_resize\" # @param {type:\"string\"}\n",
|
191 |
+
"image_size = 256 # @param {type:\"integer\"}\n",
|
192 |
+
"dataset_label = \"unpaired\" # @param [\"lol\", \"unpaired\"]\n",
|
193 |
+
"use_mixed_precision = False # @param {type:\"boolean\"}\n",
|
194 |
+
"apply_resize = True # @param {type:\"boolean\"}\n",
|
195 |
+
"apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
|
196 |
+
"apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
|
197 |
+
"apply_random_rotation = True # @param {type:\"boolean\"}\n",
|
198 |
+
"wandb_api_key = \"\" # @param {type:\"string\"}\n",
|
199 |
+
"val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
|
200 |
+
"batch_size = 16 # @param {type:\"integer\"}\n",
|
201 |
+
"learning_rate = 1e-4 # @param {type:\"number\"}\n",
|
202 |
+
"epsilon = 1e-3 # @param {type:\"number\"}\n",
|
203 |
+
"epochs = 100 # @param {type:\"slider\", min:10, max:100, step:5}"
|
204 |
+
]
|
205 |
+
},
|
206 |
+
{
|
207 |
+
"cell_type": "code",
|
208 |
+
"execution_count": null,
|
209 |
+
"metadata": {},
|
210 |
+
"outputs": [],
|
211 |
+
"source": [
|
212 |
+
"zero_dce = ZeroDCE(\n",
|
213 |
+
" experiment_name=experiment_name,\n",
|
214 |
+
" wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key,\n",
|
215 |
+
" use_mixed_precision=use_mixed_precision\n",
|
216 |
+
")"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"execution_count": null,
|
222 |
+
"metadata": {},
|
223 |
+
"outputs": [],
|
224 |
+
"source": [
|
225 |
+
"zero_dce.build_datasets(\n",
|
226 |
+
" image_size=image_size,\n",
|
227 |
+
" dataset_label=dataset_label,\n",
|
228 |
+
" apply_resize=apply_resize,\n",
|
229 |
+
" apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
|
230 |
+
" apply_random_vertical_flip=apply_random_vertical_flip,\n",
|
231 |
+
" apply_random_rotation=apply_random_rotation,\n",
|
232 |
+
" val_split=val_split,\n",
|
233 |
+
" batch_size=batch_size\n",
|
234 |
+
")"
|
235 |
+
]
|
236 |
+
},
|
237 |
+
{
|
238 |
+
"cell_type": "code",
|
239 |
+
"execution_count": null,
|
240 |
+
"metadata": {
|
241 |
+
"scrolled": false
|
242 |
+
},
|
243 |
+
"outputs": [],
|
244 |
+
"source": [
|
245 |
+
"zero_dce.compile(learning_rate=learning_rate)\n",
|
246 |
+
"history = zero_dce.train(epochs=epochs)\n",
|
247 |
+
"zero_dce.save_weights(os.path.join(experiment_name, \"weights.h5\"))"
|
248 |
+
]
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"cell_type": "code",
|
252 |
+
"execution_count": null,
|
253 |
+
"metadata": {
|
254 |
+
"scrolled": false
|
255 |
+
},
|
256 |
+
"outputs": [],
|
257 |
+
"source": [
|
258 |
+
"for index, low_image_file in enumerate(zero_dce.test_low_images):\n",
|
259 |
+
" original_image = Image.open(low_image_file)\n",
|
260 |
+
" enhanced_image = zero_dce.infer(original_image)\n",
|
261 |
+
" commons.plot_results(\n",
|
262 |
+
" [original_image, enhanced_image],\n",
|
263 |
+
" [\"Original Image\", \"Enhanced Image\"],\n",
|
264 |
+
" (18, 18),\n",
|
265 |
+
" )"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "code",
|
270 |
+
"execution_count": null,
|
271 |
+
"metadata": {},
|
272 |
+
"outputs": [],
|
273 |
"source": []
|
274 |
}
|
275 |
],
|
|
|
284 |
"provenance": []
|
285 |
},
|
286 |
"kernelspec": {
|
287 |
+
"display_name": "Python 3 (ipykernel)",
|
288 |
+
"language": "python",
|
289 |
"name": "python3"
|
290 |
},
|
291 |
"language_info": {
|
292 |
+
"codemirror_mode": {
|
293 |
+
"name": "ipython",
|
294 |
+
"version": 3
|
295 |
+
},
|
296 |
+
"file_extension": ".py",
|
297 |
+
"mimetype": "text/x-python",
|
298 |
+
"name": "python",
|
299 |
+
"nbconvert_exporter": "python",
|
300 |
+
"pygments_lexer": "ipython3",
|
301 |
+
"version": "3.8.10"
|
302 |
}
|
303 |
},
|
304 |
"nbformat": 4,
|
305 |
+
"nbformat_minor": 1
|
306 |
}
|
test.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enhance_me.commons import download_unpaired_low_light_dataset


if __name__ == "__main__":
    # Guarded so that merely importing this module does not start a download.
    download_unpaired_low_light_dataset()
|