from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax2d import joint
from jax2d.engine import select_shape
from jax2d.maths import rmat
from jax2d.sim_state import RigidBody
from jaxgl.maths import dist_from_line
from jaxgl.renderer import clear_screen, make_renderer
from jaxgl.shaders import (
    fragment_shader_quad,
    fragment_shader_edged_quad,
    make_fragment_shader_texture,
    nearest_neighbour,
    make_fragment_shader_quad_textured,
)

from kinetix.render.textures import (
    THRUSTER_TEXTURE_16_RGBA,
    RJOINT_TEXTURE_6_RGBA,
    FJOINT_TEXTURE_6_RGBA,
)
from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState


def make_render_pixels(
    params,
    static_params: StaticEnvParams,
):
    """Build a jitted function that rasterises an ``EnvState`` into an RGB image.

    The scene is drawn onto an enlarged canvas that has a margin of
    ``static_params.max_shape_size`` world units (converted to pixels) on every
    side. Presumably this keeps the fixed-size per-shape patches of shapes near
    the border inside the canvas. The margin is cropped off before returning.

    Args:
        params: Environment parameters; only ``params.pixels_per_unit`` is read.
        static_params: Static parameters. Reads ``screen_dim``, ``downscale``,
            ``max_shape_size``, ``max_polygon_vertices`` and ``num_polygons``.

    Returns:
        ``render_pixels(state)``, a ``jax.jit``-compiled function that returns
        the rendered image. Colour values are on a 0-255 scale.
    """
    screen_dim = static_params.screen_dim
    downscale = static_params.downscale

    # Edge lengths (in texels) of the square joint / thruster textures.
    joint_tex_size = 6
    thruster_tex_size = 16

    # Fill colour for static (zero inverse mass) shapes whose role is 0.
    FIXATED_COLOUR = jnp.array([80, 80, 80])
    # Indexed by ``binding + 1``: index 0 is used for joints whose motor is off
    # or that are fixed (see ``joint_colours``), and for thrusters whose binding
    # is -1 (see ``thruster_textures``).
    JOINT_COLOURS = jnp.array(
        [
            [255, 255, 255],  # white
            [255, 255, 0],  # yellow
            [255, 0, 255],  # purple/magenta
            [0, 255, 255],  # cyan
            [255, 153, 51],  # orange
        ]
    )

    def colour_thruster_texture(colour):
        # Tint the RGB channels of the first 9 entries along the texture's first
        # axis; alpha and the remaining entries are left untouched.
        return THRUSTER_TEXTURE_16_RGBA.at[:9, :, :3].mul(colour[None, None, :] / 255.0)

    # One pre-tinted thruster texture per JOINT_COLOURS entry.
    coloured_thruster_textures = jax.vmap(colour_thruster_texture)(JOINT_COLOURS)

    ROLE_COLOURS = jnp.array(
        [
            [160.0, 160.0, 160.0],  # None
            [0.0, 204.0, 0.0],  # Green: The ball
            [0.0, 102.0, 204.0],  # Blue: The goal
            [255.0, 102.0, 102.0],  # Red: Death Objects
        ]
    )

    BACKGROUND_COLOUR = jnp.array([255.0, 255.0, 255.0])

    def _get_colour(shape_role, inverse_mass):
        """Return the fill colour of a shape from its role and inverse mass.

        Shapes with non-zero ``inverse_mass`` use their role colour. Shapes with
        zero ``inverse_mass`` use ``FIXATED_COLOUR`` when the role is 0, and
        half the role colour otherwise.
        """
        base_colour = ROLE_COLOURS[shape_role]
        is_static = (inverse_mass == 0) * 1
        has_role = (shape_role != 0) * 1
        return jnp.array(
            [
                base_colour,  # non-static, role 0
                base_colour,  # non-static, role != 0
                FIXATED_COLOUR,  # static, role 0
                base_colour * 0.5,  # static, role != 0
            ]
        )[2 * is_static + has_role]

    # Pixels per unit distance
    ppud = params.pixels_per_unit // downscale
    downscaled_screen_dim = (screen_dim[0] // downscale, screen_dim[1] // downscale)
    # Canvas = screen plus a margin of ``max_shape_size * ppud`` pixels on each side.
    full_screen_size = (
        downscaled_screen_dim[0] + (static_params.max_shape_size * 2 * ppud),
        downscaled_screen_dim[1] + (static_params.max_shape_size * 2 * ppud),
    )
    cleared_screen = clear_screen(full_screen_size, BACKGROUND_COLOUR)

    def _world_space_to_pixel_space(x):
        # Shift by the canvas margin, then scale world units to pixels.
        return (x + static_params.max_shape_size) * ppud

    def fragment_shader_kinetix_circle(position, current_frag, unit_position, uniform):
        """Shade a filled circle with a black outline and a black rotation indicator."""
        centre, radius, rotation, colour, mask = uniform

        dist = jnp.sqrt(jnp.square(position - centre).sum())
        inside = dist <= radius
        # Outline band: fragments within 2 pixel units of the rim.
        on_edge = dist > radius - 2

        # Rotation indicator: fragments close to the line through ``centre``
        # along ``normal``, restricted to the half-plane dot(normal, p - c) <= 0.
        # Assumes dist_from_line gives the point-to-line distance - TODO confirm.
        # TODO - precompute?
        normal = jnp.array([jnp.sin(rotation), -jnp.cos(rotation)])
        line_dist = dist_from_line(position, centre, centre + normal)
        on_edge |= (line_dist < 1) & (jnp.dot(normal, position - centre) <= 0)

        fragment = jax.lax.select(on_edge, jnp.zeros(3), colour)
        return jax.lax.select(inside & mask, fragment, current_frag)

    def fragment_shader_kinetix_joint(position, current_frag, unit_position, uniform):
        """Alpha-blend a tinted joint texture over ``current_frag``.

        Presumably ``unit_position`` is the fragment position within its patch,
        normalised to [0, 1] - TODO confirm against jaxgl's ``make_renderer``.
        """
        texture, colour, mask = uniform
        tex_coord = (
            jnp.array(
                [
                    joint_tex_size * unit_position[0],
                    joint_tex_size * unit_position[1],
                ]
            )
            - 0.5
        )

        tex_frag = nearest_neighbour(texture, tex_coord)
        # Inactive joints (mask False) get zero alpha and so leave the pixel unchanged.
        tex_frag = tex_frag.at[3].mul(mask)
        tex_frag = tex_frag.at[:3].mul(colour / 255.0)

        # "Over" blend using the texture's alpha channel.
        return (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)

    thruster_pixel_size = thruster_tex_size // downscale
    # The patch is sized to the texture's diagonal (+1) so a rotated thruster still fits.
    thruster_pixel_size_diagonal = (thruster_pixel_size * np.sqrt(2)).astype(jnp.int32) + 1

    def fragment_shader_kinetix_thruster(fragment_position, current_frag, unit_position, uniform):
        """Alpha-blend a rotated thruster texture centred on ``thruster_position``."""
        thruster_position, rotation, texture, mask = uniform

        # Map the fragment into the thruster's local frame (assuming rmat builds
        # a 2D rotation matrix), normalised so the texture covers [0, 1] on each axis.
        tex_position = jnp.matmul(rmat(-rotation), (fragment_position - thruster_position)) / thruster_pixel_size + 0.5

        # The diagonal-sized patch is larger than the texture, so drop fragments outside it.
        mask &= (tex_position[0] >= 0) & (tex_position[0] <= 1) & (tex_position[1] >= 0) & (tex_position[1] <= 1)

        # Small offset, presumably to avoid ties exactly on texel boundaries.
        eps = 0.001
        tex_coord = (
            jnp.array(
                [
                    thruster_tex_size * tex_position[0],
                    thruster_tex_size * tex_position[1],
                ]
            )
            - 0.5
            + eps
        )

        tex_frag = nearest_neighbour(texture, tex_coord)
        tex_frag = tex_frag.at[3].mul(mask)

        # "Over" blend using the texture's alpha channel.
        return (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)

    # Square patch of side max_shape_size (in pixels) used for every polygon and circle.
    patch_size_1d = static_params.max_shape_size * ppud
    patch_size = (patch_size_1d, patch_size_1d)

    circle_renderer = make_renderer(full_screen_size, fragment_shader_kinetix_circle, patch_size, batched=True)
    quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, patch_size, batched=True)
    # Floor is drawn with a screen-sized patch rather than a max_shape_size one.
    big_quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, downscaled_screen_dim)

    joint_pixel_size = joint_tex_size // downscale
    joint_renderer = make_renderer(
        full_screen_size, fragment_shader_kinetix_joint, (joint_pixel_size, joint_pixel_size), batched=True
    )

    thruster_renderer = make_renderer(
        full_screen_size,
        fragment_shader_kinetix_thruster,
        (thruster_pixel_size_diagonal, thruster_pixel_size_diagonal),
        batched=True,
    )

    @jax.jit
    def render_pixels(state: EnvState):
        """Rasterise ``state``.

        Layers are drawn in order: floor, polygons, circles, joints, thrusters.
        Later layers paint over earlier ones.
        """
        pixels = cleared_screen

        # Floor: polygon 0, drawn in a screen-sized patch anchored at the world origin.
        # Its vertices are offset by its position but not rotated here.
        floor_uniform = (
            _world_space_to_pixel_space(state.polygon.position[0, None, :] + state.polygon.vertices[0]),
            _get_colour(state.polygon_shape_roles[0], 0),
            jnp.zeros(3),
            True,
        )
        pixels = big_quad_renderer(pixels, _world_space_to_pixel_space(jnp.zeros(2, dtype=jnp.int32)), floor_uniform)

        # Rectangles: each polygon gets a patch centred on its position.
        rectangle_patch_positions = _world_space_to_pixel_space(
            state.polygon.position - (static_params.max_shape_size / 2.0)
        ).astype(jnp.int32)

        rectangle_rmats = jax.vmap(rmat)(state.polygon.rotation)
        # One copy of each polygon's matrix per vertex, for the nested vmap below.
        rectangle_rmats = jnp.repeat(rectangle_rmats[:, None, :, :], repeats=static_params.max_polygon_vertices, axis=1)
        rectangle_vertices_pixel_space = _world_space_to_pixel_space(
            state.polygon.position[:, None, :] + jax.vmap(jax.vmap(jnp.matmul))(rectangle_rmats, state.polygon.vertices)
        )
        rectangle_colours = jax.vmap(_get_colour)(state.polygon_shape_roles, state.polygon.inverse_mass)
        rectangle_edge_colours = jnp.zeros((static_params.num_polygons, 3))

        rectangle_uniforms = (
            rectangle_vertices_pixel_space,
            rectangle_colours,
            rectangle_edge_colours,
            state.polygon.active,
        )

        pixels = quad_renderer(pixels, rectangle_patch_positions, rectangle_uniforms)

        # Circles
        circle_positions_pixel_space = _world_space_to_pixel_space(state.circle.position)
        # Radius is a length, so it is scaled but not shifted by the margin.
        circle_radii_pixel_space = state.circle.radius * ppud
        circle_patch_positions = _world_space_to_pixel_space(
            state.circle.position - (static_params.max_shape_size / 2.0)
        ).astype(jnp.int32)

        circle_colours = jax.vmap(_get_colour)(state.circle_shape_roles, state.circle.inverse_mass)

        circle_uniforms = (
            circle_positions_pixel_space,
            circle_radii_pixel_space,
            state.circle.rotation,
            circle_colours,
            state.circle.active,
        )

        pixels = circle_renderer(pixels, circle_patch_positions, circle_uniforms)

        # Joints: fixed joints use FJOINT texture, others RJOINT.
        joint_patch_positions = jnp.round(
            _world_space_to_pixel_space(state.joint.global_position) - (joint_pixel_size // 2)
        ).astype(jnp.int32)

        joint_textures = jax.vmap(jax.lax.select, in_axes=(0, None, None))(
            state.joint.is_fixed_joint, FJOINT_TEXTURE_6_RGBA, RJOINT_TEXTURE_6_RGBA
        )
        # JOINT_COLOURS[0] unless the motor is on and the joint is not fixed.
        joint_colours = JOINT_COLOURS[
            (state.motor_bindings + 1) * (state.joint.motor_on & (~state.joint.is_fixed_joint))
        ]

        joint_uniforms = (joint_textures, joint_colours, state.joint.active)

        pixels = joint_renderer(pixels, joint_patch_positions, joint_uniforms)

        # Thrusters
        thruster_positions = jnp.round(_world_space_to_pixel_space(state.thruster.global_position)).astype(jnp.int32)
        thruster_patch_positions = thruster_positions - (thruster_pixel_size_diagonal // 2)
        thruster_textures = coloured_thruster_textures[state.thruster_bindings + 1]
        # Thruster rotation is relative to the shape it is attached to.
        thruster_rotations = (
            state.thruster.rotation
            + jax.vmap(select_shape, in_axes=(None, 0, None))(
                state, state.thruster.object_index, static_params
            ).rotation
        )
        thruster_uniforms = (thruster_positions, thruster_rotations, thruster_textures, state.thruster.active)

        pixels = thruster_renderer(pixels, thruster_patch_positions, thruster_uniforms)

        # Crop out the sides (the margin added in ``full_screen_size``). Explicit end
        # indices are used instead of ``-crop_amount`` so that a zero margin yields
        # the whole image rather than an empty slice.
        crop_amount = static_params.max_shape_size * ppud
        return pixels[
            crop_amount : pixels.shape[0] - crop_amount,
            crop_amount : pixels.shape[1] - crop_amount,
        ]

    return render_pixels


@struct.dataclass
class PixelsObservation:
    """Pixel observation returned by ``make_render_pixels_rl``."""

    # Rendered image, scaled to [0, 1] by ``make_render_pixels_rl``.
    image: jnp.ndarray
    # Extra scalar features; currently ``[state.gravity[1] / 10.0]``.
    global_info: jnp.ndarray


def make_render_pixels_rl(params, static_params: StaticEnvParams):
    """Wrap ``make_render_pixels`` to produce an RL observation.

    Returns:
        ``inner(state) -> PixelsObservation``. The image is divided by 255, and
        ``global_info`` holds ``state.gravity[1]`` (presumably the vertical
        gravity component) divided by 10.
    """
    render_fn = make_render_pixels(params, static_params)

    def inner(state):
        pixels = render_fn(state) / 255.0
        return PixelsObservation(
            image=pixels,
            global_info=jnp.array([state.gravity[1] / 10.0]),
        )

    return inner