Add warning about MuJoCo usage
Browse files- README.md +7 -0
- app.py +3 -1
- packages.txt +2 -1
- requirements.txt +1 -1
- utils.py +0 -4
README.md
CHANGED
@@ -11,3 +11,10 @@ license: mit
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
14 |
+
|
15 |
+
### Application demo
|
16 |
+
|
17 |
+
- Choose an RL environment from the Gymnasium library. A pre-trained Proximal Policy Optimization (PPO) policy is loaded automatically and used to generate an expert dataset and videos of the agent's performance in the selected environment.
|
18 |
+
- Click the "Compute the symbolic policy!" button to train a KAN policy on the expert dataset. Once training is done, you can visualize the KAN network and watch videos of the KAN agent's performance in the selected environment!
|
19 |
+
|
20 |
+
<img alt="Interpretability app demo" src="demo/app_demo.gif">
|
app.py
CHANGED
@@ -36,7 +36,7 @@ For more information about KAN you can read the [paper](https://arxiv.org/abs/24
|
|
36 |
To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl).
|
37 |
"""
|
38 |
|
39 |
-
envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-
|
40 |
|
41 |
|
42 |
if __name__ == "__main__":
|
@@ -45,6 +45,8 @@ if __name__ == "__main__":
|
|
45 |
def load_video_and_dataset(_env_name):
|
46 |
env_name = _env_name
|
47 |
|
|
|
|
|
48 |
dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
|
49 |
return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
|
50 |
"dataset_path": dataset_path,
|
|
|
36 |
To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl).
|
37 |
"""
|
38 |
|
39 |
+
envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v3", "Hopper-v3"]
|
40 |
|
41 |
|
42 |
if __name__ == "__main__":
|
|
|
45 |
def load_video_and_dataset(_env_name):
|
46 |
env_name = _env_name
|
47 |
|
48 |
+
if env_name.startswith("Swimmer") or env_name.startswith("Hopper-v3"):
|
49 |
+
gr.Warning("We're currently in the process of adding support for Mujoco environments, so the application may encounter crashes during this phase. We encourage contributors to join us in the repository https://github.com/riiswa/kanrl to assist in the development and support of other environments. Your contributions are invaluable in ensuring a robust and comprehensive framework.")
|
50 |
dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
|
51 |
return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
|
52 |
"dataset_path": dataset_path,
|
packages.txt
CHANGED
@@ -3,4 +3,5 @@ libgl1-mesa-glx
|
|
3 |
libglew-dev
|
4 |
libosmesa6-dev
|
5 |
software-properties-common
|
6 |
-
patchelf
|
|
|
|
3 |
libglew-dev
|
4 |
libosmesa6-dev
|
5 |
software-properties-common
|
6 |
+
patchelf
|
7 |
+
swig
|
requirements.txt
CHANGED
@@ -13,4 +13,4 @@ stable_baselines3
|
|
13 |
rl_zoo3
|
14 |
gym
|
15 |
shimmy>=0.2.1
|
16 |
-
mujoco-py
|
|
|
13 |
rl_zoo3
|
14 |
gym
|
15 |
shimmy>=0.2.1
|
16 |
+
free-mujoco-py
|
utils.py
CHANGED
@@ -112,10 +112,6 @@ def rollouts(env, policy, num_episodes=1):
|
|
112 |
def generate_dataset_from_expert(algo, env_name, num_train_episodes=5, num_test_episodes=2, force=False):
|
113 |
if env_name.startswith("Swimmer") or env_name.startswith("Hopper"):
|
114 |
install_mujoco()
|
115 |
-
if env_name == "Swimmer-v4":
|
116 |
-
env_name = "Swimmer-v3"
|
117 |
-
elif env_name == "Hopper-v4":
|
118 |
-
env_name = "Hopper-v3"
|
119 |
dataset_path = os.path.join("datasets", f"{algo}-{env_name}.pt")
|
120 |
video_path = os.path.join("videos", f"{algo}-{env_name}.mp4")
|
121 |
if os.path.exists(dataset_path) and os.path.exists(video_path) and not force:
|
|
|
112 |
def generate_dataset_from_expert(algo, env_name, num_train_episodes=5, num_test_episodes=2, force=False):
|
113 |
if env_name.startswith("Swimmer") or env_name.startswith("Hopper"):
|
114 |
install_mujoco()
|
|
|
|
|
|
|
|
|
115 |
dataset_path = os.path.join("datasets", f"{algo}-{env_name}.pt")
|
116 |
video_path = os.path.join("videos", f"{algo}-{env_name}.mp4")
|
117 |
if os.path.exists(dataset_path) and os.path.exists(video_path) and not force:
|