# riiswa's picture
# Update
# 7377a42
# raw
# history blame
# 6.81 kB
import glob
import os
import gymnasium as gym
import numpy as np
from gymnasium.wrappers import RecordVideo
from moviepy.video.compositing.concatenate import concatenate_videoclips
from moviepy.video.io.VideoFileClip import VideoFileClip
from sympy import latex
from interpretable import InterpretablePolicyExtractor
from utils import generate_dataset_from_expert, rollouts
import matplotlib.pyplot as plt
import torch
import gradio as gr
# Markdown text for the demo page header; it is rendered through gr.Markdown(intro)
# when the Blocks layout is built.
intro = """
# Making RL Policy Interpretable with Kolmogorov-Arnold Network 🧠 ➙ 🔢
Waris Radji<sup>1</sup>, Corentin Léger<sup>2</sup>, Hector Kohler<sup>1</sup>
<small><sup>1</sup>[Inria, team Scool](https://team.inria.fr/scool/) <sup>2</sup>[Inria, team Flowers](https://flowers.inria.fr/)</small>
In this demo, we showcase a method to make a trained Reinforcement Learning (RL) policy interpretable using the Kolmogorov-Arnold Network (KAN). The process involves transferring the knowledge from a pre-trained RL policy to a KAN. We achieve this by training the KAN to map actions from observations obtained from trajectories of the pre-trained policy.
## Procedure
- Train the KAN using observations from trajectories generated by a pre-trained RL policy, the KAN learns to map observations to corresponding actions.
- Apply symbolic regression algorithms to the KAN's learned mapping.
- Extract an interpretable policy expressed in symbolic form.
For more information about KAN you can read the [paper](https://arxiv.org/abs/2404.19756), and check the [PyTorch official information](https://github.com/KindXiaoming/pykan).
To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl).
[![riiswa/kanrl - GitHub](https://gh-card.dev/repos/riiswa/kanrl.svg)](https://github.com/riiswa/kanrl)
"""
# Environment ids offered in the "Environment name" dropdown.
envs = [
    "CartPole-v1",
    "MountainCar-v0",
    "Acrobot-v1",
    "Pendulum-v1",
    "MountainCarContinuous-v0",
    "LunarLander-v2",
    "Swimmer-v3",
    "Hopper-v3",
]
# Script entry-point guard.
# NOTE(review): indentation has been lost in this copy of the file, so it is
# not possible to tell from here which of the following statements belong to
# the guard's body — confirm against the upstream file.
if __name__ == "__main__":
# Set float32 as torch's default dtype before anything else runs.
torch.set_default_dtype(torch.float32)
def load_video_and_dataset(_env_name):
    """Generate the expert dataset and video for the selected environment.

    A UI warning is shown when the environment name starts with "Swimmer" or
    "Hopper-v3". The dataset and video are then produced by
    ``generate_dataset_from_expert`` using the "ppo" expert.

    Args:
        _env_name: Environment id chosen in the dropdown.

    Returns:
        A 3-tuple ``(video_path, button, state)``: the path of the expert
        video, an interactive "Compute the symbolic policy!" button, and a
        fresh state dict holding the dataset path, the environment name and
        an empty ``"ipe"`` slot.
    """
    mujoco_prefixes = ("Swimmer", "Hopper-v3")
    if _env_name.startswith(mujoco_prefixes):
        gr.Warning("We're currently in the process of adding support for Mujoco environments, so the application may encounter crashes during this phase. We encourage contributors to join us in the repository https://github.com/riiswa/kanrl to assist in the development and support of other environments. Your contributions are invaluable in ensuring a robust and comprehensive framework.")

    # NOTE(review): the meaning of the trailing 15 and 3 is defined by
    # generate_dataset_from_expert, which is not visible here — confirm.
    dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)

    enabled_button = gr.Button("Compute the symbolic policy!", interactive=True)
    fresh_state = {
        "dataset_path": dataset_path,
        "ipe": None,
        "env_name": _env_name,
    }
    return video_path, enabled_button, fresh_state
def parse_integer_list(input_str):
    """Parse a comma-separated string of integers.

    Args:
        input_str: Text such as ``"3,3"``; whitespace around each item is
            ignored.

    Returns:
        ``None`` if the input is falsy or contains only whitespace, a tuple
        of ints if every comma-separated item is a valid integer, and
        ``False`` if any item cannot be converted.
    """
    if input_str and not input_str.isspace():
        try:
            return tuple(int(piece.strip()) for piece in input_str.split(','))
        except ValueError:
            # False (rather than None) marks malformed input so callers can
            # tell it apart from "nothing entered".
            return False
    return None
