"""Exploration strategies and a policy wrapper that applies them."""
import abc

from rlkit.policies.base import ExplorationPolicy


class ExplorationStrategy(metaclass=abc.ABCMeta):
    """Abstract interface for choosing an action with exploration applied.

    Concrete strategies must implement :meth:`get_action`. :meth:`reset` is
    an optional hook that does nothing by default.
    """

    @abc.abstractmethod
    def get_action(self, t, policy, *args, **kwargs):
        """Pick an action at step ``t`` using ``policy``.

        The argument order mirrors how
        ``PolicyWrappedWithExplorationStrategy.get_action`` invokes a
        strategy: the step counter first, then the policy, then whatever
        positional/keyword arguments the wrapper itself was called with.

        Args:
            t: Step counter supplied by the caller.
            policy: Object the strategy may query for an action.
            *args: Extra positional arguments forwarded by the caller
                (presumably the observation -- confirm against callers).
            **kwargs: Extra keyword arguments forwarded by the caller.
        """

    def reset(self):
        """Reset any internal state; the default implementation is a no-op."""


class RawExplorationStrategy(ExplorationStrategy):
    """Strategy that post-processes the action produced by the policy.

    Subclasses only implement :meth:`get_action_from_raw_action`;
    :meth:`get_action` queries the policy and feeds its action through
    that hook.
    """

    @abc.abstractmethod
    def get_action_from_raw_action(self, action, **kwargs):
        """Transform the policy's ``action`` into the action to execute.

        Args:
            action: The action returned by the policy.
            **kwargs: Extra context; :meth:`get_action` passes ``t``.
        """

    def get_action(self, t, policy, *args, **kwargs):
        """Query ``policy`` and transform the action it returns.

        Args:
            t: Step counter, forwarded to
                :meth:`get_action_from_raw_action` as the keyword ``t``.
            policy: Object whose ``get_action(*args, **kwargs)`` returns an
                ``(action, agent_info)`` pair.
            *args: Positional arguments forwarded to ``policy.get_action``.
            **kwargs: Keyword arguments forwarded to ``policy.get_action``.

        Returns:
            A tuple ``(transformed_action, agent_info)`` where
            ``agent_info`` is passed through from the policy unchanged.
        """
        action, agent_info = policy.get_action(*args, **kwargs)
        return self.get_action_from_raw_action(action, t=t), agent_info


class PolicyWrappedWithExplorationStrategy(ExplorationPolicy):
    """Policy that delegates action selection to an exploration strategy."""

    def __init__(
            self,
            exploration_strategy: ExplorationStrategy,
            policy,
    ):
        """Store the strategy and the wrapped policy.

        Args:
            exploration_strategy: Strategy asked for every action.
            policy: Policy handed to the strategy on each call.
        """
        self.es = exploration_strategy
        self.policy = policy
        # Step counter passed to the strategy; updated only through
        # set_num_steps_total.
        self.t = 0

    def set_num_steps_total(self, t):
        """Set the step counter that is passed to the strategy."""
        self.t = t

    def get_action(self, *args, **kwargs):
        """Return the strategy's result for the current step counter.

        All arguments are forwarded to the strategy after the step counter
        and the wrapped policy.
        """
        return self.es.get_action(self.t, self.policy, *args, **kwargs)

    def reset(self):
        """Reset the exploration strategy, then the wrapped policy."""
        self.es.reset()
        self.policy.reset()