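# kinet-test / Kinetix / examples / example_premade_level_replay.py
# Page metadata captured with this file: uploaded by tree3po ("Upload 190 files"),
# commit e0f25ed (verified), 1.57 kB.
# Kept as a comment so the file remains valid Python.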
import jax
import jax.numpy as jnp
import jax.random
from jax2d.engine import PhysicsEngine
from matplotlib import pyplot as plt
from kinetix.environment.env import make_kinetix_env_from_args
from kinetix.environment.env_state import StaticEnvParams, EnvParams
from kinetix.environment.ued.distributions import sample_kinetix_level
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.util.saving import load_from_json_file
def main(level_path: str = "worlds/l/grasp_easy.json", seed: int = 0) -> None:
    """Load a premade Kinetix level, take one random step, and display a rendering.

    Args:
        level_path: Path to the level JSON file handed to ``load_from_json_file``.
            Defaults to the ``grasp_easy`` level used by the original example.
            The path is presumably resolved relative to the current working
            directory (TODO confirm).
        seed: Seed for ``jax.random.PRNGKey``. All randomness used here comes
            from keys split off this seed: the reset, the action sample and
            the step.
    """
    # Load a premade level together with the static/dynamic env params stored with it.
    level, static_env_params, env_params = load_from_json_file(level_path)

    # Create the environment: pixel observations, continuous actions, "replay" reset,
    # using the static params that were loaded alongside the level.
    env = make_kinetix_env_from_args(
        obs_type="pixels",
        action_type="continuous",
        reset_type="replay",
        static_env_params=static_env_params,
    )

    # Reset the environment state to this level.
    rng = jax.random.PRNGKey(seed)
    rng, _rng = jax.random.split(rng)
    obs, env_state = env.reset_to_level(_rng, level, env_params)

    # Take a single step with a randomly sampled action.
    rng, _rng = jax.random.split(rng)
    action = env.action_space(env_params).sample(_rng)
    rng, _rng = jax.random.split(rng)
    obs, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)

    # Render the environment.
    renderer = make_render_pixels(env_params, static_env_params)
    # There are a lot of wrappers. The innermost state is reached by following
    # `.env_state` three times. NOTE(review): this depth is presumably tied to the
    # wrapper stack built by make_kinetix_env_from_args; confirm it if that changes.
    pixels = renderer(env_state.env_state.env_state.env_state)
    # Cast to uint8, swap the first two axes and flip the first axis before display.
    # Presumably the renderer emits (x, y, channel) with y pointing up, while imshow
    # expects (row, col, channel) with row 0 at the top (TODO confirm).
    plt.imshow(pixels.astype(jnp.uint8).transpose(1, 0, 2)[::-1])
    plt.show()
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not when it is imported.
    main()