Spaces:
Running
Running
# Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Input pipeline for feeding examples to a Seq2Label model graph."""
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
import tensorflow as tf | |
from google.protobuf import text_format | |
from protos import seq2label_pb2 | |
import seq2label_utils | |
# The unambiguous DNA alphabet, in the order used for every base distribution.
DNA_BASES = ('A', 'C', 'G', 'T')
NUM_DNA_BASES = len(DNA_BASES)

# FASTA characters / IUPAC ambiguity codes, each mapped to the string of DNA
# bases it may stand for.
# See https://en.wikipedia.org/wiki/Nucleic_acid_notation.
AMBIGUITY_CODES = dict(
    # Two possible bases.
    K='GT',
    M='AC',
    R='AG',
    Y='CT',
    S='CG',
    W='AT',
    # Three possible bases.
    B='CGT',
    V='ACG',
    H='ACT',
    D='AGT',
    # Any base.
    X='ACGT',
    N='ACGT',
)
def load_dataset_info(dataset_info_path):
  """Reads a text-format `Seq2LabelDatasetInfo` proto from disk.

  Args:
    dataset_info_path: path (anything `tf.gfile.Open` accepts) of the
      text-serialized proto.

  Returns:
    The parsed `seq2label_pb2.Seq2LabelDatasetInfo` message.
  """
  with tf.gfile.Open(dataset_info_path, 'r') as info_file:
    serialized_text = info_file.read()
  parsed_info = seq2label_pb2.Seq2LabelDatasetInfo()
  text_format.Parse(serialized_text, parsed_info)
  return parsed_info
