File size: 1,082 Bytes
923ccaf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from typing import Callable
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
class SelfPlayCallback(Callback):
    """Periodically push a snapshot of the training policy into a SelfPlayWrapper.

    A snapshot is taken once during construction and then again on any step
    where at least ``selfPlayWrapper.save_steps`` timesteps have elapsed since
    the previous snapshot.
    """

    def __init__(
        self,
        policy: Policy,
        policy_factory: Callable[[], Policy],
        selfPlayWrapper: SelfPlayWrapper,
    ) -> None:
        """Store collaborators and record an initial snapshot.

        Args:
            policy: The policy whose parameters are snapshotted.
            policy_factory: Zero-argument callable returning a new Policy
                instance that each snapshot is loaded into.
            selfPlayWrapper: Receiver of snapshots via ``checkpoint_policy``;
                its ``save_steps`` sets the snapshot interval.
        """
        # Base initialisation must come first: checkpoint_policy() below
        # reads self.timesteps_elapsed, which is not assigned in this class.
        super().__init__()
        self.policy = policy
        self.policy_factory = policy_factory
        # camelCase name kept as-is because callers may access this attribute.
        self.selfPlayWrapper = selfPlayWrapper
        # Initial snapshot; this is also what first sets last_checkpoint_step.
        self.checkpoint_policy()

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Let the base class account for the step, then snapshot if one is due.

        Args:
            timesteps_elapsed: Forwarded unchanged to ``Callback.on_step``.

        Returns:
            Always ``True``; the base class's return value is not used.
        """
        # NOTE(review): presumably the base on_step advances
        # self.timesteps_elapsed — confirm in Callback.
        super().on_step(timesteps_elapsed)
        due_at = self.last_checkpoint_step + self.selfPlayWrapper.save_steps
        if not self.timesteps_elapsed < due_at:
            self.checkpoint_policy()
        return True

    def checkpoint_policy(self):
        """Hand a freshly loaded policy instance to the self-play wrapper.

        A new policy is created via ``policy_factory``, and ``load_from`` is
        called on it with ``self.policy``. The value returned by ``load_from``
        is what gets passed to ``selfPlayWrapper.checkpoint_policy``. Finally
        the current ``timesteps_elapsed`` is recorded as
        ``last_checkpoint_step``.
        """
        # NOTE(review): relies on Policy.load_from returning the loaded
        # policy (not None) — confirm in Policy.
        fresh_policy = self.policy_factory()
        snapshot = fresh_policy.load_from(self.policy)
        self.selfPlayWrapper.checkpoint_policy(snapshot)
        self.last_checkpoint_step = self.timesteps_elapsed
|