chansung's picture
hello
efb03ac
import datetime
import os
from typing import List
import absl
import keras_tuner
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
import tensorflow_transform as tft
from tensorflow_cloud import CloudTuner
from tfx.v1.components import TunerFnResult
from tfx.components.trainer.fn_args_utils import DataAccessor
from tfx.components.trainer.fn_args_utils import FnArgs
from tfx.dsl.io import fileio
from tfx_bsl.tfxio import dataset_options
import tfx.extensions.google_cloud_ai_platform.constants as vertex_const
import tfx.extensions.google_cloud_ai_platform.trainer.executor as vertex_training_const
import tfx.extensions.google_cloud_ai_platform.tuner.executor as vertex_tuner_const
# Training-set size; divided by _TRAIN_BATCH_SIZE to derive steps_per_epoch.
_TRAIN_DATA_SIZE = 128
# NOTE(review): not referenced in the visible part of this file.
_EVAL_DATA_SIZE = 128
# Batch sizes handed to _input_fn for the train and eval datasets.
_TRAIN_BATCH_SIZE = 32
_EVAL_BATCH_SIZE = 32
# NOTE(review): the two learning rates below are not referenced in the
# visible part of this file; the optimizer reads "learning_rate" from the
# hyperparameters instead.
_CLASSIFIER_LEARNING_RATE = 1e-3
_FINETUNE_LEARNING_RATE = 7e-6
# Number of epochs passed to model.fit() in run_fn.
_CLASSIFIER_EPOCHS = 30
# Raw feature keys; _transformed_name() appends "_xf" to get the keys used
# for the transformed image feature and label.
_IMAGE_KEY = "image"
_LABEL_KEY = "label"
def INFO(text: str):
    """Emit ``text`` through absl's logger at INFO severity."""
    log_info = absl.logging.info
    log_info(text)
def _transformed_name(key: str) -> str:
return key + "_xf"
def _get_signature(model):
    """Build the signature dict used when saving ``model``.

    Args:
        model: Callable model that is wrapped by ``_get_serve_image_fn``.

    Returns:
        A dict with one entry, ``"serving_default"``, holding the concrete
        function traced for a float32 input of shape ``[None, 224, 224, 3]``
        named after the transformed image key.
    """
    serving_fn = _get_serve_image_fn(model)
    image_spec = tf.TensorSpec(
        shape=[None, 224, 224, 3],
        dtype=tf.float32,
        name=_transformed_name(_IMAGE_KEY),
    )
    return {"serving_default": serving_fn.get_concrete_function(image_spec)}
def _get_serve_image_fn(model):
    """Return a ``tf.function`` that forwards an image tensor to ``model``.

    Args:
        model: Callable invoked with the image tensor.

    Returns:
        The ``tf.function``-wrapped ``serve_image_fn``.
    """

    def serve_image_fn(image_tensor):
        return model(image_tensor)

    # Explicit call form of the @tf.function decorator.
    return tf.function(serve_image_fn)
def _image_augmentation(image_features):
    """Apply random flip, pad/crop to 250x250, then a random 224x224 crop.

    Args:
        image_features: Image tensor whose leading dimension is read as the
            batch size (assumes a 4-D batch with 3 channels, matching the
            crop size used below — TODO confirm against the transform).

    Returns:
        The augmented tensor produced by ``tf.image.random_crop`` with size
        ``(batch, 224, 224, 3)``.
    """
    num_images = tf.shape(image_features)[0]
    flipped = tf.image.random_flip_left_right(image_features)
    padded = tf.image.resize_with_crop_or_pad(flipped, 250, 250)
    crop_size = (num_images, 224, 224, 3)
    return tf.image.random_crop(padded, crop_size)
