# sgoodfriend's picture
# DQN playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# 923ccaf
import dataclasses
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter
@dataclass
class Episode:
    """Running totals for one episode, updated step by step until it ends.

    Attributes:
        score: Sum of rewards accumulated so far.
        length: Number of steps taken so far.
        info: Extra per-episode data. Values may be scalars or one-level
            nested dicts of scalars; EpisodesStats accepts both shapes.
    """

    score: float = 0
    length: int = 0
    # default_factory gives every Episode its own dict (no shared mutable default).
    info: Dict[str, Any] = dataclasses.field(default_factory=dict)
StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")


@dataclass
class Statistic:
    """Summary statistics (mean/std/min/max) over an array of values.

    Instances compare with ``>`` and ``>=`` by :meth:`score`.

    Attributes:
        values: Raw values; every reduction aggregates over all elements.
        round_digits: Decimal places used by ``__repr__``; 0 renders ints.
        score_function: ``"mean-std"`` (mean minus standard deviation) or
            ``"mean"``.
    """

    values: np.ndarray
    round_digits: int = 2
    score_function: str = "mean-std"

    @property
    def mean(self) -> float:
        return np.mean(self.values).item()

    @property
    def std(self) -> float:
        # Population standard deviation (NumPy's default ddof=0).
        return np.std(self.values).item()

    @property
    def min(self) -> float:
        return np.min(self.values).item()

    @property
    def max(self) -> float:
        return np.max(self.values).item()

    def sum(self) -> float:
        # NOTE: a plain method, unlike mean/std/min/max which are properties.
        return np.sum(self.values).item()

    def __len__(self) -> int:
        return len(self.values)

    def score(self) -> float:
        """Return the scalar used to rank statistics against each other.

        Raises:
            NotImplementedError: if ``score_function`` is not supported.
        """
        if self.score_function == "mean-std":
            # Subtracting std lowers the score of high-variance results.
            return self.mean - self.std
        if self.score_function == "mean":
            return self.mean
        # ``NotImplemented`` is a comparison sentinel and is not callable;
        # calling it produced a TypeError instead of this message.
        raise NotImplementedError(
            f"Only mean-std and mean score_functions supported ({self.score_function})"
        )

    def _diff(self: StatisticSelf, o: StatisticSelf) -> float:
        """Return ``self.score() - o.score()``."""
        return self.score() - o.score()

    def __gt__(self: StatisticSelf, o: StatisticSelf) -> bool:
        return self._diff(o) > 0

    def __ge__(self: StatisticSelf, o: StatisticSelf) -> bool:
        return self._diff(o) >= 0

    def __repr__(self) -> str:
        """Render ``"mean +/- std"``, or just ``"mean"`` for the mean score."""
        mean = round(self.mean, self.round_digits)
        if self.round_digits == 0:
            mean = int(mean)
        if self.score_function == "mean":
            return f"{mean}"
        std = round(self.std, self.round_digits)
        if self.round_digits == 0:
            std = int(std)
        return f"{mean} +/- {std}"

    def to_dict(self) -> Dict[str, float]:
        """Return mean/std/min/max as plain Python floats."""
        return {
            "mean": self.mean,
            "std": self.std,
            "min": self.min,
            "max": self.max,
        }
EpisodesStatsSelf = TypeVar("EpisodesStatsSelf", bound="EpisodesStats")


