import numpy as np
import tensorflow as tf
from PIL import Image


def convert_to_tf_tensor(image: Image.Image) -> tf.Tensor:
    """Convert a PIL image into a TensorFlow tensor with a leading batch axis.

    Args:
        image: The PIL image to convert.

    Returns:
        The image data as a tensor with one extra dimension of size 1
        inserted at axis 0.
    """
    np_image = np.array(image)
    tf_image = tf.convert_to_tensor(np_image)
    # `expand_dims()` adds a batch dimension because the TF augmentation
    # layers operate on batched inputs.
    return tf.expand_dims(tf_image, 0)


def _apply_augmentation(example_batch, augmentation):
    """Run `augmentation` over every image in the batch and store the result.

    Each entry of `example_batch["image"]` is converted to RGB, turned into a
    single-image batch tensor, passed through `augmentation`, and then
    stripped of its batch axis and transposed.

    Args:
        example_batch: Mapping with an "image" entry holding an iterable of
            PIL images. It is modified in place.
        augmentation: Callable applied to each single-image batch tensor.

    Returns:
        The same `example_batch`, with a new "pixel_values" entry holding one
        tensor per input image.
    """
    pixel_values = []
    for image in example_batch["image"]:
        augmented = augmentation(convert_to_tf_tensor(image.convert("RGB")))
        # Remove only the batch axis added by `convert_to_tf_tensor()`; an
        # unqualified squeeze would also drop any other size-1 dimension.
        unbatched = tf.squeeze(augmented, axis=0)
        # NOTE(review): with no `perm`, `tf.transpose` reverses all axes, so
        # an (H, W, C) tensor would become (C, W, H), assuming the
        # augmentation returns (1, H, W, C) -- TODO confirm that swapping the
        # spatial axes is intended.
        pixel_values.append(tf.transpose(unbatched))
    example_batch["pixel_values"] = pixel_values
    return example_batch


def preprocess_train(example_batch):
    """Apply the training augmentation across a batch.

    Args:
        example_batch: Mapping with an "image" entry holding PIL images.

    Returns:
        `example_batch` with a "pixel_values" entry added.
    """
    return _apply_augmentation(example_batch, train_data_augmentation)


def preprocess_val(example_batch):
    """Apply the validation augmentation across a batch.

    Args:
        example_batch: Mapping with an "image" entry holding PIL images.

    Returns:
        `example_batch` with a "pixel_values" entry added.
    """
    return _apply_augmentation(example_batch, val_data_augmentation)