diabolic6045 commited on
Commit
be1a4f6
1 Parent(s): cf444e6

Upload Copy_of_image_classification_using_cnn.ipynb

Browse files
Copy_of_image_classification_using_cnn.ipynb ADDED
@@ -0,0 +1,1099 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "## Download dataset and connect your Google drive\n",
7
+ "for that you need to get kaggle.json file for [here](https://www.kaggle.com/settings/account) where you will see API section under which you will have option to ```\"Create New Token\"``` ,which will download a ```kaggle.json``` file, upload that file it working dir."
8
+ ],
9
+ "metadata": {
10
+ "id": "TZ9ndnKCKnyQ"
11
+ }
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {
17
+ "id": "_YzUO8UV7HgP"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "!pip install -q kaggle\n",
22
+ "!mkdir -p ~/.kaggle\n",
23
+ "!cp kaggle.json ~/.kaggle/\n",
24
+ "!chmod 600 ~/.kaggle/kaggle.json\n",
25
+ "!kaggle datasets download -d divaxshah/cities-all\n",
26
+ "!unzip /content/cities-all.zip"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {
33
+ "id": "9y543mK87Gif"
34
+ },
35
+ "outputs": [],
36
+ "source": [
37
+ "!pip install split-folders tensorflow[torch] seaborn numpy matplotlib"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "source": [
43
+ "from google.colab import drive\n",
44
+ "drive.mount('/content/drive')"
45
+ ],
46
+ "metadata": {
47
+ "id": "CBecy2sZBMGY"
48
+ },
49
+ "execution_count": null,
50
+ "outputs": []
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {
55
+ "id": "tvrwZoW_7Gih"
56
+ },
57
+ "source": [
58
+ "### **Importing of Necessary Libraries**"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {
65
+ "id": "x4OJYagW7Gii"
66
+ },
67
+ "outputs": [],
68
+ "source": [
69
+ "import matplotlib.pyplot as plt\n",
70
+ "import seaborn as sns\n",
71
+ "import numpy as np\n",
72
+ "import pandas as pd\n",
73
+ "import random\n",
74
+ "import cv2\n",
75
+ "import os\n",
76
+ "import PIL\n",
77
+ "import pathlib\n",
78
+ "import splitfolders\n",
79
+ "\n",
80
+ "import tensorflow as tf\n",
81
+ "from tensorflow import keras\n",
82
+ "from tensorflow.keras import layers\n",
83
+ "from tensorflow.keras.models import Sequential\n",
84
+ "from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau\n",
85
+ "from keras.preprocessing.image import ImageDataGenerator\n",
86
+ "from keras.applications.vgg16 import VGG16"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "metadata": {
92
+ "id": "jk3HfsKQ7Gij"
93
+ },
94
+ "source": [
95
+ "### **Dataset Loading and Splitting**\n",
96
+ "Split-folders library was used to split the dataset into three parts: Training set(70%), Validation set(15%), and Test set(15%)."
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {
103
+ "id": "KAc3avxf7Gij"
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "base_ds = '/content/Citeisall'\n",
108
+ "base_ds = pathlib.Path(base_ds)"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "metadata": {
115
+ "id": "fDH51KsE7Gik",
116
+ "colab": {
117
+ "base_uri": "https://localhost:8080/"
118
+ },
119
+ "outputId": "88652468-e628-44fe-91f6-f093a05649b8"
120
+ },
121
+ "outputs": [
122
+ {
123
+ "output_type": "stream",
124
+ "name": "stderr",
125
+ "text": [
126
+ "Copying files: 12500 files [00:13, 928.79 files/s]\n"
127
+ ]
128
+ }
129
+ ],
130
+ "source": [
131
+ "splitfolders.ratio(base_ds, output='/content/imgs', seed=123, ratio=(.7,.15,.15), group_prefix=None)"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {
138
+ "id": "WFsRGuU07Gik"
139
+ },
140
+ "outputs": [],
141
+ "source": [
142
+ "Ahmedabad = [fn for fn in os.listdir(f'{base_ds}/Ahmedabad') if fn.endswith('.jpg')]\n",
143
+ "Delhi = [fn for fn in os.listdir(f'{base_ds}/Delhi') if fn.endswith('.jpg')]\n",
144
+ "Kerala = [fn for fn in os.listdir(f'{base_ds}/Kerala') if fn.endswith('.jpg')]\n",
145
+ "Kolkata = [fn for fn in os.listdir(f'{base_ds}/Kolkata') if fn.endswith('.jpg')]\n",
146
+ "Mumbai = [fn for fn in os.listdir(f'{base_ds}/Mumabi') ]\n",
147
+ "city = [Ahmedabad, Delhi, Kerala, Kolkata, Mumbai]\n",
148
+ "city_classes = []\n",
149
+ "for i in os.listdir('imgs/train'):\n",
150
+ " city_classes+=[i]\n",
151
+ "city_classes.sort()"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "metadata": {
157
+ "id": "zIOajVox7Gik"
158
+ },
159
+ "source": [
160
+ "### **Dataset Exploration**\n",
161
+ "It can be seen here the total number of images in the dataset, the number of classes, and how well the images from each variety is distributed"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": null,
167
+ "metadata": {
168
+ "id": "RKugur_t7Gil"
169
+ },
170
+ "outputs": [],
171
+ "source": [
172
+ "image_count = len(list(base_ds.glob('*/*.jpg')))\n",
173
+ "print(f'Total images: {image_count}')\n",
174
+ "print(f'Total number of classes: {len(city_classes)}')\n",
175
+ "count = 0\n",
176
+ "city_count = []\n",
177
+ "for x in city_classes:\n",
178
+ " print(f'Total {x} images: {len(city[count])}')\n",
179
+ " city_count.append(len(city[count]))\n",
180
+ " count += 1\n",
181
+ "\n",
182
+ "sns.set_style('darkgrid')\n",
183
+ "sns.barplot(x=city_classes, y=city_count, palette=\"Blues_d\")\n",
184
+ "plt.show()"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "markdown",
189
+ "metadata": {
190
+ "id": "K-i3dnII7Gil"
191
+ },
192
+ "source": [
193
+ "### Sample Images\n",
194
+ "Each image from the dataset has a dimension of 250 by 250 and a color type of RGB"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "metadata": {
201
+ "id": "rBjHPfUL7Gil"
202
+ },
203
+ "outputs": [],
204
+ "source": [
205
+ "sample_img = cv2.imread('/content/imgs/test/Ahmedabad/Ahmedabad-Test (1).jpg')\n",
206
+ "plt.imshow(sample_img)\n",
207
+ "print(f'Image dimensions: {sample_img.shape}')"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": null,
213
+ "metadata": {
214
+ "id": "Yy6zeno07Gim"
215
+ },
216
+ "outputs": [],
217
+ "source": [
218
+ "def load_random_img(dir, label):\n",
219
+ " plt.figure(figsize=(10,10))\n",
220
+ " i=0\n",
221
+ " for label in city_classes:\n",
222
+ " i+=1\n",
223
+ " plt.subplot(1, 5, i)\n",
224
+ " file = random.choice(os.listdir(f'{dir}/{label}'))\n",
225
+ " image_path = os.path.join(f'{dir}/{label}', file)\n",
226
+ " img=cv2.imread(image_path)\n",
227
+ " plt.title(label)\n",
228
+ " plt.imshow(img)\n",
229
+ " plt.grid(None)\n",
230
+ " plt.axis('off')"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": null,
236
+ "metadata": {
237
+ "id": "Mejs17hg7Gim"
238
+ },
239
+ "outputs": [],
240
+ "source": [
241
+ "for i in range(3):\n",
242
+ " load_random_img(base_ds, city_classes)"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "metadata": {
249
+ "id": "4x9dpaIM7Gim"
250
+ },
251
+ "outputs": [],
252
+ "source": [
253
+ "batch_size = 128\n",
254
+ "img_height, img_width = 175, 175\n",
255
+ "input_shape = (img_height, img_width, 3)"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "markdown",
260
+ "metadata": {
261
+ "id": "lEiaPL5a7Gim"
262
+ },
263
+ "source": [
264
+ "### **Data Pre-processing**"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "metadata": {
271
+ "id": "eER-2oNF7Gin"
272
+ },
273
+ "outputs": [],
274
+ "source": [
275
+ "datagen = ImageDataGenerator(rescale=1./255)"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "metadata": {
282
+ "id": "lWozskns7Gin"
283
+ },
284
+ "outputs": [],
285
+ "source": [
286
+ "train_ds = datagen.flow_from_directory(\n",
287
+ " 'imgs/train',\n",
288
+ " target_size = (img_height, img_width),\n",
289
+ " batch_size = batch_size,\n",
290
+ " subset = \"training\",\n",
291
+ " class_mode='categorical')\n",
292
+ "\n",
293
+ "val_ds = datagen.flow_from_directory(\n",
294
+ " 'imgs/val',\n",
295
+ " target_size = (img_height, img_width),\n",
296
+ " batch_size = batch_size,\n",
297
+ " class_mode='categorical',\n",
298
+ " shuffle=False)\n",
299
+ "\n",
300
+ "test_ds = datagen.flow_from_directory(\n",
301
+ " 'imgs/test',\n",
302
+ " target_size = (img_height, img_width),\n",
303
+ " batch_size = batch_size,\n",
304
+ " class_mode='categorical',\n",
305
+ " shuffle=False)"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": null,
311
+ "metadata": {
312
+ "id": "F4Q9lfcU7Gin"
313
+ },
314
+ "outputs": [],
315
+ "source": [
316
+ "def plot_train_history(history):\n",
317
+ " plt.figure(figsize=(15,5))\n",
318
+ " plt.subplot(1,2,1)\n",
319
+ " plt.plot(history.history['accuracy'])\n",
320
+ " plt.plot(history.history['val_accuracy'])\n",
321
+ " plt.title('Model accuracy')\n",
322
+ " plt.ylabel('accuracy')\n",
323
+ " plt.xlabel('epoch')\n",
324
+ " plt.legend(['train', 'validation'], loc='upper left')\n",
325
+ "\n",
326
+ " plt.subplot(1,2,2)\n",
327
+ " plt.plot(history.history['loss'])\n",
328
+ " plt.plot(history.history['val_loss'])\n",
329
+ " plt.title('Model loss')\n",
330
+ " plt.ylabel('loss')\n",
331
+ " plt.xlabel('epoch')\n",
332
+ " plt.legend(['train', 'validation'], loc='upper left')\n",
333
+ " plt.show()"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "markdown",
338
+ "metadata": {
339
+ "id": "7AmkjAds7Gin"
340
+ },
341
+ "source": [
342
+ "## **Vanilla CNN Model**"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": null,
348
+ "metadata": {
349
+ "id": "mrzY1vc17Gin"
350
+ },
351
+ "outputs": [],
352
+ "source": [
353
+ "model_vanilla = tf.keras.Sequential([\n",
354
+ " tf.keras.layers.Conv2D(32,(3,3), activation='relu', input_shape=input_shape),\n",
355
+ " tf.keras.layers.BatchNormalization(),\n",
356
+ " tf.keras.layers.Conv2D(32,(3,3),activation='relu',padding='same'),\n",
357
+ " tf.keras.layers.BatchNormalization(axis = 3),\n",
358
+ " tf.keras.layers.MaxPooling2D(pool_size=(2,2),padding='same'),\n",
359
+ " tf.keras.layers.Dropout(0.3),\n",
360
+ "\n",
361
+ " tf.keras.layers.Conv2D(64,(3,3),activation='relu',padding='same'),\n",
362
+ " tf.keras.layers.BatchNormalization(),\n",
363
+ " tf.keras.layers.Conv2D(64,(3,3),activation='relu',padding='same'),\n",
364
+ " tf.keras.layers.BatchNormalization(axis = 3),\n",
365
+ " tf.keras.layers.MaxPooling2D(pool_size=(2,2),padding='same'),\n",
366
+ " tf.keras.layers.Dropout(0.3),\n",
367
+ "\n",
368
+ " tf.keras.layers.Conv2D(128,(3,3),activation='relu',padding='same'),\n",
369
+ " tf.keras.layers.BatchNormalization(),\n",
370
+ " tf.keras.layers.Conv2D(128,(3,3),activation='relu',padding='same'),\n",
371
+ " tf.keras.layers.BatchNormalization(axis = 3),\n",
372
+ " tf.keras.layers.MaxPooling2D(pool_size=(2,2),padding='same'),\n",
373
+ " tf.keras.layers.Dropout(0.5),\n",
374
+ "\n",
375
+ " tf.keras.layers.Flatten(),\n",
376
+ " tf.keras.layers.Dense(512, activation='relu'),\n",
377
+ " tf.keras.layers.BatchNormalization(),\n",
378
+ " tf.keras.layers.Dropout(0.5),\n",
379
+ " tf.keras.layers.Dense(128, activation='relu'),\n",
380
+ " tf.keras.layers.Dropout(0.25),\n",
381
+ " tf.keras.layers.Dense(5, activation='softmax')\n",
382
+ "])"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "metadata": {
389
+ "id": "U8KLpKJU7Gin"
390
+ },
391
+ "outputs": [],
392
+ "source": [
393
+ "model_vanilla.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n",
394
+ "model_vanilla.summary()"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "markdown",
399
+ "metadata": {
400
+ "id": "uUBYpTCi7Gin"
401
+ },
402
+ "source": [
403
+ "## **Callbacks**"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "metadata": {
410
+ "id": "5FEvu8Lv7Gin"
411
+ },
412
+ "outputs": [],
413
+ "source": [
414
+ "models_dir = \"saved_models\"\n",
415
+ "if not os.path.exists(models_dir):\n",
416
+ " os.makedirs(models_dir)\n",
417
+ "\n",
418
+ "checkpointer = ModelCheckpoint(filepath='saved_models/model_vanilla.hdf5',\n",
419
+ " monitor='val_accuracy', mode='max',\n",
420
+ " verbose=1, save_best_only=True)\n",
421
+ "early_stopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=3)\n",
422
+ "reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,patience=2, min_lr=0.001)\n",
423
+ "callbacks=[early_stopping, reduce_lr, checkpointer]"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": null,
429
+ "metadata": {
430
+ "id": "KcIjH0kt7Gin"
431
+ },
432
+ "outputs": [],
433
+ "source": [
434
+ "history1 = model_vanilla.fit(train_ds, epochs = 40, validation_data = val_ds, callbacks=callbacks)"
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "execution_count": null,
440
+ "metadata": {
441
+ "id": "ZkstY19-7Gio"
442
+ },
443
+ "outputs": [],
444
+ "source": [
445
+ "model_vanilla.save(\"model1\")\n",
446
+ "model_vanilla.load_weights('saved_models/model_vanilla.hdf5')\n",
447
+ "plot_train_history(history1)"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "markdown",
452
+ "metadata": {
453
+ "id": "4AVhC3Gu7Gio"
454
+ },
455
+ "source": [
456
+ "## **Model Evaluation of Vanilla CNN**"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "metadata": {
463
+ "id": "jzx2cCef7Gio"
464
+ },
465
+ "outputs": [],
466
+ "source": [
467
+ "score1 = model_vanilla.evaluate(test_ds, verbose=1)"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "code",
472
+ "execution_count": null,
473
+ "metadata": {
474
+ "id": "pSYggqLK7Gio"
475
+ },
476
+ "outputs": [],
477
+ "source": [
478
+ "from sklearn.metrics import classification_report, confusion_matrix\n",
479
+ "\n",
480
+ "Y_pred = model_vanilla.predict(test_ds)"
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "code",
485
+ "execution_count": null,
486
+ "metadata": {
487
+ "id": "zd8-1pAC7Gio"
488
+ },
489
+ "outputs": [],
490
+ "source": [
491
+ "y_pred = np.argmax(Y_pred, axis=1)\n",
492
+ "confusion_mtx = confusion_matrix(y_pred, test_ds.classes)\n",
493
+ "f,ax = plt.subplots(figsize=(12, 12))\n",
494
+ "sns.heatmap(confusion_mtx, annot=True,\n",
495
+ " linewidths=0.01,\n",
496
+ " linecolor=\"white\",\n",
497
+ " fmt= '.1f',ax=ax,)\n",
498
+ "sns.color_palette(\"rocket\", as_cmap=True)\n",
499
+ "\n",
500
+ "plt.xlabel(\"Predicted Label\")\n",
501
+ "plt.ylabel(\"True Label\")\n",
502
+ "ax.xaxis.set_ticklabels(test_ds.class_indices)\n",
503
+ "ax.yaxis.set_ticklabels(city_classes)\n",
504
+ "plt.title(\"Confusion Matrix\")\n",
505
+ "plt.show()"
506
+ ]
507
+ },
508
+ {
509
+ "cell_type": "code",
510
+ "execution_count": null,
511
+ "metadata": {
512
+ "id": "YE6j5ex97Gio"
513
+ },
514
+ "outputs": [],
515
+ "source": [
516
+ "report1 = classification_report(test_ds.classes, y_pred, target_names=city_classes, output_dict=True)\n",
517
+ "df1 = pd.DataFrame(report1).transpose()\n",
518
+ "df1"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "markdown",
523
+ "metadata": {
524
+ "id": "SS8QOBdy7Gio"
525
+ },
526
+ "source": [
527
+ "## **Transfer Learning**"
528
+ ]
529
+ },
530
+ {
531
+ "cell_type": "code",
532
+ "execution_count": null,
533
+ "metadata": {
534
+ "id": "0sbFu6VZ7Gio"
535
+ },
536
+ "outputs": [],
537
+ "source": [
538
+ "vgg16 = VGG16(weights=\"imagenet\", include_top=False, input_shape=input_shape)\n",
539
+ "vgg16.trainable = False\n",
540
+ "inputs = tf.keras.Input(input_shape)\n",
541
+ "x = vgg16(inputs, training=False)\n",
542
+ "x = tf.keras.layers.GlobalAveragePooling2D()(x)\n",
543
+ "x = tf.keras.layers.Dense(1024, activation='relu')(x)\n",
544
+ "x = tf.keras.layers.Dense(5, activation='softmax')(x)\n",
545
+ "model_vgg16 = tf.keras.Model(inputs, x)"
546
+ ]
547
+ },
548
+ {
549
+ "cell_type": "code",
550
+ "execution_count": null,
551
+ "metadata": {
552
+ "id": "IrQTnBef7Gip"
553
+ },
554
+ "outputs": [],
555
+ "source": [
556
+ "model_vgg16.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n",
557
+ "model_vgg16.summary()"
558
+ ]
559
+ },
560
+ {
561
+ "cell_type": "code",
562
+ "execution_count": null,
563
+ "metadata": {
564
+ "id": "nMaw_y4y7Gip"
565
+ },
566
+ "outputs": [],
567
+ "source": [
568
+ "checkpointer = ModelCheckpoint(filepath='saved_models/model_vgg16.hdf5',\n",
569
+ " monitor='val_accuracy', mode='max',\n",
570
+ " verbose=1, save_best_only=True)\n",
571
+ "early_stopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=3)\n",
572
+ "reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,patience=2, min_lr=0.001)\n",
573
+ "callbacks=[early_stopping, reduce_lr, checkpointer]"
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": null,
579
+ "metadata": {
580
+ "id": "XscbiiWE7Gip"
581
+ },
582
+ "outputs": [],
583
+ "source": [
584
+ "history2 = model_vgg16.fit(train_ds, epochs = 40, validation_data = val_ds, callbacks=callbacks)"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": null,
590
+ "metadata": {
591
+ "id": "toiVgYCR7Gip"
592
+ },
593
+ "outputs": [],
594
+ "source": [
595
+ "model_vgg16.load_weights('saved_models/model_vgg16.hdf5')"
596
+ ]
597
+ },
598
+ {
599
+ "cell_type": "code",
600
+ "execution_count": null,
601
+ "metadata": {
602
+ "id": "46M_EVvE7Gip"
603
+ },
604
+ "outputs": [],
605
+ "source": [
606
+ "plot_train_history(history2)"
607
+ ]
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "execution_count": null,
612
+ "metadata": {
613
+ "id": "PzhAvIEE7Gip"
614
+ },
615
+ "outputs": [],
616
+ "source": [
617
+ "score2 = model_vgg16.evaluate(test_ds, verbose=1)\n",
618
+ "print(f'Model 1 Vanilla Loss: {score1[0]}, Accuracy: {score1[1]*100}')\n",
619
+ "print(f'Model 2 VGG16 Loss: {score2[0]}, Accuracy: {score2[1]*100}')\n",
620
+ "model_vgg16.save(\"model2\")"
621
+ ]
622
+ },
623
+ {
624
+ "cell_type": "markdown",
625
+ "metadata": {
626
+ "id": "e1Jmu8wQ7Giq"
627
+ },
628
+ "source": [
629
+ "## **Fine Tuning**"
630
+ ]
631
+ },
632
+ {
633
+ "cell_type": "code",
634
+ "execution_count": null,
635
+ "metadata": {
636
+ "id": "IiEIjnKp7Giv"
637
+ },
638
+ "outputs": [],
639
+ "source": [
640
+ "vgg16.trainable = True\n",
641
+ "model_vgg16.compile(optimizer=keras.optimizers.Adam(1e-5),\n",
642
+ " loss='categorical_crossentropy', metrics=['accuracy'])"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "code",
647
+ "execution_count": null,
648
+ "metadata": {
649
+ "id": "D52hK12o7Giv"
650
+ },
651
+ "outputs": [],
652
+ "source": [
653
+ "history3 = model_vgg16.fit(train_ds, epochs = 40, validation_data = val_ds, callbacks=callbacks)"
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "code",
658
+ "execution_count": null,
659
+ "metadata": {
660
+ "id": "bpEvB7Ud7Giv"
661
+ },
662
+ "outputs": [],
663
+ "source": [
664
+ "model_vgg16.load_weights('saved_models/model_vgg16.hdf5')\n",
665
+ "model_vgg16.save(\"model3\")"
666
+ ]
667
+ },
668
+ {
669
+ "cell_type": "markdown",
670
+ "metadata": {
671
+ "id": "uYUBzEgZ7Giv"
672
+ },
673
+ "source": [
674
+ "## **Final Evaluation**"
675
+ ]
676
+ },
677
+ {
678
+ "cell_type": "code",
679
+ "execution_count": null,
680
+ "metadata": {
681
+ "id": "8qMXiAQg7Giv"
682
+ },
683
+ "outputs": [],
684
+ "source": [
685
+ "score3 = model_vgg16.evaluate(test_ds, verbose=1)\n",
686
+ "print(f'Model 1 Vanilla Loss: {score1[0]}, Accuracy: {score1[1]*100}')\n",
687
+ "print(f'Model 2 VGG16 Loss: {score2[0]}, Accuracy: {score2[1]*100}')\n",
688
+ "print(f'Model 2 VGG16 Fine-tuned Loss: {score3[0]}, Accuracy: {score3[1]*100}')"
689
+ ]
690
+ },
691
+ {
692
+ "cell_type": "code",
693
+ "execution_count": null,
694
+ "metadata": {
695
+ "id": "c-yIMsfC7Giv"
696
+ },
697
+ "outputs": [],
698
+ "source": [
699
+ "Y_pred = model_vgg16.predict(test_ds)"
700
+ ]
701
+ },
702
+ {
703
+ "cell_type": "code",
704
+ "execution_count": null,
705
+ "metadata": {
706
+ "id": "h75tyWjY7Giv"
707
+ },
708
+ "outputs": [],
709
+ "source": [
710
+ "y_pred = np.argmax(Y_pred, axis=1)\n",
711
+ "confusion_mtx = confusion_matrix(y_pred, test_ds.classes)\n",
712
+ "f,ax = plt.subplots(figsize=(12, 12))\n",
713
+ "sns.heatmap(confusion_mtx, annot=True,\n",
714
+ " linewidths=0.01,\n",
715
+ " linecolor=\"white\",\n",
716
+ " fmt= '.1f',ax=ax,)\n",
717
+ "sns.color_palette(\"rocket\", as_cmap=True)\n",
718
+ "\n",
719
+ "plt.xlabel(\"Predicted Label\")\n",
720
+ "plt.ylabel(\"True Label\")\n",
721
+ "ax.xaxis.set_ticklabels(test_ds.class_indices)\n",
722
+ "ax.yaxis.set_ticklabels(city_classes)\n",
723
+ "plt.title(\"Confusion Matrix\")\n",
724
+ "plt.show()"
725
+ ]
726
+ },
727
+ {
728
+ "cell_type": "code",
729
+ "execution_count": null,
730
+ "metadata": {
731
+ "id": "lT-5Hwdz7Giw"
732
+ },
733
+ "outputs": [],
734
+ "source": [
735
+ "report2 = classification_report(test_ds.classes, y_pred, target_names=city_classes, output_dict=True)\n",
736
+ "df2 = pd.DataFrame(report1).transpose()\n",
737
+ "df2"
738
+ ]
739
+ },
740
+ {
741
+ "cell_type": "code",
742
+ "execution_count": null,
743
+ "metadata": {
744
+ "id": "8TBoFOlT7Giw"
745
+ },
746
+ "outputs": [],
747
+ "source": [
748
+ "plt.figure(figsize=(100, 100))\n",
749
+ "x, label= train_ds.next()\n",
750
+ "for i in range(25):\n",
751
+ " plt.subplot(5, 5, i+1)\n",
752
+ " plt.imshow(x[i])\n",
753
+ " result = np.where(label[i]==1)\n",
754
+ " predict = model_vgg16(tf.expand_dims(x[i], 0))\n",
755
+ " score = tf.nn.softmax(predict[0])\n",
756
+ " score_label = city_classes[np.argmax(score)]\n",
757
+ " plt.title(f'Truth: {city_classes[result[0][0]]}\\nPrediction:{score_label}')\n",
758
+ " plt.axis(False)"
759
+ ]
760
+ },
761
+ {
762
+ "cell_type": "code",
763
+ "execution_count": null,
764
+ "metadata": {
765
+ "id": "UlB0X2Bu-z8T"
766
+ },
767
+ "outputs": [],
768
+ "source": [
769
+ "model_vgg16.save(\"/content/drive/MyDrive/model\")\n",
770
+ "# Assuming your model is named model_vgg16\n",
771
+ "model_vgg16.save(\"/content/drive/MyDrive/tensorflow\", save_format='tf')\n"
772
+ ]
773
+ },
774
+ {
775
+ "cell_type": "code",
776
+ "execution_count": null,
777
+ "metadata": {
778
+ "colab": {
779
+ "base_uri": "https://localhost:8080/"
780
+ },
781
+ "id": "Ui__qIcFG_IA",
782
+ "outputId": "04f422be-7522-45e3-b386-e00bf351a4c8"
783
+ },
784
+ "outputs": [
785
+ {
786
+ "name": "stdout",
787
+ "output_type": "stream",
788
+ "text": [
789
+ "Found 1875 images belonging to 5 classes.\n",
790
+ "59/59 [==============================] - 14s 182ms/step - loss: 1.4103 - accuracy: 0.6363\n",
791
+ "Test loss: 1.4102574586868286\n",
792
+ "Test accuracy: 0.6362666487693787\n"
793
+ ]
794
+ }
795
+ ],
796
+ "source": [
797
+ "test_datagen = ImageDataGenerator(rescale=1./255)\n",
798
+ "test_generator = test_datagen.flow_from_directory(\n",
799
+ " '/content/imgs/test',\n",
800
+ " target_size=(175, 175),\n",
801
+ " batch_size=32,\n",
802
+ " class_mode='categorical'\n",
803
+ ")\n",
804
+ "\n",
805
+ "test_loss, test_accuracy = model_vgg16.evaluate(test_generator, steps=len(test_generator))\n",
806
+ "print('Test loss:', test_loss)\n",
807
+ "print('Test accuracy:', test_accuracy)"
808
+ ]
809
+ },
810
+ {
811
+ "cell_type": "markdown",
812
+ "source": [
813
+ "## **Testing single image**"
814
+ ],
815
+ "metadata": {
816
+ "id": "gM7karcDFUq3"
817
+ }
818
+ },
819
+ {
820
+ "cell_type": "code",
821
+ "execution_count": null,
822
+ "metadata": {
823
+ "colab": {
824
+ "base_uri": "https://localhost:8080/"
825
+ },
826
+ "id": "zqrovHxOHj8o",
827
+ "outputId": "9ddc9a80-c2a4-4526-c856-1146b884ce66"
828
+ },
829
+ "outputs": [
830
+ {
831
+ "name": "stderr",
832
+ "output_type": "stream",
833
+ "text": [
834
+ "WARNING:tensorflow:5 out of the last 19 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7a21dd0d7eb0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n"
835
+ ]
836
+ },
837
+ {
838
+ "name": "stdout",
839
+ "output_type": "stream",
840
+ "text": [
841
+ "1/1 [==============================] - 0s 440ms/step\n",
842
+ "Predicted class: Kolkata\n",
843
+ "Accuracy: 0.8562558\n"
844
+ ]
845
+ }
846
+ ],
847
+ "source": [
848
+ "# import tensorflow.keras as keras\n",
849
+ "# from tensorflow.keras.models import load_model\n",
850
+ "# from tensorflow.keras.preprocessing import image\n",
851
+ "# import numpy as np\n",
852
+ "\n",
853
+ "# model = load_model('/content/model3')\n",
854
+ "\n",
855
+ "\n",
856
+ "\n",
857
+ "# # Load and preprocess the input image\n",
858
+ "# img_path = '/content/Citeisall/Kolkata/Kolkata-Test (10).jpg'\n",
859
+ "# img = image.load_img(img_path, target_size=(175,175))\n",
860
+ "# img = image.img_to_array(img)\n",
861
+ "# img = np.expand_dims(img, axis=0)\n",
862
+ "# img = img / 255.0\n",
863
+ "\n",
864
+ "# # Make predictions on the input image\n",
865
+ "# predictions = model.predict(img)\n",
866
+ "# class_labels = ['Ahmedabad', 'Delhi', 'Kerala', 'Kolkata', 'Mumbai']\n",
867
+ "\n",
868
+ "# # Set the threshold for minimum accuracy\n",
869
+ "# threshold = 0.0\n",
870
+ "\n",
871
+ "# # Get the predicted class label and accuracy\n",
872
+ "# predicted_class_index = np.argmax(predictions)\n",
873
+ "# predicted_class_label = class_labels[predicted_class_index]\n",
874
+ "# accuracy = predictions[0][predicted_class_index]\n",
875
+ "\n",
876
+ "# # Check if accuracy is below the threshold for all classes\n",
877
+ "# if all(accuracy < threshold for accuracy in predictions[0]):\n",
878
+ "# print(\"This location is not in our database.\")\n",
879
+ "# else:\n",
880
+ "# print('Predicted class:', predicted_class_label)\n",
881
+ "# print('Accuracy:', accuracy)\n"
882
+ ]
883
+ },
884
+ {
885
+ "cell_type": "markdown",
886
+ "source": [
887
+ "## **Visulization**\n",
888
+ "\n",
889
+ "---\n",
890
+ "\n"
891
+ ],
892
+ "metadata": {
893
+ "id": "QYj3LdqUFGn-"
894
+ }
895
+ },
896
+ {
897
+ "cell_type": "code",
898
+ "execution_count": null,
899
+ "metadata": {
900
+ "id": "ekP5VUgYISA-"
901
+ },
902
+ "outputs": [],
903
+ "source": [
904
+ "# import numpy as np\n",
905
+ "# import matplotlib.pyplot as plt\n",
906
+ "# import seaborn as sns\n",
907
+ "\n",
908
+ "# # Sample data from the classification report you provided\n",
909
+ "# labels = [\"Ahmedabad\", \"Delhi\", \"Kerala\", \"Kolkata\", \"Mumbai\"]\n",
910
+ "# precision = [0.85, 0.60, 0.64, 0.58, 0.55]\n",
911
+ "# recall = [0.84, 0.65, 0.66, 0.58, 0.49]\n",
912
+ "# f1 = [0.85, 0.62, 0.65, 0.58, 0.52]\n",
913
+ "\n",
914
+ "# # Bar Plot\n",
915
+ "# plt.figure(figsize=(10, 5))\n",
916
+ "# barWidth = 0.25\n",
917
+ "# r1 = np.arange(len(precision))\n",
918
+ "# r2 = [x + barWidth for x in r1]\n",
919
+ "# r3 = [x + barWidth for x in r2]\n",
920
+ "# plt.bar(r1, precision, color='b', width=barWidth, edgecolor='grey', label='precision')\n",
921
+ "# plt.bar(r2, recall, color='r', width=barWidth, edgecolor='grey', label='recall')\n",
922
+ "# plt.bar(r3, f1, color='g', width=barWidth, edgecolor='grey', label='f1-score')\n",
923
+ "# plt.xlabel('Cities', fontweight='bold')\n",
924
+ "# plt.xticks([r + barWidth for r in range(len(precision))], labels)\n",
925
+ "# plt.legend()\n",
926
+ "# plt.show()\n",
927
+ "\n",
928
+ "# # Heatmap\n",
929
+ "# df = {\n",
930
+ "# 'precision': precision,\n",
931
+ "# 'recall': recall,\n",
932
+ "# 'f1-score': f1\n",
933
+ "# }\n",
934
+ "# plt.figure(figsize=(10, 5))\n",
935
+ "# sns.heatmap([precision, recall, f1], yticklabels=[\"precision\", \"recall\", \"f1-score\"], xticklabels=labels, cmap=\"YlGnBu\", annot=True, fmt='.2g')\n",
936
+ "# plt.show()\n",
937
+ "\n",
938
+ "# # Spider (Radar) Plot\n",
939
+ "# angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()\n",
940
+ "# precision += precision[:1]\n",
941
+ "# recall += recall[:1]\n",
942
+ "# f1 += f1[:1]\n",
943
+ "# angles += angles[:1]\n",
944
+ "# plt.figure(figsize=(10, 5))\n",
945
+ "# ax = plt.subplot(111, polar=True)\n",
946
+ "# ax.fill(angles, precision, color='b', alpha=0.25)\n",
947
+ "# ax.fill(angles, recall, color='r', alpha=0.25)\n",
948
+ "# ax.fill(angles, f1, color='g', alpha=0.25)\n",
949
+ "# ax.set_theta_offset(np.pi / 2)\n",
950
+ "# ax.set_theta_direction(-1)\n",
951
+ "# plt.xticks(angles[:-1], labels)\n",
952
+ "# ax.set_rlabel_position(30)\n",
953
+ "# plt.yticks([0.2, 0.4, 0.6, 0.8], [\"0.2\", \"0.4\", \"0.6\", \"0.8\"], color=\"grey\", size=12)\n",
954
+ "# plt.ylim(0, 1)\n",
955
+ "# ax.plot(angles, precision, color='b', linewidth=2, linestyle='solid', label='precision')\n",
956
+ "# ax.plot(angles, recall, color='r', linewidth=2, linestyle='solid', label='recall')\n",
957
+ "# ax.plot(angles, f1, color='g', linewidth=2, linestyle='solid', label='f1-score')\n",
958
+ "# ax.fill(angles, precision, color='b', alpha=0.4)\n",
959
+ "# ax.fill(angles, recall, color='r', alpha=0.4)\n",
960
+ "# ax.fill(angles, f1, color='g', alpha=0.4)\n",
961
+ "# plt.legend(loc=\"upper right\", bbox_to_anchor=(0.1, 0.1))\n",
962
+ "# plt.show()\n"
963
+ ]
964
+ },
965
+ {
966
+ "cell_type": "code",
967
+ "source": [
968
+ "# Import necessary libraries\n",
969
+ "import numpy as np\n",
970
+ "from keras.models import load_model\n",
971
+ "from keras.preprocessing.image import ImageDataGenerator\n",
972
+ "from sklearn.metrics import classification_report\n",
973
+ "import matplotlib.pyplot as plt\n",
974
+ "import pandas as pd\n",
975
+ "\n",
976
+ "# Load the pre-trained model\n",
977
+ "model = load_model('/content/drive/MyDrive/model.h5')\n",
978
+ "\n",
979
+ "# Preprocess the test data\n",
980
+ "test_datagen = ImageDataGenerator(rescale=1./255) # Assuming you rescaled your images during training\n",
981
+ "test_dir = '/content/imgs/test'\n",
982
+ "test_generator = test_datagen.flow_from_directory(\n",
983
+ " test_dir,\n",
984
+ " target_size=(175, 175), # Adjust if you used a different input size during training\n",
985
+ " batch_size=1,\n",
986
+ " class_mode='categorical',\n",
987
+ " shuffle=False\n",
988
+ ")\n",
989
+ "\n",
990
+ "# Predict classes using the model\n",
991
+ "predictions = model.predict(test_generator, steps=test_generator.n, verbose=1)\n",
992
+ "predicted_classes = np.argmax(predictions, axis=1)\n",
993
+ "\n",
994
+ "# Get true labels and class labels\n",
995
+ "true_classes = test_generator.classes\n",
996
+ "class_labels = list(test_generator.class_indices.keys())\n",
997
+ "\n",
998
+ "# Generate the classification report\n",
999
+ "report = classification_report(true_classes, predicted_classes, target_names=class_labels, output_dict=True)\n",
1000
+ "report_df = pd.DataFrame(report).transpose()\n",
1001
+ "\n",
1002
+ "# Plot the metrics in the report\n",
1003
+ "report_df[['precision', 'recall', 'f1-score']].drop(['accuracy', 'macro avg', 'weighted avg']).plot(kind='bar', figsize=(15, 7))\n",
1004
+ "plt.title('Classification Report Metrics')\n",
1005
+ "plt.ylabel('Score')\n",
1006
+ "plt.xticks(rotation=45)\n",
1007
+ "plt.ylim(0, 1)\n",
1008
+ "plt.grid(axis='y')\n",
1009
+ "plt.tight_layout()\n",
1010
+ "plt.show()\n",
1011
+ "\n",
1012
+ "# Import necessary libraries\n",
1013
+ "# import numpy as np\n",
1014
+ "# from keras.models import load_model\n",
1015
+ "# from keras.preprocessing.image import ImageDataGenerator\n",
1016
+ "# from sklearn.metrics import classification_report, confusion_matrix\n",
1017
+ "# import matplotlib.pyplot as plt\n",
1018
+ "# import seaborn as sns\n",
1019
+ "\n",
1020
+ "# # Load the pre-trained model\n",
1021
+ "# model = load_model('/content/drive/MyDrive/model.h5')\n",
1022
+ "\n",
1023
+ "# # Preprocess the test data\n",
1024
+ "# test_datagen = ImageDataGenerator(rescale=1./255) # Assuming you rescaled your images during training\n",
1025
+ "# test_dir = '/content/imgs/test'\n",
1026
+ "# test_generator = test_datagen.flow_from_directory(\n",
1027
+ "# test_dir,\n",
1028
+ "# target_size=(175, 175), # Adjust if you used a different input size during training\n",
1029
+ "# batch_size=1,\n",
1030
+ "# class_mode='categorical',\n",
1031
+ "# shuffle=False\n",
1032
+ "# )\n",
1033
+ "\n",
1034
+ "# # Predict classes using the model\n",
1035
+ "# predictions = model.predict(test_generator, steps=test_generator.n, verbose=1)\n",
1036
+ "# predicted_classes = np.argmax(predictions, axis=1)\n",
1037
+ "\n",
1038
+ "# # Get true labels and class labels\n",
1039
+ "# true_classes = test_generator.classes\n",
1040
+ "# class_labels = list(test_generator.class_indices.keys())\n",
1041
+ "\n",
1042
+ "# # Generate the classification report\n",
1043
+ "# report = classification_report(true_classes, predicted_classes, target_names=class_labels)\n",
1044
+ "# print(report)\n",
1045
+ "\n",
1046
+ "# # Generate the confusion matrix\n",
1047
+ "# confusion_mtx = confusion_matrix(true_classes, predicted_classes)\n",
1048
+ "\n",
1049
+ "# # Plot the heatmap using Seaborn\n",
1050
+ "# plt.figure(figsize=(10, 8))\n",
1051
+ "# sns.heatmap(confusion_mtx, annot=True, fmt='d', cmap='Blues',\n",
1052
+ "# xticklabels=class_labels,\n",
1053
+ "# yticklabels=class_labels)\n",
1054
+ "# plt.xlabel('Predicted Label')\n",
1055
+ "# plt.ylabel('True Label')\n",
1056
+ "# plt.title('Confusion Matrix')\n",
1057
+ "# plt.show()\n"
1058
+ ],
1059
+ "metadata": {
1060
+ "id": "naosYJjSFmYA"
1061
+ },
1062
+ "execution_count": null,
1063
+ "outputs": []
1064
+ },
1065
+ {
1066
+ "cell_type": "code",
1067
+ "source": [],
1068
+ "metadata": {
1069
+ "id": "pM9HS2OxE4Ej"
1070
+ },
1071
+ "execution_count": null,
1072
+ "outputs": []
1073
+ }
1074
+ ],
1075
+ "metadata": {
1076
+ "accelerator": "GPU",
1077
+ "colab": {
1078
+ "provenance": []
1079
+ },
1080
+ "kernelspec": {
1081
+ "display_name": "Python 3",
1082
+ "name": "python3"
1083
+ },
1084
+ "language_info": {
1085
+ "codemirror_mode": {
1086
+ "name": "ipython",
1087
+ "version": 3
1088
+ },
1089
+ "file_extension": ".py",
1090
+ "mimetype": "text/x-python",
1091
+ "name": "python",
1092
+ "nbconvert_exporter": "python",
1093
+ "pygments_lexer": "ipython3",
1094
+ "version": "3.6.4"
1095
+ }
1096
+ },
1097
+ "nbformat": 4,
1098
+ "nbformat_minor": 0
1099
+ }