|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utility functions for creating a tf.train.Example proto of image triplets.""" |
|
|
|
import io |
|
import os |
|
from typing import Any, List, Mapping, Optional |
|
|
|
from absl import logging |
|
import apache_beam as beam |
|
import numpy as np |
|
import PIL.Image |
|
import six |
|
from skimage import transform |
|
import tensorflow as tf |
|
|
|
# Largest value a uint8 can hold (255), kept as a float so that pixel values
# can be mapped to and from the [0, 1] range without integer arithmetic.
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)

# Exponent applied to normalized pixel values before resampling; its
# reciprocal is applied afterwards to restore the original encoding.
_GAMMA = 2.2
|
|
|
|
|
def _resample_image(image: np.ndarray, resample_image_width: int,
                    resample_image_height: int) -> np.ndarray:
    """Resizes `image` to the requested width and height.

    Pixel values are normalized by the uint8 maximum, raised to `_GAMMA`,
    resized with `transform.resize_local_mean`, raised to `1 / _GAMMA`, and
    finally rounded back to uint8.

    Args:
      image: Pixel array to resize. It is divided by 255, so it presumably
        holds 8-bit values -- confirm against callers.
      resample_image_width: Width of the returned image, in pixels.
      resample_image_height: Height of the returned image, in pixels.

    Returns:
      The resized image as a uint8 array.
    """
    # Normalize to [0, 1] and apply the forward gamma exponent.
    normalized = image.astype(np.float32) / _UINT8_MAX_F
    linear = np.power(np.clip(normalized, 0, 1), _GAMMA)

    # The resize routine takes the target shape as (rows, columns).
    target_shape = (resample_image_height, resample_image_width)
    resized = transform.resize_local_mean(linear, target_shape)

    # Undo the gamma exponent, then quantize; the +0.5 makes the truncating
    # uint8 cast round to the nearest integer.
    encoded = np.power(np.clip(resized, 0, 1), 1.0 / _GAMMA)
    quantized = np.clip(encoded * _UINT8_MAX_F + 0.5, 0.0, _UINT8_MAX_F)
    return quantized.astype(np.uint8)
|
|
|
|
|
def _center_crop(image: np.ndarray, center_crop_factor: int) -> np.ndarray:
    """Returns `image` with an equal margin trimmed from all four borders.

    The margin on each axis is `size // (2 * center_crop_factor)`, so a factor
    of 2 keeps the central half of each dimension.

    NOTE(review): for factors above 2 the kept extent is roughly
    `size * (1 - 1 / factor)`, not `size / factor`. This mirrors the existing
    arithmetic and is left unchanged -- confirm it is the intended crop.
    """
    height, width = image.shape[:2]
    margin_y = height // (2 * center_crop_factor)
    margin_x = width // (2 * center_crop_factor)
    # Use explicit end indices: a slice written as `m:-m` selects nothing when
    # the margin rounds down to zero (image smaller than 2 * factor).
    return image[margin_y:height - margin_y, margin_x:width - margin_x]


def _encode_png(pil_image, image_path: str) -> Optional[bytes]:
    """Encodes `pil_image` as PNG; logs and returns None if encoding fails."""
    buffer = io.BytesIO()
    try:
        pil_image.save(buffer, format='PNG')
    except OSError:
        logging.exception('Cannot encode image file: %s', image_path)
        return None
    return buffer.getvalue()


def generate_image_triplet_example(
        triplet_dict: Mapping[str, str],
        scale_factor: int = 1,
        center_crop_factor: int = 1) -> Optional[tf.train.Example]:
    """Generates a tf.train.Example proto from an image triplet.

    Default setting creates a triplet Example with the input images unchanged.
    Images are processed in the order of center-crop then downscale. An image
    that is cropped and/or downscaled is re-encoded as PNG, and its
    `<key>/format` feature then records `png`; an untouched image keeps its
    original bytes and its original format name.

    The Example holds a `path` feature (the directory of the `frame_1` file)
    and, for every image key, the features `<key>/encoded`, `<key>/format`,
    `<key>/height` and `<key>/width`.

    Args:
      triplet_dict: A dict of image key to filepath of the triplet images. It
        must contain the key `frame_1`.
      scale_factor: An integer scale factor to isotropically downsample
        images.
      center_crop_factor: An integer cropping factor; a margin of
        `size // (2 * center_crop_factor)` pixels is removed from every border.

    Returns:
      tf.train.Example proto, or None if a file is missing or an image cannot
      be read, decoded or encoded.

    Raises:
      ValueError if triplet_dict length is different from three or the scale
        input arguments are non-positive.
    """
    if len(triplet_dict) != 3:
        raise ValueError(
            f'Length of triplet_dict must be exactly 3, not {len(triplet_dict)}.')

    if scale_factor <= 0 or center_crop_factor <= 0:
        raise ValueError('(scale_factor, center_crop_factor) must be positive, '
                         f'Not ({scale_factor}, {center_crop_factor}).')

    feature = {}

    # Record the directory the middle frame came from.
    mid_frame_path = os.path.dirname(triplet_dict['frame_1'])
    feature['path'] = tf.train.Feature(
        bytes_list=tf.train.BytesList(
            value=[six.ensure_binary(mid_frame_path)]))

    for image_key, image_path in triplet_dict.items():
        if not tf.io.gfile.exists(image_path):
            logging.error('File not found: %s', image_path)
            return None

        # Both the raw bytes and the image size are needed, so the file is
        # read through GFile and the bytes are then handed to PIL. The `with`
        # block guarantees the file handle is closed.
        try:
            with tf.io.gfile.GFile(image_path, 'rb') as image_file:
                byte_array = image_file.read()
        except tf.errors.InvalidArgumentError:
            logging.exception('Cannot read image file: %s', image_path)
            return None
        try:
            pil_image = PIL.Image.open(io.BytesIO(byte_array))
        except PIL.UnidentifiedImageError:
            logging.exception('Cannot decode image file: %s', image_path)
            return None
        width, height = pil_image.size
        image_format = pil_image.format

        # Decoded pixels; only materialized when a transform is requested.
        image = None

        # Optionally center-crop the image.
        if center_crop_factor > 1:
            image = _center_crop(np.array(pil_image), center_crop_factor)

        # Optionally downsample the (possibly cropped) image.
        if scale_factor > 1:
            if image is None:
                image = np.array(pil_image)
            image = _resample_image(image, image.shape[1] // scale_factor,
                                    image.shape[0] // scale_factor)

        # A transformed image replaces the original bytes, so its size and
        # encoding must be refreshed to describe what is actually stored.
        if image is not None:
            height, width = image.shape[:2]
            byte_array = _encode_png(PIL.Image.fromarray(image), image_path)
            if byte_array is None:
                return None
            image_format = 'PNG'

        # Create tf Features and add them to the feature map.
        feature[f'{image_key}/encoded'] = tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[byte_array]))
        feature[f'{image_key}/format'] = tf.train.Feature(
            bytes_list=tf.train.BytesList(
                value=[six.ensure_binary(image_format.lower())]))
        feature[f'{image_key}/height'] = tf.train.Feature(
            int64_list=tf.train.Int64List(value=[height]))
        feature[f'{image_key}/width'] = tf.train.Feature(
            int64_list=tf.train.Int64List(value=[width]))

    # Create tf Example.
    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features)
|
|
|
|
|
class ExampleGenerator(beam.DoFn):
    """Beam DoFn that converts image-triplet filepaths to serialized Examples."""

    def __init__(self,
                 images_map: Mapping[str, Any],
                 scale_factor: int = 1,
                 center_crop_factor: int = 1):
        """Stores the settings applied to every processed triplet.

        Args:
          images_map: Map from image key to image filepath. It is kept on the
            instance but is not read by `process`.
          scale_factor: A scale factor to downsample frames.
          center_crop_factor: A factor to center-crop frames.
        """
        super().__init__()
        self._center_crop_factor = center_crop_factor
        self._scale_factor = scale_factor
        self._images_map = images_map

    def process(self, triplet_dict: Mapping[str, str]) -> List[bytes]:
        """Builds the serialized tf.train.Example for one image triplet.

        Args:
          triplet_dict: A dict of image key to filepath of the triplet images.

        Returns:
          A single-element list holding the serialized tf.train.Example, or an
          empty list when no Example could be generated for the triplet.
        """
        triplet_example = generate_image_triplet_example(
            triplet_dict,
            scale_factor=self._scale_factor,
            center_crop_factor=self._center_crop_factor)
        return [triplet_example.SerializeToString()] if triplet_example else []
|
|