# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines convolutional model graph for Seq2Species.
Builds TensorFlow computation graph for predicting the given taxonomic target
labels from short reads of DNA using convolutional filters, followed by
fully-connected layers and a softmax output layer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import math
import tensorflow as tf
import input as seq2species_input
import seq2label_utils
class ConvolutionalNet(object):
"""Class to build and store the model's computational graph and operations.
Attributes:
read_length: int; the length in basepairs of the input reads of DNA.
    placeholders: dict; mapping from name to placeholder Tensor.
global_step: tf.Variable tracking number of training iterations performed.
train_op: operation to perform one training step by gradient descent.
summary_op: operation to log model's performance metrics to TF event files.
    accuracy: dict mapping each target name to a Tensor giving the model's
      read-level accuracy for the current inputs.
    weighted_accuracy: dict mapping each target name to a Tensor giving the
      model's read-level weighted accuracy for the current inputs.
    loss: dict mapping each target name to a Tensor giving the model's current
      cross-entropy loss.
    logits: dict mapping each target name to the model's logits for the
      current inputs.
    predictions: dict mapping each target name to the model's predicted
      probability distributions over the current inputs.
    possible_labels: a dict of possible label values (lists of strings), keyed
      by target name. The order of labels in each list defines their integer
      encoding.
    use_tpu: whether the model is to be run on TPU.
"""
def __init__(self, hparams, dataset_info, targets, use_tpu=False):
"""Initializes the ConvolutionalNet according to provided hyperparameters.
    Does not build the graph; that is done by calling `build_graph` on the
    constructed object or by using `model_fn`.
Args:
      hparams: tf.contrib.training.HParams object containing the model's
        hyperparameters; see configuration.py for hyperparameter definitions.
dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
metadata.
targets: list of strings: the names of the prediction targets.
use_tpu: whether we are running on TPU; if True, summaries will be
disabled.
"""
self._placeholders = {}
self._targets = targets
self._dataset_info = dataset_info
self._hparams = hparams
all_label_values = seq2label_utils.get_all_label_values(self.dataset_info)
self._possible_labels = {
target: all_label_values[target]
for target in self.targets
}
self._use_tpu = use_tpu
@property
def hparams(self):
return self._hparams
@property
def dataset_info(self):
return self._dataset_info
@property
def possible_labels(self):
return self._possible_labels
@property
def bases(self):
return seq2species_input.DNA_BASES
@property
def n_bases(self):
return seq2species_input.NUM_DNA_BASES
@property
def targets(self):
return self._targets
@property
def read_length(self):
return self.dataset_info.read_length
@property
def placeholders(self):
return self._placeholders
@property
def global_step(self):
return self._global_step
@property
def train_op(self):
return self._train_op
@property
def summary_op(self):
return self._summary_op
@property
def accuracy(self):
return self._accuracy
@property
def weighted_accuracy(self):
return self._weighted_accuracy
@property
def loss(self):
return self._loss
@property
def total_loss(self):
return self._total_loss
@property
def logits(self):
return self._logits
@property
def predictions(self):
return self._predictions
@property
def use_tpu(self):
return self._use_tpu
def _summary_scalar(self, name, scalar):
"""Adds a summary scalar, if the platform supports summaries."""
if not self.use_tpu:
return tf.summary.scalar(name, scalar)
else:
return None
def _summary_histogram(self, name, values):
"""Adds a summary histogram, if the platform supports summaries."""
if not self.use_tpu:
return tf.summary.histogram(name, values)
else:
return None
def _init_weights(self, shape, scale=1.0, name='weights'):
"""Randomly initializes a weight Tensor of the given shape.
Args:
shape: list; desired Tensor dimensions.
scale: float; standard deviation scale with which to initialize weights.
name: string name for the variable.
Returns:
      TF Variable containing weights initialized from a truncated normal
      distribution.
"""
num_inputs = shape[0] if len(shape) < 3 else shape[0] * shape[1] * shape[2]
stddev = scale / math.sqrt(num_inputs)
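    # For example (illustrative numbers only): a 4-D conv filter of shape
    # [1, 5, 4, 64] has fan-in 1 * 5 * 4 = 20, so with scale=1.0 the weights
    # are drawn with stddev 1 / sqrt(20) ~= 0.22, a fan-in scaling in the
    # spirit of LeCun initialization.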
return tf.get_variable(
name,
shape=shape,
initializer=tf.truncated_normal_initializer(0., stddev))
def _init_bias(self, size):
"""Initializes bias vector of given shape as zeros.
Args:
size: int; desired size of bias Tensor.
Returns:
TF Variable containing the initialized biases.
"""
return tf.get_variable(
name='b_{}'.format(size),
shape=[size],
initializer=tf.zeros_initializer())
def _add_summaries(self, mode, gradient_norm, parameter_norm):
"""Defines TensorFlow operation for logging summaries to event files.
Args:
mode: the ModeKey string.
gradient_norm: Tensor; norm of gradients produced during the current
training operation.
parameter_norm: Tensor; norm of the model parameters produced during the
current training operation.
"""
# Log summaries for TensorBoard.
if mode == tf.estimator.ModeKeys.TRAIN:
self._summary_scalar('norm_of_gradients', gradient_norm)
self._summary_scalar('norm_of_parameters', parameter_norm)
self._summary_scalar('total_loss', self.total_loss)
self._summary_scalar('learning_rate', self._learn_rate)
for target in self.targets:
self._summary_scalar('per_read_weighted_accuracy/{}'.format(target),
self.weighted_accuracy[target])
self._summary_scalar('per_read_accuracy/{}'.format(target),
self.accuracy[target])
self._summary_histogram('prediction_frequency/{}'.format(target),
self._predictions[target])
self._summary_scalar('cross_entropy_loss/{}'.format(target),
self._loss[target])
self._summary_op = tf.summary.merge_all()
else:
# Log average performance metrics over many batches using placeholders.
summaries = []
for target in self.targets:
accuracy_ph = tf.placeholder(tf.float32, shape=())
weighted_accuracy_ph = tf.placeholder(tf.float32, shape=())
cross_entropy_ph = tf.placeholder(tf.float32, shape=())
self._placeholders.update({
'accuracy/{}'.format(target): accuracy_ph,
'weighted_accuracy/{}'.format(target): weighted_accuracy_ph,
'cross_entropy/{}'.format(target): cross_entropy_ph,
})
summaries += [
self._summary_scalar('cross_entropy_loss/{}'.format(target),
cross_entropy_ph),
self._summary_scalar('per_read_accuracy/{}'.format(target),
accuracy_ph),
self._summary_scalar('per_read_weighted_accuracy/{}'.format(target),
weighted_accuracy_ph)
]
self._summary_op = tf.summary.merge(summaries)
def _convolution(self,
inputs,
filter_dim,
pointwise_dim=None,
scale=1.0,
padding='SAME'):
"""Applies convolutional filter of given dimensions to given input Tensor.
If a pointwise dimension is specified, a depthwise separable convolution is
performed.
Args:
inputs: 4D Tensor of shape (# reads, 1, # basepairs, # bases).
filter_dim: integer tuple of the form (width, depth).
pointwise_dim: int; output dimension for pointwise convolution.
scale: float; standard deviation scale with which to initialize weights.
padding: string; type of padding to use. One of "SAME" or "VALID".
Returns:
4D Tensor result of applying the convolutional filter to the inputs.
"""
in_channels = inputs.get_shape()[3].value
filter_width, filter_depth = filter_dim
filters = self._init_weights([1, filter_width, in_channels, filter_depth],
scale)
self._summary_histogram(filters.name.split(':')[0].split('/')[1], filters)
if pointwise_dim is None:
return tf.nn.conv2d(
inputs,
filters,
strides=[1, 1, 1, 1],
padding=padding,
name='weights')
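    # Depthwise separable path: tf.nn.separable_conv2d first applies the
    # filter_depth per-channel filters in `filters` to each of the in_channels
    # input channels independently, then mixes the resulting
    # filter_depth * in_channels channels down to pointwise_dim outputs using
    # the 1x1 pointwise filters below.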
pointwise_filters = self._init_weights(
[1, 1, filter_depth * in_channels, pointwise_dim],
scale,
name='pointwise_weights')
self._summary_histogram(
pointwise_filters.name.split(':')[0].split('/')[1], pointwise_filters)
return tf.nn.separable_conv2d(
inputs,
filters,
pointwise_filters,
strides=[1, 1, 1, 1],
padding=padding)
def _pool(self, inputs, pooling_type):
"""Performs pooling across width and height of the given inputs.
Args:
inputs: Tensor shaped (batch, height, width, channels) over which to pool.
In our case, height is a unitary dimension and width can be thought of
as the read dimension.
pooling_type: string; one of "avg" or "max".
Returns:
Tensor result of performing pooling of the given pooling_type over the
height and width dimensions of the given inputs.
"""
    if pooling_type == 'max':
      return tf.reduce_max(inputs, axis=[1, 2])
    if pooling_type == 'avg':
      # Height is unitary, so dividing the sum by the width gives the mean.
      return tf.reduce_sum(
          inputs, axis=[1, 2]) / tf.to_float(tf.shape(inputs)[2])
    raise ValueError('Unsupported pooling_type: {}'.format(pooling_type))
def _leaky_relu(self, lrelu_slope, inputs):
"""Applies leaky ReLu activation to the given inputs with the given slope.
Args:
lrelu_slope: float; slope value for the activation function.
        A slope of 0.0 defines a standard ReLU activation, while a positive
        slope defines a leaky ReLU.
inputs: Tensor upon which to apply the activation function.
Returns:
Tensor result of applying the activation function to the given inputs.
"""
with tf.variable_scope('leaky_relu_activation'):
return tf.maximum(lrelu_slope * inputs, inputs)
def _dropout(self, inputs, keep_prob):
"""Applies dropout to the given inputs.
Args:
inputs: Tensor upon which to apply dropout.
keep_prob: float; probability with which to randomly retain values in
the given input.
Returns:
Tensor result of applying dropout to the given inputs.
"""
with tf.variable_scope('dropout'):
if keep_prob < 1.0:
return tf.nn.dropout(inputs, keep_prob)
return inputs
def build_graph(self, features, labels, mode, batch_size):
"""Creates TensorFlow model graph.
Args:
features: a dict of input features Tensors.
labels: a dict (by target name) of prediction labels.
mode: the ModeKey string.
batch_size: the integer batch size.
Side Effect:
Adds the following key Tensors and operations as class attributes:
placeholders, global_step, train_op, summary_op, accuracy,
weighted_accuracy, loss, logits, and predictions.
"""
is_train = (mode == tf.estimator.ModeKeys.TRAIN)
read = features['sequence']
# Add a unitary dimension, so we can use conv2d.
read = tf.expand_dims(read, 1)
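    # `read` now has shape (batch, 1, read_length, n_bases), the 4D layout
    # expected by _convolution.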
prev_out = read
filters = zip(self.hparams.filter_widths, self.hparams.filter_depths)
for i, f in enumerate(filters):
with tf.variable_scope('convolution_' + str(i)):
if self.hparams.use_depthwise_separable:
p = self.hparams.pointwise_depths[i]
else:
p = None
conv_out = self._convolution(
prev_out, f, pointwise_dim=p, scale=self.hparams.weight_scale)
conv_act_out = self._leaky_relu(self.hparams.lrelu_slope, conv_out)
prev_out = (
self._dropout(conv_act_out, self.hparams.keep_prob)
if is_train else conv_act_out)
    for i in range(self.hparams.num_fc_layers):
with tf.variable_scope('fully_connected_' + str(i)):
# Create a convolutional layer which is equivalent to a fully-connected
# layer when reads have length self.hparams.min_read_length.
# The convolution will tile the layer appropriately for longer reads.
biases = self._init_bias(self.hparams.num_fc_units)
if i == 0:
# Take entire min_read_length segment as input.
# Output a single value per min_read_length_segment.
filter_dimensions = (self.hparams.min_read_length,
self.hparams.num_fc_units)
else:
# Take single output value of previous layer as input.
filter_dimensions = (1, self.hparams.num_fc_units)
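        # Shape sketch (illustrative, e.g. min_read_length=100): with 'VALID'
        # padding, layer 0's (min_read_length, num_fc_units) filter maps a
        # (batch, 1, L, C) input to (batch, 1, L - min_read_length + 1,
        # num_fc_units). When L == min_read_length there is a single output
        # position, i.e. an ordinary fully-connected layer; longer reads are
        # tiled, and the width-1 filters of later layers preserve that tiling.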
fc_out = biases + self._convolution(
prev_out,
filter_dimensions,
scale=self.hparams.weight_scale,
padding='VALID')
self._summary_histogram(biases.name.split(':')[0].split('/')[1], biases)
fc_act_out = self._leaky_relu(self.hparams.lrelu_slope, fc_out)
prev_out = (
self._dropout(fc_act_out, self.hparams.keep_prob)
if is_train else fc_act_out)
# Pool to collapse tiling for reads longer than hparams.min_read_length.
with tf.variable_scope('pool'):
pool_out = self._pool(prev_out, self.hparams.pooling_type)
with tf.variable_scope('output'):
self._logits = {}
self._predictions = {}
self._weighted_accuracy = {}
self._accuracy = {}
self._loss = collections.OrderedDict()
for target in self.targets:
with tf.variable_scope(target):
label = labels[target]
possible_labels = self.possible_labels[target]
weights = self._init_weights(
[pool_out.get_shape()[1].value,
len(possible_labels)],
self.hparams.weight_scale,
name='weights')
biases = self._init_bias(len(possible_labels))
self._summary_histogram(
weights.name.split(':')[0].split('/')[1], weights)
self._summary_histogram(
biases.name.split(':')[0].split('/')[1], biases)
logits = tf.matmul(pool_out, weights) + biases
predictions = tf.nn.softmax(logits)
gather_inds = tf.stack([tf.range(batch_size), label], axis=1)
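          # gather_nd with (row, true label) index pairs extracts each read's
          # predicted probability of its correct class, so weighted accuracy
          # is the mean probability mass assigned to the truth.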
self._weighted_accuracy[target] = tf.reduce_mean(
tf.gather_nd(predictions, gather_inds))
argmax_prediction = tf.cast(tf.argmax(predictions, axis=1), tf.int32)
self._accuracy[target] = tf.reduce_mean(
tf.to_float(tf.equal(label, argmax_prediction)))
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=label, logits=logits)
self._loss[target] = tf.reduce_mean(losses)
self._logits[target] = logits
self._predictions[target] = predictions
# Compute total loss
    self._total_loss = tf.add_n(list(self._loss.values()))
# Define the optimizer.
# tf.estimator framework builds the global_step for us, but if we aren't
# using the framework we have to make it ourselves.
self._global_step = tf.train.get_or_create_global_step()
if self.hparams.lr_decay < 0:
self._learn_rate = self.hparams.lr_init
else:
self._learn_rate = tf.train.exponential_decay(
self.hparams.lr_init,
self._global_step,
int(self.hparams.train_steps),
self.hparams.lr_decay,
staircase=False)
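    # With staircase=False the decay is smooth:
    #   learn_rate = lr_init * lr_decay ** (global_step / train_steps),
    # reaching lr_init * lr_decay after train_steps steps.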
    if self.hparams.optimizer == 'adam':
      opt = tf.train.AdamOptimizer(self._learn_rate, self.hparams.optimizer_hp)
    elif self.hparams.optimizer == 'momentum':
      opt = tf.train.MomentumOptimizer(self._learn_rate,
                                       self.hparams.optimizer_hp)
    else:
      raise ValueError(
          'Unsupported optimizer: {}'.format(self.hparams.optimizer))
if self.use_tpu:
opt = tf.contrib.tpu.CrossShardOptimizer(opt)
gradients, variables = zip(*opt.compute_gradients(self._total_loss))
clipped_gradients, _ = tf.clip_by_global_norm(gradients,
self.hparams.grad_clip_norm)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
self._train_op = opt.apply_gradients(
zip(clipped_gradients, variables), global_step=self._global_step)
if not self.use_tpu:
grad_norm = tf.global_norm(gradients) if is_train else None
param_norm = tf.global_norm(variables) if is_train else None
self._add_summaries(mode, grad_norm, param_norm)
def model_fn(self, features, labels, mode, params):
"""Function fulfilling the tf.estimator model_fn interface.
Args:
features: a dict containing the input features for prediction.
      labels: a dict mapping each target name to its Tensor of labels.
mode: the ModeKey string.
      params: a dictionary of parameters for building the model; currently
        only params['batch_size'], the integer batch size, is used.
Returns:
      A tf.estimator.EstimatorSpec object ready for use in training, inference,
      or evaluation.
"""
self.build_graph(features, labels, mode, params['batch_size'])
return tf.estimator.EstimatorSpec(
mode,
predictions=self.predictions,
loss=self.total_loss,
train_op=self.train_op,
eval_metric_ops={})
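# Minimal usage sketch. The names `my_hparams`, `my_dataset_info`,
# `my_input_fn`, and the 'species' target are illustrative assumptions:
# hyperparameters would come from configuration.py, the dataset metadata from
# a populated Seq2LabelDatasetInfo message, and the input function from
# input.py.
#
#   model = ConvolutionalNet(my_hparams, my_dataset_info, targets=['species'])
#   classifier = tf.estimator.Estimator(
#       model_fn=model.model_fn, params={'batch_size': 64})
#   classifier.train(input_fn=my_input_fn, steps=1000)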