def extract_interpretable_policy(kan_widths, epochs, state):
    """Train a KAN policy from the stored expert dataset and render its plot.

    Args:
        kan_widths: Raw text of the hidden-layer widths (comma-separated
            integers); empty means no hidden layer.
        epochs: Value forwarded as ``steps`` to ``train_from_dataset``.
        state: Session dict; ``"env_name"`` and ``"dataset_path"`` are read
            and ``"ipe"`` is overwritten with the newly trained extractor.

    Returns:
        A 3-tuple ``(kan_architecture, state, fig)``: the RGBA pixel buffer
        of the plotted policy as a numpy array, the updated state dict, and
        the matplotlib figure that was drawn (the current pyplot figure is
        closed before returning).
    """
    widths = parse_integer_list(kan_widths)
    # parse_integer_list returns False for malformed input. The check must be
    # on the parsed result, not on the raw textbox string, otherwise the
    # fallback never triggers and False is passed on as the widths.
    if widths is False:
        gr.Warning(f"Please enter widths {kan_widths} in the right format... The current run is executed with no hidden layer.")
        widths = None

    extractor = InterpretablePolicyExtractor(state["env_name"], widths)
    state["ipe"] = extractor
    extractor.train_from_dataset(state["dataset_path"], steps=epochs)
    extractor.policy.prune()
    extractor.policy.plot(mask=True, scale=5)

    # The plot call draws on pyplot's current figure; rasterize it so it can
    # be shown in an image component.
    fig = plt.gcf()
    fig.canvas.draw()
    kan_architecture = np.array(fig.canvas.renderer.buffer_rgba())
    plt.close()
    return kan_architecture, state, fig
def symbolic_policy(state):
    """Convert the trained KAN to symbolic form, record it, and format it.

    Args:
        state: Session dict; ``state["ipe"]`` must hold the trained extractor
            and ``state["env_name"]`` the environment id.

    Returns:
        A 2-tuple ``(video_path, symbolic_formula)``: the path of the
        concatenated rollout video and a markdown string with one LaTeX
        formula per policy output (plus an argmax line when the extractor
        reports a discrete action space).
    """
    extractor = state["ipe"]
    env_name = state["env_name"]

    lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
    extractor.policy.auto_symbolic(lib=lib)

    # Record every episode of the symbolic policy under a "kan-<env>" prefix.
    env = gym.make(env_name, render_mode="rgb_array")
    env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"kan-{env_name}")
    # NOTE(review): the trailing 2 is presumably the number of episodes —
    # confirm against utils.rollouts.
    rollouts(env, extractor.forward, 2)

    video_path = os.path.join("videos", f"kan-{env_name}.mp4")
    # glob returns matches in arbitrary order; sort so the episodes are always
    # concatenated in a stable, name-based order.
    video_files = sorted(glob.glob(os.path.join("videos", f"kan-{env_name}-episode*.mp4")))

    clips = []
    try:
        for file in video_files:
            clips.append(VideoFileClip(file))
        final_clip = concatenate_videoclips(clips)
        final_clip.write_videofile(video_path, codec="libx264", fps=24)
    finally:
        # Release the readers held by each source clip, even if writing fails.
        for clip in clips:
            clip.close()

    formulas = extractor.policy.symbolic_formula()[0]
    lines = ["### The symbolic formula of the policy is:"]
    lines.extend(f"$$ a_{i}={latex(formula)}$$" for i, formula in enumerate(formulas))
    if extractor._action_is_discrete:
        lines.append(r"$$ a = \underset{i}{\mathrm{argmax}} \ a_i.$$")
    symbolic_formula = "\n".join(lines)
    return video_path, symbolic_formula
# Custom stylesheet passed to gr.Blocks: sets horizontal overflow to auto on
# the element with id "formula" (the formula Markdown component uses that id).
css = """
#formula {overflow-x: auto!important};
"""
# Build the Gradio UI, wire the callbacks and launch the app.
# NOTE(review): indentation has been lost in this copy of the file, so the
# exact nesting of components inside the Row/Column contexts (and whether
# app.launch() sits inside or after the Blocks context) must be confirmed
# against the upstream file.
with gr.Blocks(theme='gradio/monochrome', css=css) as app:
# State dict handed to and returned from the callbacks wired below.
state = gr.State({
"dataset_path": None,
"ipe": None,
"env_name": None
})
gr.Markdown(intro)
with gr.Row():
# Presumably the left column: expert policy selection and KAN settings.
with gr.Column():
gr.Markdown("### Pretrained policy loading (PPO from [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo))")
choice = gr.Dropdown(envs, label="Environment name")
expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")
epochs = gr.Number(value=20, label="KAN training Steps.", minimum=1, maximum=100)
# Created non-interactive; load_video_and_dataset returns an interactive button for this output.
button = gr.Button("Compute the symbolic policy!", interactive=False)
# Presumably the right column: results of the symbolic extraction.
with gr.Column():
gr.Markdown("### Symbolic policy extraction")
kan_architecture = gr.Image(interactive=False, label="KAN architecture")
sym_video = gr.Video(label="Symbolic policy video", interactive=False, autoplay=True)
sym_formula = gr.Markdown(elem_id="formula")
# Dropdown input: load the expert video, update the button and replace the state.
choice.input(load_video_and_dataset, inputs=[choice], outputs=[expert_video, button, state])
# Button click: train and plot the KAN, then chain the symbolic extraction step.
# NOTE(review): extract_interpretable_policy returns three values while only
# two outputs are listed here — confirm the extra value is meant to be ignored.
button.click(extract_interpretable_policy, inputs=[kan_widths, epochs, state], outputs=[kan_architecture, state]).then(
symbolic_policy, inputs=[state], outputs=[sym_video, sym_formula]
)
# Start the Gradio app.
app.launch()