File size: 18,567 Bytes
0b8359d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Defines convolutional model graph for Seq2Species.

Builds TensorFlow computation graph for predicting the given taxonomic target
labels from short reads of DNA using convolutional filters, followed by
fully-connected layers and a softmax output layer.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

import tensorflow as tf

import input as seq2species_input
import seq2label_utils


class ConvolutionalNet(object):
  """Class to build and store the model's computational graph and operations.

  Most attributes below only exist after `build_graph` (or `model_fn`) has been
  called; accessing them earlier raises AttributeError.

  Attributes:
    hparams: the hyperparameter object the model was constructed with.
    dataset_info: the dataset metadata the model was constructed with.
    targets: list of strings; the names of the prediction targets.
    read_length: int; the length in basepairs of the input reads of DNA.
    placeholders: dict; mapping from name to tf.Placeholder.
    global_step: tf.Variable tracking number of training iterations performed.
    train_op: operation to perform one training step by gradient descent.
    summary_op: operation to log model's performance metrics to TF event files.
      Only created when not running on TPU.
    accuracy: dict from target name to a scalar Tensor giving the model's
      read-level accuracy for the current inputs.
    weighted_accuracy: dict from target name to a scalar Tensor giving the mean
      probability the model assigns to the true label of the current inputs.
    loss: ordered dict from target name to a scalar Tensor giving the model's
      mean cross entropy loss for that target.
    total_loss: scalar Tensor; the sum of the per-target losses.
    logits: dict from target name to the Tensor of logits for the current
      inputs.
    predictions: dict from target name to the Tensor of predicted probability
      distributions (softmax of the logits) for the current inputs.
    possible_labels: a dict of possible label values (list of strings), keyed by
      target name.  Labels in the lists are the order used for integer encoding.
    use_tpu: whether model is to be run on TPU.
  """

  def __init__(self, hparams, dataset_info, targets, use_tpu=False):
    """Initializes the ConvolutionalNet according to provided hyperparameters.

    Does not build the graph---this is done by calling `build_graph` on the
    constructed object or using `model_fn`.

    Args:
      hparams: tf.contrib.training.Hparams object containing the model's
        hyperparameters; see configuration.py for hyperparameter definitions.
      dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
        metadata.
      targets: list of strings: the names of the prediction targets.
      use_tpu: whether we are running on TPU; if True, summaries will be
        disabled.
    """
    self._placeholders = {}
    self._targets = targets
    self._dataset_info = dataset_info
    self._hparams = hparams
    all_label_values = seq2label_utils.get_all_label_values(self.dataset_info)
    # Keep only the label vocabularies for the targets this model predicts.
    self._possible_labels = {
        target: all_label_values[target]
        for target in self.targets
    }
    self._use_tpu = use_tpu

  @property
  def hparams(self):
    return self._hparams

  @property
  def dataset_info(self):
    return self._dataset_info

  @property
  def possible_labels(self):
    return self._possible_labels

  @property
  def bases(self):
    return seq2species_input.DNA_BASES

  @property
  def n_bases(self):
    return seq2species_input.NUM_DNA_BASES

  @property
  def targets(self):
    return self._targets

  @property
  def read_length(self):
    return self.dataset_info.read_length

  @property
  def placeholders(self):
    return self._placeholders

  @property
  def global_step(self):
    return self._global_step

  @property
  def train_op(self):
    return self._train_op

  @property
  def summary_op(self):
    return self._summary_op

  @property
  def accuracy(self):
    return self._accuracy

  @property
  def weighted_accuracy(self):
    return self._weighted_accuracy

  @property
  def loss(self):
    return self._loss

  @property
  def total_loss(self):
    return self._total_loss

  @property
  def logits(self):
    return self._logits

  @property
  def predictions(self):
    return self._predictions

  @property
  def use_tpu(self):
    return self._use_tpu

  def _summary_scalar(self, name, scalar):
    """Adds a summary scalar, if the platform supports summaries.

    Args:
      name: string name for the summary.
      scalar: scalar Tensor (or value) to log.

    Returns:
      The summary op, or None when running on TPU.
    """
    if not self.use_tpu:
      return tf.summary.scalar(name, scalar)
    else:
      return None

  def _summary_histogram(self, name, values):
    """Adds a summary histogram, if the platform supports summaries.

    Args:
      name: string name for the summary.
      values: Tensor whose values are to be logged as a histogram.

    Returns:
      The summary op, or None when running on TPU.
    """
    if not self.use_tpu:
      return tf.summary.histogram(name, values)
    else:
      return None

  def _init_weights(self, shape, scale=1.0, name='weights'):
    """Randomly initializes a weight Tensor of the given shape.

    The standard deviation is `scale / sqrt(fan_in)`, where fan_in is the first
    dimension for shapes of rank < 3 and the product of the first three
    dimensions otherwise (i.e. height * width * in_channels for a 4D filter).

    Args:
      shape: list; desired Tensor dimensions.
      scale: float; standard deviation scale with which to initialize weights.
      name: string name for the variable.

    Returns:
      TF Variable containing truncated random Normal initialized weights.
    """
    num_inputs = shape[0] if len(shape) < 3 else shape[0] * shape[1] * shape[2]
    stddev = scale / math.sqrt(num_inputs)
    return tf.get_variable(
        name,
        shape=shape,
        initializer=tf.truncated_normal_initializer(0., stddev))

  def _init_bias(self, size):
    """Initializes bias vector of given shape as zeros.

    The variable is named 'b_<size>' within the current variable scope.

    Args:
      size: int; desired size of bias Tensor.

    Returns:
      TF Variable containing the initialized biases.
    """
    return tf.get_variable(
        name='b_{}'.format(size),
        shape=[size],
        initializer=tf.zeros_initializer())

  def _add_summaries(self, mode, gradient_norm, parameter_norm):
    """Defines TensorFlow operation for logging summaries to event files.

    In TRAIN mode the summaries read directly from the model's Tensors.  In any
    other mode the summaries read from scalar placeholders (registered in
    `self.placeholders`) so that the caller can feed metrics averaged over many
    batches.

    Args:
      mode: the ModeKey string.
      gradient_norm: Tensor; norm of gradients produced during the current
        training operation.  Only used in TRAIN mode.
      parameter_norm: Tensor; norm of the model parameters produced during the
        current training operation.  Only used in TRAIN mode.
    """
    # Log summaries for TensorBoard.
    if mode == tf.estimator.ModeKeys.TRAIN:
      self._summary_scalar('norm_of_gradients', gradient_norm)
      self._summary_scalar('norm_of_parameters', parameter_norm)
      self._summary_scalar('total_loss', self.total_loss)
      self._summary_scalar('learning_rate', self._learn_rate)
      for target in self.targets:
        self._summary_scalar('per_read_weighted_accuracy/{}'.format(target),
                             self.weighted_accuracy[target])
        self._summary_scalar('per_read_accuracy/{}'.format(target),
                             self.accuracy[target])
        self._summary_histogram('prediction_frequency/{}'.format(target),
                                self._predictions[target])
        self._summary_scalar('cross_entropy_loss/{}'.format(target),
                             self._loss[target])
      # Also picks up the weight/bias histograms added while building layers.
      self._summary_op = tf.summary.merge_all()
    else:
      # Log average performance metrics over many batches using placeholders.
      summaries = []
      for target in self.targets:
        accuracy_ph = tf.placeholder(tf.float32, shape=())
        weighted_accuracy_ph = tf.placeholder(tf.float32, shape=())
        cross_entropy_ph = tf.placeholder(tf.float32, shape=())
        self._placeholders.update({
            'accuracy/{}'.format(target): accuracy_ph,
            'weighted_accuracy/{}'.format(target): weighted_accuracy_ph,
            'cross_entropy/{}'.format(target): cross_entropy_ph,
        })
        summaries += [
            self._summary_scalar('cross_entropy_loss/{}'.format(target),
                                 cross_entropy_ph),
            self._summary_scalar('per_read_accuracy/{}'.format(target),
                                 accuracy_ph),
            self._summary_scalar('per_read_weighted_accuracy/{}'.format(target),
                                 weighted_accuracy_ph)
        ]

      # Merge only the placeholder-backed summaries, not the training ones.
      self._summary_op = tf.summary.merge(summaries)

  def _convolution(self,
                   inputs,
                   filter_dim,
                   pointwise_dim=None,
                   scale=1.0,
                   padding='SAME'):
    """Applies convolutional filter of given dimensions to given input Tensor.

    If a pointwise dimension is specified, a depthwise separable convolution is
    performed.

    Args:
      inputs: 4D Tensor of shape (# reads, 1, # basepairs, # bases).
      filter_dim: integer tuple of the form (width, depth).
      pointwise_dim: int; output dimension for pointwise convolution.
      scale: float; standard deviation scale with which to initialize weights.
      padding: string; type of padding to use. One of "SAME" or "VALID".

    Returns:
      4D Tensor result of applying the convolutional filter to the inputs.
    """
    in_channels = inputs.get_shape()[3].value
    filter_width, filter_depth = filter_dim
    # Filter height is 1 because the inputs carry a unitary height dimension.
    filters = self._init_weights([1, filter_width, in_channels, filter_depth],
                                 scale)
    # Summary name is the second component of the variable's scoped name.
    self._summary_histogram(filters.name.split(':')[0].split('/')[1], filters)
    if pointwise_dim is None:
      return tf.nn.conv2d(
          inputs,
          filters,
          strides=[1, 1, 1, 1],
          padding=padding,
          name='weights')
    # For the separable case `filters` acts as the depthwise filter, so
    # filter_depth is the channel multiplier and the pointwise filter takes
    # filter_depth * in_channels input channels.
    pointwise_filters = self._init_weights(
        [1, 1, filter_depth * in_channels, pointwise_dim],
        scale,
        name='pointwise_weights')
    self._summary_histogram(
        pointwise_filters.name.split(':')[0].split('/')[1], pointwise_filters)
    return tf.nn.separable_conv2d(
        inputs,
        filters,
        pointwise_filters,
        strides=[1, 1, 1, 1],
        padding=padding)

  def _pool(self, inputs, pooling_type):
    """Performs pooling across width and height of the given inputs.

    Args:
      inputs: Tensor shaped (batch, height, width, channels) over which to pool.
        In our case, height is a unitary dimension and width can be thought of
        as the read dimension.
      pooling_type: string; one of "avg" or "max".

    Returns:
      Tensor result of performing pooling of the given pooling_type over the
      height and width dimensions of the given inputs.

    Raises:
      ValueError: if pooling_type is neither "max" nor "avg".
    """
    if pooling_type == 'max':
      return tf.reduce_max(inputs, axis=[1, 2])
    if pooling_type == 'avg':
      # Divides by the width only; this is the mean because height is 1.
      return tf.reduce_sum(
          inputs, axis=[1, 2]) / tf.to_float(tf.shape(inputs)[2])
    # Previously fell through and returned None, which only surfaced later as
    # an opaque AttributeError on the pooled output.
    raise ValueError(
        'Unsupported pooling_type {!r}; expected "max" or "avg".'.format(
            pooling_type))

  def _leaky_relu(self, lrelu_slope, inputs):
    """Applies leaky ReLu activation to the given inputs with the given slope.

    Args:
      lrelu_slope: float; slope value for the activation function.
        A slope of 0.0 defines a standard ReLu activation, while a positive
        slope defines a leaky ReLu.
      inputs: Tensor upon which to apply the activation function.

    Returns:
      Tensor result of applying the activation function to the given inputs.
    """
    with tf.variable_scope('leaky_relu_activation'):
      return tf.maximum(lrelu_slope * inputs, inputs)

  def _dropout(self, inputs, keep_prob):
    """Applies dropout to the given inputs.

    Args:
      inputs: Tensor upon which to apply dropout.
      keep_prob: float; probability with which to randomly retain values in
        the given input.  Values >= 1.0 disable dropout entirely.

    Returns:
      Tensor result of applying dropout to the given inputs.
    """
    with tf.variable_scope('dropout'):
      if keep_prob < 1.0:
        return tf.nn.dropout(inputs, keep_prob)
      return inputs

  def build_graph(self, features, labels, mode, batch_size):
    """Creates TensorFlow model graph.

    The graph is a stack of convolutions, followed by "fully-connected" layers
    implemented as convolutions, a pooling step over the read dimension, and one
    softmax output layer per prediction target.

    Args:
      features: a dict of input features Tensors; must contain 'sequence'.
      labels: a dict (by target name) of prediction labels.
      mode: the ModeKey string.
      batch_size: the integer batch size.

    Raises:
      ValueError: if hparams.optimizer is not "adam" or "momentum", or if
        hparams.pooling_type is not "max" or "avg".

    Side Effect:
      Adds the following key Tensors and operations as class attributes:
        placeholders, global_step, train_op, summary_op, accuracy,
        weighted_accuracy, loss, total_loss, logits, and predictions.
    """
    # Fail fast on an unusable optimizer setting, before any ops are created;
    # otherwise the error would only appear after the whole graph was built.
    optimizer_name = self.hparams.optimizer
    if optimizer_name not in ('adam', 'momentum'):
      raise ValueError(
          'Unsupported optimizer {!r}; expected "adam" or "momentum".'.format(
              optimizer_name))

    is_train = (mode == tf.estimator.ModeKeys.TRAIN)
    read = features['sequence']

    # Add a unitary dimension, so we can use conv2d.
    read = tf.expand_dims(read, 1)
    prev_out = read

    filters = zip(self.hparams.filter_widths, self.hparams.filter_depths)
    for i, f in enumerate(filters):
      with tf.variable_scope('convolution_' + str(i)):
        if self.hparams.use_depthwise_separable:
          p = self.hparams.pointwise_depths[i]
        else:
          p = None
        conv_out = self._convolution(
            prev_out, f, pointwise_dim=p, scale=self.hparams.weight_scale)
        conv_act_out = self._leaky_relu(self.hparams.lrelu_slope, conv_out)
        # Dropout is only applied while training.
        prev_out = (
            self._dropout(conv_act_out, self.hparams.keep_prob)
            if is_train else conv_act_out)

    # `range` rather than `xrange`, which does not exist on Python 3.
    for i in range(self.hparams.num_fc_layers):
      with tf.variable_scope('fully_connected_' + str(i)):
        # Create a convolutional layer which is equivalent to a fully-connected
        # layer when reads have length self.hparams.min_read_length.
        # The convolution will tile the layer appropriately for longer reads.
        biases = self._init_bias(self.hparams.num_fc_units)
        if i == 0:
          # Take entire min_read_length segment as input.
          # Output a single value per min_read_length_segment.
          filter_dimensions = (self.hparams.min_read_length,
                               self.hparams.num_fc_units)
        else:
          # Take single output value of previous layer as input.
          filter_dimensions = (1, self.hparams.num_fc_units)
        fc_out = biases + self._convolution(
            prev_out,
            filter_dimensions,
            scale=self.hparams.weight_scale,
            padding='VALID')
        self._summary_histogram(biases.name.split(':')[0].split('/')[1], biases)
        fc_act_out = self._leaky_relu(self.hparams.lrelu_slope, fc_out)
        prev_out = (
            self._dropout(fc_act_out, self.hparams.keep_prob)
            if is_train else fc_act_out)

    # Pool to collapse tiling for reads longer than hparams.min_read_length.
    with tf.variable_scope('pool'):
      pool_out = self._pool(prev_out, self.hparams.pooling_type)

    with tf.variable_scope('output'):
      self._logits = {}
      self._predictions = {}
      self._weighted_accuracy = {}
      self._accuracy = {}
      # Ordered so the per-target losses are summed in a deterministic order.
      self._loss = collections.OrderedDict()

      for target in self.targets:
        with tf.variable_scope(target):
          # NOTE(review): labels is indexed for every target regardless of
          # mode, so labels presumably must be provided even outside training
          # -- confirm before using this graph for label-free prediction.
          label = labels[target]
          possible_labels = self.possible_labels[target]
          weights = self._init_weights(
              [pool_out.get_shape()[1].value,
               len(possible_labels)],
              self.hparams.weight_scale,
              name='weights')
          biases = self._init_bias(len(possible_labels))
          self._summary_histogram(
              weights.name.split(':')[0].split('/')[1], weights)
          self._summary_histogram(
              biases.name.split(':')[0].split('/')[1], biases)
          logits = tf.matmul(pool_out, weights) + biases
          predictions = tf.nn.softmax(logits)

          # Weighted accuracy is the mean probability assigned to the true
          # label: gather predictions[row, label[row]] for every row.
          gather_inds = tf.stack([tf.range(batch_size), label], axis=1)
          self._weighted_accuracy[target] = tf.reduce_mean(
              tf.gather_nd(predictions, gather_inds))
          argmax_prediction = tf.cast(tf.argmax(predictions, axis=1), tf.int32)
          self._accuracy[target] = tf.reduce_mean(
              tf.to_float(tf.equal(label, argmax_prediction)))

          losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
              labels=label, logits=logits)
          self._loss[target] = tf.reduce_mean(losses)
          self._logits[target] = logits
          self._predictions[target] = predictions

    # Compute total loss.  Materialize the values as a list: on Python 3
    # dict.values() is a view, which tf.add_n does not accept as a list.
    self._total_loss = tf.add_n(list(self._loss.values()))

    # Define the optimizer.

    # tf.estimator framework builds the global_step for us, but if we aren't
    # using the framework we have to make it ourselves.
    self._global_step = tf.train.get_or_create_global_step()
    if self.hparams.lr_decay < 0:
      # A negative decay rate disables learning rate decay.
      self._learn_rate = self.hparams.lr_init
    else:
      self._learn_rate = tf.train.exponential_decay(
          self.hparams.lr_init,
          self._global_step,
          int(self.hparams.train_steps),
          self.hparams.lr_decay,
          staircase=False)
    # optimizer_hp is passed positionally: it is beta1 for Adam and the
    # momentum coefficient for the Momentum optimizer.
    if optimizer_name == 'adam':
      opt = tf.train.AdamOptimizer(self._learn_rate, self.hparams.optimizer_hp)
    else:
      # Must be 'momentum'; validated at the top of this method.
      opt = tf.train.MomentumOptimizer(self._learn_rate,
                                       self.hparams.optimizer_hp)
    if self.use_tpu:
      opt = tf.contrib.tpu.CrossShardOptimizer(opt)

    gradients, variables = zip(*opt.compute_gradients(self._total_loss))
    clipped_gradients, _ = tf.clip_by_global_norm(gradients,
                                                  self.hparams.grad_clip_norm)
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
      self._train_op = opt.apply_gradients(
          zip(clipped_gradients, variables), global_step=self._global_step)

    if not self.use_tpu:
      # The logged gradient norm is taken before clipping.
      grad_norm = tf.global_norm(gradients) if is_train else None
      param_norm = tf.global_norm(variables) if is_train else None
      self._add_summaries(mode, grad_norm, param_norm)

  def model_fn(self, features, labels, mode, params):
    """Function fulfilling the tf.estimator model_fn interface.

    Args:
      features: a dict containing the input features for prediction.
      labels: a dict from target name to Tensor-value prediction.
      mode: the ModeKey string.
      params: a dictionary of parameters for building the model; current params
        are params["batch_size"]: the integer batch size.

    Returns:
      A tf.estimator.EstimatorSpec object ready for use in training, inference,
      or evaluation.
    """
    self.build_graph(features, labels, mode, params['batch_size'])

    return tf.estimator.EstimatorSpec(
        mode,
        predictions=self.predictions,
        loss=self.total_loss,
        train_op=self.train_op,
        eval_metric_ops={})