# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipe for feeding examples to a Seq2Label model graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from google.protobuf import text_format

from protos import seq2label_pb2
import seq2label_utils

DNA_BASES = tuple('ACGT')
NUM_DNA_BASES = len(DNA_BASES)

# Possible FASTA characters/IUPAC ambiguity codes.
# See https://en.wikipedia.org/wiki/Nucleic_acid_notation.
AMBIGUITY_CODES = {
'K': 'GT',
'M': 'AC',
'R': 'AG',
'Y': 'CT',
'S': 'CG',
'W': 'AT',
'B': 'CGT',
'V': 'ACG',
'H': 'ACT',
'D': 'AGT',
'X': 'ACGT',
'N': 'ACGT'
}
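
# For example, AMBIGUITY_CODES['R'] == 'AG': a base recorded as 'R' may be
# either 'A' or 'G', and _character_to_base_distribution below spreads its
# probability mass evenly over those possibilities.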


def load_dataset_info(dataset_info_path):
  """Loads a `Seq2LabelDatasetInfo` from a serialized text proto file."""
dataset_info = seq2label_pb2.Seq2LabelDatasetInfo()
with tf.gfile.Open(dataset_info_path, 'r') as f:
text_format.Parse(f.read(), dataset_info)
return dataset_info
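
# Example usage (an illustrative sketch; the path is hypothetical):
#   dataset_info = load_dataset_info('/path/to/dataset_info.pbtxt')
#   all_label_values = seq2label_utils.get_all_label_values(dataset_info)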


class _InputEncoding(object):
  """A helper class providing the graph operations needed to encode input.

  Instantiating an _InputEncoding writes ops onto the default TF graph, so it
  should only be done inside the `input_fn`.

  Attributes:
    mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, INFER}.
    targets: list of strings; the names of the labels of interest (e.g.
      "species").
    dna_bases: a tuple of the recognized DNA alphabet.
    n_bases: the size of the DNA alphabet.
    all_characters: list of the recognized alphabet, including ambiguity
      codes.
    label_values: a tuple of strings; the possible label values of the
      prediction target.
    n_labels: the size of label_values.
    fixed_read_length: an integer value of the statically-known read length,
      or None if the read length is to be determined dynamically.
  """

  def __init__(self,
               dataset_info,
               mode,
               targets,
               noise_rate=0.0,
               fixed_read_length=None):
self.mode = mode
self.targets = targets
self.dna_bases = DNA_BASES
self.n_bases = NUM_DNA_BASES
self.all_characters = list(DNA_BASES) + sorted(AMBIGUITY_CODES.keys())
self.character_encodings = np.concatenate(
[[self._character_to_base_distribution(char)]
for char in self.all_characters],
axis=0)
all_legal_label_values = seq2label_utils.get_all_label_values(dataset_info)
# TF lookup tables.
self.characters_table = tf.contrib.lookup.index_table_from_tensor(
mapping=self.all_characters)
self.label_tables = {
target: tf.contrib.lookup.index_table_from_tensor(
all_legal_label_values[target])
for target in targets
}
self.fixed_read_length = fixed_read_length
self.noise_rate = noise_rate

  def _character_to_base_distribution(self, char):
    """Maps the given character to a probability distribution over DNA bases.

    Args:
      char: character to be encoded as a probability distribution over bases.

    Returns:
      Array of shape (self.n_bases,) representing the identity of the given
      character as a distribution over the possible DNA bases,
      self.dna_bases.

    Raises:
      ValueError: if the given character is not contained in the recognized
        alphabet, self.all_characters.
    """
if char not in self.all_characters:
raise ValueError(
'Base distribution requested for unrecognized character %s.' % char)
    possible_bases = AMBIGUITY_CODES.get(char, char)
    base_indices = [self.dna_bases.index(base) for base in possible_bases]
    probability_weight = 1.0 / len(possible_bases)
    distribution = np.zeros((self.n_bases,))
    distribution[base_indices] = probability_weight
return distribution
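
  # Worked example (illustrative comment, not part of the original code),
  # with self.dna_bases == ('A', 'C', 'G', 'T'):
  #   'A' -> [1.0, 0.0, 0.0, 0.0]
  #   'R' -> [0.5, 0.0, 0.5, 0.0]      ('A' or 'G')
  #   'N' -> [0.25, 0.25, 0.25, 0.25]  (any base)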

  def encode_read(self, string_seq):
    """Converts the input read sequence to a one-hot encoding.

    Args:
      string_seq: tf.string; input read sequence.

    Returns:
      Input read sequence as a one-hot encoded Tensor, with the depth and
      ordering of the one-hot encoding determined by the given bases.
      Ambiguous characters such as "N" and "S" are encoded as a probability
      distribution over the possible bases they represent.
    """
with tf.variable_scope('encode_read'):
read = tf.string_split([string_seq], delimiter='').values
read = self.characters_table.lookup(read)
read = tf.cast(tf.gather(self.character_encodings, read), tf.float32)
if self.fixed_read_length:
read = tf.reshape(read, (self.fixed_read_length, self.n_bases))
return read
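
  # For example (illustrative comment): with fixed_read_length=None, the
  # string "ACGN" encodes to a float32 Tensor of shape (4, 4):
  #   [[1.0, 0.0, 0.0, 0.0],          # A
  #    [0.0, 1.0, 0.0, 0.0],          # C
  #    [0.0, 0.0, 1.0, 0.0],          # G
  #    [0.25, 0.25, 0.25, 0.25]]      # N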

  def encode_label(self, target, string_label):
    """Converts the label value to an integer encoding.

    Args:
      target: str; the target name.
      string_label: tf.string; value of the label for the current input read.

    Returns:
      The given label value as an index into the possible label values for
      `target`.
    """
with tf.variable_scope('encode_label/{}'.format(target)):
return tf.cast(self.label_tables[target].lookup(string_label), tf.int32)

  def _empty_label(self):
return tf.constant((), dtype=tf.int32, shape=())

  def parse_single_tfexample(self, serialized_example):
    """Parses a tf.train.Example proto to a one-hot encoded (read, label) pair.

    Injects noise into the incoming tf.train.Example's read sequence when
    noise_rate is non-zero.

    Args:
      serialized_example: string; the serialized tf.train.Example proto
        containing the read sequence and the label values of interest as
        tf.FixedLenFeatures.

    Returns:
      Tuple (features, labels) of dicts for the input features and prediction
      targets.
    """
with tf.variable_scope('parse_single_tfexample'):
features_spec = {'sequence': tf.FixedLenFeature([], tf.string)}
for target in self.targets:
features_spec[target] = tf.FixedLenFeature([], tf.string)
features = tf.parse_single_example(
serialized_example, features=features_spec)
      if self.noise_rate > 0.0:
        # Base-flipping noise is injected outside the graph via a py_func.
        read_sequence = tf.py_func(seq2label_utils.add_read_noise,
                                   [features['sequence'], self.noise_rate],
                                   tf.string)
      else:
        read_sequence = features['sequence']
read_sequence = self.encode_read(read_sequence)
read_features = {'sequence': read_sequence}
if self.mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
label = {
target: self.encode_label(target, features[target])
for target in self.targets
}
else:
label = {target: self._empty_label() for target in self.targets}
return read_features, label
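
  # For reference, the expected record layout (an illustrative sketch; the
  # target feature name 'species' and its value are hypothetical):
  #   example = tf.train.Example(features=tf.train.Features(feature={
  #       'sequence': tf.train.Feature(
  #           bytes_list=tf.train.BytesList(value=[b'ACGTNACGT'])),
  #       'species': tf.train.Feature(
  #           bytes_list=tf.train.BytesList(value=[b'some_species'])),
  #   }))
  #   serialized_example = example.SerializeToString()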


class InputDataset(object):
  """A class providing access to input data for the Seq2Label model.

  Attributes:
    mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, INFER}.
    targets: list of strings; the names of the labels of interest (e.g.
      "species").
    dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
      metadata.
    initializer: the TF initializer op for the underlying iterator, which
      will rewind the iterator.
    is_train: boolean indicating whether or not the execution mode is TRAIN.
  """

  def __init__(self,
               mode,
               targets,
               dataset_info,
               train_epochs=None,
               noise_rate=0.0,
               random_seed=None,
               input_tfrecord_files=None,
               fixed_read_length=None,
               ensure_constant_batch_size=False,
               num_parallel_calls=32):
    """Constructor for InputDataset.

    Args:
      mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, INFER}.
      targets: list of strings; the names of the labels of interest (e.g.
        "species").
      dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
        metadata.
      train_epochs: the number of training epochs to perform, if mode==TRAIN.
      noise_rate: float in [0.0, 1.0] specifying the rate at which to inject
        base-flipping noise into the read sequences.
      random_seed: seed to be used for shuffling, if mode==TRAIN.
      input_tfrecord_files: a list of filenames for TFRecords of TF examples.
      fixed_read_length: an integer value of the statically-known read length,
        or None if the read length is to be determined dynamically. The read
        length must be known statically for TPU execution.
      ensure_constant_batch_size: ensure a constant batch size at the expense
        of discarding the last "short" batch. This also gives us a statically
        constant batch size, which is essential for e.g. the TPU platform.
      num_parallel_calls: the number of dataset elements to process in
        parallel. If None, elements will be processed sequentially.
    """
self.input_tfrecord_files = input_tfrecord_files
self.mode = mode
self.targets = targets
self.dataset_info = dataset_info
self._train_epochs = train_epochs
self._noise_rate = noise_rate
self._random_seed = random_seed
if random_seed is not None:
np.random.seed(random_seed)
self._fixed_read_length = fixed_read_length
self._ensure_constant_batch_size = ensure_constant_batch_size
self._num_parallel_calls = num_parallel_calls

  @staticmethod
  def from_tfrecord_files(input_tfrecord_files, *args, **kwargs):
    """Builds an InputDataset reading from the given TFRecord files."""
    return InputDataset(
        *args, input_tfrecord_files=input_tfrecord_files, **kwargs)

  @property
  def is_train(self):
    return self.mode == tf.estimator.ModeKeys.TRAIN

  def input_fn(self, params):
    """Supplies input to the model.

    This function supplies input to our model as a function of the mode.

    Args:
      params: a dictionary, containing:
        - params['batch_size']: the integer batch size.

    Returns:
      A tuple of two values as follows:
        1) the *features* dict, containing a tensor value for the key:
           - "sequence": the one-hot encoded read input sequence.
        2) the *labels* dict, containing a key for each target, whose value
           is:
           - an integer-encoded label Tensor (in TRAIN/EVAL mode), or
           - a blank Tensor (PREDICT mode).
    """
randomize_input = self.is_train
batch_size = params['batch_size']
encoding = _InputEncoding(
self.dataset_info,
self.mode,
self.targets,
noise_rate=self._noise_rate,
fixed_read_length=self._fixed_read_length)
dataset = tf.data.TFRecordDataset(self.input_tfrecord_files)
dataset = dataset.map(
encoding.parse_single_tfexample,
num_parallel_calls=self._num_parallel_calls)
dataset = dataset.repeat(self._train_epochs if self.is_train else 1)
if randomize_input:
dataset = dataset.shuffle(
buffer_size=max(1000, batch_size), seed=self._random_seed)
if self._ensure_constant_batch_size:
# Only take batches of *exactly* size batch_size; then we get a
# statically knowable batch shape.
dataset = dataset.batch(batch_size, drop_remainder=True)
else:
dataset = dataset.batch(batch_size)
# Prefetch to allow infeed to be in parallel with model computations.
dataset = dataset.prefetch(2)
# Use initializable iterator to support table lookups.
iterator = dataset.make_initializable_iterator()
self.initializer = iterator.initializer
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
features, labels = iterator.get_next()
return (features, labels)
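

# Example wiring with a tf.estimator.Estimator (an illustrative sketch; the
# paths, target name, and `model_fn` are hypothetical, not part of this
# module):
#   dataset_info = load_dataset_info('/path/to/dataset_info.pbtxt')
#   input_dataset = InputDataset.from_tfrecord_files(
#       ['/path/to/train.tfrecord'],
#       tf.estimator.ModeKeys.TRAIN,
#       ['species'],
#       dataset_info,
#       train_epochs=1)
#   estimator = tf.estimator.Estimator(model_fn, params={'batch_size': 32})
#   estimator.train(input_fn=input_dataset.input_fn)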