# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import numpy as np
import logging

import src.utils as utils
import datasets.nav_env_config as nec
from datasets import factory


def adjust_args_for_mode(args, mode):
  if mode == 'train':
    args.control.train = True

  elif mode == 'val1':
    # Same settings as for training, to make sure nothing wonky is happening
    # there.
    args.control.test = True
    args.control.test_mode = 'val'
    args.navtask.task_params.batch_size = 32

  elif mode == 'val2':
    # No data augmentation, take the argmax action instead of sampling, and do
    # not sample from the ground truth at all.
    args.control.test = True
    args.arch.action_sample_type = 'argmax'
    args.arch.sample_gt_prob_type = 'zero'
    args.navtask.task_params.data_augment = \
        utils.Foo(lr_flip=0, delta_angle=0, delta_xy=0, relight=False,
                  relight_fast=False, structured=False)
    args.control.test_mode = 'val'
    args.navtask.task_params.batch_size = 32

  elif mode == 'bench':
    # Actually test the agent, in settings that are kept the same between
    # different runs.
    args.navtask.task_params.batch_size = 16
    args.control.test = True
    args.arch.action_sample_type = 'argmax'
    args.arch.sample_gt_prob_type = 'zero'
    args.navtask.task_params.data_augment = \
        utils.Foo(lr_flip=0, delta_angle=0, delta_xy=0, relight=False,
                  relight_fast=False, structured=False)
    args.summary.test_iters = 250
    args.control.only_eval_when_done = True
    args.control.reset_rng_seed = True
    args.control.test_mode = 'test'

  else:
    logging.fatal('Unknown mode: %s.', mode)
    assert(False)
  return args
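
# Usage sketch (hypothetical `args`): assuming `args` is a utils.Foo
# configuration tree with `control`, `arch`, `navtask` and `summary`
# sub-configs, the benchmark-mode overrides above can be applied with:
#   args = adjust_args_for_mode(args, 'bench')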


def get_solver_vars(solver_str):
  if solver_str == '':
    vals = []
  else:
    vals = solver_str.split('_')
  ks = ['clip', 'dlw', 'long', 'typ', 'rlw', 'isdk', 'adam_eps', 'init_lr']
  ks = ks[:len(vals)]

  # Gradient clipping or not.
  if len(vals) == 0: ks.append('clip'); vals.append('noclip')
  # Data loss weight.
  if len(vals) == 1: ks.append('dlw'); vals.append('dlw20')
  # How long to train for.
  if len(vals) == 2: ks.append('long'); vals.append('nolong')
  # Adam.
  if len(vals) == 3: ks.append('typ'); vals.append('adam2')
  # Regularization loss weight.
  if len(vals) == 4: ks.append('rlw'); vals.append('rlw1')
  # isd_k.
  if len(vals) == 5: ks.append('isdk'); vals.append('isdk415')  # 415, inflexion at 2.5k.
  # Adam epsilon.
  if len(vals) == 6: ks.append('adam_eps'); vals.append('aeps1en8')
  # Initial learning rate.
  if len(vals) == 7: ks.append('init_lr'); vals.append('lr1en3')

  assert(len(vals) == 8)

  vars = utils.Foo()
  for k, v in zip(ks, vals):
    setattr(vars, k, v)
  logging.error('solver_vars: %s', vars)
  return vars


def process_solver_str(solver_str):
  solver = utils.Foo(
      seed=0, learning_rate_decay=None, clip_gradient_norm=None,
      max_steps=None, initial_learning_rate=None, momentum=None,
      steps_per_decay=None, logdir=None, sync=False, adjust_lr_sync=True,
      wt_decay=0.0001, data_loss_wt=None, reg_loss_wt=None, freeze_conv=True,
      num_workers=1, task=0, ps_tasks=0, master='local', typ=None,
      momentum2=None, adam_eps=None)

  # Clobber with overrides from the solver string.
  solver_vars = get_solver_vars(solver_str)

  solver.data_loss_wt = float(solver_vars.dlw[3:].replace('x', '.'))
  solver.adam_eps = float(
      solver_vars.adam_eps[4:].replace('x', '.').replace('n', '-'))
  solver.initial_learning_rate = float(
      solver_vars.init_lr[2:].replace('x', '.').replace('n', '-'))
  solver.reg_loss_wt = float(solver_vars.rlw[3:].replace('x', '.'))
  solver.isd_k = float(solver_vars.isdk[4:].replace('x', '.'))

  long = solver_vars.long
  if long == 'long':
    solver.steps_per_decay = 40000
    solver.max_steps = 120000
  elif long == 'long2':
    solver.steps_per_decay = 80000
    solver.max_steps = 120000
  elif long == 'nolong' or long == 'nol':
    solver.steps_per_decay = 20000
    solver.max_steps = 60000
  else:
    logging.fatal('solver_vars.long should be long, long2, nolong or nol.')
    assert(False)

  clip = solver_vars.clip
  if clip == 'noclip' or clip == 'nocl':
    solver.clip_gradient_norm = 0
  elif clip[:4] == 'clip':
    solver.clip_gradient_norm = float(clip[4:].replace('x', '.'))
  else:
    logging.fatal('Unknown solver_vars.clip: %s', clip)
    assert(False)

  typ = solver_vars.typ
  if typ == 'adam':
    solver.typ = 'adam'
    solver.momentum = 0.9
    solver.momentum2 = 0.999
    solver.learning_rate_decay = 1.0
  elif typ == 'adam2':
    solver.typ = 'adam'
    solver.momentum = 0.9
    solver.momentum2 = 0.999
    solver.learning_rate_decay = 0.1
  elif typ == 'sgd':
    solver.typ = 'sgd'
    solver.momentum = 0.99
    solver.momentum2 = None
    solver.learning_rate_decay = 0.1
  else:
    logging.fatal('Unknown solver_vars.typ: %s', typ)
    assert(False)

  logging.error('solver: %s', solver)
  return solver
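
# Example (hypothetical string, decoded by the parsing above):
#   solver = process_solver_str('clip2_dlw20_long_adam2_rlw1_isdk415_aeps1en8_lr1en3')
# gives clip_gradient_norm=2.0, data_loss_wt=20.0, steps_per_decay=40000,
# max_steps=120000, an Adam solver with learning_rate_decay=0.1,
# reg_loss_wt=1.0, isd_k=415.0, adam_eps=1e-8 and initial_learning_rate=1e-3.
# An empty string '' selects all the defaults listed in get_solver_vars.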


def get_navtask_vars(navtask_str):
  if navtask_str == '':
    vals = []
  else:
    vals = navtask_str.split('_')

  ks_all = ['dataset_name', 'modality', 'task', 'history', 'max_dist',
            'num_steps', 'step_size', 'n_ori', 'aux_views', 'data_aug']
  ks = ks_all[:len(vals)]

  # All data or not.
  if len(vals) == 0: ks.append('dataset_name'); vals.append('sbpd')
  # Modality.
  if len(vals) == 1: ks.append('modality'); vals.append('rgb')
  # Semantic task?
  if len(vals) == 2: ks.append('task'); vals.append('r2r')
  # Number of history frames.
  if len(vals) == 3: ks.append('history'); vals.append('h0')
  # Maximum distance.
  if len(vals) == 4: ks.append('max_dist'); vals.append('32')
  # Number of steps.
  if len(vals) == 5: ks.append('num_steps'); vals.append('40')
  # Step size.
  if len(vals) == 6: ks.append('step_size'); vals.append('8')
  # n_ori.
  if len(vals) == 7: ks.append('n_ori'); vals.append('4')
  # Auxiliary views.
  if len(vals) == 8: ks.append('aux_views'); vals.append('nv0')
  # Normal data augmentation, as opposed to structured data augmentation (if
  # set to 'straug').
  if len(vals) == 9: ks.append('data_aug'); vals.append('straug')

  assert(len(vals) == 10)
  for i in range(len(ks)):
    assert(ks[i] == ks_all[i])

  vars = utils.Foo()
  for k, v in zip(ks, vals):
    setattr(vars, k, v)
  logging.error('navtask_vars: %s', vars)
  return vars


def process_navtask_str(navtask_str):
  navtask = nec.nav_env_base_config()

  # Clobber with overrides from the navtask string.
  navtask_vars = get_navtask_vars(navtask_str)

  navtask.task_params.n_ori = int(navtask_vars.n_ori)
  navtask.task_params.max_dist = int(navtask_vars.max_dist)
  navtask.task_params.num_steps = int(navtask_vars.num_steps)
  navtask.task_params.step_size = int(navtask_vars.step_size)
  navtask.task_params.data_augment.delta_xy = int(navtask_vars.step_size)/2.

  n_aux_views_each = int(navtask_vars.aux_views[2])
  aux_delta_thetas = np.concatenate((np.arange(n_aux_views_each) + 1,
                                     -1 - np.arange(n_aux_views_each)))
  aux_delta_thetas = aux_delta_thetas * np.deg2rad(navtask.camera_param.fov)
  navtask.task_params.aux_delta_thetas = aux_delta_thetas

  if navtask_vars.data_aug == 'aug':
    navtask.task_params.data_augment.structured = False
  elif navtask_vars.data_aug == 'straug':
    navtask.task_params.data_augment.structured = True
  else:
    logging.fatal('Unknown navtask_vars.data_aug: %s.', navtask_vars.data_aug)
    assert(False)

  navtask.task_params.num_history_frames = int(navtask_vars.history[1:])
  navtask.task_params.n_views = 1 + navtask.task_params.num_history_frames

  navtask.task_params.goal_channels = int(navtask_vars.n_ori)

  if navtask_vars.task == 'hard':
    navtask.task_params.type = 'rng_rejection_sampling_many'
    navtask.task_params.rejection_sampling_M = 2000
    navtask.task_params.min_dist = 10
  elif navtask_vars.task == 'r2r':
    navtask.task_params.type = 'room_to_room_many'
  elif navtask_vars.task == 'ST':
    # Semantic task at hand.
    navtask.task_params.goal_channels = \
        len(navtask.task_params.semantic_task.class_map_names)
    navtask.task_params.rel_goal_loc_dim = \
        len(navtask.task_params.semantic_task.class_map_names)
    navtask.task_params.type = 'to_nearest_obj_acc'
  else:
    logging.fatal('navtask_vars.task should be hard, r2r or ST.')
    assert(False)

  if navtask_vars.modality == 'rgb':
    navtask.camera_param.modalities = ['rgb']
    navtask.camera_param.img_channels = 3
  elif navtask_vars.modality == 'd':
    navtask.camera_param.modalities = ['depth']
    navtask.camera_param.img_channels = 2

  navtask.task_params.img_height = navtask.camera_param.height
  navtask.task_params.img_width = navtask.camera_param.width
  navtask.task_params.modalities = navtask.camera_param.modalities
  navtask.task_params.img_channels = navtask.camera_param.img_channels
  navtask.task_params.img_fov = navtask.camera_param.fov

  navtask.dataset = factory.get_dataset(navtask_vars.dataset_name)
  return navtask
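
# Example (hypothetical string, decoded by the parsing above):
#   navtask = process_navtask_str('sbpd_rgb_r2r_h0_64_80_8_8_nv2_straug')
# selects the 'sbpd' dataset with RGB images, the room-to-room task, no
# history frames, max_dist=64, num_steps=80, step_size=8, n_ori=8, two
# auxiliary views on either side, and structured data augmentation.
# An empty string '' selects all the defaults listed in get_navtask_vars.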