# Author: sgoodfriend
# A2C playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# Revision: 05b94c0
from abc import abstractmethod
from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
import gym
import numpy as np
import torch
from gym.spaces import Box, Space
from rl_algo_impls.shared.policy.actor_critic_network import (
ConnectedTrioActorCriticNetwork,
SeparateActorCriticNetwork,
UNetActorCriticNetwork,
)
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnv,
VecEnvObs,
single_action_space,
single_observation_space,
)
class Step(NamedTuple):
    """Per-environment output of one sampling pass of a policy.

    ``ActorCritic.step`` builds this from a single batched network call.
    """

    # Action sampled from the policy distribution (before any clamping).
    a: np.ndarray
    # Value estimate returned alongside the distribution.
    v: np.ndarray
    # Log-probability of ``a`` under the distribution it was sampled from.
    logp_a: np.ndarray
    # ``a`` passed through ``clamp_actions``; this is what ``ActorCritic.act``
    # returns when it is not deterministic.
    clamped_a: np.ndarray
class ACForward(NamedTuple):
    """Tensors returned by ``ActorCritic.forward`` for a batch of given actions."""

    # Log-probability of the supplied actions under the current policy.
    logp_a: torch.Tensor
    # Entropy term reported by the policy head.
    entropy: torch.Tensor
    # Value estimate for the observations.
    v: torch.Tensor
# File names for the policy's sub-network weights. They are not referenced
# in this part of the file; presumably used by save/load code elsewhere
# (TODO: confirm).
FEAT_EXT_FILE_NAME = "feat_ext.pt"
V_FEAT_EXT_FILE_NAME = "v_feat_ext.pt"
PI_FILE_NAME = "pi.pt"
V_FILE_NAME = "v.pt"

# Self-type used so methods such as ``ActorCritic.load_from`` can be typed as
# returning the same (sub)class they were called on.
ActorCriticSelf = TypeVar("ActorCriticSelf", bound="ActorCritic")
def clamp_actions(
    actions: np.ndarray, action_space: gym.Space, squash_output: bool
) -> np.ndarray:
    """Bring raw policy actions into the bounds of ``action_space``.

    Only ``Box`` spaces are adjusted; for any other space type the actions are
    returned unchanged.

    Args:
        actions: Actions produced by the policy.
        action_space: The single-environment action space.
        squash_output: If True, ``actions`` are taken to already lie in
            [-1, 1] and are linearly rescaled onto ``[low, high]``. If False,
            they are clipped to ``[low, high]``.

    Returns:
        The adjusted actions.
    """
    if not isinstance(action_space, Box):
        return actions
    lower = action_space.low  # type: ignore
    upper = action_space.high  # type: ignore
    if not squash_output:
        return np.clip(actions, lower, upper)
    # Squashed outputs are already in [-1, 1]; map that interval linearly
    # onto [lower, upper] so bounds other than [-1, 1] are honoured too.
    fraction = 0.5 * (actions + 1)
    return lower + fraction * (upper - lower)
class OnPolicy(Policy):
    """Abstract policy interface: value estimates, action sampling, action shape.

    Concrete subclasses (e.g. ``ActorCritic``) implement all three members.
    """

    @abstractmethod
    def value(self, obs: VecEnvObs) -> np.ndarray:
        """Return value estimates for a batch of observations."""

    @abstractmethod
    def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
        """Sample actions for ``obs`` (optionally masked) and return a ``Step``."""

    @property
    @abstractmethod
    def action_shape(self) -> Tuple[int, ...]:
        """Shape of the actions produced by the concrete policy."""
