File size: 1,965 Bytes
c8d52e7
 
192c48a
c8d52e7
2673600
192c48a
 
 
2673600
 
 
 
 
 
 
6fd61b9
 
 
 
c8d52e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192c48a
8ff8f45
192c48a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import wandb
from glob import glob
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import utils


def read_image(image_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image.set_shape([None, None, 3])
    image = tf.cast(image, dtype=tf.float32) / 255.0
    return image


def peak_signal_noise_ratio(y_true, y_pred):
    return tf.image.psnr(y_pred, y_true, max_val=255.0)


def plot_results(images, titles, figure_size=(12, 12)):
    fig = plt.figure(figsize=figure_size)
    for i in range(len(images)):
        fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
        _ = plt.imshow(images[i])
        plt.axis("off")
    plt.show()


def closest_number(n, m):
    q = int(n / m)
    n1 = m * q
    if (n * m) > 0:
        n2 = m * (q + 1)
    else:
        n2 = m * (q - 1)
    if abs(n - n1) < abs(n - n2):
        return n1
    return n2


def init_wandb(project_name, experiment_name, wandb_api_key):
    if project_name is not None and experiment_name is not None:
        os.environ["WANDB_API_KEY"] = wandb_api_key
        wandb.init(project=project_name, name=experiment_name, sync_tensorboard=True)


def download_lol_dataset():
    utils.get_file(
        "lol_dataset.zip",
        "https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip",
        cache_dir="./",
        cache_subdir="./datasets",
        extract=True,
    )
    low_images = sorted(glob("./datasets/lol_dataset/our485/low/*"))
    enhanced_images = sorted(glob("./datasets/lol_dataset/our485/high/*"))
    assert len(low_images) == len(enhanced_images)
    test_low_images = sorted(glob("./datasets/lol_dataset/eval15/low/*"))
    test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
    assert len(test_low_images) == len(test_enhanced_images)
    return (low_images, enhanced_images), (test_low_images, test_enhanced_images)