"""Default configuration for agent and environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging

from common import config_lib  # brain coder


def default_config():
  return config_lib.Config(
      agent=config_lib.OneOf(
          [config_lib.Config(
              algorithm='pg',
              policy_lstm_sizes=[35, 35],
              # Set value_lstm_sizes to None to share weights with policy.
              value_lstm_sizes=[35, 35],
              obs_embedding_size=10,
              grad_clip_threshold=10.0,
              param_init_factor=1.0,
              lr=5e-5,
              pi_loss_hparam=1.0,
              vf_loss_hparam=0.5,
              entropy_beta=1e-2,
              regularizer=0.0,
              softmax_tr=1.0,  # Reciprocal temperature.
              optimizer='rmsprop',  # 'adam', 'sgd', 'rmsprop'
              topk=0,  # Top-k unique codes will be stored.
              topk_loss_hparam=0.0,  # Off-policy loss multiplier.
              # Uniformly sample this many episodes from topk buffer per
              # batch. If topk is 0, this has no effect.
              topk_batch_size=1,
              # Exponential moving average baseline for REINFORCE.
              # If zero, A2C is used.
              # If non-zero, should be close to 1, like .99, .999, etc.
              ema_baseline_decay=0.99,
              # Whether the agent can emit an EOS token. If True, the agent
              # can emit EOS, which ends the episode early (ends the
              # sequence). If False, the agent must emit tokens until the
              # timestep limit is reached, i.e. True means variable-length
              # code, False means fixed-length code.
              # WARNING: Making this False slows things down.
              eos_token=False,
              replay_temperature=1.0,
              # Replay probability. 1 = always replay, 0 = always on policy.
              alpha=0.0,
              # Whether to normalize importance weights in each minibatch.
              iw_normalize=True),
           config_lib.Config(
               algorithm='ga',
               crossover_rate=0.99,
               mutation_rate=0.086),
           config_lib.Config(
               algorithm='rand')],
          algorithm='pg',
      ),
      env=config_lib.Config(
          # If True, task-specific settings are not needed.
          task='',  # 'print', 'echo', 'reverse', 'remove', ...
          task_cycle=[],  # If non-empty, repetitions will cycle through tasks.
          task_kwargs='{}',  # Python dict literal.
          task_manager_config=config_lib.Config(
              # Reward received per test case. These bonuses will be scaled
              # based on how many test cases there are.
              correct_bonus=2.0,  # Bonus for code getting the correct answer.
              code_length_bonus=1.0),  # Maximum bonus for short code.
          correct_syntax=False,
      ),
      batch_size=64,
      timestep_limit=32)


def default_config_with_updates(config_string, do_logging=True):
  if do_logging:
    logging.info('Config string: "%s"', config_string)
  config = default_config()
  config.strict_update(config_lib.Config.parse(config_string))
  if do_logging:
    logging.info('Config:\n%s', config.pretty_str())
  return config
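

if __name__ == '__main__':
  # Minimal usage sketch (added for illustration; not part of the original
  # module). It assumes config_lib.Config.parse accepts a comma-separated
  # string of key=value pairs whose keys mirror the structure of
  # default_config(); see config_lib for the authoritative grammar.
  demo_config = default_config_with_updates(
      "batch_size=32,timestep_limit=64", do_logging=False)
  print(demo_config.batch_size)      # -> 32 (overridden)
  print(demo_config.timestep_limit)  # -> 64 (overridden)
  # strict_update is expected to reject keys absent from default_config(),
  # so overrides can only touch parameters defined above.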