# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper around a gym env.

Allows for using batches of possibly identically seeded environments.
"""

import gym
import numpy as np
import random
from six.moves import xrange

import env_spec


def get_env(env_str):
  return gym.make(env_str)


class GymWrapper(object):
  """Batches `distinct * count` copies of a gym environment.

  Environments are grouped by seed: there are `distinct` seeds with `count`
  copies per seed, so copies sharing a seed behave identically when fed
  identical actions.
  """

  def __init__(self, env_str, distinct=1, count=1, seeds=None):
    self.distinct = distinct
    self.count = count
    self.total = self.distinct * self.count
    # Use int(1e12) so the bound is an integer; random.randint rejects
    # float arguments on Python 3.
    self.seeds = seeds or [random.randint(0, int(1e12))
                           for _ in xrange(self.distinct)]

    self.envs = []
    for seed in self.seeds:
      for _ in xrange(self.count):
        env = get_env(env_str)
        env.seed(seed)
        if hasattr(env, 'last'):
          env.last = 100  # for algorithmic envs
        self.envs.append(env)

    self.dones = [True] * self.total
    self.num_episodes_played = 0

    one_env = self.get_one()
    self.use_action_list = hasattr(one_env.action_space, 'spaces')
    self.env_spec = env_spec.EnvSpec(self.get_one())

  def get_seeds(self):
    return self.seeds

  def reset(self):
    """Resets all environments, drawing fresh synchronized seeds."""
    self.dones = [False] * self.total
    self.num_episodes_played += len(self.envs)

    # Re-seed so that each group of `count` environments stays synchronized.
    self.seeds = [random.randint(0, int(1e12))
                  for _ in xrange(self.distinct)]
    counter = 0
    for seed in self.seeds:
      for _ in xrange(self.count):
        self.envs[counter].seed(seed)
        counter += 1

    return [self.env_spec.convert_obs_to_list(env.reset())
            for env in self.envs]

  def reset_if(self, predicate=None):
    """Resets only the environments for which `predicate` is true.

    With `count > 1`, identically seeded copies must stay in lockstep, so
    partial resets are disallowed: all environments must be done and are
    reset together.
    """
    if predicate is None:
      predicate = self.dones
    if self.count != 1:
      assert np.all(predicate)
      return self.reset()
    self.num_episodes_played += sum(predicate)
    output = [self.env_spec.convert_obs_to_list(env.reset())
              if pred else None
              for env, pred in zip(self.envs, predicate)]
    for i, pred in enumerate(predicate):
      if pred:
        self.dones[i] = False
    return output

  def all_done(self):
    return all(self.dones)

  def step(self, actions):
    """Steps all environments that are not yet done.

    `actions` is batched per action component: actions[i][j] is component i
    of the action for environment j. Environments that are already done
    receive a dummy transition of (initial obs, 0 reward, done=True, None).
    """
    def env_step(env, action):
      action = self.env_spec.convert_action_to_gym(action)
      obs, reward, done, tt = env.step(action)
      obs = self.env_spec.convert_obs_to_list(obs)
      return obs, reward, done, tt

    # Transpose from per-component lists to one action tuple per environment.
    actions = zip(*actions)
    outputs = [env_step(env, action)
               if not done else
               (self.env_spec.initial_obs(None), 0, True, None)
               for action, env, done in zip(actions, self.envs, self.dones)]
    for i, (_, _, done, _) in enumerate(outputs):
      self.dones[i] = self.dones[i] or done
    obs, reward, done, tt = zip(*outputs)

    # Transpose observations back to per-component lists.
    obs = [list(oo) for oo in zip(*obs)]

    return [obs, reward, done, tt]

  def get_one(self):
    return random.choice(self.envs)

  def __len__(self):
    return len(self.envs)
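

# ------------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). The environment
# id and the action-building helper below are assumptions, not env_spec API:
#
#   wrapped = GymWrapper('Copy-v0', distinct=2, count=3)  # 6 envs, 3 per seed
#   obs = wrapped.reset()                  # per-component observation lists
#   while not wrapped.all_done():
#     # actions[i][j] = component i of the action for environment j; `step`
#     # transposes this with zip(*actions) into one action per environment.
#     actions = build_batched_actions(wrapped)            # hypothetical helper
#     obs, rewards, dones, infos = wrapped.step(actions)
#   wrapped.reset_if(dones)                # partial resets require count == 1
# ------------------------------------------------------------------------------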