Spaces:
Running
Running
File size: 3,593 Bytes
0b8359d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Wrapper for selecting the navigation environment that we want to train and
test on.
"""
import numpy as np
import os, glob
import platform
import logging
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
import render.swiftshader_renderer as renderer
import src.file_utils as fu
import src.utils as utils
def get_dataset(dataset_name):
if dataset_name == 'sbpd':
dataset = StanfordBuildingParserDataset(dataset_name)
else:
logging.fatal('Not one of sbpd')
return dataset
class Loader():
def get_data_dir():
pass
def get_meta_data(self, file_name, data_dir=None):
if data_dir is None:
data_dir = self.get_data_dir()
full_file_name = os.path.join(data_dir, 'meta', file_name)
assert(fu.exists(full_file_name)), \
'{:s} does not exist'.format(full_file_name)
ext = os.path.splitext(full_file_name)[1]
if ext == '.txt':
ls = []
with fu.fopen(full_file_name, 'r') as f:
for l in f:
ls.append(l.rstrip())
elif ext == '.pkl':
ls = utils.load_variables(full_file_name)
return ls
def load_building(self, name, data_dir=None):
if data_dir is None:
data_dir = self.get_data_dir()
out = {}
out['name'] = name
out['data_dir'] = data_dir
out['room_dimension_file'] = os.path.join(data_dir, 'room-dimension',
name+'.pkl')
out['class_map_folder'] = os.path.join(data_dir, 'class-maps')
return out
def load_building_meshes(self, building):
dir_name = os.path.join(building['data_dir'], 'mesh', building['name'])
mesh_file_name = glob.glob1(dir_name, '*.obj')[0]
mesh_file_name_full = os.path.join(dir_name, mesh_file_name)
logging.error('Loading building from obj file: %s', mesh_file_name_full)
shape = renderer.Shape(mesh_file_name_full, load_materials=True,
name_prefix=building['name']+'_')
return [shape]
class StanfordBuildingParserDataset(Loader):
def __init__(self, ver):
self.ver = ver
self.data_dir = None
def get_data_dir(self):
if self.data_dir is None:
self.data_dir = 'data/stanford_building_parser_dataset/'
return self.data_dir
def get_benchmark_sets(self):
return self._get_benchmark_sets()
def get_split(self, split_name):
if self.ver == 'sbpd':
return self._get_split(split_name)
else:
logging.fatal('Unknown version.')
def _get_benchmark_sets(self):
sets = ['train1', 'val', 'test']
return sets
def _get_split(self, split_name):
train = ['area1', 'area5a', 'area5b', 'area6']
train1 = ['area1']
val = ['area3']
test = ['area4']
sets = {}
sets['train'] = train
sets['train1'] = train1
sets['val'] = val
sets['test'] = test
sets['all'] = sorted(list(set(train + val + test)))
return sets[split_name]
|