NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
9.38 kB
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cell structure used by NAS."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from six.moves import range
from six.moves import zip
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import slim as contrib_slim
from deeplab.core import xception as xception_utils
from deeplab.core.utils import resize_bilinear
from deeplab.core.utils import scale_dimension
from tensorflow.contrib.slim.nets import resnet_utils
arg_scope = contrib_framework.arg_scope
slim = contrib_slim
separable_conv2d_same = functools.partial(xception_utils.separable_conv2d_same,
regularize_depthwise=True)
class NASBaseCell(object):
"""NASNet Cell class that is used as a 'layer' in image architectures."""
def __init__(self, num_conv_filters, operations, used_hiddenstates,
hiddenstate_indices, drop_path_keep_prob, total_num_cells,
total_training_steps, batch_norm_fn=slim.batch_norm):
"""Init function.
For more details about NAS cell, see
https://arxiv.org/abs/1707.07012 and https://arxiv.org/abs/1712.00559.
Args:
num_conv_filters: The number of filters for each convolution operation.
operations: List of operations that are performed in the NASNet Cell in
order.
used_hiddenstates: Binary array that signals if the hiddenstate was used
within the cell. This is used to determine what outputs of the cell
should be concatenated together.
hiddenstate_indices: Determines what hiddenstates should be combined
together with the specified operations to create the NASNet cell.
drop_path_keep_prob: Float, drop path keep probability.
total_num_cells: Integer, total number of cells.
total_training_steps: Integer, total training steps.
batch_norm_fn: Function, batch norm function. Defaults to
slim.batch_norm.
"""
if len(hiddenstate_indices) != len(operations):
raise ValueError(
'Number of hiddenstate_indices and operations should be the same.')
if len(operations) % 2:
raise ValueError('Number of operations should be even.')
self._num_conv_filters = num_conv_filters
self._operations = operations
self._used_hiddenstates = used_hiddenstates
self._hiddenstate_indices = hiddenstate_indices
self._drop_path_keep_prob = drop_path_keep_prob
self._total_num_cells = total_num_cells
self._total_training_steps = total_training_steps
self._batch_norm_fn = batch_norm_fn
def __call__(self, net, scope, filter_scaling, stride, prev_layer, cell_num):
"""Runs the conv cell."""
self._cell_num = cell_num
self._filter_scaling = filter_scaling
self._filter_size = int(self._num_conv_filters * filter_scaling)
with tf.variable_scope(scope):
net = self._cell_base(net, prev_layer)
for i in range(len(self._operations) // 2):
with tf.variable_scope('comb_iter_{}'.format(i)):
h1 = net[self._hiddenstate_indices[i * 2]]
h2 = net[self._hiddenstate_indices[i * 2 + 1]]
with tf.variable_scope('left'):
h1 = self._apply_conv_operation(
h1, self._operations[i * 2], stride,
self._hiddenstate_indices[i * 2] < 2)
with tf.variable_scope('right'):
h2 = self._apply_conv_operation(
h2, self._operations[i * 2 + 1], stride,
self._hiddenstate_indices[i * 2 + 1] < 2)
with tf.variable_scope('combine'):
h = h1 + h2
net.append(h)
with tf.variable_scope('cell_output'):
net = self._combine_unused_states(net)
return net
def _cell_base(self, net, prev_layer):
"""Runs the beginning of the conv cell before the chosen ops are run."""
filter_size = self._filter_size
if prev_layer is None:
prev_layer = net
else:
if net.shape[2] != prev_layer.shape[2]:
prev_layer = resize_bilinear(
prev_layer, tf.shape(net)[1:3], prev_layer.dtype)
if filter_size != prev_layer.shape[3]:
prev_layer = tf.nn.relu(prev_layer)
prev_layer = slim.conv2d(prev_layer, filter_size, 1, scope='prev_1x1')
prev_layer = self._batch_norm_fn(prev_layer, scope='prev_bn')
net = tf.nn.relu(net)
net = slim.conv2d(net, filter_size, 1, scope='1x1')
net = self._batch_norm_fn(net, scope='beginning_bn')
net = tf.split(axis=3, num_or_size_splits=1, value=net)
net.append(prev_layer)
return net
def _apply_conv_operation(self, net, operation, stride,
is_from_original_input):
"""Applies the predicted conv operation to net."""
if stride > 1 and not is_from_original_input:
stride = 1
input_filters = net.shape[3]
filter_size = self._filter_size
if 'separable' in operation:
num_layers = int(operation.split('_')[-1])
kernel_size = int(operation.split('x')[0][-1])
for layer_num in range(num_layers):
net = tf.nn.relu(net)
net = separable_conv2d_same(
net,
filter_size,
kernel_size,
depth_multiplier=1,
scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
stride=stride)
net = self._batch_norm_fn(
net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
stride = 1
elif 'atrous' in operation:
kernel_size = int(operation.split('x')[0][-1])
net = tf.nn.relu(net)
if stride == 2:
scaled_height = scale_dimension(tf.shape(net)[1], 0.5)
scaled_width = scale_dimension(tf.shape(net)[2], 0.5)
net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
net = resnet_utils.conv2d_same(
net, filter_size, kernel_size, rate=1, stride=1,
scope='atrous_{0}x{0}'.format(kernel_size))
else:
net = resnet_utils.conv2d_same(
net, filter_size, kernel_size, rate=2, stride=1,
scope='atrous_{0}x{0}'.format(kernel_size))
net = self._batch_norm_fn(net, scope='bn_atr_{0}x{0}'.format(kernel_size))
elif operation in ['none']:
if stride > 1 or (input_filters != filter_size):
net = tf.nn.relu(net)
net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
net = self._batch_norm_fn(net, scope='bn_1')
elif 'pool' in operation:
pooling_type = operation.split('_')[0]
pooling_shape = int(operation.split('_')[-1].split('x')[0])
if pooling_type == 'avg':
net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding='SAME')
elif pooling_type == 'max':
net = slim.max_pool2d(net, pooling_shape, stride=stride, padding='SAME')
else:
raise ValueError('Unimplemented pooling type: ', pooling_type)
if input_filters != filter_size:
net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
net = self._batch_norm_fn(net, scope='bn_1')
else:
raise ValueError('Unimplemented operation', operation)
if operation != 'none':
net = self._apply_drop_path(net)
return net
def _combine_unused_states(self, net):
"""Concatenates the unused hidden states of the cell."""
used_hiddenstates = self._used_hiddenstates
states_to_combine = ([
h for h, is_used in zip(net, used_hiddenstates) if not is_used])
net = tf.concat(values=states_to_combine, axis=3)
return net
@contrib_framework.add_arg_scope
def _apply_drop_path(self, net):
"""Apply drop_path regularization."""
drop_path_keep_prob = self._drop_path_keep_prob
if drop_path_keep_prob < 1.0:
# Scale keep prob by layer number.
assert self._cell_num != -1
layer_ratio = (self._cell_num + 1) / float(self._total_num_cells)
drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)
# Decrease keep prob over time.
current_step = tf.cast(tf.train.get_or_create_global_step(), tf.float32)
current_ratio = tf.minimum(1.0, current_step / self._total_training_steps)
drop_path_keep_prob = (1 - current_ratio * (1 - drop_path_keep_prob))
# Drop path.
noise_shape = [tf.shape(net)[0], 1, 1, 1]
random_tensor = drop_path_keep_prob
random_tensor += tf.random_uniform(noise_shape, dtype=tf.float32)
binary_tensor = tf.cast(tf.floor(random_tensor), net.dtype)
keep_prob_inv = tf.cast(1.0 / drop_path_keep_prob, net.dtype)
net = net * keep_prob_inv * binary_tensor
return net