class _InputEncoding(object):
  """Builds the graph operations that encode raw examples into model input.

  Constructing an instance creates lookup tables on the default TF graph, so
  it should only be instantiated inside the `input_fn`.

  Attributes:
    mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, INFER}.
    targets: list of strings; the names of the labels of interest (e.g.
      "species").
    dna_bases: a tuple of the recognized DNA alphabet.
    n_bases: the size of the DNA alphabet.
    all_characters: list of every recognized character: the DNA bases
      followed by the ambiguity codes in sorted order.
    character_encodings: array of shape (len(all_characters), n_bases); row i
      is the base distribution of all_characters[i].
    characters_table: lookup table mapping a character to its index in
      all_characters.
    label_tables: dict mapping each target name to a lookup table from label
      value string to its index among the legal values for that target.
    fixed_read_length: an integer value of the statically-known read length,
      or None if the read length is to be determined dynamically.
    noise_rate: rate passed to `seq2label_utils.add_read_noise`; noise is only
      injected when this is > 0.0.
  """

  def __init__(self,
               dataset_info,
               mode,
               targets,
               noise_rate=0.0,
               fixed_read_length=None):
    self.mode = mode
    self.targets = targets
    self.fixed_read_length = fixed_read_length
    self.noise_rate = noise_rate

    # A character's position in all_characters is both its row in
    # character_encodings and the id returned by characters_table.
    self.dna_bases = DNA_BASES
    self.n_bases = NUM_DNA_BASES
    self.all_characters = list(DNA_BASES) + sorted(AMBIGUITY_CODES)
    self.character_encodings = np.stack([
        self._character_to_base_distribution(character)
        for character in self.all_characters
    ])

    legal_values_by_target = seq2label_utils.get_all_label_values(dataset_info)

    # TF lookup tables. NOTE(review): values absent from a table presumably
    # map to the table's default id (-1 for index_table_from_tensor); confirm
    # that inputs never contain such values (e.g. lowercase bases).
    self.characters_table = tf.contrib.lookup.index_table_from_tensor(
        mapping=self.all_characters)
    self.label_tables = {}
    for target in targets:
      self.label_tables[target] = tf.contrib.lookup.index_table_from_tensor(
          legal_values_by_target[target])

  def _character_to_base_distribution(self, char):
    """Encodes one character as a probability distribution over DNA bases.

    A plain base becomes a one-hot vector. An ambiguity code spreads the
    probability mass uniformly over the bases it may stand for.

    Args:
      char: character to be encoded.

    Returns:
      Array of shape (self.n_bases,), indexed in the order of self.dna_bases.

    Raises:
      ValueError: if char is not in the recognized alphabet,
        self.all_characters.
    """
    if char not in self.all_characters:
      raise ValueError(
          'Base distribution requested for unrecognized character %s.' % char)
    candidate_bases = AMBIGUITY_CODES.get(char, char)
    distribution = np.zeros(self.n_bases)
    positions = [self.dna_bases.index(base) for base in candidate_bases]
    distribution[positions] = 1.0 / len(candidate_bases)
    return distribution

  def encode_read(self, string_seq):
    """Converts a read string into a per-position base distribution tensor.

    Args:
      string_seq: tf.string scalar; the input read sequence.

    Returns:
      float32 Tensor of shape (read_length, n_bases). Plain bases are one-hot;
      ambiguous characters such as "N" and "S" are a uniform distribution over
      the bases they represent. The read length is static when
      fixed_read_length is set.
    """
    with tf.variable_scope('encode_read'):
      characters = tf.string_split([string_seq], delimiter='').values
      character_ids = self.characters_table.lookup(characters)
      encoded = tf.cast(
          tf.gather(self.character_encodings, character_ids), tf.float32)
      if self.fixed_read_length:
        # Pin the static shape (required e.g. for TPU execution).
        encoded = tf.reshape(encoded, (self.fixed_read_length, self.n_bases))
    return encoded

  def encode_label(self, target, string_label):
    """Maps a label value to its integer index for the given target.

    Args:
      target: str; the target name.
      string_label: tf.string; the label value of the current input read.

    Returns:
      int32 Tensor holding the index of string_label among the legal values
      of target.
    """
    scope_name = 'encode_label/{}'.format(target)
    with tf.variable_scope(scope_name):
      label_ids = self.label_tables[target].lookup(string_label)
      return tf.cast(label_ids, tf.int32)

  def _empty_label(self):
    """Returns a scalar int32 placeholder label for unlabeled (INFER) input."""
    return tf.constant((), dtype=tf.int32, shape=())

  def parse_single_tfexample(self, serialized_example):
    """Parses a serialized tf.train.Example into (features, labels) dicts.

    When noise_rate is positive, the read sequence is first perturbed with
    `seq2label_utils.add_read_noise`.

    Args:
      serialized_example: string; a serialized tf.train.Example holding the
        read sequence and each target's label value as scalar string
        tf.FixedLenFeatures.

    Returns:
      Tuple (features, labels). features maps 'sequence' to the encoded read.
      labels maps each target to its encoded label in TRAIN/EVAL mode, or to
      a placeholder scalar otherwise.
    """
    with tf.variable_scope('parse_single_tfexample'):
      feature_names = ['sequence'] + list(self.targets)
      features_spec = {
          name: tf.FixedLenFeature([], tf.string) for name in feature_names
      }
      parsed = tf.parse_single_example(
          serialized_example, features=features_spec)

      raw_read = parsed['sequence']
      if self.noise_rate > 0.0:
        raw_read = tf.py_func(seq2label_utils.add_read_noise,
                              [raw_read, self.noise_rate], tf.string)
      read_features = {'sequence': self.encode_read(raw_read)}

      has_labels = self.mode in (tf.estimator.ModeKeys.TRAIN,
                                 tf.estimator.ModeKeys.EVAL)
      if has_labels:
        labels = {
            target: self.encode_label(target, parsed[target])
            for target in self.targets
        }
      else:
        labels = {target: self._empty_label() for target in self.targets}
    return read_features, labels
