"""Default configuration for agent and environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging

from common import config_lib  # brain coder


def default_config():
  return config_lib.Config(
      agent=config_lib.OneOf(
          [config_lib.Config(
              algorithm='pg',
              policy_lstm_sizes=[35, 35],
              # Set value_lstm_sizes to None to share weights with policy.
              value_lstm_sizes=[35, 35],
              obs_embedding_size=10,
              grad_clip_threshold=10.0,
              param_init_factor=1.0,
              lr=5e-5,
              pi_loss_hparam=1.0,
              vf_loss_hparam=0.5,
              entropy_beta=1e-2,
              regularizer=0.0,
              softmax_tr=1.0,  # Reciprocal temperature.
              optimizer='rmsprop',  # 'adam', 'sgd', 'rmsprop'
              topk=0,  # Top-k unique codes will be stored.
              topk_loss_hparam=0.0,  # Off-policy loss multiplier.
              # Uniformly sample this many episodes from the topk buffer per
              # batch. If topk is 0, this has no effect.
              topk_batch_size=1,
              # Exponential moving average baseline for REINFORCE.
              # If zero, A2C is used.
              # If non-zero, should be close to 1, like 0.99, 0.999, etc.
              ema_baseline_decay=0.99,
              # Whether the agent can emit the EOS token. If True, emitting
              # EOS ends the episode early (i.e. ends the sequence). If False,
              # the agent must emit tokens until the timestep limit is
              # reached. In other words, True means variable-length code and
              # False means fixed-length code.
              # WARNING: Setting this to False slows things down.
              eos_token=False,
              replay_temperature=1.0,
              # Replay probability. 1 = always replay, 0 = always on-policy.
              alpha=0.0,
              # Whether to normalize importance weights in each minibatch.
              iw_normalize=True),
           config_lib.Config(
               algorithm='ga',
               crossover_rate=0.99,
               mutation_rate=0.086),
           config_lib.Config(
               algorithm='rand')],
          algorithm='pg',
      ),
      env=config_lib.Config(
          # If True, task-specific settings are not needed.
          task='',  # 'print', 'echo', 'reverse', 'remove', ...
          task_cycle=[],  # If non-empty, repetitions will cycle through tasks.
          task_kwargs='{}',  # Python dict literal.
          task_manager_config=config_lib.Config(
              # Reward received per test case. These bonuses will be scaled
              # based on how many test cases there are.
              correct_bonus=2.0,  # Bonus for code getting the correct answer.
              code_length_bonus=1.0),  # Maximum bonus for short code.
          correct_syntax=False,
      ),
      batch_size=64,
      timestep_limit=32)


def default_config_with_updates(config_string, do_logging=True):
  """Returns the default config with updates parsed from config_string."""
  if do_logging:
    logging.info('Config string: "%s"', config_string)
  config = default_config()
  config.strict_update(config_lib.Config.parse(config_string))
  if do_logging:
    logging.info('Config:\n%s', config.pretty_str())
  return config
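

# Usage sketch (illustrative, not part of the original module): override a few
# top-level defaults via a config string. The exact syntax accepted by
# config_lib.Config.parse is assumed here to be comma-separated 'key=value'
# pairs; config.strict_update rejects keys that are not already present in the
# defaults above.
if __name__ == '__main__':
  config = default_config_with_updates('batch_size=32,timestep_limit=16')
  print(config.pretty_str())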