MeetMeAt92 commited on
Commit
639184d
·
1 Parent(s): 24a597f

Create model.h5

Browse files
Files changed (1) hide show
  1. model.h5 +331 -3
model.h5 CHANGED
@@ -1,3 +1,331 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7e0bd457aa184c3a8fde411375b292e4f7776aaf7cfc9c29661f577309be451c
3
- size 441222364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import random
4
+ import numpy as np
5
+ from glob import glob
6
+ from PIL import Image, ImageOps
7
+ import matplotlib.pyplot as plt
8
+
9
+ import tensorflow as tf
10
+ from tensorflow import keras
11
+ from tensorflow.keras import layers
12
+
13
+ from google.colab import drive
14
+ drive.mount('/content/gdrive')
15
+
16
+
17
+ random.seed(10)
18
+
19
+ IMAGE_SIZE = 128
20
+ BATCH_SIZE = 4
21
+ MAX_TRAIN_IMAGES = 300
22
+
23
+
24
+ def read_image(image_path):
25
+ image = tf.io.read_file(image_path)
26
+ image = tf.image.decode_png(image, channels=3)
27
+ image.set_shape([None, None, 3])
28
+ image = tf.cast(image, dtype=tf.float32) / 255.0
29
+
30
+ return image
31
+
32
+
33
+ def random_crop(low_image, enhanced_image):
34
+ low_image_shape = tf.shape(low_image)[:2]
35
+ low_w = tf.random.uniform(
36
+ shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
37
+ )
38
+ low_h = tf.random.uniform(
39
+ shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
40
+ )
41
+ enhanced_w = low_w
42
+ enhanced_h = low_h
43
+ low_image_cropped = low_image[
44
+ low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
45
+ ]
46
+ enhanced_image_cropped = enhanced_image[
47
+ enhanced_h : enhanced_h + IMAGE_SIZE, enhanced_w : enhanced_w + IMAGE_SIZE
48
+ ]
49
+ return low_image_cropped, enhanced_image_cropped
50
+
51
+
52
+ def load_data(low_light_image_path, enhanced_image_path):
53
+ low_light_image = read_image(low_light_image_path)
54
+ enhanced_image = read_image(enhanced_image_path)
55
+ low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
56
+ return low_light_image, enhanced_image
57
+
58
+
59
+ def get_dataset(low_light_images, enhanced_images):
60
+ dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
61
+
62
+ dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
63
+
64
+ dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
65
+ return dataset
66
+
67
+
68
+ train_low_light_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
69
+ train_enhanced_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]
70
+
71
+ val_low_light_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
72
+ val_enhanced_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]
73
+
74
+ test_low_light_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/eval15/low/*"))
75
+ test_enhanced_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/eval15/high/*"))
76
+
77
+
78
+ train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
79
+ val_dataset = get_dataset(val_low_light_images, val_enhanced_images)
80
+
81
+
82
+ print("Train Dataset:", train_dataset)
83
+ print("Val Dataset:", val_dataset)
84
+
85
+
86
+ def selective_kernel_feature_fusion(
87
+ multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
88
+ ):
89
+ channels = list(multi_scale_feature_1.shape)[-1]
90
+ combined_feature = layers.Add()(
91
+ [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
92
+ )
93
+ gap = layers.GlobalAveragePooling2D()(combined_feature)
94
+ channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
95
+ compact_feature_representation = layers.Conv2D(
96
+ filters=channels // 8, kernel_size=(1, 1), activation="relu"
97
+ )(channel_wise_statistics)
98
+ feature_descriptor_1 = layers.Conv2D(
99
+ channels, kernel_size=(1, 1), activation="softmax"
100
+ )(compact_feature_representation)
101
+ feature_descriptor_2 = layers.Conv2D(
102
+ channels, kernel_size=(1, 1), activation="softmax"
103
+ )(compact_feature_representation)
104
+ feature_descriptor_3 = layers.Conv2D(
105
+ channels, kernel_size=(1, 1), activation="softmax"
106
+ )(compact_feature_representation)
107
+ feature_1 = multi_scale_feature_1 * feature_descriptor_1
108
+ feature_2 = multi_scale_feature_2 * feature_descriptor_2
109
+ feature_3 = multi_scale_feature_3 * feature_descriptor_3
110
+ aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
111
+ return aggregated_feature
112
+
113
+
114
+
115
+
116
+ def spatial_attention_block(input_tensor):
117
+ average_pooling = tf.reduce_max(input_tensor, axis=-1)
118
+ average_pooling = tf.expand_dims(average_pooling, axis=-1)
119
+ max_pooling = tf.reduce_mean(input_tensor, axis=-1)
120
+ max_pooling = tf.expand_dims(max_pooling, axis=-1)
121
+ concatenated = layers.Concatenate(axis=-1)([average_pooling, max_pooling])
122
+ feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
123
+ feature_map = tf.nn.sigmoid(feature_map)
124
+ return input_tensor * feature_map
125
+
126
+
127
+ def channel_attention_block(input_tensor):
128
+ channels = list(input_tensor.shape)[-1]
129
+ average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
130
+ feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
131
+ feature_activations = layers.Conv2D(
132
+ filters=channels // 8, kernel_size=(1, 1), activation="relu"
133
+ )(feature_descriptor)
134
+ feature_activations = layers.Conv2D(
135
+ filters=channels, kernel_size=(1, 1), activation="sigmoid"
136
+ )(feature_activations)
137
+ return input_tensor * feature_activations
138
+
139
+
140
+ def dual_attention_unit_block(input_tensor):
141
+ channels = list(input_tensor.shape)[-1]
142
+ feature_map = layers.Conv2D(
143
+ channels, kernel_size=(3, 3), padding="same", activation="relu"
144
+ )(input_tensor)
145
+ feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
146
+ feature_map
147
+ )
148
+ channel_attention = channel_attention_block(feature_map)
149
+ spatial_attention = spatial_attention_block(feature_map)
150
+ concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
151
+ concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
152
+ return layers.Add()([input_tensor, concatenation])
153
+
154
+
155
+ # Recursive Residual Modules
156
+
157
+
158
+ def down_sampling_module(input_tensor):
159
+ channels = list(input_tensor.shape)[-1]
160
+ main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
161
+ input_tensor
162
+ )
163
+ main_branch = layers.Conv2D(
164
+ channels, kernel_size=(3, 3), padding="same", activation="relu"
165
+ )(main_branch)
166
+ main_branch = layers.MaxPooling2D()(main_branch)
167
+ main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
168
+ skip_branch = layers.MaxPooling2D()(input_tensor)
169
+ skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
170
+ return layers.Add()([skip_branch, main_branch])
171
+
172
+
173
+ def up_sampling_module(input_tensor):
174
+ channels = list(input_tensor.shape)[-1]
175
+ main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
176
+ input_tensor
177
+ )
178
+ main_branch = layers.Conv2D(
179
+ channels, kernel_size=(3, 3), padding="same", activation="relu"
180
+ )(main_branch)
181
+ main_branch = layers.UpSampling2D()(main_branch)
182
+ main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
183
+ skip_branch = layers.UpSampling2D()(input_tensor)
184
+ skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
185
+ return layers.Add()([skip_branch, main_branch])
186
+
187
+
188
+ # MRB Block
189
+ def multi_scale_residual_block(input_tensor, channels):
190
+ # features
191
+ level1 = input_tensor
192
+ level2 = down_sampling_module(input_tensor)
193
+ level3 = down_sampling_module(level2)
194
+ # DAU
195
+ level1_dau = dual_attention_unit_block(level1)
196
+ level2_dau = dual_attention_unit_block(level2)
197
+ level3_dau = dual_attention_unit_block(level3)
198
+ # SKFF
199
+ level1_skff = selective_kernel_feature_fusion(
200
+ level1_dau,
201
+ up_sampling_module(level2_dau),
202
+ up_sampling_module(up_sampling_module(level3_dau)),
203
+ )
204
+ level2_skff = selective_kernel_feature_fusion(
205
+ down_sampling_module(level1_dau), level2_dau, up_sampling_module(level3_dau)
206
+ )
207
+ level3_skff = selective_kernel_feature_fusion(
208
+ down_sampling_module(down_sampling_module(level1_dau)),
209
+ down_sampling_module(level2_dau),
210
+ level3_dau,
211
+ )
212
+ # DAU 2
213
+ level1_dau_2 = dual_attention_unit_block(level1_skff)
214
+ level2_dau_2 = up_sampling_module((dual_attention_unit_block(level2_skff)))
215
+ level3_dau_2 = up_sampling_module(
216
+ up_sampling_module(dual_attention_unit_block(level3_skff))
217
+ )
218
+ # SKFF 2
219
+ skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
220
+ conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
221
+ return layers.Add()([input_tensor, conv])
222
+
223
+
224
+
225
+
226
+ def recursive_residual_group(input_tensor, num_mrb, channels):
227
+ conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
228
+ for _ in range(num_mrb):
229
+ conv1 = multi_scale_residual_block(conv1, channels)
230
+ conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
231
+ return layers.Add()([conv2, input_tensor])
232
+
233
+
234
+ def mirnet_model(num_rrg, num_mrb, channels):
235
+ input_tensor = keras.Input(shape=[None, None, 3])
236
+ x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
237
+ for _ in range(num_rrg):
238
+ x1 = recursive_residual_group(x1, num_mrb, channels)
239
+ conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
240
+ output_tensor = layers.Add()([input_tensor, conv])
241
+ return keras.Model(input_tensor, output_tensor)
242
+
243
+
244
+ model = mirnet_model(num_rrg=3, num_mrb=2, channels=64)
245
+
246
+
247
+ def charbonnier_loss(y_true, y_pred):
248
+ return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))
249
+
250
+
251
+ def peak_signal_noise_ratio(y_true, y_pred):
252
+ return tf.image.psnr(y_pred, y_true, max_val=255.0)
253
+
254
+
255
+ optimizer = keras.optimizers.Adam(learning_rate=1e-4)
256
+ model.compile(
257
+ optimizer=optimizer, loss=charbonnier_loss, metrics=[peak_signal_noise_ratio]
258
+ )
259
+
260
+ history = model.fit(
261
+ train_dataset,
262
+ validation_data=val_dataset,
263
+ #epochs traning cycles set krna k lia
264
+ epochs=1,
265
+ callbacks=[
266
+ keras.callbacks.ReduceLROnPlateau(
267
+ monitor="val_peak_signal_noise_ratio",
268
+ factor=0.5,
269
+ patience=5,
270
+ verbose=1,
271
+ min_delta=1e-7,
272
+ mode="max",
273
+ )
274
+ ],
275
+ )
276
+
277
+ plt.plot(history.history["loss"], label="train_loss")
278
+ plt.plot(history.history["val_loss"], label="val_loss")
279
+ plt.xlabel("Epochs")
280
+ plt.ylabel("Loss")
281
+ plt.title("Train and Validation Losses Over Epochs", fontsize=14)
282
+ plt.legend()
283
+ plt.grid()
284
+ plt.show()
285
+
286
+
287
+ plt.plot(history.history["peak_signal_noise_ratio"], label="train_psnr")
288
+ plt.plot(history.history["val_peak_signal_noise_ratio"], label="val_psnr")
289
+ plt.xlabel("Epochs")
290
+ plt.ylabel("PSNR")
291
+ plt.title("Train and Validation PSNR Over Epochs", fontsize=14)
292
+ plt.legend()
293
+ plt.grid()
294
+ plt.show()
295
+
296
+
297
+
298
+
299
+ def plot_results(images, titles, figure_size=(12, 12)):
300
+ fig = plt.figure(figsize=figure_size)
301
+ for i in range(len(images)):
302
+ fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
303
+ _ = plt.imshow(images[i])
304
+ plt.axis("off")
305
+ plt.show()
306
+
307
+
308
+ def infer(original_image):
309
+ image = keras.preprocessing.image.img_to_array(original_image)
310
+ image = image.astype("float16") / 255.0
311
+ image = np.expand_dims(image, axis=0)
312
+ output = model.predict(image)
313
+ output_image = output[0] * 255.0
314
+ output_image = output_image.clip(0, 255)
315
+ output_image = output_image.reshape(
316
+ (np.shape(output_image)[0], np.shape(output_image)[1], 3)
317
+ )
318
+ output_image = Image.fromarray(np.uint8(output_image))
319
+ original_image = Image.fromarray(np.uint8(original_image))
320
+ return output_image
321
+
322
+
323
+
324
+ for low_light_image in random.sample(test_low_light_images, 2):
325
+ original_image = Image.open(low_light_image)
326
+ enhanced_image = infer(original_image)
327
+ plot_results(
328
+ [original_image, ImageOps.autocontrast(original_image), enhanced_image],
329
+ ["Original", "PIL Autocontrast", "MIRNet Enhanced"],
330
+ (20, 12),
331
+ )