Spaces:
Runtime error
Runtime error
import numpy as np | |
import tensorflow as tf | |
from PIL import Image | |
def convert_to_tf_tensor(image: Image.Image) -> tf.Tensor:
    """Convert a PIL image into a TensorFlow tensor with a leading batch axis.

    Args:
        image: The PIL image to convert. (`Image` is the `PIL.Image` module,
            so the image class itself is `Image.Image`.)

    Returns:
        The image data as a tensor with an extra axis of size 1 inserted at
        position 0.
    """
    np_image = np.array(image)
    tf_image = tf.convert_to_tensor(np_image)
    # `expand_dims()` adds a batch dimension because the TF augmentation
    # layers operate on batched inputs.
    return tf.expand_dims(tf_image, 0)
def preprocess_train(example_batch):
    """Apply the training augmentation to every image in a batch.

    Each entry of ``example_batch["image"]`` is converted to RGB, turned into
    a batched tensor by ``convert_to_tf_tensor``, and passed through
    ``train_data_augmentation``. The result is squeezed and transposed, and
    the list of resulting tensors is stored under ``"pixel_values"``.

    NOTE(review): `tf.transpose` is called without `perm`, which reverses the
    order of all axes. Assuming the augmentation output is (1, H, W, C), each
    stored tensor would be (C, W, H), i.e. height and width are swapped as
    well -- TODO confirm this is what the downstream model expects.

    Args:
        example_batch: Mapping with an ``"image"`` entry holding an iterable
            of images that support ``.convert("RGB")`` (presumably PIL
            images; verify against the caller).

    Returns:
        The same ``example_batch`` object, updated in place with a
        ``"pixel_values"`` entry.
    """
    pixel_values = []
    for pil_image in example_batch["image"]:
        batched = convert_to_tf_tensor(pil_image.convert("RGB"))
        augmented = train_data_augmentation(batched)
        # Remove size-1 axes (including the batch axis added above), then
        # reverse the axis order.
        pixel_values.append(tf.transpose(tf.squeeze(augmented)))
    example_batch["pixel_values"] = pixel_values
    return example_batch
def preprocess_val(example_batch):
    """Apply the validation transforms to every image in a batch.

    Each entry of ``example_batch["image"]`` is converted to RGB, turned into
    a batched tensor by ``convert_to_tf_tensor``, and passed through
    ``val_data_augmentation``. The result is squeezed and transposed, and the
    list of resulting tensors is stored under ``"pixel_values"``.

    NOTE(review): `tf.transpose` is called without `perm`, which reverses the
    order of all axes. Assuming the transform output is (1, H, W, C), each
    stored tensor would be (C, W, H), i.e. height and width are swapped as
    well -- TODO confirm this is what the downstream model expects.

    Args:
        example_batch: Mapping with an ``"image"`` entry holding an iterable
            of images that support ``.convert("RGB")`` (presumably PIL
            images; verify against the caller).

    Returns:
        The same ``example_batch`` object, updated in place with a
        ``"pixel_values"`` entry.
    """
    pixel_values = []
    for pil_image in example_batch["image"]:
        batched = convert_to_tf_tensor(pil_image.convert("RGB"))
        transformed = val_data_augmentation(batched)
        # Remove size-1 axes (including the batch axis added above), then
        # reverse the axis order.
        pixel_values.append(tf.transpose(tf.squeeze(transformed)))
    example_batch["pixel_values"] = pixel_values
    return example_batch