# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for image patches."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf


def get_patch_mask(y, x, patch_size, image_shape):
  """Builds a boolean image-sized mask that is True inside a square patch.

  The patch is centered on pixel (y, x), which must lie inside the image. The
  patch itself may extend past the image border; any part that falls outside
  is clipped away.

  For an odd patch_size the patch is symmetric around the center. For an even
  patch_size it spans patch_size / 2 pixels before the center and
  patch_size / 2 - 1 pixels after it on each axis. In other words, the extra
  pixel goes toward the top and the left.

  Args:
    y: Python int or scalar int32 tensor. Row of the patch center, in
      [0, image_height).
    x: Python int or scalar int32 tensor. Column of the patch center, in
      [0, image_width).
    patch_size: Python int or scalar int32 tensor. Side length of the square
      patch; must be >= 1.
    image_shape: List or 1D int32 tensor whose first two entries are the image
      height and width, e.g. [image_height, image_width] or
      [image_height, image_width, image_channels].

  Returns:
    A tf.bool tensor of shape [image_height, image_width], True for pixels
    inside the (clipped) patch and False elsewhere.

  Raises:
    tf.errors.InvalidArgumentError: if patch_size < 1, or if (y, x) is not
      inside [0, image_height) x [0, image_width).
  """
  height_width = image_shape[:2]
  center = tf.stack([y, x])

  # Validate the inputs. Routing `center` through an identity op under the
  # control dependencies forces every downstream op to wait on the checks.
  checks = [
      tf.debugging.assert_greater_equal(
          patch_size, 1, message='Patch size must be >= 1'),
      tf.debugging.assert_greater_equal(
          center, 0, message='Patch center (y, x) must be >= (0, 0)'),
      tf.debugging.assert_less(
          center, height_width,
          message='Patch center (y, x) must be < image (h, w)'),
  ]
  with tf.control_dependencies(checks):
    center = tf.identity(center)

  # floor(p / 2) pixels before the center and ceil(p / 2) after it (exclusive
  # end). Together these always span exactly p pixels, with the extra pixel of
  # an even size landing before the center.
  half_size = tf.cast(patch_size, dtype=tf.float32) / 2
  extent_before = tf.cast(tf.floor(half_size), dtype=tf.int32)
  extent_after = tf.cast(tf.ceil(half_size), dtype=tf.int32)

  # Clip the half-open patch box [top_left, bottom_right) to the image.
  top_left = tf.maximum(center - extent_before, 0)
  bottom_right = tf.minimum(center + extent_after, height_width)

  top, left = top_left[0], top_left[1]
  bottom, right = bottom_right[0], bottom_right[1]

  # Build the all-True patch, then pad it with False back to full image size.
  patch = tf.ones([bottom - top, right - left], dtype=tf.bool)
  paddings = [
      [top, height_width[0] - bottom],
      [left, height_width[1] - right],
  ]
  return tf.pad(patch, paddings)