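"""Train VGG16 from scratch on FractalDB-1k with tf.keras mixed precision.

Pipeline: random 224x224 crops, SGD with momentum and weight decay, a step-decay
learning-rate schedule, and weight checkpoints every 5 epochs.
"""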
import os
import random

import tensorflow as tf
from tensorflow.keras import models, layers, mixed_precision
from tensorflow.keras.applications import VGG16
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
# The experimental namespace provides weight_decay on older TF 2.x releases;
# on TF >= 2.11, tf.keras.optimizers.SGD accepts weight_decay directly.
from tensorflow.keras.optimizers.experimental import SGD

# Compute in float16 while keeping variables in float32.
mixed_precision.set_global_policy('mixed_float16')

epochs = 90
crop_size = 224
main_directory = "FractalDB-1k/"


def lr_scheduler(epoch):
    # Step decay: 1e-2 for epochs 0-29, 1e-3 for 30-59, 1e-4 afterwards.
    if epoch < 30:
        return 0.01
    elif epoch < 60:
        return 0.01 * 0.1
    else:
        return 0.01 * 0.01


batch_size = 128

print('Loading data...')


# Collect image file paths and class-name labels from the per-class subdirectories.
def get_image_paths_and_labels(directory):
    image_paths = []
    labels = []
    # Sort class folders so label indices are assigned deterministically.
    class_names = sorted(os.listdir(directory))
    for class_name in class_names:
        class_dir = os.path.join(directory, class_name)
        # Skip stray files that are not class directories.
        if not os.path.isdir(class_dir):
            continue
        class_image_paths = [os.path.join(class_dir, fname) for fname in os.listdir(class_dir)]
        image_paths += class_image_paths
        labels += [class_name] * len(class_image_paths)
    return image_paths, labels


image_paths, labels = get_image_paths_and_labels(main_directory)
print(len(image_paths))

# Map class names to integer indices; sort so the mapping is reproducible across runs.
label_to_index = dict((name, index) for index, name in enumerate(sorted(set(labels))))
numeric_labels = [label_to_index[name] for name in labels]

print(numeric_labels[:10])
print(len(label_to_index))

# Shuffle paths and labels together so the pairs stay aligned.
paired_list = list(zip(image_paths, numeric_labels))
random.shuffle(paired_list)
image_paths, numeric_labels = zip(*paired_list)
image_paths = list(image_paths)
numeric_labels = list(numeric_labels)

print(image_paths[:10])
print(numeric_labels[:10])

path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
label_ds = tf.data.Dataset.from_tensor_slices(numeric_labels)
image_label_ds = tf.data.Dataset.zip((path_ds, label_ds))

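# Per-image preprocessing applied on the fly by tf.data.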
def load_and_preprocess_image(path, label):
    image = tf.io.read_file(path)
    image = tf.image.decode_png(image, channels=3)
    # random_crop requires the source image to be at least crop_size on each side.
    image = tf.image.random_crop(image, size=[crop_size, crop_size, 3])
    image = tf.cast(image, tf.float32)

    # Scale to [0, 1], then normalize with per-channel mean/std.
    mean = tf.constant([0.2, 0.2, 0.2])
    std = tf.constant([0.5, 0.5, 0.5])
    image /= 255.0
    image = (image - mean) / std
    return image, tf.cast(label, tf.int32)


image_label_ds = image_label_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
image_label_ds = image_label_ds.shuffle(10000).batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)


lr_scheduler_cb = LearningRateScheduler(lr_scheduler)

print('Building model...')
base_model = VGG16(weights=None, include_top=False, input_shape=(crop_size, crop_size, 3))
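# VGG16 convolutional base topped with a freshly initialized VGG-style classifier
# head (two 4096-unit FC layers with dropout, then a 1000-way softmax).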
model = models.Sequential([
    base_model,
    layers.Flatten(),
    layers.Dense(4096, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(4096, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(1000),
    # Keep the final softmax in float32 for numerical stability under mixed precision.
    layers.Softmax(dtype='float32')
])

model.summary()


# Save weights every 5 epochs; make sure the output directory exists first.
# ('period' is deprecated in newer Keras releases in favor of save_freq.)
os.makedirs('drop02', exist_ok=True)
checkpoint_cb = ModelCheckpoint(
    'drop02/vgg-fractaldb-{epoch:02d}.h5', save_weights_only=True, period=5)

# Under the mixed_float16 policy, compile() wraps the optimizer with loss scaling automatically.
model.compile(optimizer=SGD(learning_rate=0.01, weight_decay=1e-4, momentum=0.9),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

print('Training...')
model.fit(image_label_ds, epochs=epochs, callbacks=[checkpoint_cb, lr_scheduler_cb])