def _data_augmentation(feature_dict):
    """Augment the transformed image feature of ``feature_dict`` in place.

    Args:
        feature_dict: Mapping that contains the transformed image key.

    Returns:
        The same ``feature_dict`` object, with its image entry replaced by
        the output of ``_image_augmentation``.
    """
    image_key = _transformed_name(_IMAGE_KEY)
    feature_dict[image_key] = _image_augmentation(feature_dict[image_key])
    return feature_dict
def _input_fn(
    file_pattern: List[str],
    data_accessor: DataAccessor,
    tf_transform_output: tft.TFTransformOutput,
    is_train: bool = False,
    batch_size: int = 200,
) -> tf.data.Dataset:
    """Create a batched dataset of (features, label) pairs.

    Args:
        file_pattern: File paths or patterns handed to the dataset factory.
        data_accessor: Provides ``tf_dataset_factory`` for reading the files.
        tf_transform_output: Supplies the transformed-metadata schema.
        is_train: When true, image augmentation is mapped over the features.
        batch_size: Batch size passed in the dataset options.

    Returns:
        The dataset from ``data_accessor.tf_dataset_factory``, augmented when
        ``is_train`` is set.
    """
    options = dataset_options.TensorFlowDatasetOptions(
        batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)
    )
    schema = tf_transform_output.transformed_metadata.schema
    dataset = data_accessor.tf_dataset_factory(file_pattern, options, schema)

    if not is_train:
        return dataset

    # Only the feature dict is augmented; labels pass through unchanged.
    return dataset.map(
        lambda features, label: (_data_augmentation(features), label)
    )
def _get_hyperparameters() -> keras_tuner.HyperParameters:
    """Return the hyperparameter search space.

    Returns:
        A ``keras_tuner.HyperParameters`` holding a single ``learning_rate``
        choice over ``[1e-3, 1e-2]`` with default ``1e-3``.
    """
    search_space = keras_tuner.HyperParameters()
    search_space.Choice("learning_rate", [1e-3, 1e-2], default=1e-3)
    return search_space
def _build_keras_model(hparams: keras_tuner.HyperParameters) -> tf.keras.Model:
    """Build and compile a ResNet50-based 10-way classifier.

    Args:
        hparams: Hyperparameters; ``"learning_rate"`` is read for the Adam
            optimizer.

    Returns:
        A compiled ``tf.keras.Sequential`` made of an input layer, the
        ResNet50 base with ``trainable`` set to ``False``, dropout and a
        softmax dense layer.
    """
    backbone = tf.keras.applications.ResNet50(
        input_shape=(224, 224, 3), include_top=False, weights="imagenet", pooling="max"
    )
    backbone.input_spec = None
    backbone.trainable = False

    layer_stack = [
        tf.keras.layers.InputLayer(
            input_shape=(224, 224, 3), name=_transformed_name(_IMAGE_KEY)
        ),
        backbone,
        tf.keras.layers.Dropout(0.1),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
    classifier = tf.keras.Sequential(layer_stack)

    optimizer = Adam(learning_rate=hparams.get("learning_rate"))
    classifier.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=optimizer,
        metrics=["sparse_categorical_accuracy"],
    )

    # Route the layer summary through the module logger instead of stdout.
    classifier.summary(print_fn=INFO)
    return classifier
