# NOABOL35631y's picture
# Create Y
# 309bbd2
import numpy as np
import tensorflow as tf
from PIL import Image
def convert_to_tf_tensor(image: Image.Image) -> tf.Tensor:
    """Convert a PIL image into a TensorFlow tensor with a leading batch axis.

    Args:
        image: The PIL image to convert. (`Image` is the imported PIL module;
            the image class itself is `Image.Image`.)

    Returns:
        A tensor built from the image's NumPy array, with one extra
        dimension of size 1 inserted at axis 0.
    """
    np_image = np.array(image)
    tf_image = tf.convert_to_tensor(np_image)
    # `expand_dims()` is used to add a batch dimension since
    # the TF augmentation layers operate on batched inputs.
    return tf.expand_dims(tf_image, 0)
def preprocess_train(example_batch):
    """Apply train_transforms across a batch.

    Every entry of ``example_batch["image"]`` is converted to RGB, turned into
    a batched tensor with ``convert_to_tf_tensor`` and passed through
    ``train_data_augmentation``. The results are stored under
    ``example_batch["pixel_values"]``.

    Args:
        example_batch: Mapping with an ``"image"`` entry holding an iterable of
            images that support ``.convert("RGB")``.

    Returns:
        The same ``example_batch`` object, mutated in place with the new
        ``"pixel_values"`` list.
    """
    images = [
        train_data_augmentation(convert_to_tf_tensor(image.convert("RGB")))
        for image in example_batch["image"]
    ]
    # Drop only the batch axis that `convert_to_tf_tensor` added. A bare
    # `tf.squeeze` would also remove any other size-1 axis and silently
    # change the rank of the result.
    # NOTE(review): `tf.transpose` without `perm` reverses all axes, so an
    # (H, W, C) input would come out as (C, W, H) -- confirm this is intended.
    example_batch["pixel_values"] = [
        tf.transpose(tf.squeeze(image, axis=0)) for image in images
    ]
    return example_batch
def preprocess_val(example_batch):
    """Apply val_transforms across a batch.

    Every entry of ``example_batch["image"]`` is converted to RGB, turned into
    a batched tensor with ``convert_to_tf_tensor`` and passed through
    ``val_data_augmentation``. The results are stored under
    ``example_batch["pixel_values"]``.

    Args:
        example_batch: Mapping with an ``"image"`` entry holding an iterable of
            images that support ``.convert("RGB")``.

    Returns:
        The same ``example_batch`` object, mutated in place with the new
        ``"pixel_values"`` list.
    """
    images = [
        val_data_augmentation(convert_to_tf_tensor(image.convert("RGB")))
        for image in example_batch["image"]
    ]
    # Drop only the batch axis that `convert_to_tf_tensor` added. A bare
    # `tf.squeeze` would also remove any other size-1 axis and silently
    # change the rank of the result.
    # NOTE(review): `tf.transpose` without `perm` reverses all axes, so an
    # (H, W, C) input would come out as (C, W, H) -- confirm this is intended.
    example_batch["pixel_values"] = [
        tf.transpose(tf.squeeze(image, axis=0)) for image in images
    ]
    return example_batch