# bobfromjapan's picture
# Update train.py
# 9fd189a verified
import tensorflow as tf
from tensorflow.keras import models, layers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers.experimental import SGD
from tensorflow.keras import mixed_precision
import os
import random
# Training hyper-parameters and dataset location.
epochs = 90  # number of epochs passed to model.fit
crop_size = 224  # side length of the square random crop, also the model input size
main_directory = "FractalDB-1k/"  # root folder scanned for per-class sub-folders

# Select the Keras 'mixed_float16' global dtype policy before any layer is built.
mixed_precision.set_global_policy('mixed_float16')
#print(tf.config.list_physical_devices('GPU'))
def lr_scheduler(epoch):
    """Return the learning rate for the given epoch index.

    Step schedule: 0.01 while ``epoch < 30``, a tenth of that while
    ``epoch < 60``, and a hundredth of it afterwards.
    """
    # (exclusive upper epoch bound, learning rate) pairs, checked in order.
    schedule = ((30, 0.01), (60, 0.01 * 0.1))
    for upper_bound, rate in schedule:
        if epoch < upper_bound:
            return rate
    return 0.01 * 0.01
# Dataset folder name without a trailing slash.
directory = 'FractalDB-1k'
# Number of samples per mini-batch in the tf.data pipeline.
batch_size = 128

print('Loading data...')
def get_image_paths_and_labels(directory):
    """Collect every file under ``directory`` together with its class name.

    ``directory`` is expected to contain one sub-directory per class; the
    sub-directory name is used as the label of every file inside it.

    Args:
        directory: Root folder holding the per-class sub-directories.

    Returns:
        A tuple ``(image_paths, labels)`` of two equally long lists, where
        ``labels[i]`` is the class-folder name of ``image_paths[i]``. Classes
        and the files within each class appear in sorted order.
    """
    image_paths = []
    labels = []
    # os.listdir returns entries in arbitrary order; sorting makes the result
    # reproducible across runs and machines.
    for class_name in sorted(os.listdir(directory)):
        class_dir = os.path.join(directory, class_name)
        # Stray files next to the class folders are not classes and would make
        # the inner os.listdir raise NotADirectoryError.
        if not os.path.isdir(class_dir):
            continue
        candidates = (
            os.path.join(class_dir, fname) for fname in sorted(os.listdir(class_dir))
        )
        # Keep regular files only; a nested directory is not a readable image.
        class_image_paths = [path for path in candidates if os.path.isfile(path)]
        image_paths.extend(class_image_paths)
        labels.extend([class_name] * len(class_image_paths))
    return image_paths, labels
# Enumerate every image on disk together with its class-folder name.
image_paths, labels = get_image_paths_and_labels(main_directory)
print(len(image_paths))
if not image_paths:
    # Fail with a clear message instead of the opaque unpacking error that
    # zip(*[]) would raise further down.
    raise ValueError(f"No images found under {main_directory!r}")

# Map class names to integer ids. Iterating a set of strings directly gives an
# order that can change between interpreter runs (hash randomization), which
# would make saved weights impossible to match to classes later; sorting pins
# the mapping.
label_to_index = {name: index for index, name in enumerate(sorted(set(labels)))}
numeric_labels = [label_to_index[name] for name in labels]
print(numeric_labels[:10])
print(len(label_to_index))

# Shuffle paths and labels together so each path keeps its own label.
paired_list = list(zip(image_paths, numeric_labels))
random.shuffle(paired_list)
image_paths, numeric_labels = (list(column) for column in zip(*paired_list))
print(image_paths[:10])
print(numeric_labels[:10])

# Dataset of (path, label) pairs; decoding happens later in the pipeline.
path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
label_ds = tf.data.Dataset.from_tensor_slices(numeric_labels)
image_label_ds = tf.data.Dataset.zip((path_ds, label_ds))
def load_and_preprocess_image(path, label):
    """Read one PNG file and return a randomly cropped, normalized tensor.

    Args:
        path: Path of the PNG file to read.
        label: Class id belonging to the image.

    Returns:
        A tuple ``(image, label)``: a float32 crop of size
        ``crop_size`` x ``crop_size`` x 3 that was divided by 255 and then
        standardized with mean 0.2 and std 0.5 per channel, and the label
        cast to int32.
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    cropped = tf.image.random_crop(decoded, size=[crop_size, crop_size, 3])

    # Divide by 255 first, then subtract the mean and divide by the std.
    scaled = tf.cast(cropped, tf.float32) / 255.0
    channel_mean = tf.constant([0.2, 0.2, 0.2])
    channel_std = tf.constant([0.5, 0.5, 0.5])
    normalized = (scaled - channel_mean) / channel_std

    return normalized, tf.cast(label, tf.int32)
# Input pipeline. Shuffle while elements are still (path, label) pairs so the
# 10000-element buffer holds path strings rather than decoded float images,
# then decode and crop in parallel, batch, and prefetch.
image_label_ds = (
    image_label_ds
    .shuffle(10000)
    .map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
lr_scheduler_cb = LearningRateScheduler(lr_scheduler)

print('Building model...')
# VGG16 convolutional trunk with random initial weights and no classifier head.
base_model = VGG16(weights=None, include_top=False, input_shape=(crop_size, crop_size, 3))
model = models.Sequential([
    base_model,
    layers.Flatten(),
    # No input_shape here: it is only meaningful on the first layer of a
    # Sequential model, and the size is inferred from Flatten's output.
    layers.Dense(4096, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(4096, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(1000),
    # Under the 'mixed_float16' policy the final softmax is kept in float32 so
    # the probabilities fed to the loss are not computed in half precision.
    layers.Softmax(dtype='float32'),
])
model.summary()

# Save weights only, every 5 epochs, with the epoch number in the file name.
# NOTE(review): `period` is a legacy argument; newer Keras releases expect
# `save_freq` instead. Confirm against the installed version.
checkpoint_cb = ModelCheckpoint(
    'drop02/vgg-fractaldb-{epoch:02d}.h5', save_weights_only=True, period=5)

# The initial learning rate matches the first stage of lr_scheduler.
model.compile(
    optimizer=SGD(learning_rate=0.01, weight_decay=1e-4, momentum=0.9),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

print('Training...')
model.fit(image_label_ds, epochs=epochs, callbacks=[checkpoint_cb, lr_scheduler_cb])