def cloud_tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    """Build a ``CloudTuner`` and the arguments used to fit it.

    Args:
        fn_args: Arguments from the Tuner component. This function reads
            ``custom_config`` (the ``"project"`` and ``"region"`` entries
            under the tuning-args key), ``working_dir``,
            ``transform_graph_path``, ``train_files``, ``eval_files``,
            ``data_accessor`` and ``eval_steps``.

    Returns:
        A ``TunerFnResult`` holding the tuner and the ``fit_kwargs`` for the
        search (datasets, steps per epoch and validation steps).

    Raises:
        KeyError: If the tuning args, ``"project"`` or ``"region"`` are
            missing from ``fn_args.custom_config``.
    """
    # Previously undefined in this function (NameError when building
    # fit_kwargs); computed the same way as in tuner_fn and run_fn.
    steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)

    tuning_args = fn_args.custom_config[vertex_tuner_const.TUNING_ARGS_KEY]

    tuner = CloudTuner(
        _build_keras_model,
        max_trials=6,
        hyperparameters=_get_hyperparameters(),
        project_id=tuning_args["project"],
        region=tuning_args["region"],
        objective="val_sparse_categorical_accuracy",
        directory=fn_args.working_dir,
    )

    tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=True,
        batch_size=_TRAIN_BATCH_SIZE,
    )

    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=False,
        batch_size=_EVAL_BATCH_SIZE,
    )

    return TunerFnResult(
        tuner=tuner,
        fit_kwargs={
            "x": train_dataset,
            "validation_data": eval_dataset,
            "steps_per_epoch": steps_per_epoch,
            "validation_steps": fn_args.eval_steps,
        },
    )
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    """Build a local ``RandomSearch`` tuner and the arguments used to fit it.

    Args:
        fn_args: Arguments from the Tuner component. This function reads
            ``working_dir``, ``transform_graph_path``, ``train_files``,
            ``eval_files``, ``data_accessor`` and ``eval_steps``.

    Returns:
        A ``TunerFnResult`` with a 6-trial random search that maximises
        ``val_sparse_categorical_accuracy``, plus its ``fit_kwargs``.
    """
    steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)

    random_search = keras_tuner.RandomSearch(
        _build_keras_model,
        max_trials=6,
        hyperparameters=_get_hyperparameters(),
        allow_new_entries=False,
        objective=keras_tuner.Objective("val_sparse_categorical_accuracy", "max"),
        directory=fn_args.working_dir,
        project_name="img_classification_tuning",
    )

    transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

    training_data = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        transform_output,
        is_train=True,
        batch_size=_TRAIN_BATCH_SIZE,
    )
    validation_data = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        transform_output,
        is_train=False,
        batch_size=_EVAL_BATCH_SIZE,
    )

    search_kwargs = {
        "x": training_data,
        "validation_data": validation_data,
        "steps_per_epoch": steps_per_epoch,
        "validation_steps": fn_args.eval_steps,
    }
    return TunerFnResult(tuner=random_search, fit_kwargs=search_kwargs)
def run_fn(fn_args: FnArgs):
    """Train the classifier and export it with a serving signature.

    Args:
        fn_args: Arguments from the Trainer component. This function reads
            ``train_steps``, ``eval_steps``, ``transform_output``,
            ``train_files``, ``eval_files``, ``data_accessor``,
            ``model_run_dir``, ``hyperparameters`` and ``serving_model_dir``.

    Raises:
        ValueError: If ``_CLASSIFIER_EPOCHS`` exceeds the number of epochs
            implied by ``fn_args.train_steps`` and the steps per epoch.
    """
    steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)
    total_epochs = int(fn_args.train_steps / steps_per_epoch)
    if _CLASSIFIER_EPOCHS > total_epochs:
        raise ValueError("Classifier epochs is greater than the total epochs")

    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=True,
        batch_size=_TRAIN_BATCH_SIZE,
    )

    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=False,
        batch_size=_EVAL_BATCH_SIZE,
    )

    INFO("Tensorboard logging to {}".format(fn_args.model_run_dir))
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq="batch"
    )

    # Use the supplied hyperparameters when present; otherwise fall back to
    # the search space returned by _get_hyperparameters().
    if fn_args.hyperparameters:
        hparams = keras_tuner.HyperParameters.from_config(fn_args.hyperparameters)
    else:
        hparams = _get_hyperparameters()
    # The original message used "${...}", which logged a stray "$" before
    # the config.
    INFO(f"HyperParameters for training: {hparams.get_config()}")

    model = _build_keras_model(hparams)
    model.fit(
        train_dataset,
        epochs=_CLASSIFIER_EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback],
    )

    model.save(
        fn_args.serving_model_dir, save_format="tf", signatures=_get_signature(model)
    )