Spaces:
Runtime error
Runtime error
geekyrakshit
commited on
Commit
•
b40a1f8
1
Parent(s):
91e5f9b
added unpaired low-light dataset
Browse files- enhance_me/commons.py +15 -0
- enhance_me/zero_dce/zero_dce.py +18 -3
- test.py +4 -0
enhance_me/commons.py
CHANGED
@@ -61,3 +61,18 @@ def download_lol_dataset():
|
|
61 |
test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
|
62 |
assert len(test_low_images) == len(test_enhanced_images)
|
63 |
return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
|
62 |
assert len(test_low_images) == len(test_enhanced_images)
|
63 |
return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
|
64 |
+
|
65 |
+
|
66 |
+
def download_unpaired_low_light_dataset():
|
67 |
+
utils.get_file(
|
68 |
+
"low_light_dataset.zip",
|
69 |
+
"https://github.com/soumik12345/enhance-me/releases/download/v0.3/low_light_dataset.zip",
|
70 |
+
cache_dir="./",
|
71 |
+
cache_subdir="./datasets",
|
72 |
+
extract=True,
|
73 |
+
)
|
74 |
+
low_images = glob("./datasets/low_light_dataset/*.png")
|
75 |
+
test_low_images = sorted(glob("./datasets/low_light_dataset/eval15/low/*"))
|
76 |
+
test_enhanced_images = sorted(glob("./datasets/low_light_dataset/eval15/high/*"))
|
77 |
+
assert len(test_low_images) == len(test_enhanced_images)
|
78 |
+
return low_images, (test_low_images, test_enhanced_images)
|
enhance_me/zero_dce/zero_dce.py
CHANGED
@@ -16,15 +16,25 @@ from .losses import (
|
|
16 |
illumination_smoothness_loss,
|
17 |
SpatialConsistencyLoss,
|
18 |
)
|
19 |
-
from ..commons import
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
class ZeroDCE(Model):
|
23 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
super(ZeroDCE, self).__init__(**kwargs)
|
25 |
self.experiment_name = experiment_name
|
26 |
if use_mixed_precision:
|
27 |
-
policy = mixed_precision.Policy(
|
28 |
mixed_precision.set_global_policy(policy)
|
29 |
if wandb_api_key is not None:
|
30 |
init_wandb("zero-dce", experiment_name, wandb_api_key)
|
@@ -125,6 +135,11 @@ class ZeroDCE(Model):
|
|
125 |
) -> None:
|
126 |
if dataset_label == "lol":
|
127 |
(self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
|
|
|
|
|
|
|
|
|
|
|
128 |
data_loader = UnpairedLowLightDataset(
|
129 |
image_size,
|
130 |
apply_resize,
|
|
|
16 |
illumination_smoothness_loss,
|
17 |
SpatialConsistencyLoss,
|
18 |
)
|
19 |
+
from ..commons import (
|
20 |
+
download_lol_dataset,
|
21 |
+
download_unpaired_low_light_dataset,
|
22 |
+
init_wandb,
|
23 |
+
)
|
24 |
|
25 |
|
26 |
class ZeroDCE(Model):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
experiment_name=None,
|
30 |
+
wandb_api_key=None,
|
31 |
+
use_mixed_precision: bool = False,
|
32 |
+
**kwargs
|
33 |
+
):
|
34 |
super(ZeroDCE, self).__init__(**kwargs)
|
35 |
self.experiment_name = experiment_name
|
36 |
if use_mixed_precision:
|
37 |
+
policy = mixed_precision.Policy("mixed_float16")
|
38 |
mixed_precision.set_global_policy(policy)
|
39 |
if wandb_api_key is not None:
|
40 |
init_wandb("zero-dce", experiment_name, wandb_api_key)
|
|
|
135 |
) -> None:
|
136 |
if dataset_label == "lol":
|
137 |
(self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
|
138 |
+
elif dataset_label == "unpaired":
|
139 |
+
self.low_images, (
|
140 |
+
self.test_low_images,
|
141 |
+
_,
|
142 |
+
) = download_unpaired_low_light_dataset()
|
143 |
data_loader = UnpairedLowLightDataset(
|
144 |
image_size,
|
145 |
apply_resize,
|
test.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enhance_me.commons import download_unpaired_low_light_dataset
|
2 |
+
|
3 |
+
|
4 |
+
download_unpaired_low_light_dataset()
|