geekyrakshit commited on
Commit
171e1d9
1 Parent(s): d52e07e

updated notebook

Browse files
Files changed (1) hide show
  1. notebooks/enhance_me_train.ipynb +20 -3
notebooks/enhance_me_train.ipynb CHANGED
@@ -22,7 +22,7 @@
22
  },
23
  "outputs": [],
24
  "source": [
25
- "!git clone https://github.com/soumik12345/enhance-me\n",
26
  "!pip install -qqq wandb streamlit"
27
  ]
28
  },
@@ -171,7 +171,7 @@
171
  " enhanced_image = mirnet.infer(original_image)\n",
172
  " ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
173
  " commons.plot_results(\n",
174
- " [original_image, ground_truth, ground_truth],\n",
175
  " [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
176
  " (18, 18),\n",
177
  " )"
@@ -238,7 +238,24 @@
238
  "outputs": [],
239
  "source": [
240
  "zero_dce.compile(learning_rate=learning_rate)\n",
241
- "zero_dce.train(epochs=epochs)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  ]
243
  }
244
  ],
 
22
  },
23
  "outputs": [],
24
  "source": [
25
+ "!git clone https://github.com/soumik12345/enhance-me -b zero-dce\n",
26
  "!pip install -qqq wandb streamlit"
27
  ]
28
  },
 
171
  " enhanced_image = mirnet.infer(original_image)\n",
172
  " ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
173
  " commons.plot_results(\n",
174
+ " [original_image, ground_truth, enhanced_image],\n",
175
  " [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
176
  " (18, 18),\n",
177
  " )"
 
238
  "outputs": [],
239
  "source": [
240
  "zero_dce.compile(learning_rate=learning_rate)\n",
241
+ "history = zero_dce.train(epochs=epochs)\n",
242
+ "zero_dce.save_weights(os.path.join(experiment_name, \"weights.h5\"))"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "metadata": {},
249
+ "outputs": [],
250
+ "source": [
251
+ "for index, low_image_file in enumerate(zero_dce.test_low_images):\n",
252
+ " original_image = Image.open(low_image_file)\n",
253
+ " enhanced_image = zero_dce.infer(original_image)\n",
254
+ " commons.plot_results(\n",
255
+ " [original_image, enhanced_image],\n",
256
+ " [\"Original Image\", \"Enhanced Image\"],\n",
257
+ " (18, 18),\n",
258
+ " )"
259
  ]
260
  }
261
  ],