class EpisodesStats:
    """Score, length and info statistics over a collection of episodes.

    Entries of each episode's ``info`` are gathered into ``additional_stats``:
    scalar values under their own key, and nested dicts flattened one level
    into ``"<outer>_<inner>"`` keys.

    Args:
        episodes: Episodes to summarize (kept by reference as ``episodes``).
        simple: When True, ``write_to_tensorboard`` only writes the mean score.
        score_function: Passed to the score ``Statistic`` for ranking.
    """

    def __init__(
        self,
        episodes: Sequence[Episode],
        simple: bool = False,
        score_function: str = "mean-std",
    ) -> None:
        self.episodes = episodes
        self.simple = simple

        episode_scores = np.array([ep.score for ep in episodes])
        episode_lengths = np.array([ep.length for ep in episodes])
        self.score = Statistic(episode_scores, score_function=score_function)
        self.length = Statistic(episode_lengths, round_digits=0)

        collected: Dict[str, List[Any]] = defaultdict(list)
        for ep in self.episodes:
            if not ep.info:
                continue
            for outer_key, outer_val in ep.info.items():
                if not isinstance(outer_val, dict):
                    collected[outer_key].append(outer_val)
                    continue
                for inner_key, inner_val in outer_val.items():
                    collected[f"{outer_key}_{inner_key}"].append(inner_val)

        self.additional_stats = {
            name: Statistic(np.array(vals)) for name, vals in collected.items()
        }
        self.score_function = score_function

    def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
        # Ranking delegates entirely to the score Statistic.
        return self.score > o.score

    def __ge__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
        return self.score >= o.score

    def __repr__(self) -> str:
        """Show score and length; append the rounded ranking score when it differs from the mean."""
        ranking = self.score.score()
        suffix = "" if self.score.mean == ranking else f" ({round(ranking)})"
        return f"Score: {self.score}{suffix} | Length: {self.length}"

    def __len__(self) -> int:
        return len(self.episodes)

    def _asdict(self) -> dict:
        """Return episode count plus score and length summaries."""
        summary = {"n_episodes": len(self.episodes)}
        summary["score"] = self.score.to_dict()
        summary["length"] = self.length.to_dict()
        return summary

    def write_to_tensorboard(
        self, tb_writer: SummaryWriter, main_tag: str, global_step: Optional[int] = None
    ) -> None:
        """Write the summary as ``<main_tag>/<name>`` scalars."""
        scalars: Dict[str, Any] = {"mean": self.score.mean}
        if not self.simple:
            scalars["min"] = self.score.min
            scalars["max"] = self.score.max
            scalars["result"] = self.score.score()
            scalars["n_episodes"] = len(self.episodes)
            scalars["length"] = self.length.mean
            # NOTE(review): additional stats are logged only in non-simple
            # mode — confirm this matches the intended nesting.
            scalars.update(
                (key, stat.mean) for key, stat in self.additional_stats.items()
            )
        for tag, value in scalars.items():
            tb_writer.add_scalar(f"{main_tag}/{tag}", value, global_step=global_step)
class EpisodeAccumulator:
    """Track one in-progress Episode per env and collect them as they finish.

    Subclasses may override :meth:`on_done` to react to each finished episode.
    """

    def __init__(self, num_envs: int):
        self._episodes: List[Episode] = []
        self.current_episodes = [Episode() for _ in range(num_envs)]

    @property
    def episodes(self) -> List[Episode]:
        """Finished episodes, in completion order."""
        return self._episodes

    def step(self, reward: np.ndarray, done: np.ndarray, info: List[Dict]) -> None:
        """Add one step's reward to every env's episode and finalize finished ones.

        ``reward``, ``done`` and ``info`` are indexed by env index.
        """
        for idx, current in enumerate(self.current_episodes):
            current.score += reward[idx]
            current.length += 1
            if done[idx]:
                self._episodes.append(current)
                # Start a fresh episode for this env before notifying the hook.
                self.current_episodes[idx] = Episode()
                self.on_done(idx, current, info[idx])

    def __len__(self) -> int:
        return len(self.episodes)

    def on_done(self, ep_idx: int, episode: Episode, info: Dict) -> None:
        """Hook called after an episode finishes; no-op by default."""
        pass

    def stats(self) -> EpisodesStats:
        """Summarize the episodes finished so far.

        A snapshot copy is passed because EpisodesStats keeps the sequence by
        reference. Passing the live list would let later ``step`` calls change
        ``len(stats)`` and ``n_episodes`` while the score stats stay frozen.
        """
        return EpisodesStats(list(self.episodes))
def log_scalars(
    tb_writer: SummaryWriter,
    main_tag: str,
    tag_scalar_dict: Dict[str, Union[int, float]],
    global_step: int,
) -> None:
    """Write every entry of ``tag_scalar_dict`` as a ``<main_tag>/<tag>`` scalar at ``global_step``."""
    for name, scalar in tag_scalar_dict.items():
        full_tag = f"{main_tag}/{name}"
        tb_writer.add_scalar(full_tag, scalar, global_step)