# Kinetix example: examples/example_random_level_replay.py
# Source: kinet-test / Kinetix, uploaded by tree3po ("Upload 190 files",
# commit e0f25ed, verified; 1.7 kB).
import jax
import jax.numpy as jnp
import jax.random
from jax2d.engine import PhysicsEngine
from matplotlib import pyplot as plt
from kinetix.environment.env import make_kinetix_env_from_args
from kinetix.environment.env_state import StaticEnvParams, EnvParams
from kinetix.environment.ued.distributions import sample_kinetix_level
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.render.renderer_pixels import make_render_pixels
def main():
    """Sample a random Kinetix level, take one random step in it, and plot the result.

    Uses default ``EnvParams``, ``StaticEnvParams`` and ``UEDParams``. The
    environment is built with pixel observations, continuous actions and
    ``reset_type="replay"``. It is reset to the sampled level and advanced by
    one randomly sampled action. The resulting state is rendered with the
    pixel renderer and shown with matplotlib.
    """
    # Default configuration objects: dynamic env params, static (shape) params,
    # and the parameters used when sampling a level.
    params = EnvParams()
    static_params = StaticEnvParams()
    level_sampling_params = UEDParams()

    env = make_kinetix_env_from_args(
        obs_type="pixels",
        action_type="continuous",
        reset_type="replay",
        static_env_params=static_params,
    )

    # Fixed seed keeps the example reproducible. Every consumer below gets its
    # own subkey split off the running key.
    key = jax.random.PRNGKey(0)

    key, level_key = jax.random.split(key)
    level = sample_kinetix_level(
        level_key,
        env.physics_engine,
        params,
        static_params,
        level_sampling_params,
    )

    # Start an episode on exactly the sampled level.
    key, reset_key = jax.random.split(key)
    _, state = env.reset_to_level(reset_key, level, params)

    key, action_key = jax.random.split(key)
    random_action = env.action_space(params).sample(action_key)

    key, step_key = jax.random.split(key)
    _, state, _, _, _ = env.step(step_key, state, random_action, params)

    render = make_render_pixels(params, static_params)

    # The stepped state is nested inside several wrapper states; peel off three
    # `.env_state` layers to reach the state handed to the renderer.
    inner_state = state
    for _ in range(3):
        inner_state = inner_state.env_state
    image = render(inner_state)

    # Cast to uint8 for imshow, swap the first two axes, then flip the new first
    # axis. NOTE(review): presumably this turns the renderer's (x, y, channel)
    # layout with y pointing up into imshow's (row, col, channel) order — confirm.
    frame = image.astype(jnp.uint8).transpose(1, 0, 2)[::-1]
    plt.imshow(frame)
    plt.show()
if __name__ == "__main__":
    # Run the example only when this file is executed as a script, not when imported.
    main()