import copy
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from torch.optim import Adam
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
from torch.utils.tensorboard.writer import SummaryWriter
from typing import List, NamedTuple, Optional, TypeVar

from dqn.policy import DQNPolicy
from shared.algorithm import Algorithm
from shared.callbacks.callback import Callback
from shared.schedule import linear_schedule


class Transition(NamedTuple):
    obs: np.ndarray
    action: np.ndarray
    reward: float
    done: bool
    next_obs: np.ndarray


class Batch(NamedTuple):
    obs: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    dones: np.ndarray
    next_obs: np.ndarray


class ReplayBuffer:
    """Uniform-sampling replay buffer that stores per-env transitions from a VecEnv."""

    def __init__(self, num_envs: int, maxlen: int) -> None:
        self.num_envs = num_envs
        self.buffer = deque(maxlen=maxlen)

    def add(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        next_obs: VecEnvObs,
    ) -> None:
        assert isinstance(obs, np.ndarray)
        assert isinstance(next_obs, np.ndarray)
        # Split the batched VecEnv step into one Transition per environment.
        for i in range(self.num_envs):
            self.buffer.append(
                Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
            )

    def sample(self, batch_size: int) -> Batch:
        ts = random.sample(self.buffer, batch_size)
        return Batch(
            obs=np.array([t.obs for t in ts]),
            actions=np.array([t.action for t in ts]),
            rewards=np.array([t.reward for t in ts]),
            dones=np.array([t.done for t in ts]),
            next_obs=np.array([t.next_obs for t in ts]),
        )

    def __len__(self) -> int:
        return len(self.buffer)


DQNSelf = TypeVar("DQNSelf", bound="DQN")


class DQN(Algorithm):
    def __init__(
        self,
        policy: DQNPolicy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 50_000,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: int = 4,
        gradient_steps: int = 1,
        target_update_interval: int = 10_000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10.0,
    ) -> None:
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)

        # Target network starts as a frozen copy of the online Q-network.
        self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
        self.target_q_net.train(False)
        self.tau = tau
        self.target_update_interval = target_update_interval

        self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
        self.batch_size = batch_size

        self.learning_starts = learning_starts
        self.train_freq = train_freq
        self.gradient_steps = gradient_steps

        self.gamma = gamma
        # Epsilon for epsilon-greedy exploration decays linearly over the first
        # `exploration_fraction` of training.
        self.exploration_eps_schedule = linear_schedule(
            exploration_initial_eps,
            exploration_final_eps,
            end_fraction=exploration_fraction,
        )

        self.max_grad_norm = max_grad_norm

    def learn(
        self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None
    ) -> DQNSelf:
        self.policy.train(True)
        obs = self.env.reset()
        # Prefill the replay buffer with fully exploratory (eps=1) experience.
        obs = self._collect_rollout(self.learning_starts, obs, 1)
        learning_steps = total_timesteps - self.learning_starts
        timesteps_elapsed = 0
        steps_since_target_update = 0
        while timesteps_elapsed < learning_steps:
            progress = timesteps_elapsed / learning_steps
            eps = self.exploration_eps_schedule(progress)
            obs = self._collect_rollout(self.train_freq, obs, eps)
            rollout_steps = self.train_freq
            timesteps_elapsed += rollout_steps
            # If gradient_steps <= 0, take one gradient step per collected timestep.
            for _ in range(
                self.gradient_steps if self.gradient_steps > 0 else self.train_freq
            ):
                self.train()
            steps_since_target_update += rollout_steps
            if steps_since_target_update >= self.target_update_interval:
                self._update_target()
                steps_since_target_update = 0
            if callback:
                callback.on_step(timesteps_elapsed=rollout_steps)
        return self

    def train(self) -> None:
        if len(self.replay_buffer) < self.batch_size:
            return
        o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
        o = torch.as_tensor(o, device=self.device)
        a = torch.as_tensor(a, device=self.device).unsqueeze(1)
        r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
        d = torch.as_tensor(d, dtype=torch.long, device=self.device)
        next_o = torch.as_tensor(next_o, device=self.device)

        # TD target: r + gamma * max_a' Q_target(s', a'), zeroed for terminal transitions.
        with torch.no_grad():
            target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
        current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
        loss = F.smooth_l1_loss(current, target)

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm:
            nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

    def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
        # Each env.step() advances all num_envs environments by one timestep.
        for _ in range(0, timesteps, self.env.num_envs):
            action = self.policy.act(obs, eps, deterministic=False)
            next_obs, reward, done, _ = self.env.step(action)
            self.replay_buffer.add(obs, action, reward, done, next_obs)
            obs = next_obs
        return obs

    def _update_target(self) -> None:
        # Polyak update of the target network; with tau=1.0 this is a hard copy.
        for target_param, param in zip(
            self.target_q_net.parameters(), self.policy.q_net.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )
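

# --- Usage sketch (illustrative, not part of the original module) -------------
# A minimal example of wiring the pieces together, kept as comments because the
# constructor signatures of DQNPolicy and the Algorithm base class live elsewhere
# in this repo (dqn/policy.py, shared/algorithm.py) and are assumed here.
# `make_vec_env` is the standard stable_baselines3 helper for building a VecEnv.
#
#   from stable_baselines3.common.env_util import make_vec_env
#
#   env = make_vec_env("CartPole-v1", n_envs=4)
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   tb_writer = SummaryWriter("runs/dqn_cartpole")
#   policy = DQNPolicy(env, device)  # hypothetical signature; see dqn/policy.py
#   dqn = DQN(policy, env, device, tb_writer, learning_starts=1_000)
#   dqn.learn(total_timesteps=100_000)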