geekyrakshit commited on
Commit
8ff8f45
1 Parent(s): 841fed0

updated mirnet

Browse files
enhance_me/commons.py CHANGED
@@ -43,7 +43,7 @@ def closest_number(n, m):
43
  def init_wandb(project_name, experiment_name, wandb_api_key):
44
  if project_name is not None and experiment_name is not None:
45
  os.environ["WANDB_API_KEY"] = wandb_api_key
46
- wandb.init(project=project_name, name=experiment_name)
47
 
48
 
49
  def download_lol_dataset():
 
43
  def init_wandb(project_name, experiment_name, wandb_api_key):
44
  if project_name is not None and experiment_name is not None:
45
  os.environ["WANDB_API_KEY"] = wandb_api_key
46
+ wandb.init(project=project_name, name=experiment_name, sync_tensorboard=True)
47
 
48
 
49
  def download_lol_dataset():
enhance_me/mirnet/mirnet.py CHANGED
@@ -21,11 +21,7 @@ from ..commons import (
21
 
22
 
23
  class MIRNet:
24
- def __init__(
25
- self,
26
- experiment_name: str,
27
- wandb_api_key=None,
28
- ) -> None:
29
  self.experiment_name = experiment_name
30
  if wandb_api_key is not None:
31
  init_wandb("mirnet", experiment_name, wandb_api_key)
 
21
 
22
 
23
  class MIRNet:
24
+ def __init__(self, experiment_name: str, wandb_api_key=None) -> None:
 
 
 
 
25
  self.experiment_name = experiment_name
26
  if wandb_api_key is not None:
27
  init_wandb("mirnet", experiment_name, wandb_api_key)
notebooks/enhance_me_train.ipynb CHANGED
@@ -1,284 +1,197 @@
1
  {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "accelerator": "GPU",
 
 
 
 
 
 
 
 
 
 
 
6
  "colab": {
7
- "name": "enhance-me-train.ipynb",
8
- "provenance": [],
9
- "collapsed_sections": [],
10
- "machine_shape": "hm",
11
- "authorship_tag": "ABX9TyN4LuJh6kWhbqxzA5s9sp7k",
12
- "include_colab_link": true
13
- },
14
- "kernelspec": {
15
- "display_name": "Python 3",
16
- "name": "python3"
17
  },
18
- "language_info": {
19
- "name": "python"
20
- }
 
 
 
 
 
21
  },
22
- "cells": [
23
- {
24
- "cell_type": "markdown",
25
- "metadata": {
26
- "id": "view-in-github",
27
- "colab_type": "text"
28
- },
29
- "source": [
30
- "<a href=\"https://colab.research.google.com/github/soumik12345/enhance-me/blob/mirnet/notebooks/enhance_me_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
31
- ]
32
- },
33
- {
34
- "cell_type": "code",
35
- "metadata": {
36
- "colab": {
37
- "base_uri": "https://localhost:8080/"
38
- },
39
- "id": "1JryaVhtBHij",
40
- "outputId": "97ee6a4a-2479-4124-e96a-f0a792bdec46"
41
- },
42
- "source": [
43
- "!git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
44
- "!pip install -qqq wandb streamlit"
45
- ],
46
- "execution_count": 1,
47
- "outputs": [
48
- {
49
- "output_type": "stream",
50
- "name": "stdout",
51
- "text": [
52
- "Cloning into 'enhance-me'...\n",
53
- "remote: Enumerating objects: 89, done.\u001b[K\n",
54
- "remote: Counting objects: 100% (89/89), done.\u001b[K\n",
55
- "remote: Compressing objects: 100% (61/61), done.\u001b[K\n",
56
- "remote: Total 89 (delta 43), reused 63 (delta 23), pack-reused 0\u001b[K\n",
57
- "Unpacking objects: 100% (89/89), done.\n",
58
- "\u001b[K |████████████████████████████████| 1.7 MB 8.2 MB/s \n",
59
- "\u001b[K |████████████████████████████████| 9.1 MB 33.4 MB/s \n",
60
- "\u001b[K |████████████████████████████████| 140 kB 74.7 MB/s \n",
61
- "\u001b[K |████████████████████████████████| 97 kB 8.6 MB/s \n",
62
- "\u001b[K |████████████████████████████████| 180 kB 83.6 MB/s \n",
63
- "\u001b[K |████████████████████████████████| 63 kB 2.1 MB/s \n",
64
- "\u001b[K |████████████████████████████████| 4.3 MB 83.4 MB/s \n",
65
- "\u001b[K |████████████████████████████████| 178 kB 68.0 MB/s \n",
66
- "\u001b[K |████████████████████████████████| 76 kB 6.0 MB/s \n",
67
- "\u001b[K |████████████████████████████████| 111 kB 81.8 MB/s \n",
68
- "\u001b[K |████████████████████████████████| 125 kB 86.7 MB/s \n",
69
- "\u001b[K |████████████████████████████████| 791 kB 67.2 MB/s \n",
70
- "\u001b[K |████████████████████████████████| 374 kB 83.4 MB/s \n",
71
- "\u001b[?25h Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
72
- " Building wheel for pympler (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
73
- " Building wheel for blinker (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
74
- " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
75
- "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
76
- "jupyter-console 5.2.0 requires prompt-toolkit<2.0.0,>=1.0.0, but you have prompt-toolkit 3.0.23 which is incompatible.\n",
77
- "google-colab 1.0.0 requires ipykernel~=4.10, but you have ipykernel 6.5.1 which is incompatible.\n",
78
- "google-colab 1.0.0 requires ipython~=5.5.0, but you have ipython 7.30.0 which is incompatible.\u001b[0m\n"
79
- ]
80
- }
81
- ]
82
- },
83
- {
84
- "cell_type": "code",
85
- "metadata": {
86
- "id": "G_c4VtXWHR5l"
87
- },
88
- "source": [
89
- "import sys\n",
90
- "sys.path.append(\"./enhance-me\")\n",
91
- "\n",
92
- "from PIL import Image\n",
93
- "from enhance_me import commons\n",
94
- "from enhance_me.mirnet import MIRNet"
95
- ],
96
- "execution_count": 2,
97
- "outputs": []
98
- },
99
- {
100
- "cell_type": "code",
101
- "metadata": {
102
- "id": "ZpBHbYaMIqP_"
103
- },
104
- "source": [
105
- "#@title MIRNet Train Configs\n",
106
- "\n",
107
- "experiment_name = 'lol_dataset_256' #@param {type:\"string\"}\n",
108
- "image_size = 128 #@param {type:\"integer\"}\n",
109
- "dataset_label = 'lol' #@param [\"lol\"]\n",
110
- "apply_random_horizontal_flip = True #@param {type:\"boolean\"}\n",
111
- "apply_random_vertical_flip = True #@param {type:\"boolean\"}\n",
112
- "apply_random_rotation = True #@param {type:\"boolean\"}\n",
113
- "wandb_api_key = '' #@param {type:\"string\"}\n",
114
- "val_split = 0.1 #@param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
115
- "batch_size = 4 #@param {type:\"integer\"}\n",
116
- "num_recursive_residual_groups = 3 #@param {type:\"slider\", min:1, max:5, step:1}\n",
117
- "num_multi_scale_residual_blocks = 2 #@param {type:\"slider\", min:1, max:5, step:1}\n",
118
- "learning_rate = 1e-4 #@param {type:\"number\"}\n",
119
- "epsilon = 1e-3 #@param {type:\"number\"}\n",
120
- "epochs = 50 #@param {type:\"slider\", min:10, max:100, step:5}"
121
- ],
122
- "execution_count": 3,
123
- "outputs": []
124
- },
125
- {
126
- "cell_type": "code",
127
- "metadata": {
128
- "colab": {
129
- "base_uri": "https://localhost:8080/",
130
- "height": 52
131
- },
132
- "id": "IVRoedqBIMuH",
133
- "outputId": "53ca5beb-871a-4ec3-b757-173e09a15331"
134
- },
135
- "source": [
136
- "mirnet = MIRNet(\n",
137
- " experiment_name=experiment_name,\n",
138
- " wandb_api_key=None if wandb_api_key == '' else wandb_api_key\n",
139
- ")"
140
- ],
141
- "execution_count": 4,
142
- "outputs": [
143
- {
144
- "output_type": "stream",
145
- "name": "stderr",
146
- "text": [
147
- "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m19soumik-rakshit96\u001b[0m (use `wandb login --relogin` to force relogin)\n"
148
- ]
149
- },
150
- {
151
- "output_type": "display_data",
152
- "data": {
153
- "text/html": [
154
- "\n",
155
- " Syncing run <strong><a href=\"https://wandb.ai/19soumik-rakshit96/mirnet/runs/3p3rc341\" target=\"_blank\">lol_dataset_256</a></strong> to <a href=\"https://wandb.ai/19soumik-rakshit96/mirnet\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
156
- "\n",
157
- " "
158
- ],
159
- "text/plain": [
160
- "<IPython.core.display.HTML object>"
161
- ]
162
- },
163
- "metadata": {}
164
- }
165
- ]
166
- },
167
- {
168
- "cell_type": "code",
169
- "metadata": {
170
- "colab": {
171
- "base_uri": "https://localhost:8080/"
172
- },
173
- "id": "O66Iwzx8IsGh",
174
- "outputId": "0b6f1683-65d1-4737-a32f-d36b331d2bc2"
175
- },
176
- "source": [
177
- "mirnet.build_datasets(\n",
178
- " image_size=image_size,\n",
179
- " dataset_label=dataset_label,\n",
180
- " apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
181
- " apply_random_vertical_flip=apply_random_vertical_flip,\n",
182
- " apply_random_rotation=apply_random_rotation,\n",
183
- " val_split=val_split,\n",
184
- " batch_size=batch_size\n",
185
- ")"
186
- ],
187
- "execution_count": 5,
188
- "outputs": [
189
- {
190
- "output_type": "stream",
191
- "name": "stdout",
192
- "text": [
193
- "Downloading data from https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip\n",
194
- "347176960/347171015 [==============================] - 13s 0us/step\n",
195
- "347185152/347171015 [==============================] - 13s 0us/step\n",
196
- "Number of train data points: 436\n",
197
- "Number of validation data points: 49\n"
198
- ]
199
- }
200
- ]
201
  },
202
- {
203
- "cell_type": "code",
204
- "metadata": {
205
- "id": "tsfKrBCsL_Bb"
206
- },
207
- "source": [
208
- "mirnet.build_model(\n",
209
- " num_recursive_residual_groups=num_recursive_residual_groups,\n",
210
- " num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,\n",
211
- " learning_rate=learning_rate,\n",
212
- " epsilon=epsilon\n",
213
- ")"
214
- ],
215
- "execution_count": 6,
216
- "outputs": []
 
 
217
  },
218
- {
219
- "cell_type": "code",
220
- "metadata": {
221
- "colab": {
222
- "base_uri": "https://localhost:8080/"
223
- },
224
- "id": "y3L9wlpkNziL",
225
- "outputId": "5149f0e7-91f4-450f-c43a-1b6028692bbc"
226
- },
227
- "source": [
228
- "history = mirnet.train(epochs=epochs)"
229
- ],
230
- "execution_count": null,
231
- "outputs": [
232
- {
233
- "output_type": "stream",
234
- "name": "stderr",
235
- "text": [
236
- "/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py:1410: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
237
- " layer_config = serialize_layer_fn(layer)\n"
238
- ]
239
- },
240
- {
241
- "output_type": "stream",
242
- "name": "stdout",
243
- "text": [
244
- "Epoch 1/50\n",
245
- " 66/218 [========>.....................] - ETA: 2:25 - loss: 0.1721 - peak_signal_noise_ratio: 63.2555"
246
- ]
247
- }
248
- ]
 
 
 
 
 
 
 
249
  },
250
- {
251
- "cell_type": "code",
252
- "metadata": {
253
- "colab": {
254
- "background_save": true
255
- },
256
- "id": "daFKbgBkiyzc"
257
- },
258
- "source": [
259
- "for index, low_image_file in enumerate(mirnet.test_low_images):\n",
260
- " original_image = Image.open(low_image_file)\n",
261
- " enhanced_image = mirnet.infer(original_image)\n",
262
- " ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
263
- " commons.plot_results(\n",
264
- " [original_image, ground_truth, ground_truth],\n",
265
- " [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
266
- " (18, 18)\n",
267
- " )"
268
- ],
269
- "execution_count": null,
270
- "outputs": []
271
  },
272
- {
273
- "cell_type": "code",
274
- "metadata": {
275
- "id": "dO-IbNQHkB3R"
276
- },
277
- "source": [
278
- ""
279
- ],
280
- "execution_count": null,
281
- "outputs": []
282
- }
283
- ]
284
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/soumik12345/enhance-me/blob/mirnet/notebooks/enhance_me_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {
17
  "colab": {
18
+ "base_uri": "https://localhost:8080/"
 
 
 
 
 
 
 
 
 
19
  },
20
+ "id": "1JryaVhtBHij",
21
+ "outputId": "97ee6a4a-2479-4124-e96a-f0a792bdec46"
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "!git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
26
+ "!pip install -qqq wandb streamlit"
27
+ ]
28
  },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {
33
+ "id": "G_c4VtXWHR5l"
34
+ },
35
+ "outputs": [],
36
+ "source": [
37
+ "import sys\n",
38
+ "\n",
39
+ "sys.path.append(\"./enhance-me\")\n",
40
+ "\n",
41
+ "from PIL import Image\n",
42
+ "from enhance_me import commons\n",
43
+ "from enhance_me.mirnet import MIRNet"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {
50
+ "id": "ZpBHbYaMIqP_"
51
+ },
52
+ "outputs": [],
53
+ "source": [
54
+ "# @title MIRNet Train Configs\n",
55
+ "\n",
56
+ "experiment_name = \"lol_dataset_256\" # @param {type:\"string\"}\n",
57
+ "image_size = 128 # @param {type:\"integer\"}\n",
58
+ "dataset_label = \"lol\" # @param [\"lol\"]\n",
59
+ "apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
60
+ "apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
61
+ "apply_random_rotation = True # @param {type:\"boolean\"}\n",
62
+ "wandb_api_key = \"\" # @param {type:\"string\"}\n",
63
+ "val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
64
+ "batch_size = 4 # @param {type:\"integer\"}\n",
65
+ "num_recursive_residual_groups = 3 # @param {type:\"slider\", min:1, max:5, step:1}\n",
66
+ "num_multi_scale_residual_blocks = 2 # @param {type:\"slider\", min:1, max:5, step:1}\n",
67
+ "learning_rate = 1e-4 # @param {type:\"number\"}\n",
68
+ "epsilon = 1e-3 # @param {type:\"number\"}\n",
69
+ "epochs = 50 # @param {type:\"slider\", min:10, max:100, step:5}"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {
76
+ "colab": {
77
+ "base_uri": "https://localhost:8080/",
78
+ "height": 52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  },
80
+ "id": "IVRoedqBIMuH",
81
+ "outputId": "53ca5beb-871a-4ec3-b757-173e09a15331"
82
+ },
83
+ "outputs": [],
84
+ "source": [
85
+ "mirnet = MIRNet(\n",
86
+ " experiment_name=experiment_name,\n",
87
+ " wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key,\n",
88
+ ")"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {
95
+ "colab": {
96
+ "base_uri": "https://localhost:8080/"
97
  },
98
+ "id": "O66Iwzx8IsGh",
99
+ "outputId": "0b6f1683-65d1-4737-a32f-d36b331d2bc2"
100
+ },
101
+ "outputs": [],
102
+ "source": [
103
+ "mirnet.build_datasets(\n",
104
+ " image_size=image_size,\n",
105
+ " dataset_label=dataset_label,\n",
106
+ " apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
107
+ " apply_random_vertical_flip=apply_random_vertical_flip,\n",
108
+ " apply_random_rotation=apply_random_rotation,\n",
109
+ " val_split=val_split,\n",
110
+ " batch_size=batch_size,\n",
111
+ ")"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {
118
+ "id": "tsfKrBCsL_Bb"
119
+ },
120
+ "outputs": [],
121
+ "source": [
122
+ "mirnet.build_model(\n",
123
+ " num_recursive_residual_groups=num_recursive_residual_groups,\n",
124
+ " num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,\n",
125
+ " learning_rate=learning_rate,\n",
126
+ " epsilon=epsilon,\n",
127
+ ")"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "metadata": {
134
+ "colab": {
135
+ "base_uri": "https://localhost:8080/"
136
  },
137
+ "id": "y3L9wlpkNziL",
138
+ "outputId": "5149f0e7-91f4-450f-c43a-1b6028692bbc"
139
+ },
140
+ "outputs": [],
141
+ "source": [
142
+ "history = mirnet.train(epochs=epochs)"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": null,
148
+ "metadata": {
149
+ "colab": {
150
+ "background_save": true
 
 
 
 
 
 
 
151
  },
152
+ "id": "daFKbgBkiyzc"
153
+ },
154
+ "outputs": [],
155
+ "source": [
156
+ "for index, low_image_file in enumerate(mirnet.test_low_images):\n",
157
+ " original_image = Image.open(low_image_file)\n",
158
+ " enhanced_image = mirnet.infer(original_image)\n",
159
+ " ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
160
+ " commons.plot_results(\n",
161
+ " [original_image, ground_truth, ground_truth],\n",
162
+ " [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
163
+ " (18, 18),\n",
164
+ " )"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "metadata": {
171
+ "id": "dO-IbNQHkB3R"
172
+ },
173
+ "outputs": [],
174
+ "source": []
175
+ }
176
+ ],
177
+ "metadata": {
178
+ "accelerator": "GPU",
179
+ "colab": {
180
+ "authorship_tag": "ABX9TyN4LuJh6kWhbqxzA5s9sp7k",
181
+ "collapsed_sections": [],
182
+ "include_colab_link": true,
183
+ "machine_shape": "hm",
184
+ "name": "enhance-me-train.ipynb",
185
+ "provenance": []
186
+ },
187
+ "kernelspec": {
188
+ "display_name": "Python 3",
189
+ "name": "python3"
190
+ },
191
+ "language_info": {
192
+ "name": "python"
193
+ }
194
+ },
195
+ "nbformat": 4,
196
+ "nbformat_minor": 0
197
+ }