# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/c51/#c51_jaxpy
import argparse
import os
import random
import time

import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

try:
    from distutils.util import strtobool
except ImportError:  # `distutils` was removed from the standard library in Python 3.12

    def strtobool(val):
        """Convert a truth-value string to 1 or 0 (drop-in for `distutils.util.strtobool`).

        Raises:
            ValueError: if `val` is not one of the recognised truth values.
        """
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if val in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError("invalid truth value %r" % (val,))


def parse_args():
    """Parse the command-line options of the C51 (JAX) training script.

    Returns:
        argparse.Namespace: experiment/tracking settings plus the algorithm hyperparameters.
    """
    # `str.rstrip(".py")` strips any trailing '.', 'p' or 'y' characters (e.g. "happy.py" -> "ha"),
    # so the file extension is removed with `os.path.splitext` instead.
    default_exp_name = os.path.splitext(os.path.basename(__file__))[0]
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=default_exp_name,
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="CartPole-v1",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=500000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--n-atoms", type=int, default=101,
        help="the number of atoms")
    parser.add_argument("--v-min", type=float, default=-100,
        help="the return lower bound")
    parser.add_argument("--v-max", type=float, default=100,
        help="the return upper bound")
    parser.add_argument("--buffer-size", type=int, default=10000,
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--target-network-frequency", type=int, default=500,
        help="the timesteps it takes to update the target network")
    parser.add_argument("--batch-size", type=int, default=128,
        help="the batch size of sample from the replay memory")
    parser.add_argument("--start-e", type=float, default=1,
        help="the starting epsilon for exploration")
    parser.add_argument("--end-e", type=float, default=0.05,
        help="the ending epsilon for exploration")
    parser.add_argument("--exploration-fraction", type=float, default=0.5,
        help="the fraction of `total-timesteps` it takes from start-e to go end-e")
    parser.add_argument("--learning-starts", type=int, default=10000,
        help="timestep to start learning")
    parser.add_argument("--train-frequency", type=int, default=10,
        help="the frequency of training")
    args = parser.parse_args()
    # fmt: on
    return args


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one seeded gym environment.

    The environment is wrapped in `RecordEpisodeStatistics`; when `capture_video` is set, only the
    environment with `idx == 0` is additionally wrapped in `RecordVideo` (output in `videos/{run_name}`).
    The env, its action space and its observation space are all seeded with `seed`.
    """

    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """MLP that outputs, for every action, a categorical distribution over `n_atoms` return atoms.

    The input is expected to carry a leading batch dimension (presumably `(batch, obs_dim)` — the
    final reshape uses `x.shape[0]`). The output has shape `(batch, action_dim, n_atoms)` and is
    softmax-normalised over the atom axis.
    """

    action_dim: int
    n_atoms: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(120)(x)
        x = nn.relu(x)
        x = nn.Dense(84)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim * self.n_atoms)(x)
        x = x.reshape((x.shape[0], self.action_dim, self.n_atoms))
        x = nn.softmax(x, axis=-1)  # pmfs
        return x


class TrainState(TrainState):
    """Flax `TrainState` extended with target-network parameters and the fixed atom support.

    NOTE: this intentionally shadows the imported base class of the same name.
    """

    target_params: flax.core.FrozenDict  # parameters of the periodically synced target network
    atoms: jnp.ndarray  # support of the return distribution, built with `np.linspace(..., num=n_atoms)`


def linear_schedule(start_e: float, end_e: float, duration: float, t: int):
    """Linearly interpolate from `start_e` to `end_e` over `duration` steps, then hold `end_e`.

    The result is clamped from below by `end_e` (so this assumes a decaying schedule,
    `end_e <= start_e`). A non-positive `duration` means "no annealing" and returns `end_e`
    immediately instead of dividing by zero.
    """
    if duration <= 0:
        return end_e
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, q_key = jax.random.split(key, 2)

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    obs = envs.reset()

    q_network = QNetwork(action_dim=envs.single_action_space.n, n_atoms=args.n_atoms)

    q_state = TrainState.create(
        apply_fn=q_network.apply,
        params=q_network.init(q_key, obs),
        target_params=q_network.init(q_key, obs),
        # directly using jnp.linspace leads to numerical errors
        atoms=jnp.asarray(np.linspace(args.v_min, args.v_max, num=args.n_atoms)),
        tx=optax.adam(learning_rate=args.learning_rate, eps=0.01 / args.batch_size),
    )
    # jit the forward pass that is also called eagerly for greedy action selection below
    q_network.apply = jax.jit(q_network.apply)
    # This step is not necessary as init called on same observation and key will always lead to same initializations
    q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        "cpu",
        handle_timeout_termination=True,
    )

    @jax.jit
    def update(q_state, observations, actions, next_observations, rewards, dones):
        """Run one C51 gradient step on a replay batch.

        The target distribution is taken from the target network at the greedy next action, shifted by
        the Bellman update `r + gamma * z * (1 - done)`, clipped to `[v_min, v_max]` and projected back
        onto the fixed atom support. The online network is then trained with the cross-entropy between
        that target and its own distribution for the taken actions.

        `rewards` and `dones` presumably have shape `(batch_size, 1)` so they broadcast against the
        `(n_atoms,)` support — TODO confirm against the replay buffer's sample layout.

        Returns:
            tuple: `(loss, Q-values of the taken actions, updated q_state)`.
        """
        next_pmfs = q_network.apply(q_state.target_params, next_observations)  # (batch_size, num_actions, num_atoms)
        next_vals = (next_pmfs * q_state.atoms).sum(axis=-1)  # (batch_size, num_actions)
        next_action = jnp.argmax(next_vals, axis=-1)  # (batch_size,)
        next_pmfs = next_pmfs[np.arange(next_pmfs.shape[0]), next_action]
        next_atoms = rewards + args.gamma * q_state.atoms * (1 - dones)
        # projection
        delta_z = q_state.atoms[1] - q_state.atoms[0]
        # bounds are passed positionally: the `a_min`/`a_max` keywords are deprecated in recent JAX
        tz = jnp.clip(next_atoms, args.v_min, args.v_max)

        b = (tz - args.v_min) / delta_z
        l = jnp.clip(jnp.floor(b), 0, args.n_atoms - 1)
        u = jnp.clip(jnp.ceil(b), 0, args.n_atoms - 1)
        # (l == u).astype(jnp.float32) handles the case where b_j is exactly an integer: then l_j == u_j,
        # and the extra 1 routes all of the probability mass to atom l_j instead of dropping it.
        d_m_l = (u + (l == u).astype(jnp.float32) - b) * next_pmfs
        d_m_u = (b - l) * next_pmfs
        target_pmfs = jnp.zeros_like(next_pmfs)

        def project_to_bins(i, val):
            # scatter-add row i's split mass onto its lower and upper neighbouring atoms
            val = val.at[i, l[i].astype(jnp.int32)].add(d_m_l[i])
            val = val.at[i, u[i].astype(jnp.int32)].add(d_m_u[i])
            return val

        target_pmfs = jax.lax.fori_loop(0, target_pmfs.shape[0], project_to_bins, target_pmfs)

        def loss_fn(q_params, observations, actions, target_pmfs):
            pmfs = q_network.apply(q_params, observations)
            old_pmfs = pmfs[np.arange(pmfs.shape[0]), actions.squeeze()]
            # clip before the log so zero-probability atoms cannot produce -inf
            old_pmfs_l = jnp.clip(old_pmfs, 1e-5, 1 - 1e-5)
            loss = (-(target_pmfs * jnp.log(old_pmfs_l)).sum(-1)).mean()
            return loss, (old_pmfs * q_state.atoms).sum(-1)

        (loss_value, old_values), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            q_state.params, observations, actions, target_pmfs
        )
        q_state = q_state.apply_gradients(grads=grads)
        return loss_value, old_values, q_state

    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            pmfs = q_network.apply(q_state.params, obs)
            q_vals = (pmfs * q_state.atoms).sum(axis=-1)
            actions = q_vals.argmax(axis=-1)
            actions = jax.device_get(actions)

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        for info in infos:
            if "episode" in info.keys():
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                writer.add_scalar("charts/epsilon", epsilon, global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
            data = rb.sample(args.batch_size)
            loss, old_val, q_state = update(
                q_state,
                data.observations.numpy(),
                data.actions.numpy(),
                data.next_observations.numpy(),
                data.rewards.numpy(),
                data.dones.numpy(),
            )

            if global_step % 100 == 0:
                writer.add_scalar("losses/loss", jax.device_get(loss), global_step)
                writer.add_scalar("losses/q_values", jax.device_get(old_val.mean()), global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

            # update the target network
            if global_step % args.target_network_frequency == 0:
                q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        model_data = {
            "model_weights": q_state.params,
            "args": vars(args),
        }
        with open(model_path, "wb") as f:
            f.write(flax.serialization.to_bytes(model_data))
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.c51_jax_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=QNetwork,
            epsilon=0.05,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()