riiswa committed
Commit 81461c8
1 parent: d0fc1a7

Fix state handling
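
In plain terms, the commit drops the module-level globals (`dataset_path`, `ipe`, `env_name`) and the stdout-redirecting `Logger`/logs accordion, and instead threads a per-session dict through the Gradio event handlers via `gr.State`: each callback receives the dict as an extra input and returns the updated dict as an extra output. Below is a minimal sketch of that pattern, assuming a toy layout; the `pick_env`, `out`, and `demo` names are illustrative and not from app.py, only the `gr.State` input/output threading mirrors the actual change.

```python
import gradio as gr

with gr.Blocks() as demo:
    # One dict per browser session instead of shared module globals.
    state = gr.State({"dataset_path": None, "ipe": None, "env_name": None})
    env = gr.Dropdown(["CartPole-v1", "Pendulum-v1"], label="Environment name")
    out = gr.Textbox(label="Session state")

    def pick_env(env_name, state):
        state["env_name"] = env_name  # mutate the session-local copy
        return str(state), state      # return state so Gradio stores the update

    # The state component appears in both inputs and outputs,
    # just like choice.input(...) in the diff below.
    env.input(pick_env, inputs=[env, state], outputs=[out, state])

demo.launch()
```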

Files changed (1): app.py (+32, -59)
app.py CHANGED
@@ -15,7 +15,6 @@ import matplotlib.pyplot as plt
 import torch
 
 import gradio as gr
-import sys
 
 intro = """
 # Making RL Policy Interpretable with Kolmogorov-Arnold Network 🧠 ➙ 🔢
@@ -40,46 +39,18 @@ To follow the progress of KAN in RL you can check the repo [kanrl](https://githu
 envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v4", "Hopper-v4"]
 
 
-class Logger:
-    def __init__(self, filename):
-        self.terminal = sys.stdout
-        self.log = open(filename, "w")
-
-    def write(self, message):
-        self.terminal.write(message)
-        self.log.write(message)
-
-    def flush(self):
-        self.terminal.flush()
-        self.log.flush()
-
-    def isatty(self):
-        return False
-
-
-sys.stdout = Logger("output.log")
-sys.stderr = Logger("output.log")
-
-
-def read_logs():
-    sys.stdout.flush()
-    with open("output.log", "r") as f:
-        return f.read()
-
-
 if __name__ == "__main__":
     torch.set_default_dtype(torch.float32)
-    dataset_path = None
-    ipe = None
-    env_name = None
 
     def load_video_and_dataset(_env_name):
-        global dataset_path
-        global env_name
         env_name = _env_name
 
         dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
-        return video_path, gr.Button("Compute the symbolic policy!", interactive=True)
+        return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
+            "dataset_path": dataset_path,
+            "ipe": None,
+            "env_name": env_name
+        }
 
 
     def parse_integer_list(input_str):
@@ -94,45 +65,44 @@ if __name__ == "__main__":
         except ValueError:
             return False
 
-    def extract_interpretable_policy(env_name, kan_widths):
-        global ipe
-
+    def extract_interpretable_policy(kan_widths, epochs, state):
         widths = parse_integer_list(kan_widths)
         if kan_widths is False:
             gr.Warning(f"Please enter widths {kan_widths} in the right format... The current run is executed with no hidden layer.")
             widths = None
 
-        ipe = InterpretablePolicyExtractor(env_name, widths)
-        ipe.train_from_dataset(dataset_path, steps=50)
+        state["ipe"] = InterpretablePolicyExtractor(state["env_name"], widths)
+        state["ipe"].train_from_dataset(state["dataset_path"], steps=epochs)
 
-        ipe.policy.prune()
-        ipe.policy.plot(mask=True, scale=5)
+        state["ipe"].policy.prune()
+        state["ipe"].policy.plot(mask=True, scale=5)
 
         fig = plt.gcf()
         fig.canvas.draw()
-        return np.array(fig.canvas.renderer.buffer_rgba())
+        kan_architecture = np.array(fig.canvas.renderer.buffer_rgba())
+        plt.close()
+
+        return kan_architecture, state, fig
 
-    def symbolic_policy():
-        global ipe
-        global env_name
+    def symbolic_policy(state):
         lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
-        ipe.policy.auto_symbolic(lib=lib)
-        env = gym.make(env_name, render_mode="rgb_array")
-        env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"kan-{env_name}")
+        state["ipe"].policy.auto_symbolic(lib=lib)
+        env = gym.make(state["env_name"], render_mode="rgb_array")
+        env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"""kan-{state["env_name"]}""")
 
-        rollouts(env, ipe.forward, 2)
+        rollouts(env, state["ipe"].forward, 2)
 
-        video_path = os.path.join("videos", f"kan-{env_name}.mp4")
-        video_files = glob.glob(os.path.join("videos", f"kan-{env_name}-episode*.mp4"))
+        video_path = os.path.join("videos", f"""kan-{state["env_name"]}.mp4""")
+        video_files = glob.glob(os.path.join("videos", f"""kan-{state["env_name"]}-episode*.mp4"""))
         clips = [VideoFileClip(file) for file in video_files]
         final_clip = concatenate_videoclips(clips)
         final_clip.write_videofile(video_path, codec="libx264", fps=24)
 
         symbolic_formula = f"### The symbolic formula of the policy is:"
-        formulas = ipe.policy.symbolic_formula()[0]
+        formulas = state["ipe"].policy.symbolic_formula()[0]
         for i, formula in enumerate(formulas):
             symbolic_formula += "\n$$ a_" + str(i) + "=" + latex(formula) + "$$"
-        if ipe._action_is_discrete:
+        if state["ipe"]._action_is_discrete:
             symbolic_formula += "\n" + r"$$ a = \underset{i}{\mathrm{argmax}} \ a_i.$$"
 
         return video_path, symbolic_formula
@@ -143,6 +113,11 @@ if __name__ == "__main__":
     """
 
     with gr.Blocks(theme='gradio/monochrome', css=css) as app:
+        state = gr.State({
+            "dataset_path": None,
+            "ipe": None,
+            "env_name": None
+        })
         gr.Markdown(intro)
 
         with gr.Row():
@@ -151,18 +126,16 @@ if __name__ == "__main__":
                 choice = gr.Dropdown(envs, label="Environment name")
                 expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
                 kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")
+                epochs = gr.Number(value=20, label="KAN training Steps.", minimum=1, maximum=100)
                 button = gr.Button("Compute the symbolic policy!", interactive=False)
             with gr.Column():
                 gr.Markdown("### Symbolic policy extraction")
                 kan_architecture = gr.Image(interactive=False, label="KAN architecture")
                 sym_video = gr.Video(label="Symbolic policy video", interactive=False, autoplay=True)
                 sym_formula = gr.Markdown(elem_id="formula")
-        with gr.Accordion("See logs"):
-            logs = gr.Textbox(label="Logs", interactive=False)
-        choice.input(load_video_and_dataset, inputs=[choice], outputs=[expert_video, button])
-        button.click(extract_interpretable_policy, inputs=[choice, kan_widths], outputs=[kan_architecture]).then(
-            symbolic_policy, inputs=[], outputs=[sym_video, sym_formula]
+        choice.input(load_video_and_dataset, inputs=[choice], outputs=[expert_video, button, state])
+        button.click(extract_interpretable_policy, inputs=[kan_widths, epochs, state], outputs=[kan_architecture, state]).then(
+            symbolic_policy, inputs=[state], outputs=[sym_video, sym_formula]
         )
-        app.load(read_logs, None, logs, every=1)
 
     app.launch()