File size: 1,082 Bytes
923ccaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import Callable

from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper


class SelfPlayCallback(Callback):
    def __init__(
        self,
        policy: Policy,
        policy_factory: Callable[[], Policy],
        selfPlayWrapper: SelfPlayWrapper,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.policy_factory = policy_factory
        self.selfPlayWrapper = selfPlayWrapper
        self.checkpoint_policy()

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        super().on_step(timesteps_elapsed)
        if (
            self.timesteps_elapsed
            >= self.last_checkpoint_step + self.selfPlayWrapper.save_steps
        ):
            self.checkpoint_policy()
        return True

    def checkpoint_policy(self):
        self.selfPlayWrapper.checkpoint_policy(
            self.policy_factory().load_from(self.policy)
        )
        self.last_checkpoint_step = self.timesteps_elapsed