sgoodfriend's picture
DQN playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
923ccaf
from abc import ABC, abstractmethod
from typing import List, Optional, TypeVar
import gym
import torch
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
# Type variable bound to ``Algorithm`` (the class defined below). ``learn``
# annotates both ``self`` and its return value with it. The declared return
# type is therefore the concrete subclass the method was called on, not plain
# ``Algorithm``.
AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
class Algorithm(ABC):
    """Abstract base for algorithms built around a ``Policy`` and a ``VecEnv``.

    Both ``__init__`` and ``learn`` are decorated with ``@abstractmethod``.
    A concrete subclass must therefore define both before it can be
    instantiated.
    """

    @abstractmethod
    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        **kwargs,
    ) -> None:
        """Store the components shared by every algorithm as attributes.

        Subclasses must override this method because it is abstract. They can
        call it through ``super().__init__(...)`` to populate ``self.policy``,
        ``self.env``, ``self.device`` and ``self.tb_writer``.

        Args:
            policy: The policy this algorithm operates on.
            env: The environment this algorithm interacts with.
            device: Torch device stored for use by subclasses.
            tb_writer: TensorBoard ``SummaryWriter`` stored for use by
                subclasses.
            **kwargs: Accepted and ignored here. Presumably a catch-all for
                subclass-specific options; confirm against the subclasses.
        """
        super().__init__()
        (
            self.policy,
            self.env,
            self.device,
            self.tb_writer,
        ) = (policy, env, device, tb_writer)

    @abstractmethod
    def learn(
        self: AlgorithmSelf,
        train_timesteps: int,
        callbacks: Optional[List[Callback]] = None,
        total_timesteps: Optional[int] = None,
        start_timesteps: int = 0,
    ) -> AlgorithmSelf:
        """Run training. Each concrete algorithm must implement this.

        The base version has no body. It returns ``None`` if reached through
        ``super()``.

        Args:
            train_timesteps: Requested number of training timesteps. The exact
                accounting is defined by the subclass.
            callbacks: Optional list of ``Callback`` objects; ``None`` by
                default.
            total_timesteps: Optional overall timestep count; ``None`` by
                default. Presumably the full budget when this call is one
                segment of a longer run. Confirm in the subclasses.
            start_timesteps: Timestep count to start from; ``0`` by default.
                Presumably used when resuming. Confirm in the subclasses.

        Returns:
            An instance of the caller's own type (``AlgorithmSelf``).
            Conventionally this is ``self``, so calls can be chained.
        """