JosefJilek commited on
Commit
40f30f3
·
1 Parent(s): 08c8253
main.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import tensorflow as tf
6
+
7
+ from tensorflow import keras
8
+ from tensorflow.keras import layers
9
+ from tensorflow.keras.models import Sequential
10
+
11
+ import pathlib
12
+
13
+ from tensorflow.python.client import device_lib
14
+ print(device_lib.list_local_devices())
15
+
16
+ data_dir = "C:/Users/jilek/Downloads/AAT+"
17
+ data_dir = pathlib.Path(data_dir).with_suffix('')
18
+
19
+ data_dir_test = "C:/Users/jilek/Downloads/AAT+_TEST"
20
+ data_dir_test = pathlib.Path(data_dir_test).with_suffix('')
21
+
22
+ image_count = len(list(data_dir.glob('*/*.jpg')))
23
+ print(image_count)
24
+
25
+ batch_size = 1
26
+ img_height = 1024
27
+ img_width = 1024
28
+
29
+ train_ds = tf.keras.utils.image_dataset_from_directory(
30
+ data_dir,
31
+ validation_split=0.0,
32
+ #subset="training",
33
+ seed=123,
34
+ labels='inferred',
35
+ label_mode='categorical',
36
+ class_names=["C100", "C095", "C090", "C085", "C080", "C070", "C060", "C040", "C020"],
37
+ color_mode="grayscale", #grayscale
38
+ shuffle=True,
39
+ image_size=(img_height, img_width),
40
+ batch_size=batch_size)
41
+
42
+ val_ds = tf.keras.utils.image_dataset_from_directory(
43
+ data_dir_test,
44
+ validation_split=0.0,
45
+ #subset="validation",
46
+ seed=123,
47
+ labels='inferred',
48
+ label_mode='categorical',
49
+ class_names=["C100", "C095", "C090", "C085", "C080", "C070", "C060", "C040", "C020"],
50
+ color_mode="grayscale",
51
+ image_size=(img_height, img_width),
52
+ batch_size=batch_size)
53
+
54
+ class_names = train_ds.class_names
55
+ print(class_names)
56
+
57
+ for image_batch, labels_batch in train_ds:
58
+ print(image_batch.shape)
59
+ print(labels_batch.shape)
60
+ break
61
+
62
+ AUTOTUNE = tf.data.AUTOTUNE
63
+
64
+ data_augmentation = keras.Sequential(
65
+ [
66
+ layers.RandomFlip("horizontal_and_vertical",
67
+ input_shape=(img_height,
68
+ img_width,
69
+ 1)), #rgb
70
+ #layers.RandomRotation(0.5),
71
+ #layers.RandomZoom(0.5),
72
+ ]
73
+ )
74
+
75
+ train_ds = train_ds.shuffle(buffer_size=900).prefetch(buffer_size=AUTOTUNE) #.cache()
76
+ val_ds = val_ds.prefetch(buffer_size=AUTOTUNE) #.cache()
77
+
78
+
79
+ num_classes = len(class_names)
80
+ print(str(num_classes))
81
+
82
+ model = Sequential([
83
+ layers.Rescaling(1.0/255, input_shape=(img_height, img_width, 1)), #rgb
84
+ #layers.Dropout(0.0),
85
+ #layers.MaxPooling2D(pool_size=(8, 8)),
86
+ layers.Conv2D(4, (4, 4), strides=(2, 2), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(1024, 1024, 1), activation='relu'),
87
+ layers.Conv2D(8, (4, 4), strides=(2, 2), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(512, 512, 4), activation='relu'),
88
+ layers.Conv2D(16, (4, 4), strides=(4, 4), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(256, 256, 8), activation='relu'),
89
+ layers.Conv2D(32, (4, 4), strides=(4, 4), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(64, 64, 16), activation='relu'),
90
+ layers.Conv2D(64, (4, 4), strides=(4, 4), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(16, 16, 32), activation='relu'),
91
+ #layers.Conv2D(128, (4, 4), strides=(1, 1), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(8, 8, 64), activation='relu'),
92
+ #layers.Dropout(0.1),
93
+ layers.Flatten(),
94
+ layers.Dense(32, activation='relu'),
95
+ layers.Dense(num_classes, activation='softmax')
96
+ ])
97
+ model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
98
+ loss=tf.keras.losses.CategoricalCrossentropy(),
99
+ metrics=['accuracy'])
100
+
101
+ model.summary()
102
+ model.save("./model/AAT+")
103
+
104
+ epochs = 130
105
+ history = model.fit(
106
+ train_ds,
107
+ validation_data=val_ds,
108
+ epochs=epochs
109
+ )
110
+
111
+ acc = history.history['accuracy']
112
+ val_acc = history.history['val_accuracy']
113
+
114
+ loss = history.history['loss']
115
+ val_loss = history.history['val_loss']
116
+
117
+ epochs_range = range(epochs)
118
+
119
+ plt.figure(figsize=(8, 8))
120
+ plt.subplot(1, 2, 1)
121
+ plt.plot(epochs_range, acc, label='Training Accuracy')
122
+ plt.plot(epochs_range, val_acc, label='Validation Accuracy')
123
+ plt.legend(loc='lower right')
124
+ plt.title('Training and Validation Accuracy')
125
+
126
+ plt.subplot(1, 2, 2)
127
+ plt.plot(epochs_range, loss, label='Training Loss')
128
+ plt.plot(epochs_range, val_loss, label='Validation Loss')
129
+ plt.legend(loc='upper right')
130
+ plt.title('Training and Validation Loss')
131
+ plt.show()
132
+
133
+ test_dir = "C:/Users/jilek/Downloads/AAT_T/"
134
+ for file_name in os.listdir(test_dir):
135
+ file_path = os.path.join(test_dir, file_name)
136
+ img = tf.keras.utils.load_img(
137
+ file_path, target_size=(img_height, img_width), color_mode="grayscale" #grayscale
138
+ )
139
+ img_array = tf.keras.utils.img_to_array(img)
140
+ img_array = tf.expand_dims(img_array, 0) # Create a batch
141
+
142
+ predictions = model.predict(img_array)
143
+ score = tf.nn.softmax(predictions[0])
144
+
145
+ print(file_name)
146
+ print(
147
+ "This image most likely belongs to {} with a {:.2f} percent confidence."
148
+ .format(class_names[np.argmax(score)], 100 * np.max(score))
149
+ )
150
+
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
+
model/AAT+/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:429c89709452f69986c9cc2c9549be3e6e6525d467876355348d373241545d32
3
+ size 23711
model/AAT+/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf5d62d94e882bf080ea5aa8f5ba9c4635bddd6a2cfd03d4be9ba875f2b5fa1a
3
+ size 148529
model/AAT+/variables/variables.data-00000-of-00001 ADDED
Binary file (258 kB). View file
 
model/AAT+/variables/variables.index ADDED
Binary file (1.26 kB). View file