# Copyright 2018 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utilities for the instance embedding for segmentation.""" import numpy as np import tensorflow as tf from deeplab import model from deeplab.core import preprocess_utils from feelvos.utils import mask_damaging slim = tf.contrib.slim resolve_shape = preprocess_utils.resolve_shape WRONG_LABEL_PADDING_DISTANCE = 1e20 # With correlation_cost local matching will be much faster. But we provide a # slow fallback for convenience. USE_CORRELATION_COST = False if USE_CORRELATION_COST: # pylint: disable=g-import-not-at-top from correlation_cost.python.ops import correlation_cost_op def pairwise_distances(x, y): """Computes pairwise squared l2 distances between tensors x and y. Args: x: Tensor of shape [n, feature_dim]. y: Tensor of shape [m, feature_dim]. Returns: Float32 distances tensor of shape [n, m]. """ # d[i,j] = (x[i] - y[j]) * (x[i] - y[j])' # = sum(x[i]^2, 1) + sum(y[j]^2, 1) - 2 * x[i] * y[j]' xs = tf.reduce_sum(x * x, axis=1)[:, tf.newaxis] ys = tf.reduce_sum(y * y, axis=1)[tf.newaxis, :] d = xs + ys - 2 * tf.matmul(x, y, transpose_b=True) return d def pairwise_distances2(x, y): """Computes pairwise squared l2 distances between tensors x and y. Naive implementation, high memory use. Could be useful to test the more efficient implementation. Args: x: Tensor of shape [n, feature_dim]. y: Tensor of shape [m, feature_dim]. Returns: distances of shape [n, m]. """ return tf.reduce_sum(tf.squared_difference( x[:, tf.newaxis], y[tf.newaxis, :]), axis=-1) def cross_correlate(x, y, max_distance=9): """Efficiently computes the cross correlation of x and y. Optimized implementation using correlation_cost. Note that we do not normalize by the feature dimension. Args: x: Float32 tensor of shape [height, width, feature_dim]. y: Float32 tensor of shape [height, width, feature_dim]. max_distance: Integer, the maximum distance in pixel coordinates per dimension which is considered to be in the search window. Returns: Float32 tensor of shape [height, width, (2 * max_distance + 1) ** 2]. """ with tf.name_scope('cross_correlation'): corr = correlation_cost_op.correlation_cost( x[tf.newaxis], y[tf.newaxis], kernel_size=1, max_displacement=max_distance, stride_1=1, stride_2=1, pad=max_distance) corr = tf.squeeze(corr, axis=0) # This correlation implementation takes the mean over the feature_dim, # but we want sum here, so multiply by feature_dim. feature_dim = resolve_shape(x)[-1] corr *= feature_dim return corr def local_pairwise_distances(x, y, max_distance=9): """Computes pairwise squared l2 distances using a local search window. Optimized implementation using correlation_cost. Args: x: Float32 tensor of shape [height, width, feature_dim]. y: Float32 tensor of shape [height, width, feature_dim]. max_distance: Integer, the maximum distance in pixel coordinates per dimension which is considered to be in the search window. Returns: Float32 distances tensor of shape [height, width, (2 * max_distance + 1) ** 2]. """ with tf.name_scope('local_pairwise_distances'): # d[i,j] = (x[i] - y[j]) * (x[i] - y[j])' # = sum(x[i]^2, -1) + sum(y[j]^2, -1) - 2 * x[i] * y[j]' corr = cross_correlate(x, y, max_distance=max_distance) xs = tf.reduce_sum(x * x, axis=2)[..., tf.newaxis] ys = tf.reduce_sum(y * y, axis=2)[..., tf.newaxis] ones_ys = tf.ones_like(ys) ys = cross_correlate(ones_ys, ys, max_distance=max_distance) d = xs + ys - 2 * corr # Boundary should be set to Inf. boundary = tf.equal( cross_correlate(ones_ys, ones_ys, max_distance=max_distance), 0) d = tf.where(boundary, tf.fill(tf.shape(d), tf.constant(np.float('inf'))), d) return d def local_pairwise_distances2(x, y, max_distance=9): """Computes pairwise squared l2 distances using a local search window. Naive implementation using map_fn. Used as a slow fallback for when correlation_cost is not available. Args: x: Float32 tensor of shape [height, width, feature_dim]. y: Float32 tensor of shape [height, width, feature_dim]. max_distance: Integer, the maximum distance in pixel coordinates per dimension which is considered to be in the search window. Returns: Float32 distances tensor of shape [height, width, (2 * max_distance + 1) ** 2]. """ with tf.name_scope('local_pairwise_distances2'): padding_val = 1e20 padded_y = tf.pad(y, [[max_distance, max_distance], [max_distance, max_distance], [0, 0]], constant_values=padding_val) height, width, _ = resolve_shape(x) dists = [] for y_start in range(2 * max_distance + 1): y_end = y_start + height y_slice = padded_y[y_start:y_end] for x_start in range(2 * max_distance + 1): x_end = x_start + width offset_y = y_slice[:, x_start:x_end] dist = tf.reduce_sum(tf.squared_difference(x, offset_y), axis=2) dists.append(dist) dists = tf.stack(dists, axis=2) return dists def majority_vote(labels): """Performs a label majority vote along axis 1. Second try, hopefully this time more efficient. We assume that the labels are contiguous starting from 0. It will also work for non-contiguous labels, but be inefficient. Args: labels: Int tensor of shape [n, k] Returns: The majority of labels along axis 1 """ max_label = tf.reduce_max(labels) one_hot = tf.one_hot(labels, depth=max_label + 1) summed = tf.reduce_sum(one_hot, axis=1) majority = tf.argmax(summed, axis=1) return majority def assign_labels_by_nearest_neighbors(reference_embeddings, query_embeddings, reference_labels, k=1): """Segments by nearest neighbor query wrt the reference frame. Args: reference_embeddings: Tensor of shape [height, width, embedding_dim], the embedding vectors for the reference frame query_embeddings: Tensor of shape [n_query_images, height, width, embedding_dim], the embedding vectors for the query frames reference_labels: Tensor of shape [height, width, 1], the class labels of the reference frame k: Integer, the number of nearest neighbors to use Returns: The labels of the nearest neighbors as [n_query_frames, height, width, 1] tensor Raises: ValueError: If k < 1. """ if k < 1: raise ValueError('k must be at least 1') dists = flattened_pairwise_distances(reference_embeddings, query_embeddings) if k == 1: nn_indices = tf.argmin(dists, axis=1)[..., tf.newaxis] else: _, nn_indices = tf.nn.top_k(-dists, k, sorted=False) reference_labels = tf.reshape(reference_labels, [-1]) nn_labels = tf.gather(reference_labels, nn_indices) if k == 1: nn_labels = tf.squeeze(nn_labels, axis=1) else: nn_labels = majority_vote(nn_labels) height = tf.shape(reference_embeddings)[0] width = tf.shape(reference_embeddings)[1] n_query_frames = query_embeddings.shape[0] nn_labels = tf.reshape(nn_labels, [n_query_frames, height, width, 1]) return nn_labels def flattened_pairwise_distances(reference_embeddings, query_embeddings): """Calculates flattened tensor of pairwise distances between ref and query. Args: reference_embeddings: Tensor of shape [..., embedding_dim], the embedding vectors for the reference frame query_embeddings: Tensor of shape [n_query_images, height, width, embedding_dim], the embedding vectors for the query frames. Returns: A distance tensor of shape [reference_embeddings.size / embedding_dim, query_embeddings.size / embedding_dim] """ embedding_dim = resolve_shape(query_embeddings)[-1] reference_embeddings = tf.reshape(reference_embeddings, [-1, embedding_dim]) first_dim = -1 query_embeddings = tf.reshape(query_embeddings, [first_dim, embedding_dim]) dists = pairwise_distances(query_embeddings, reference_embeddings) return dists def nearest_neighbor_features_per_object( reference_embeddings, query_embeddings, reference_labels, max_neighbors_per_object, k_nearest_neighbors, gt_ids=None, n_chunks=100): """Calculates the distance to the nearest neighbor per object. For every pixel of query_embeddings calculate the distance to the nearest neighbor in the (possibly subsampled) reference_embeddings per object. Args: reference_embeddings: Tensor of shape [height, width, embedding_dim], the embedding vectors for the reference frame. query_embeddings: Tensor of shape [n_query_images, height, width, embedding_dim], the embedding vectors for the query frames. reference_labels: Tensor of shape [height, width, 1], the class labels of the reference frame. max_neighbors_per_object: Integer, the maximum number of candidates for the nearest neighbor query per object after subsampling, or 0 for no subsampling. k_nearest_neighbors: Integer, the number of nearest neighbors to use. gt_ids: Int tensor of shape [n_objs] of the sorted unique ground truth ids in the first frame. If None, it will be derived from reference_labels. n_chunks: Integer, the number of chunks to use to save memory (set to 1 for no chunking). Returns: nn_features: A float32 tensor of nearest neighbor features of shape [n_query_images, height, width, n_objects, feature_dim]. gt_ids: An int32 tensor of the unique sorted object ids present in the reference labels. """ with tf.name_scope('nn_features_per_object'): reference_labels_flat = tf.reshape(reference_labels, [-1]) if gt_ids is None: ref_obj_ids, _ = tf.unique(reference_labels_flat) ref_obj_ids = tf.contrib.framework.sort(ref_obj_ids) gt_ids = ref_obj_ids embedding_dim = resolve_shape(reference_embeddings)[-1] reference_embeddings_flat = tf.reshape(reference_embeddings, [-1, embedding_dim]) reference_embeddings_flat, reference_labels_flat = ( subsample_reference_embeddings_and_labels(reference_embeddings_flat, reference_labels_flat, gt_ids, max_neighbors_per_object)) shape = resolve_shape(query_embeddings) query_embeddings_flat = tf.reshape(query_embeddings, [-1, embedding_dim]) nn_features = _nearest_neighbor_features_per_object_in_chunks( reference_embeddings_flat, query_embeddings_flat, reference_labels_flat, gt_ids, k_nearest_neighbors, n_chunks) nn_features_dim = resolve_shape(nn_features)[-1] nn_features_reshaped = tf.reshape(nn_features, tf.stack(shape[:3] + [tf.size(gt_ids), nn_features_dim])) return nn_features_reshaped, gt_ids def _nearest_neighbor_features_per_object_in_chunks( reference_embeddings_flat, query_embeddings_flat, reference_labels_flat, ref_obj_ids, k_nearest_neighbors, n_chunks): """Calculates the nearest neighbor features per object in chunks to save mem. Uses chunking to bound the memory use. Args: reference_embeddings_flat: Tensor of shape [n, embedding_dim], the embedding vectors for the reference frame. query_embeddings_flat: Tensor of shape [m, embedding_dim], the embedding vectors for the query frames. reference_labels_flat: Tensor of shape [n], the class labels of the reference frame. ref_obj_ids: int tensor of unique object ids in the reference labels. k_nearest_neighbors: Integer, the number of nearest neighbors to use. n_chunks: Integer, the number of chunks to use to save memory (set to 1 for no chunking). Returns: nn_features: A float32 tensor of nearest neighbor features of shape [m, n_objects, feature_dim]. """ chunk_size = tf.cast(tf.ceil(tf.cast(tf.shape(query_embeddings_flat)[0], tf.float32) / n_chunks), tf.int32) wrong_label_mask = tf.not_equal(reference_labels_flat, ref_obj_ids[:, tf.newaxis]) all_features = [] for n in range(n_chunks): if n_chunks == 1: query_embeddings_flat_chunk = query_embeddings_flat else: chunk_start = n * chunk_size chunk_end = (n + 1) * chunk_size query_embeddings_flat_chunk = query_embeddings_flat[chunk_start:chunk_end] # Use control dependencies to make sure that the chunks are not processed # in parallel which would prevent any peak memory savings. with tf.control_dependencies(all_features): features = _nn_features_per_object_for_chunk( reference_embeddings_flat, query_embeddings_flat_chunk, wrong_label_mask, k_nearest_neighbors ) all_features.append(features) if n_chunks == 1: nn_features = all_features[0] else: nn_features = tf.concat(all_features, axis=0) return nn_features def _nn_features_per_object_for_chunk( reference_embeddings, query_embeddings, wrong_label_mask, k_nearest_neighbors): """Extracts features for each object using nearest neighbor attention. Args: reference_embeddings: Tensor of shape [n_chunk, embedding_dim], the embedding vectors for the reference frame. query_embeddings: Tensor of shape [m_chunk, embedding_dim], the embedding vectors for the query frames. wrong_label_mask: k_nearest_neighbors: Integer, the number of nearest neighbors to use. Returns: nn_features: A float32 tensor of nearest neighbor features of shape [m_chunk, n_objects, feature_dim]. """ reference_embeddings_key = reference_embeddings query_embeddings_key = query_embeddings dists = flattened_pairwise_distances(reference_embeddings_key, query_embeddings_key) dists = (dists[:, tf.newaxis, :] + tf.cast(wrong_label_mask[tf.newaxis, :, :], tf.float32) * WRONG_LABEL_PADDING_DISTANCE) if k_nearest_neighbors == 1: features = tf.reduce_min(dists, axis=2, keepdims=True) else: # Find the closest k and combine them according to attention_feature_type dists, _ = tf.nn.top_k(-dists, k=k_nearest_neighbors) dists = -dists # If not enough real neighbors were found, pad with the farthest real # neighbor. valid_mask = tf.less(dists, WRONG_LABEL_PADDING_DISTANCE) masked_dists = dists * tf.cast(valid_mask, tf.float32) pad_dist = tf.tile(tf.reduce_max(masked_dists, axis=2)[..., tf.newaxis], multiples=[1, 1, k_nearest_neighbors]) dists = tf.where(valid_mask, dists, pad_dist) # take mean of distances features = tf.reduce_mean(dists, axis=2, keepdims=True) return features def create_embedding_segmentation_features(features, feature_dimension, n_layers, kernel_size, reuse, atrous_rates=None): """Extracts features which can be used to estimate the final segmentation. Args: features: input features of shape [batch, height, width, features] feature_dimension: Integer, the dimensionality used in the segmentation head layers. n_layers: Integer, the number of layers in the segmentation head. kernel_size: Integer, the kernel size used in the segmentation head. reuse: reuse mode for the variable_scope. atrous_rates: List of integers of length n_layers, the atrous rates to use. Returns: Features to be used to estimate the segmentation labels of shape [batch, height, width, embedding_seg_feat_dim]. """ if atrous_rates is None or not atrous_rates: atrous_rates = [1 for _ in range(n_layers)] assert len(atrous_rates) == n_layers with tf.variable_scope('embedding_seg', reuse=reuse): for n in range(n_layers): features = model.split_separable_conv2d( features, feature_dimension, kernel_size=kernel_size, rate=atrous_rates[n], scope='split_separable_conv2d_{}'.format(n)) return features def add_image_summaries(images, nn_features, logits, batch_size, prev_frame_nn_features=None): """Adds image summaries of input images, attention features and logits. Args: images: Image tensor of shape [batch, height, width, channels]. nn_features: Nearest neighbor attention features of shape [batch_size, height, width, n_objects, 1]. logits: Float32 tensor of logits. batch_size: Integer, the number of videos per clone per mini-batch. prev_frame_nn_features: Nearest neighbor attention features wrt. the last frame of shape [batch_size, height, width, n_objects, 1]. Can be None. """ # Separate reference and query images. reshaped_images = tf.reshape(images, tf.stack( [batch_size, -1] + resolve_shape(images)[1:])) reference_images = reshaped_images[:, 0] query_images = reshaped_images[:, 1:] query_images_reshaped = tf.reshape(query_images, tf.stack( [-1] + resolve_shape(images)[1:])) tf.summary.image('ref_images', reference_images, max_outputs=batch_size) tf.summary.image('query_images', query_images_reshaped, max_outputs=10) predictions = tf.cast( tf.argmax(logits, axis=-1), tf.uint8)[..., tf.newaxis] # Scale up so that we can actually see something. tf.summary.image('predictions', predictions * 32, max_outputs=10) # We currently only show the first dimension of the features for background # and the first foreground object. tf.summary.image('nn_fg_features', nn_features[..., 0:1, 0], max_outputs=batch_size) if prev_frame_nn_features is not None: tf.summary.image('nn_fg_features_prev', prev_frame_nn_features[..., 0:1, 0], max_outputs=batch_size) tf.summary.image('nn_bg_features', nn_features[..., 1:2, 0], max_outputs=batch_size) if prev_frame_nn_features is not None: tf.summary.image('nn_bg_features_prev', prev_frame_nn_features[..., 1:2, 0], max_outputs=batch_size) def get_embeddings(images, model_options, embedding_dimension): """Extracts embedding vectors for images. Should only be used for inference. Args: images: A tensor of shape [batch, height, width, channels]. model_options: A ModelOptions instance to configure models. embedding_dimension: Integer, the dimension of the embedding. Returns: embeddings: A tensor of shape [batch, height, width, embedding_dimension]. """ features, end_points = model.extract_features( images, model_options, is_training=False) if model_options.decoder_output_stride is not None: decoder_output_stride = min(model_options.decoder_output_stride) if model_options.crop_size is None: height = tf.shape(images)[1] width = tf.shape(images)[2] else: height, width = model_options.crop_size features = model.refine_by_decoder( features, end_points, crop_size=[height, width], decoder_output_stride=[decoder_output_stride], decoder_use_separable_conv=model_options.decoder_use_separable_conv, model_variant=model_options.model_variant, is_training=False) with tf.variable_scope('embedding'): embeddings = split_separable_conv2d_with_identity_initializer( features, embedding_dimension, scope='split_separable_conv2d') return embeddings def get_logits_with_matching(images, model_options, weight_decay=0.0001, reuse=None, is_training=False, fine_tune_batch_norm=False, reference_labels=None, batch_size=None, num_frames_per_video=None, embedding_dimension=None, max_neighbors_per_object=0, k_nearest_neighbors=1, use_softmax_feedback=True, initial_softmax_feedback=None, embedding_seg_feature_dimension=256, embedding_seg_n_layers=4, embedding_seg_kernel_size=7, embedding_seg_atrous_rates=None, normalize_nearest_neighbor_distances=True, also_attend_to_previous_frame=True, damage_initial_previous_frame_mask=False, use_local_previous_frame_attention=True, previous_frame_attention_window_size=15, use_first_frame_matching=True, also_return_embeddings=False, ref_embeddings=None): """Gets the logits by atrous/image spatial pyramid pooling using attention. Args: images: A tensor of size [batch, height, width, channels]. model_options: A ModelOptions instance to configure models. weight_decay: The weight decay for model variables. reuse: Reuse the model variables or not. is_training: Is training or not. fine_tune_batch_norm: Fine-tune the batch norm parameters or not. reference_labels: The segmentation labels of the reference frame on which attention is applied. batch_size: Integer, the number of videos on a batch num_frames_per_video: Integer, the number of frames per video embedding_dimension: Integer, the dimension of the embedding max_neighbors_per_object: Integer, the maximum number of candidates for the nearest neighbor query per object after subsampling. Can be 0 for no subsampling. k_nearest_neighbors: Integer, the number of nearest neighbors to use. use_softmax_feedback: Boolean, whether to give the softmax predictions of the last frame as additional input to the segmentation head. initial_softmax_feedback: List of Float32 tensors, or None. Can be used to initialize the softmax predictions used for the feedback loop. Only has an effect if use_softmax_feedback is True. embedding_seg_feature_dimension: Integer, the dimensionality used in the segmentation head layers. embedding_seg_n_layers: Integer, the number of layers in the segmentation head. embedding_seg_kernel_size: Integer, the kernel size used in the segmentation head. embedding_seg_atrous_rates: List of integers of length embedding_seg_n_layers, the atrous rates to use for the segmentation head. normalize_nearest_neighbor_distances: Boolean, whether to normalize the nearest neighbor distances to [0,1] using sigmoid, scale and shift. also_attend_to_previous_frame: Boolean, whether to also use nearest neighbor attention with respect to the previous frame. damage_initial_previous_frame_mask: Boolean, whether to artificially damage the initial previous frame mask. Only has an effect if also_attend_to_previous_frame is True. use_local_previous_frame_attention: Boolean, whether to restrict the previous frame attention to a local search window. Only has an effect, if also_attend_to_previous_frame is True. previous_frame_attention_window_size: Integer, the window size used for local previous frame attention, if use_local_previous_frame_attention is True. use_first_frame_matching: Boolean, whether to extract features by matching to the reference frame. This should always be true except for ablation experiments. also_return_embeddings: Boolean, whether to return the embeddings as well. ref_embeddings: Tuple of (first_frame_embeddings, previous_frame_embeddings), each of shape [batch, height, width, embedding_dimension], or None. Returns: outputs_to_logits: A map from output_type to logits. If also_return_embeddings is True, it will also return an embeddings tensor of shape [batch, height, width, embedding_dimension]. """ features, end_points = model.extract_features( images, model_options, weight_decay=weight_decay, reuse=reuse, is_training=is_training, fine_tune_batch_norm=fine_tune_batch_norm) if model_options.decoder_output_stride: decoder_output_stride = min(model_options.decoder_output_stride) if model_options.crop_size is None: height = tf.shape(images)[1] width = tf.shape(images)[2] else: height, width = model_options.crop_size decoder_height = model.scale_dimension(height, 1.0 / decoder_output_stride) decoder_width = model.scale_dimension(width, 1.0 / decoder_output_stride) features = model.refine_by_decoder( features, end_points, crop_size=[height, width], decoder_output_stride=[decoder_output_stride], decoder_use_separable_conv=model_options.decoder_use_separable_conv, model_variant=model_options.model_variant, weight_decay=weight_decay, reuse=reuse, is_training=is_training, fine_tune_batch_norm=fine_tune_batch_norm) with tf.variable_scope('embedding', reuse=reuse): embeddings = split_separable_conv2d_with_identity_initializer( features, embedding_dimension, scope='split_separable_conv2d') embeddings = tf.identity(embeddings, name='embeddings') scaled_reference_labels = tf.image.resize_nearest_neighbor( reference_labels, resolve_shape(embeddings, 4)[1:3], align_corners=True) h, w = decoder_height, decoder_width if num_frames_per_video is None: num_frames_per_video = tf.size(embeddings) // ( batch_size * h * w * embedding_dimension) new_labels_shape = tf.stack([batch_size, -1, h, w, 1]) reshaped_reference_labels = tf.reshape(scaled_reference_labels, new_labels_shape) new_embeddings_shape = tf.stack([batch_size, num_frames_per_video, h, w, embedding_dimension]) reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape) all_nn_features = [] all_ref_obj_ids = [] # To keep things simple, we do all this separate for each sequence for now. for n in range(batch_size): embedding = reshaped_embeddings[n] if ref_embeddings is None: n_chunks = 100 reference_embedding = embedding[0] if also_attend_to_previous_frame or use_softmax_feedback: queries_embedding = embedding[2:] else: queries_embedding = embedding[1:] else: if USE_CORRELATION_COST: n_chunks = 20 else: n_chunks = 500 reference_embedding = ref_embeddings[0][n] queries_embedding = embedding reference_labels = reshaped_reference_labels[n][0] nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object( reference_embedding, queries_embedding, reference_labels, max_neighbors_per_object, k_nearest_neighbors, n_chunks=n_chunks) if normalize_nearest_neighbor_distances: nn_features_n = (tf.nn.sigmoid(nn_features_n) - 0.5) * 2 all_nn_features.append(nn_features_n) all_ref_obj_ids.append(ref_obj_ids) feat_dim = resolve_shape(features)[-1] features = tf.reshape(features, tf.stack( [batch_size, num_frames_per_video, h, w, feat_dim])) if ref_embeddings is None: # Strip the features for the reference frame. if also_attend_to_previous_frame or use_softmax_feedback: features = features[:, 2:] else: features = features[:, 1:] # To keep things simple, we do all this separate for each sequence for now. outputs_to_logits = {output: [] for output in model_options.outputs_to_num_classes} for n in range(batch_size): features_n = features[n] nn_features_n = all_nn_features[n] nn_features_n_tr = tf.transpose(nn_features_n, [3, 0, 1, 2, 4]) n_objs = tf.shape(nn_features_n_tr)[0] # Repeat features for every object. features_n_tiled = tf.tile(features_n[tf.newaxis], multiples=[n_objs, 1, 1, 1, 1]) prev_frame_labels = None if also_attend_to_previous_frame: prev_frame_labels = reshaped_reference_labels[n, 1] if is_training and damage_initial_previous_frame_mask: # Damage the previous frame masks. prev_frame_labels = mask_damaging.damage_masks(prev_frame_labels, dilate=False) tf.summary.image('prev_frame_labels', tf.cast(prev_frame_labels[tf.newaxis], tf.uint8) * 32) initial_softmax_feedback_n = create_initial_softmax_from_labels( prev_frame_labels, reshaped_reference_labels[n][0], decoder_output_stride=None, reduce_labels=True) elif initial_softmax_feedback is not None: initial_softmax_feedback_n = initial_softmax_feedback[n] else: initial_softmax_feedback_n = None if initial_softmax_feedback_n is None: last_softmax = tf.zeros((n_objs, h, w, 1), dtype=tf.float32) else: last_softmax = tf.transpose(initial_softmax_feedback_n, [2, 0, 1])[ ..., tf.newaxis] assert len(model_options.outputs_to_num_classes) == 1 output = model_options.outputs_to_num_classes.keys()[0] logits = [] n_ref_frames = 1 prev_frame_nn_features_n = None if also_attend_to_previous_frame or use_softmax_feedback: n_ref_frames += 1 if ref_embeddings is not None: n_ref_frames = 0 for t in range(num_frames_per_video - n_ref_frames): to_concat = [features_n_tiled[:, t]] if use_first_frame_matching: to_concat.append(nn_features_n_tr[:, t]) if use_softmax_feedback: to_concat.append(last_softmax) if also_attend_to_previous_frame: assert normalize_nearest_neighbor_distances, ( 'previous frame attention currently only works when normalized ' 'distances are used') embedding = reshaped_embeddings[n] if ref_embeddings is None: last_frame_embedding = embedding[t + 1] query_embeddings = embedding[t + 2, tf.newaxis] else: last_frame_embedding = ref_embeddings[1][0] query_embeddings = embedding if use_local_previous_frame_attention: assert query_embeddings.shape[0] == 1 prev_frame_nn_features_n = ( local_previous_frame_nearest_neighbor_features_per_object( last_frame_embedding, query_embeddings[0], prev_frame_labels, all_ref_obj_ids[n], max_distance=previous_frame_attention_window_size) ) else: prev_frame_nn_features_n, _ = ( nearest_neighbor_features_per_object( last_frame_embedding, query_embeddings, prev_frame_labels, max_neighbors_per_object, k_nearest_neighbors, gt_ids=all_ref_obj_ids[n])) prev_frame_nn_features_n = (tf.nn.sigmoid( prev_frame_nn_features_n) - 0.5) * 2 prev_frame_nn_features_n_sq = tf.squeeze(prev_frame_nn_features_n, axis=0) prev_frame_nn_features_n_tr = tf.transpose( prev_frame_nn_features_n_sq, [2, 0, 1, 3]) to_concat.append(prev_frame_nn_features_n_tr) features_n_concat_t = tf.concat(to_concat, axis=-1) embedding_seg_features_n_t = ( create_embedding_segmentation_features( features_n_concat_t, embedding_seg_feature_dimension, embedding_seg_n_layers, embedding_seg_kernel_size, reuse or n > 0, atrous_rates=embedding_seg_atrous_rates)) logits_t = model.get_branch_logits( embedding_seg_features_n_t, 1, model_options.atrous_rates, aspp_with_batch_norm=model_options.aspp_with_batch_norm, kernel_size=model_options.logits_kernel_size, weight_decay=weight_decay, reuse=reuse or n > 0 or t > 0, scope_suffix=output ) logits.append(logits_t) prev_frame_labels = tf.transpose(tf.argmax(logits_t, axis=0), [2, 0, 1]) last_softmax = tf.nn.softmax(logits_t, axis=0) logits = tf.stack(logits, axis=1) logits_shape = tf.stack( [n_objs, num_frames_per_video - n_ref_frames] + resolve_shape(logits)[2:-1]) logits_reshaped = tf.reshape(logits, logits_shape) logits_transposed = tf.transpose(logits_reshaped, [1, 2, 3, 0]) outputs_to_logits[output].append(logits_transposed) add_image_summaries( images[n * num_frames_per_video: (n+1) * num_frames_per_video], nn_features_n, logits_transposed, batch_size=1, prev_frame_nn_features=prev_frame_nn_features_n) if also_return_embeddings: return outputs_to_logits, embeddings else: return outputs_to_logits def subsample_reference_embeddings_and_labels( reference_embeddings_flat, reference_labels_flat, ref_obj_ids, max_neighbors_per_object): """Subsamples the reference embedding vectors and labels. After subsampling, at most max_neighbors_per_object items will remain per class. Args: reference_embeddings_flat: Tensor of shape [n, embedding_dim], the embedding vectors for the reference frame. reference_labels_flat: Tensor of shape [n, 1], the class labels of the reference frame. ref_obj_ids: An int32 tensor of the unique object ids present in the reference labels. max_neighbors_per_object: Integer, the maximum number of candidates for the nearest neighbor query per object after subsampling, or 0 for no subsampling. Returns: reference_embeddings_flat: Tensor of shape [n_sub, embedding_dim], the subsampled embedding vectors for the reference frame. reference_labels_flat: Tensor of shape [n_sub, 1], the class labels of the reference frame. """ if max_neighbors_per_object == 0: return reference_embeddings_flat, reference_labels_flat same_label_mask = tf.equal(reference_labels_flat[tf.newaxis, :], ref_obj_ids[:, tf.newaxis]) max_neighbors_per_object_repeated = tf.tile( tf.constant(max_neighbors_per_object)[tf.newaxis], multiples=[tf.size(ref_obj_ids)]) # Somehow map_fn on GPU caused trouble sometimes, so let's put it on CPU # for now. with tf.device('cpu:0'): subsampled_indices = tf.map_fn(_create_subsampling_mask, (same_label_mask, max_neighbors_per_object_repeated), dtype=tf.int64, name='subsample_labels_map_fn', parallel_iterations=1) mask = tf.not_equal(subsampled_indices, tf.constant(-1, dtype=tf.int64)) masked_indices = tf.boolean_mask(subsampled_indices, mask) reference_embeddings_flat = tf.gather(reference_embeddings_flat, masked_indices) reference_labels_flat = tf.gather(reference_labels_flat, masked_indices) return reference_embeddings_flat, reference_labels_flat def _create_subsampling_mask(args): """Creates boolean mask which can be used to subsample the labels. Args: args: tuple of (label_mask, max_neighbors_per_object), where label_mask is the mask to be subsampled and max_neighbors_per_object is a int scalar, the maximum number of neighbors to be retained after subsampling. Returns: The boolean mask for subsampling the labels. """ label_mask, max_neighbors_per_object = args indices = tf.squeeze(tf.where(label_mask), axis=1) shuffled_indices = tf.random_shuffle(indices) subsampled_indices = shuffled_indices[:max_neighbors_per_object] n_pad = max_neighbors_per_object - tf.size(subsampled_indices) padded_label = -1 padding = tf.fill((n_pad,), tf.constant(padded_label, dtype=tf.int64)) padded = tf.concat([subsampled_indices, padding], axis=0) return padded def conv2d_identity_initializer(scale=1.0, mean=0, stddev=3e-2): """Creates an identity initializer for TensorFlow conv2d. We add a small amount of normal noise to the initialization matrix. Code copied from lcchen@. Args: scale: The scale coefficient for the identity weight matrix. mean: A 0-D Tensor or Python value of type `dtype`. The mean of the truncated normal distribution. stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation of the truncated normal distribution. Returns: An identity initializer function for TensorFlow conv2d. """ def _initializer(shape, dtype=tf.float32, partition_info=None): """Returns the identity matrix scaled by `scale`. Args: shape: A tuple of int32 numbers indicating the shape of the initializing matrix. dtype: The data type of the initializing matrix. partition_info: (Optional) variable_scope._PartitionInfo object holding additional information about how the variable is partitioned. This input is not used in our case, but is required by TensorFlow. Returns: A identity matrix. Raises: ValueError: If len(shape) != 4, or shape[0] != shape[1], or shape[0] is not odd, or shape[1] is not odd.. """ del partition_info if len(shape) != 4: raise ValueError('Expect shape length to be 4.') if shape[0] != shape[1]: raise ValueError('Expect shape[0] = shape[1].') if shape[0] % 2 != 1: raise ValueError('Expect shape[0] to be odd value.') if shape[1] % 2 != 1: raise ValueError('Expect shape[1] to be odd value.') weights = np.zeros(shape, dtype=np.float32) center_y = shape[0] / 2 center_x = shape[1] / 2 min_channel = min(shape[2], shape[3]) for i in range(min_channel): weights[center_y, center_x, i, i] = scale return tf.constant(weights, dtype=dtype) + tf.truncated_normal( shape, mean=mean, stddev=stddev, dtype=dtype) return _initializer def split_separable_conv2d_with_identity_initializer( inputs, filters, kernel_size=3, rate=1, weight_decay=0.00004, scope=None): """Splits a separable conv2d into depthwise and pointwise conv2d. This operation differs from `tf.layers.separable_conv2d` as this operation applies activation function between depthwise and pointwise conv2d. Args: inputs: Input tensor with shape [batch, height, width, channels]. filters: Number of filters in the 1x1 pointwise convolution. kernel_size: A list of length 2: [kernel_height, kernel_width] of of the filters. Can be an int if both values are the same. rate: Atrous convolution rate for the depthwise convolution. weight_decay: The weight decay to use for regularizing the model. scope: Optional scope for the operation. Returns: Computed features after split separable conv2d. """ initializer = conv2d_identity_initializer() outputs = slim.separable_conv2d( inputs, None, kernel_size=kernel_size, depth_multiplier=1, rate=rate, weights_initializer=initializer, weights_regularizer=None, scope=scope + '_depthwise') return slim.conv2d( outputs, filters, 1, weights_initializer=initializer, weights_regularizer=slim.l2_regularizer(weight_decay), scope=scope + '_pointwise') def create_initial_softmax_from_labels(last_frame_labels, reference_labels, decoder_output_stride, reduce_labels): """Creates initial softmax predictions from last frame labels. Args: last_frame_labels: last frame labels of shape [1, height, width, 1]. reference_labels: reference frame labels of shape [1, height, width, 1]. decoder_output_stride: Integer, the stride of the decoder. Can be None, in this case it's assumed that the last_frame_labels and reference_labels are already scaled to the decoder output resolution. reduce_labels: Boolean, whether to reduce the depth of the softmax one_hot encoding to the actual number of labels present in the reference frame (otherwise the depth will be the highest label index + 1). Returns: init_softmax: the initial softmax predictions. """ if decoder_output_stride is None: labels_output_size = last_frame_labels reference_labels_output_size = reference_labels else: h = tf.shape(last_frame_labels)[1] w = tf.shape(last_frame_labels)[2] h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride) w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride) labels_output_size = tf.image.resize_nearest_neighbor( last_frame_labels, [h_sub, w_sub], align_corners=True) reference_labels_output_size = tf.image.resize_nearest_neighbor( reference_labels, [h_sub, w_sub], align_corners=True) if reduce_labels: unique_labels, _ = tf.unique(tf.reshape(reference_labels_output_size, [-1])) depth = tf.size(unique_labels) else: depth = tf.reduce_max(reference_labels_output_size) + 1 one_hot_assertion = tf.assert_less(tf.reduce_max(labels_output_size), depth) with tf.control_dependencies([one_hot_assertion]): init_softmax = tf.one_hot(tf.squeeze(labels_output_size, axis=-1), depth=depth, dtype=tf.float32) return init_softmax def local_previous_frame_nearest_neighbor_features_per_object( prev_frame_embedding, query_embedding, prev_frame_labels, gt_ids, max_distance=9): """Computes nearest neighbor features while only allowing local matches. Args: prev_frame_embedding: Tensor of shape [height, width, embedding_dim], the embedding vectors for the last frame. query_embedding: Tensor of shape [height, width, embedding_dim], the embedding vectors for the query frames. prev_frame_labels: Tensor of shape [height, width, 1], the class labels of the previous frame. gt_ids: Int Tensor of shape [n_objs] of the sorted unique ground truth ids in the first frame. max_distance: Integer, the maximum distance allowed for local matching. Returns: nn_features: A float32 np.array of nearest neighbor features of shape [1, height, width, n_objects, 1]. """ with tf.name_scope( 'local_previous_frame_nearest_neighbor_features_per_object'): if USE_CORRELATION_COST: tf.logging.info('Using correlation_cost.') d = local_pairwise_distances(query_embedding, prev_frame_embedding, max_distance=max_distance) else: # Slow fallback in case correlation_cost is not available. tf.logging.warn('correlation cost is not available, using slow fallback ' 'implementation.') d = local_pairwise_distances2(query_embedding, prev_frame_embedding, max_distance=max_distance) d = (tf.nn.sigmoid(d) - 0.5) * 2 height = tf.shape(prev_frame_embedding)[0] width = tf.shape(prev_frame_embedding)[1] # Create offset versions of the mask. if USE_CORRELATION_COST: # New, faster code with cross-correlation via correlation_cost. # Due to padding we have to add 1 to the labels. offset_labels = correlation_cost_op.correlation_cost( tf.ones((1, height, width, 1)), tf.cast(prev_frame_labels + 1, tf.float32)[tf.newaxis], kernel_size=1, max_displacement=max_distance, stride_1=1, stride_2=1, pad=max_distance) offset_labels = tf.squeeze(offset_labels, axis=0)[..., tf.newaxis] # Subtract the 1 again and round. offset_labels = tf.round(offset_labels - 1) offset_masks = tf.equal( offset_labels, tf.cast(gt_ids, tf.float32)[tf.newaxis, tf.newaxis, tf.newaxis, :]) else: # Slower code, without dependency to correlation_cost masks = tf.equal(prev_frame_labels, gt_ids[tf.newaxis, tf.newaxis]) padded_masks = tf.pad(masks, [[max_distance, max_distance], [max_distance, max_distance], [0, 0]]) offset_masks = [] for y_start in range(2 * max_distance + 1): y_end = y_start + height masks_slice = padded_masks[y_start:y_end] for x_start in range(2 * max_distance + 1): x_end = x_start + width offset_mask = masks_slice[:, x_start:x_end] offset_masks.append(offset_mask) offset_masks = tf.stack(offset_masks, axis=2) pad = tf.ones((height, width, (2 * max_distance + 1) ** 2, tf.size(gt_ids))) d_tiled = tf.tile(d[..., tf.newaxis], multiples=(1, 1, 1, tf.size(gt_ids))) d_masked = tf.where(offset_masks, d_tiled, pad) dists = tf.reduce_min(d_masked, axis=2) dists = tf.reshape(dists, (1, height, width, tf.size(gt_ids), 1)) return dists