riiswa committed
Commit 934779e
1 Parent(s): b471ab8

Try to debug

Files changed (2)
  1. app.py +2 -6
  2. interpretable.py +1 -3
app.py CHANGED
@@ -37,7 +37,7 @@ To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl)
 
 [![riiswa/kanrl - GitHub](https://gh-card.dev/repos/riiswa/kanrl.svg)](https://github.com/riiswa/kanrl)
 
-*Please be patient, as the process may take a few minutes to run, especially in environments with large state/action spaces or with a complex KAN architecture.*
+*Please be patient, as the process may take a few minutes to run, especially in environments with large state/action spaces or with a complex KAN architecture. For optimal performance, default parameters may not suffice. Feel free to experiment with different settings to achieve desired results.*
 """
 
 envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v3", "Hopper-v3", "HalfCheetah-v3", "Walker2d-v3"]
@@ -48,13 +48,9 @@ if __name__ == "__main__":
 
     def load_video_and_dataset(_env_name):
         env_name = _env_name
-        if env_name in ["Swimmer-v3", "Hopper-v3", "HalfCheetah-v3", "Walker2d-v3"]:
-            gr.Warning(
-                "We're currently in the process of adding support for Mujoco environments, so the application may encounter crashes during this phase. We encourage contributors to join us in the repository https://github.com/riiswa/kanrl to assist in the development and support of other environments. Your contributions are invaluable in ensuring a robust and comprehensive framework."
-            )
         agent = "ppo"
         if env_name == "Swimmer-v3" or env_name == "Walker2d-v3":
-            agent ="trpo"
+            agent = "trpo"
 
         dataset_path, video_path = generate_dataset_from_expert(agent, _env_name, 15, 3)
         return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
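
For context on the deleted hunk: gr.Warning is designed to be raised from inside a Gradio event handler and only exists in relatively recent Gradio releases, so a call like the one removed here is a plausible crash site in a "Try to debug" pass. A minimal sketch of the usual pattern follows; the select_env handler and the demo app are hypothetical, not part of app.py:

import gradio as gr

def select_env(env_name):
    # Inside an event handler, gr.Warning shows a toast in the UI
    # without interrupting execution (unlike raising gr.Error).
    if env_name in ["Swimmer-v3", "Hopper-v3", "HalfCheetah-v3", "Walker2d-v3"]:
        gr.Warning(f"Support for {env_name} is experimental and may crash.")
    return f"Selected {env_name}"

with gr.Blocks() as demo:  # hypothetical minimal app
    dropdown = gr.Dropdown(["CartPole-v1", "Swimmer-v3"], label="Environment")
    status = gr.Textbox(label="Status")
    dropdown.change(select_env, dropdown, status)

demo.launch()
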
interpretable.py CHANGED
@@ -35,12 +35,10 @@ class InterpretablePolicyExtractor:
         dataset["test_label"] = dataset["test_label"][:, None]
         dataset["train_input"] = dataset["train_input"].float()
         dataset["test_input"] = dataset["test_input"].float()
-        for k,v in dataset.items():
-            print(k, v.shape, v.dtype)
         return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
 
     def forward(self, observation):
-        observation = torch.from_numpy(observation)
+        observation = torch.from_numpy(observation).float()
         action = self.policy(observation.unsqueeze(0))
         if self._action_is_discrete:
             return action.argmax(axis=-1).squeeze().item()
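
The .float() cast is the substantive fix in this file: Gym returns float64 NumPy observations for many environments, while PyTorch modules, the KAN policy included, default to float32 parameters, so the uncast tensor fails at the first forward pass (the deleted print loop looks like the temporary instrumentation used to spot this). A minimal sketch of that failure mode, with a hypothetical torch.nn.Linear standing in for the policy:

import numpy as np
import torch

policy = torch.nn.Linear(4, 2)       # stand-in for the KAN policy (float32 weights)
obs = np.zeros(4, dtype=np.float64)  # e.g. a raw CartPole-v1 observation

try:
    policy(torch.from_numpy(obs).unsqueeze(0))  # float64 input vs float32 weights
except RuntimeError as err:
    print(err)  # dtype mismatch; exact message varies by PyTorch version

out = policy(torch.from_numpy(obs).float().unsqueeze(0))  # cast first: works
print(out.shape)  # torch.Size([1, 2])
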