geekyrakshit commited on
Commit
192c48a
1 Parent(s): c8d52e7

added download function for lol dataset

Browse files
.gitignore CHANGED
@@ -127,3 +127,6 @@ dmypy.json
127
 
128
  # Pyre type checker
129
  .pyre/
 
 
 
 
127
 
128
  # Pyre type checker
129
  .pyre/
130
+
131
+ # Datasets
132
+ datasets/
enhance_me/commons.py CHANGED
@@ -1,8 +1,11 @@
1
  import os
2
  import wandb
3
- import tensorflow as tf
4
  import matplotlib.pyplot as plt
5
 
 
 
 
6
 
7
  def read_image(image_path):
8
  image = tf.io.read_file(image_path)
@@ -39,5 +42,22 @@ def closest_number(n, m):
39
 
40
  def init_wandb(project_name, experiment_name, wandb_api_key):
41
  if project_name is not None and experiment_name is not None:
42
- os.environ['WANDB_API_KEY'] = wandb_api_key
43
  wandb.init(project=project_name, name=experiment_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import wandb
3
+ from glob import glob
4
  import matplotlib.pyplot as plt
5
 
6
+ import tensorflow as tf
7
+ from tensorflow.keras import utils
8
+
9
 
10
  def read_image(image_path):
11
  image = tf.io.read_file(image_path)
 
42
 
43
  def init_wandb(project_name, experiment_name, wandb_api_key):
44
  if project_name is not None and experiment_name is not None:
45
+ os.environ["WANDB_API_KEY"] = wandb_api_key
46
  wandb.init(project=project_name, name=experiment_name)
47
+
48
+
49
+ def download_lol_dataset():
50
+ utils.get_file(
51
+ "lol_dataset.zip",
52
+ "https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip",
53
+ cache_dir="./",
54
+ cache_subdir="./datasets",
55
+ extract=True,
56
+ )
57
+ low_images = sorted(glob("./datasets/lol_dataset/our485/low/*"))
58
+ enhanced_images = sorted(glob("./datasets/lol_dataset/our485/high/*"))
59
+ assert len(low_images) == len(enhanced_images)
60
+ test_low_images = sorted(glob("./datasets/lol_dataset/eval15/low/*"))
61
+ test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
62
+ assert len(test_low_images) == len(test_enhanced_images)
63
+ return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
enhance_me/mirnet/mirnet.py CHANGED
@@ -12,7 +12,12 @@ from wandb.keras import WandbCallback
12
  from .dataloader import LowLightDataset
13
  from .models import build_mirnet_model
14
  from .losses import CharbonnierLoss
15
- from ..commons import peak_signal_noise_ratio, closest_number, init_wandb
 
 
 
 
 
16
 
17
 
18
  class MIRNet:
@@ -20,12 +25,15 @@ class MIRNet:
20
  self,
21
  experiment_name: str,
22
  image_size: int = 256,
 
23
  apply_random_horizontal_flip: bool = True,
24
  apply_random_vertical_flip: bool = True,
25
  apply_random_rotation: bool = True,
26
  wandb_api_key=None,
27
  ) -> None:
28
  self.experiment_name = experiment_name
 
 
29
  self.data_loader = LowLightDataset(
30
  image_size=image_size,
31
  apply_random_horizontal_flip=apply_random_horizontal_flip,
 
12
  from .dataloader import LowLightDataset
13
  from .models import build_mirnet_model
14
  from .losses import CharbonnierLoss
15
+ from ..commons import (
16
+ peak_signal_noise_ratio,
17
+ closest_number,
18
+ init_wandb,
19
+ download_lol_dataset,
20
+ )
21
 
22
 
23
  class MIRNet:
 
25
  self,
26
  experiment_name: str,
27
  image_size: int = 256,
28
+ dataset_label: str = "lol",
29
  apply_random_horizontal_flip: bool = True,
30
  apply_random_vertical_flip: bool = True,
31
  apply_random_rotation: bool = True,
32
  wandb_api_key=None,
33
  ) -> None:
34
  self.experiment_name = experiment_name
35
+ if dataset_label == "lol":
36
+ download_lol_dataset()
37
  self.data_loader = LowLightDataset(
38
  image_size=image_size,
39
  apply_random_horizontal_flip=apply_random_horizontal_flip,