class ActorCritic(OnPolicy):
    """Actor-critic policy built around one of three network variants.

    The network class is chosen from the constructor arguments:

    * ``actor_head_style == "unet"`` -> ``UNetActorCriticNetwork``
    * else ``share_features_extractor`` -> ``ConnectedTrioActorCriticNetwork``
    * else -> ``SeparateActorCriticNetwork``

    Extra ``**kwargs`` are forwarded to ``Policy.__init__``.
    """

    def __init__(
        self,
        env: VecEnv,
        pi_hidden_sizes: Optional[Sequence[int]] = None,
        v_hidden_sizes: Optional[Sequence[int]] = None,
        init_layers_orthogonal: bool = True,
        activation_fn: str = "tanh",
        log_std_init: float = -0.5,
        use_sde: bool = False,
        full_std: bool = True,
        squash_output: bool = False,
        share_features_extractor: bool = True,
        cnn_flatten_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
        actor_head_style: str = "single",
        **kwargs,
    ) -> None:
        super().__init__(env, **kwargs)
        observation_space = single_observation_space(env)
        action_space = single_action_space(env)
        # Optional env attribute; envs that lack it yield None.
        action_plane_space = getattr(env, "action_plane_space", None)

        self.action_space = action_space
        self.squash_output = squash_output

        if actor_head_style == "unet":
            # The UNet network receives only this subset of the options; the
            # rest (log_std_init, use_sde, pi_hidden_sizes, ...) are not passed.
            self.network = UNetActorCriticNetwork(
                observation_space,
                action_space,
                action_plane_space,
                v_hidden_sizes=v_hidden_sizes,
                init_layers_orthogonal=init_layers_orthogonal,
                activation_fn=activation_fn,
                cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
            )
            return

        # The shared- and separate-extractor networks take exactly the same
        # options, so build the keyword set once and only vary the class.
        network_kwargs = {
            "pi_hidden_sizes": pi_hidden_sizes,
            "v_hidden_sizes": v_hidden_sizes,
            "init_layers_orthogonal": init_layers_orthogonal,
            "activation_fn": activation_fn,
            "log_std_init": log_std_init,
            "use_sde": use_sde,
            "full_std": full_std,
            "squash_output": squash_output,
            "cnn_flatten_dim": cnn_flatten_dim,
            "cnn_style": cnn_style,
            "cnn_layers_init_orthogonal": cnn_layers_init_orthogonal,
            "impala_channels": impala_channels,
            "actor_head_style": actor_head_style,
            "action_plane_space": action_plane_space,
        }
        network_cls = (
            ConnectedTrioActorCriticNetwork
            if share_features_extractor
            else SeparateActorCriticNetwork
        )
        self.network = network_cls(observation_space, action_space, **network_kwargs)

    def _to_tensors(
        self, obs: VecEnvObs, action_masks: Optional[np.ndarray]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Convert observations and (optional) action masks via ``_as_tensor``."""
        obs_t = self._as_tensor(obs)
        masks_t = None if action_masks is None else self._as_tensor(action_masks)
        return obs_t, masks_t

    def forward(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        action_masks: Optional[torch.Tensor] = None,
    ) -> ACForward:
        """Evaluate ``action`` for ``obs`` and return log-prob, entropy and value.

        Unlike ``value``/``step``/``act``, this is not wrapped in
        ``torch.no_grad``, so the returned tensors keep their autograd history.
        """
        pi_forward, v = self.network(obs, action, action_masks=action_masks)
        _, logp_a, entropy = pi_forward
        # The network must fill these in when an action is supplied.
        assert logp_a is not None
        assert entropy is not None
        return ACForward(logp_a, entropy, v)

    def value(self, obs: VecEnvObs) -> np.ndarray:
        """Return value estimates for ``obs`` as a NumPy array (no gradients)."""
        obs_t = self._as_tensor(obs)
        with torch.no_grad():
            values = self.network.value(obs_t)
        return values.cpu().numpy()

    def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
        """Sample actions and return them with their value and log-probability.

        Both the raw sample and its ``clamp_actions`` version are returned.
        """
        obs_t, masks_t = self._to_tensors(obs, action_masks)
        with torch.no_grad():
            (dist, _, _), values = self.network.distribution_and_value(
                obs_t, action_masks=masks_t
            )
            # Kept under no_grad so log_prob does not build an autograd graph
            # (required for the .numpy() conversions below).
            sampled = dist.sample()
            log_prob = dist.log_prob(sampled)
        actions = sampled.cpu().numpy()
        clamped = clamp_actions(actions, self.action_space, self.squash_output)
        return Step(actions, values.cpu().numpy(), log_prob.cpu().numpy(), clamped)

    def act(
        self,
        obs: np.ndarray,
        deterministic: bool = True,
        action_masks: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Return clamped actions for ``obs``.

        With ``deterministic=False`` this samples via ``step``; otherwise it
        uses the distribution's ``mode``.
        """
        if not deterministic:
            return self.step(obs, action_masks=action_masks).clamped_a
        obs_t, masks_t = self._to_tensors(obs, action_masks)
        with torch.no_grad():
            (dist, _, _), _ = self.network.distribution_and_value(
                obs_t, action_masks=masks_t
            )
            mode = dist.mode
        return clamp_actions(mode.cpu().numpy(), self.action_space, self.squash_output)

    def load(self, path: str) -> None:
        """Call ``Policy.load(path)``, then ``reset_noise()``."""
        super().load(path)
        self.reset_noise()

    def load_from(self: ActorCriticSelf, policy: ActorCriticSelf) -> ActorCriticSelf:
        """Call ``Policy.load_from(policy)``, reset noise, and return ``self``."""
        super().load_from(policy)
        self.reset_noise()
        return self

    def reset_noise(self, batch_size: Optional[int] = None) -> None:
        """Delegate to ``self.network.reset_noise``.

        Any falsy ``batch_size`` (None or 0) falls back to ``self.env.num_envs``.
        """
        self.network.reset_noise(batch_size=batch_size or self.env.num_envs)

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """Action shape as reported by the underlying network."""
        return self.network.action_shape