geekyrakshit commited on
Commit
841fed0
1 Parent(s): 2dd1081

Created using Colaboratory

Browse files
Files changed (1) hide show
  1. notebooks/enhance_me_train.ipynb +161 -61
notebooks/enhance_me_train.ipynb CHANGED
@@ -1,10 +1,30 @@
1
  {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
  "metadata": {
6
- "colab_type": "text",
7
- "id": "view-in-github"
8
  },
9
  "source": [
10
  "<a href=\"https://colab.research.google.com/github/soumik12345/enhance-me/blob/mirnet/notebooks/enhance_me_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
@@ -12,28 +32,59 @@
12
  },
13
  {
14
  "cell_type": "code",
15
- "execution_count": null,
16
  "metadata": {
17
  "colab": {
18
- "base_uri": "https://localhost:8080/",
19
- "height": 1000
20
  },
21
  "id": "1JryaVhtBHij",
22
- "outputId": "4fac7fb6-787c-4a1b-f6ef-12ec48024619"
23
  },
24
- "outputs": [],
25
  "source": [
26
  "!git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
27
- "!pip install wandb streamlit"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  ]
29
  },
30
  {
31
  "cell_type": "code",
32
- "execution_count": null,
33
  "metadata": {
34
  "id": "G_c4VtXWHR5l"
35
  },
36
- "outputs": [],
37
  "source": [
38
  "import sys\n",
39
  "sys.path.append(\"./enhance-me\")\n",
@@ -41,19 +92,19 @@
41
  "from PIL import Image\n",
42
  "from enhance_me import commons\n",
43
  "from enhance_me.mirnet import MIRNet"
44
- ]
 
 
45
  },
46
  {
47
  "cell_type": "code",
48
- "execution_count": null,
49
  "metadata": {
50
  "id": "ZpBHbYaMIqP_"
51
  },
52
- "outputs": [],
53
  "source": [
54
  "#@title MIRNet Train Configs\n",
55
  "\n",
56
- "experiment_name = 'lol_dataset_128' #@param {type:\"string\"}\n",
57
  "image_size = 128 #@param {type:\"integer\"}\n",
58
  "dataset_label = 'lol' #@param [\"lol\"]\n",
59
  "apply_random_horizontal_flip = True #@param {type:\"boolean\"}\n",
@@ -67,41 +118,92 @@
67
  "learning_rate = 1e-4 #@param {type:\"number\"}\n",
68
  "epsilon = 1e-3 #@param {type:\"number\"}\n",
69
  "epochs = 50 #@param {type:\"slider\", min:10, max:100, step:5}"
70
- ]
 
 
71
  },
72
  {
73
  "cell_type": "code",
74
- "execution_count": null,
75
  "metadata": {
76
  "colab": {
77
  "base_uri": "https://localhost:8080/",
78
- "height": 124
79
  },
80
  "id": "IVRoedqBIMuH",
81
- "outputId": "388a806f-f41f-420c-9c03-01024decb2d3"
82
  },
83
- "outputs": [],
84
  "source": [
85
  "mirnet = MIRNet(\n",
86
  " experiment_name=experiment_name,\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  " image_size=image_size,\n",
88
  " dataset_label=dataset_label,\n",
89
- " val_split=val_split,\n",
90
- " batch_size=batch_size,\n",
91
  " apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
92
  " apply_random_vertical_flip=apply_random_vertical_flip,\n",
93
  " apply_random_rotation=apply_random_rotation,\n",
94
- " wandb_api_key=None if wandb_api_key == '' else wandb_api_key\n",
 
95
  ")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  ]
97
  },
98
  {
99
  "cell_type": "code",
100
- "execution_count": null,
101
  "metadata": {
102
  "id": "tsfKrBCsL_Bb"
103
  },
104
- "outputs": [],
105
  "source": [
106
  "mirnet.build_model(\n",
107
  " num_recursive_residual_groups=num_recursive_residual_groups,\n",
@@ -109,36 +211,50 @@
109
  " learning_rate=learning_rate,\n",
110
  " epsilon=epsilon\n",
111
  ")"
112
- ]
 
 
113
  },
114
  {
115
  "cell_type": "code",
116
- "execution_count": null,
117
  "metadata": {
118
  "colab": {
119
  "base_uri": "https://localhost:8080/"
120
  },
121
  "id": "y3L9wlpkNziL",
122
- "outputId": "65e7ba4d-1607-4c14-d5d7-e55c4641ad0a"
123
  },
124
- "outputs": [],
125
  "source": [
126
  "history = mirnet.train(epochs=epochs)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  ]
128
  },
129
  {
130
  "cell_type": "code",
131
- "execution_count": null,
132
  "metadata": {
133
  "colab": {
134
- "background_save": true,
135
- "base_uri": "https://localhost:8080/",
136
- "height": 1000
137
  },
138
- "id": "daFKbgBkiyzc",
139
- "outputId": "38c3fc7a-8cef-4332-8efe-35103c75f1a3"
140
  },
141
- "outputs": [],
142
  "source": [
143
  "for index, low_image_file in enumerate(mirnet.test_low_images):\n",
144
  " original_image = Image.open(low_image_file)\n",
@@ -149,36 +265,20 @@
149
  " [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
150
  " (18, 18)\n",
151
  " )"
152
- ]
 
 
153
  },
154
  {
155
  "cell_type": "code",
156
- "execution_count": null,
157
  "metadata": {
158
  "id": "dO-IbNQHkB3R"
159
  },
160
- "outputs": [],
161
- "source": []
162
- }
163
- ],
164
- "metadata": {
165
- "accelerator": "GPU",
166
- "colab": {
167
- "authorship_tag": "ABX9TyMwNbyaCs348ucM56hcLJop",
168
- "collapsed_sections": [],
169
- "include_colab_link": true,
170
- "machine_shape": "hm",
171
- "name": "enhance-me-train.ipynb",
172
- "provenance": []
173
- },
174
- "kernelspec": {
175
- "display_name": "Python 3",
176
- "name": "python3"
177
- },
178
- "language_info": {
179
- "name": "python"
180
  }
181
- },
182
- "nbformat": 4,
183
- "nbformat_minor": 0
184
- }
 
1
  {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "accelerator": "GPU",
6
+ "colab": {
7
+ "name": "enhance-me-train.ipynb",
8
+ "provenance": [],
9
+ "collapsed_sections": [],
10
+ "machine_shape": "hm",
11
+ "authorship_tag": "ABX9TyN4LuJh6kWhbqxzA5s9sp7k",
12
+ "include_colab_link": true
13
+ },
14
+ "kernelspec": {
15
+ "display_name": "Python 3",
16
+ "name": "python3"
17
+ },
18
+ "language_info": {
19
+ "name": "python"
20
+ }
21
+ },
22
  "cells": [
23
  {
24
  "cell_type": "markdown",
25
  "metadata": {
26
+ "id": "view-in-github",
27
+ "colab_type": "text"
28
  },
29
  "source": [
30
  "<a href=\"https://colab.research.google.com/github/soumik12345/enhance-me/blob/mirnet/notebooks/enhance_me_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
 
32
  },
33
  {
34
  "cell_type": "code",
 
35
  "metadata": {
36
  "colab": {
37
+ "base_uri": "https://localhost:8080/"
 
38
  },
39
  "id": "1JryaVhtBHij",
40
+ "outputId": "97ee6a4a-2479-4124-e96a-f0a792bdec46"
41
  },
 
42
  "source": [
43
  "!git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
44
+ "!pip install -qqq wandb streamlit"
45
+ ],
46
+ "execution_count": 1,
47
+ "outputs": [
48
+ {
49
+ "output_type": "stream",
50
+ "name": "stdout",
51
+ "text": [
52
+ "Cloning into 'enhance-me'...\n",
53
+ "remote: Enumerating objects: 89, done.\u001b[K\n",
54
+ "remote: Counting objects: 100% (89/89), done.\u001b[K\n",
55
+ "remote: Compressing objects: 100% (61/61), done.\u001b[K\n",
56
+ "remote: Total 89 (delta 43), reused 63 (delta 23), pack-reused 0\u001b[K\n",
57
+ "Unpacking objects: 100% (89/89), done.\n",
58
+ "\u001b[K |████████████████████████████████| 1.7 MB 8.2 MB/s \n",
59
+ "\u001b[K |████████████████████████████████| 9.1 MB 33.4 MB/s \n",
60
+ "\u001b[K |████████████████████████████████| 140 kB 74.7 MB/s \n",
61
+ "\u001b[K |████████████████████████████████| 97 kB 8.6 MB/s \n",
62
+ "\u001b[K |████████████████████████████████| 180 kB 83.6 MB/s \n",
63
+ "\u001b[K |████████████████████████████████| 63 kB 2.1 MB/s \n",
64
+ "\u001b[K |████████████████████████████████| 4.3 MB 83.4 MB/s \n",
65
+ "\u001b[K |████████████████████████████████| 178 kB 68.0 MB/s \n",
66
+ "\u001b[K |████████████████████████████████| 76 kB 6.0 MB/s \n",
67
+ "\u001b[K |████████████████████████████████| 111 kB 81.8 MB/s \n",
68
+ "\u001b[K |████████████████████████████████| 125 kB 86.7 MB/s \n",
69
+ "\u001b[K |████████████████████████████████| 791 kB 67.2 MB/s \n",
70
+ "\u001b[K |████████████████████████████████| 374 kB 83.4 MB/s \n",
71
+ "\u001b[?25h Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
72
+ " Building wheel for pympler (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
73
+ " Building wheel for blinker (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
74
+ " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
75
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
76
+ "jupyter-console 5.2.0 requires prompt-toolkit<2.0.0,>=1.0.0, but you have prompt-toolkit 3.0.23 which is incompatible.\n",
77
+ "google-colab 1.0.0 requires ipykernel~=4.10, but you have ipykernel 6.5.1 which is incompatible.\n",
78
+ "google-colab 1.0.0 requires ipython~=5.5.0, but you have ipython 7.30.0 which is incompatible.\u001b[0m\n"
79
+ ]
80
+ }
81
  ]
82
  },
83
  {
84
  "cell_type": "code",
 
85
  "metadata": {
86
  "id": "G_c4VtXWHR5l"
87
  },
 
88
  "source": [
89
  "import sys\n",
90
  "sys.path.append(\"./enhance-me\")\n",
 
92
  "from PIL import Image\n",
93
  "from enhance_me import commons\n",
94
  "from enhance_me.mirnet import MIRNet"
95
+ ],
96
+ "execution_count": 2,
97
+ "outputs": []
98
  },
99
  {
100
  "cell_type": "code",
 
101
  "metadata": {
102
  "id": "ZpBHbYaMIqP_"
103
  },
 
104
  "source": [
105
  "#@title MIRNet Train Configs\n",
106
  "\n",
107
+ "experiment_name = 'lol_dataset_256' #@param {type:\"string\"}\n",
108
  "image_size = 128 #@param {type:\"integer\"}\n",
109
  "dataset_label = 'lol' #@param [\"lol\"]\n",
110
  "apply_random_horizontal_flip = True #@param {type:\"boolean\"}\n",
 
118
  "learning_rate = 1e-4 #@param {type:\"number\"}\n",
119
  "epsilon = 1e-3 #@param {type:\"number\"}\n",
120
  "epochs = 50 #@param {type:\"slider\", min:10, max:100, step:5}"
121
+ ],
122
+ "execution_count": 3,
123
+ "outputs": []
124
  },
125
  {
126
  "cell_type": "code",
 
127
  "metadata": {
128
  "colab": {
129
  "base_uri": "https://localhost:8080/",
130
+ "height": 52
131
  },
132
  "id": "IVRoedqBIMuH",
133
+ "outputId": "53ca5beb-871a-4ec3-b757-173e09a15331"
134
  },
 
135
  "source": [
136
  "mirnet = MIRNet(\n",
137
  " experiment_name=experiment_name,\n",
138
+ " wandb_api_key=None if wandb_api_key == '' else wandb_api_key\n",
139
+ ")"
140
+ ],
141
+ "execution_count": 4,
142
+ "outputs": [
143
+ {
144
+ "output_type": "stream",
145
+ "name": "stderr",
146
+ "text": [
147
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m19soumik-rakshit96\u001b[0m (use `wandb login --relogin` to force relogin)\n"
148
+ ]
149
+ },
150
+ {
151
+ "output_type": "display_data",
152
+ "data": {
153
+ "text/html": [
154
+ "\n",
155
+ " Syncing run <strong><a href=\"https://wandb.ai/19soumik-rakshit96/mirnet/runs/3p3rc341\" target=\"_blank\">lol_dataset_256</a></strong> to <a href=\"https://wandb.ai/19soumik-rakshit96/mirnet\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
156
+ "\n",
157
+ " "
158
+ ],
159
+ "text/plain": [
160
+ "<IPython.core.display.HTML object>"
161
+ ]
162
+ },
163
+ "metadata": {}
164
+ }
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "metadata": {
170
+ "colab": {
171
+ "base_uri": "https://localhost:8080/"
172
+ },
173
+ "id": "O66Iwzx8IsGh",
174
+ "outputId": "0b6f1683-65d1-4737-a32f-d36b331d2bc2"
175
+ },
176
+ "source": [
177
+ "mirnet.build_datasets(\n",
178
  " image_size=image_size,\n",
179
  " dataset_label=dataset_label,\n",
 
 
180
  " apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
181
  " apply_random_vertical_flip=apply_random_vertical_flip,\n",
182
  " apply_random_rotation=apply_random_rotation,\n",
183
+ " val_split=val_split,\n",
184
+ " batch_size=batch_size\n",
185
  ")"
186
+ ],
187
+ "execution_count": 5,
188
+ "outputs": [
189
+ {
190
+ "output_type": "stream",
191
+ "name": "stdout",
192
+ "text": [
193
+ "Downloading data from https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip\n",
194
+ "347176960/347171015 [==============================] - 13s 0us/step\n",
195
+ "347185152/347171015 [==============================] - 13s 0us/step\n",
196
+ "Number of train data points: 436\n",
197
+ "Number of validation data points: 49\n"
198
+ ]
199
+ }
200
  ]
201
  },
202
  {
203
  "cell_type": "code",
 
204
  "metadata": {
205
  "id": "tsfKrBCsL_Bb"
206
  },
 
207
  "source": [
208
  "mirnet.build_model(\n",
209
  " num_recursive_residual_groups=num_recursive_residual_groups,\n",
 
211
  " learning_rate=learning_rate,\n",
212
  " epsilon=epsilon\n",
213
  ")"
214
+ ],
215
+ "execution_count": 6,
216
+ "outputs": []
217
  },
218
  {
219
  "cell_type": "code",
 
220
  "metadata": {
221
  "colab": {
222
  "base_uri": "https://localhost:8080/"
223
  },
224
  "id": "y3L9wlpkNziL",
225
+ "outputId": "5149f0e7-91f4-450f-c43a-1b6028692bbc"
226
  },
 
227
  "source": [
228
  "history = mirnet.train(epochs=epochs)"
229
+ ],
230
+ "execution_count": null,
231
+ "outputs": [
232
+ {
233
+ "output_type": "stream",
234
+ "name": "stderr",
235
+ "text": [
236
+ "/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py:1410: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
237
+ " layer_config = serialize_layer_fn(layer)\n"
238
+ ]
239
+ },
240
+ {
241
+ "output_type": "stream",
242
+ "name": "stdout",
243
+ "text": [
244
+ "Epoch 1/50\n",
245
+ " 66/218 [========>.....................] - ETA: 2:25 - loss: 0.1721 - peak_signal_noise_ratio: 63.2555"
246
+ ]
247
+ }
248
  ]
249
  },
250
  {
251
  "cell_type": "code",
 
252
  "metadata": {
253
  "colab": {
254
+ "background_save": true
 
 
255
  },
256
+ "id": "daFKbgBkiyzc"
 
257
  },
 
258
  "source": [
259
  "for index, low_image_file in enumerate(mirnet.test_low_images):\n",
260
  " original_image = Image.open(low_image_file)\n",
 
265
  " [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
266
  " (18, 18)\n",
267
  " )"
268
+ ],
269
+ "execution_count": null,
270
+ "outputs": []
271
  },
272
  {
273
  "cell_type": "code",
 
274
  "metadata": {
275
  "id": "dO-IbNQHkB3R"
276
  },
277
+ "source": [
278
+ ""
279
+ ],
280
+ "execution_count": null,
281
+ "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  }
283
+ ]
284
+ }