File size: 2,342 Bytes
d30b0db 923ccaf d30b0db 923ccaf d30b0db 923ccaf d30b0db 923ccaf d30b0db 923ccaf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from dataclasses import asdict
from typing import Any, Dict, Optional
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.shared.vec_env.microrts import make_microrts_env
from rl_algo_impls.shared.vec_env.procgen import make_procgen_env
from rl_algo_impls.shared.vec_env.vec_env import make_vec_env
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
def make_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a vectorized environment using the factory for ``hparams.env_type``.

    The factory is chosen as follows:

    * ``"procgen"`` -> ``make_procgen_env``
    * ``"sb3vec"`` or ``"gymvec"`` -> ``make_vec_env``
    * ``"microrts"`` -> ``make_microrts_env``

    ``config`` and ``hparams`` are passed positionally, and ``training``,
    ``render``, ``normalize_load_path`` and ``tb_writer`` are passed by
    keyword, all unchanged, to whichever factory is selected.

    Returns:
        Whatever the selected factory returns.

    Raises:
        ValueError: If ``hparams.env_type`` is not one of the values above.
    """
    env_type = hparams.env_type

    # Pick the factory first so the shared argument list is written only once.
    if env_type == "procgen":
        factory = make_procgen_env
    elif env_type in {"sb3vec", "gymvec"}:
        factory = make_vec_env
    elif env_type == "microrts":
        factory = make_microrts_env
    else:
        raise ValueError(f"env_type {hparams.env_type} not supported")

    return factory(
        config,
        hparams,
        training=training,
        render=render,
        normalize_load_path=normalize_load_path,
        tb_writer=tb_writer,
    )
def make_eval_env(
    config: Config,
    hparams: EnvHyperparams,
    override_hparams: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> VecEnv:
    """Build an environment through ``make_env`` with ``training=False``.

    Hyperparameters are layered in this order, each layer producing a new
    ``EnvHyperparams`` instance:

    1. The ``hparams`` argument.
    2. ``config.eval_hyperparams["env_overrides"]``, when present and truthy.
    3. ``override_hparams``, when given and truthy. While these are applied,
       an entry ``n_envs == 1`` also sets ``vec_env_class`` to ``"sync"``.
       Entries are applied in iteration order, so a ``vec_env_class`` entry
       that comes after ``n_envs`` replaces that ``"sync"`` value.

    Args:
        config: Run configuration; only ``eval_hyperparams`` is read here.
        hparams: Base environment hyperparameters.
        override_hparams: Optional field overrides applied last.
        **kwargs: Extra keyword arguments forwarded to ``make_env``. Any
            ``training`` value supplied by the caller is replaced by ``False``.

    Returns:
        The result of ``make_env`` for the final hyperparameters.
    """
    forwarded = dict(kwargs)
    forwarded["training"] = False

    config_overrides = config.eval_hyperparams.get("env_overrides")
    if config_overrides:
        fields = asdict(hparams)
        fields.update(config_overrides)
        hparams = EnvHyperparams(**fields)

    if override_hparams:
        fields = asdict(hparams)
        for name, value in override_hparams.items():
            fields[name] = value
            # A single environment is switched to the "sync" vec env class.
            if name == "n_envs" and value == 1:
                fields["vec_env_class"] = "sync"
        hparams = EnvHyperparams(**fields)

    return make_env(config, hparams, **forwarded)
|