Spaces:
Sleeping
Sleeping
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides dataset dictionaries as used in our network models.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import tensorflow as tf | |
import tensorflow.contrib.slim as slim | |
from tensorflow.contrib.slim.python.slim.data import dataset | |
from tensorflow.contrib.slim.python.slim.data import dataset_data_provider | |
from tensorflow.contrib.slim.python.slim.data import tfexample_decoder | |
_ITEMS_TO_DESCRIPTIONS = { | |
'image': 'Images', | |
'mask': 'Masks', | |
'vox': 'Voxels' | |
} | |
def _get_split(file_pattern, num_samples, num_views, image_size, vox_size):
  """Describes one TFRecord split as a slim `dataset.Dataset`.

  Every serialized example is parsed as three dense float32 features, and
  each one is exposed as a decoded tensor item of the same name and shape:

    'image': [num_views, image_size, image_size, 3]
    'mask':  [num_views, image_size, image_size, 1]
    'vox':   [vox_size, vox_size, vox_size, 1]

  Args:
    file_pattern: File pattern handed to the dataset as `data_sources`.
    num_samples: Example count recorded on the returned dataset.
    num_views: Leading dimension of the 'image' and 'mask' features.
    image_size: Size of both square spatial dimensions of 'image'/'mask'.
    vox_size: Size of each of the three spatial dimensions of 'vox'.

  Returns:
    A `dataset.Dataset` that reads with `tf.TFRecordReader` and decodes
    with a `tfexample_decoder.TFExampleDecoder`.
  """
  # Single source of truth for the per-item shapes; the parsing spec and the
  # decoder handlers below must agree on them.
  item_dims = (
      ('image', (num_views, image_size, image_size, 3)),
      ('mask', (num_views, image_size, image_size, 1)),
      ('vox', (vox_size, vox_size, vox_size, 1)),
  )

  # Each consumer gets its own fresh list, mirroring the separate literals
  # the parsing spec and the handlers would otherwise spell out twice.
  parse_spec = {
      name: tf.FixedLenFeature(
          shape=list(dims), dtype=tf.float32, default_value=None)
      for name, dims in item_dims
  }
  handlers = {
      name: tfexample_decoder.Tensor(name, shape=list(dims))
      for name, dims in item_dims
  }

  example_decoder = tfexample_decoder.TFExampleDecoder(parse_spec, handlers)
  return dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=example_decoder,
      num_samples=num_samples,
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS)
def get(dataset_dir,
        dataset_name,
        split_name,
        shuffle=True,
        num_readers=1,
        common_queue_capacity=64,
        common_queue_min=50):
  """Builds the input tensors for one split of a supported dataset.

  Args:
    dataset_dir: Directory joined in front of the dataset's TFRecord file
      pattern.
    dataset_name: 'shapenet_chair' or 'shapenet_all'.
    split_name: 'train', 'val' or 'test'.
    shuffle: Forwarded to `DatasetDataProvider`.
    num_readers: Forwarded to `DatasetDataProvider`.
    common_queue_capacity: Forwarded to `DatasetDataProvider`.
    common_queue_min: Forwarded to `DatasetDataProvider`.

  Returns:
    A dict with, in this order, 'num_samples' (the split size), 'image',
    'mask' and 'voxel'. The last three are the items the data provider
    returns for 'image', 'mask' and 'vox'.

  Raises:
    KeyError: if `dataset_name` or `split_name` is not listed above.
  """
  # Both supported datasets share this geometry. They differ only in which
  # shards the file pattern selects and in their split sizes.
  geometry = {'num_views': 24, 'image_size': 64, 'vox_size': 32}

  # Patterns for every dataset are formatted up front, before the lookup.
  pattern_for = {
      name: template % split_name
      for name, template in (
          ('shapenet_chair', '03001627_%s.tfrecords'),
          ('shapenet_all', '*_%s.tfrecords'),
      )
  }
  sizes_by_split = {
      'shapenet_chair': {'train': 4744, 'val': 678, 'test': 1356},
      'shapenet_all': {'train': 30643, 'val': 4378, 'test': 8762},
  }

  file_pattern = os.path.join(dataset_dir, pattern_for[dataset_name])
  num_samples = sizes_by_split[dataset_name][split_name]
  dataset_split = _get_split(
      file_pattern=file_pattern, num_samples=num_samples, **geometry)

  provider = dataset_data_provider.DatasetDataProvider(
      dataset_split,
      num_readers=num_readers,
      common_queue_capacity=common_queue_capacity,
      common_queue_min=common_queue_min,
      shuffle=shuffle)
  image, mask, vox = provider.get(['image', 'mask', 'vox'])

  # Note: the stored 'vox' item is exposed to callers under the key 'voxel'.
  return {
      'num_samples': dataset_split.num_samples,
      'image': image,
      'mask': mask,
      'voxel': vox,
  }