class InputDataset(object):
  """A class providing access to input data for the Seq2Label model.

  Attributes:
    mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, INFER}.
    targets: list of strings; the names of the labels of interest (e.g.
      "species").
    dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
      metadata.
    input_tfrecord_files: the TFRecord filename(s) read by `input_fn`.
    initializer: the TF initializer op for the underlying iterator, which
      will rewind the iterator. None until `input_fn` has been called.

  Call `is_train()` to check whether the execution mode is TRAIN.
  """

  def __init__(self,
               mode,
               targets,
               dataset_info,
               train_epochs=None,
               noise_rate=0.0,
               random_seed=None,
               input_tfrecord_files=None,
               fixed_read_length=None,
               ensure_constant_batch_size=False,
               num_parallel_calls=32):
    """Constructor for InputDataset.

    Args:
      mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, INFER}.
      targets: list of strings; the names of the labels of interest (e.g.
        "species").
      dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
        metadata.
      train_epochs: the number of training epochs to perform, if mode==TRAIN.
        None repeats the training data indefinitely.
      noise_rate: float [0.0, 1.0] specifying rate at which to inject
        base-flipping noise into the read sequences.
      random_seed: seed to be used for shuffling, if mode==TRAIN. When given,
        NumPy's global random state is also seeded with it.
      input_tfrecord_files: a list of filenames for TFRecords of TF examples.
      fixed_read_length: an integer value of the statically-known read length,
        or None if the read length is to be determined dynamically. The read
        length must be known statically for TPU execution.
      ensure_constant_batch_size: ensure a constant batch size at the expense of
        discarding the last "short" batch. This also gives us a statically
        constant batch size, which is essential for e.g. the TPU platform.
      num_parallel_calls: the number of dataset elements to process in parallel.
        If None, elements will be processed sequentially.
    """
    self.input_tfrecord_files = input_tfrecord_files
    self.mode = mode
    self.targets = targets
    self.dataset_info = dataset_info
    # Populated by `input_fn`; defined here so the documented attribute
    # always exists.
    self.initializer = None
    self._train_epochs = train_epochs
    self._noise_rate = noise_rate
    self._random_seed = random_seed
    if random_seed is not None:
      # Seeds NumPy's *global* RNG. NOTE(review): presumably this makes the
      # NumPy-based read noise (applied via tf.py_func) reproducible; confirm
      # against seq2label_utils.add_read_noise.
      np.random.seed(random_seed)
    self._fixed_read_length = fixed_read_length
    self._ensure_constant_batch_size = ensure_constant_batch_size
    self._num_parallel_calls = num_parallel_calls

  @staticmethod
  def from_tfrecord_files(input_tfrecord_files, *args, **kwargs):
    """Alternate constructor: an InputDataset reading the given TFRecords.

    Args:
      input_tfrecord_files: a list of filenames for TFRecords of TF examples.
      *args: positional arguments forwarded to `InputDataset.__init__`
        (mode, targets, dataset_info, ...).
      **kwargs: keyword arguments forwarded to `InputDataset.__init__`.

    Returns:
      A new InputDataset.
    """
    return InputDataset(
        *args, input_tfrecord_files=input_tfrecord_files, **kwargs)

  def is_train(self):
    """Returns True iff the execution mode is TRAIN."""
    return self.mode == tf.estimator.ModeKeys.TRAIN

  def input_fn(self, params):
    """Supplies input for the model.

    This function supplies input to our model as a function of the mode.
    Only TRAIN mode shuffles and repeats the data (for `train_epochs` epochs);
    other modes make a single ordered pass.

    Args:
      params: a dictionary, containing:
        - params['batch_size']: the integer batch size.

    Returns:
      A tuple of two values as follows:
        1) the *features* dict, containing a tensor value for keys as follows:
          - "sequence" - the batched encoded read input sequences, float32 of
            shape (batch, read_length, 4).
        2) the *labels* dict, containing a key for each target, whose value is:
          - an int32 Tensor of label indices (in TRAIN/EVAL mode), or
          - a placeholder int32 Tensor (PREDICT mode).
    """
    # `is_train` is a method: it must be called. Using the bound method
    # itself as a condition is always truthy, which used to shuffle and
    # endlessly repeat EVAL/PREDICT input.
    is_train = self.is_train()
    batch_size = params['batch_size']
    encoding = _InputEncoding(
        self.dataset_info,
        self.mode,
        self.targets,
        noise_rate=self._noise_rate,
        fixed_read_length=self._fixed_read_length)

    dataset = tf.data.TFRecordDataset(self.input_tfrecord_files)
    dataset = dataset.map(
        encoding.parse_single_tfexample,
        num_parallel_calls=self._num_parallel_calls)
    dataset = dataset.repeat(self._train_epochs if is_train else 1)
    if is_train:
      dataset = dataset.shuffle(
          buffer_size=max(1000, batch_size), seed=self._random_seed)
    # With drop_remainder=True only batches of *exactly* batch_size are
    # emitted, so the batch dimension is statically known.
    # NOTE(review): without fixed_read_length, reads of differing lengths
    # cannot be stacked by `batch`; confirm reads share one length.
    dataset = dataset.batch(
        batch_size, drop_remainder=bool(self._ensure_constant_batch_size))
    # Prefetch to allow infeed to be in parallel with model computations.
    dataset = dataset.prefetch(2)

    # Use an initializable iterator to support table lookups; registering its
    # initializer as a table initializer lets the standard graph
    # initialization run it.
    iterator = dataset.make_initializable_iterator()
    self.initializer = iterator.initializer
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
    features, labels = iterator.get_next()
    return (features, labels)