fahmizainal17 committed (verified)
Commit: 2c8df23
1 Parent(s): cd24bf1

Upload folder using huggingface_hub

Files changed (5):
  1. .github/workflows/update_space.yml +28 -0
  2. .gitignore +174 -0
  3. README.md +4 -8
  4. requirements.txt +3 -0
  5. rl_gradio.py +564 -0
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+name: Run Python script
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.12.8'
+
+      - name: Install Gradio
+        run: python -m pip install gradio
+
+      - name: Log in to Hugging Face
+        run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+      - name: Deploy to Spaces
+        run: gradio deploy
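Editor's note: the login step above inlines the secret into a `python -c` one-liner. For reference, here is a minimal local sketch of the same login call (not part of the commit). It assumes the token is exported as an environment variable named HF_TOKEN, a name chosen here for illustration; the workflow itself reads `${{ secrets.hf_token }}`.

# Local sketch of the workflow's login step (illustrative, not part of the commit).
import os
import huggingface_hub

token = os.environ.get("HF_TOKEN")  # assumed variable name; the workflow uses secrets.hf_token
if not token:
    raise SystemExit("Set HF_TOKEN before running this sketch.")

huggingface_hub.login(token=token)  # stores the credential so a subsequent `gradio deploy` can authenticate

After `login()` has stored the credential, running `gradio deploy` locally authenticates the same way the workflow's final step does.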
.gitignore ADDED
@@ -0,0 +1,174 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# PyPI configuration file
+.pypirc
+
+.gradio
+.hugging_face
README.md CHANGED
@@ -1,12 +1,8 @@
 ---
-title: Q-Learning GridWorld Simulator
-emoji: 🐠
-colorFrom: purple
-colorTo: blue
+title: Q-Learning_GridWorld_Simulator
+app_file: rl_gradio.py
 sdk: gradio
 sdk_version: 5.19.0
-app_file: app.py
-pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Reinforcement_Learning_Project
+Simple Project to enforce learning by Q-Learning
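Editor's note: the new README is terse, so here is a small worked example of the tabular Q-learning update that `rl_gradio.py` (added below) implements in `QLearningAgent.update`. The transition and numbers are illustrative only, using the app's default α = 0.1 and γ = 0.9 on the 4x4 grid, where state = row * width + col.

# Illustrative one-step Q-learning update (a sketch, not output from the app).
import numpy as np

alpha, gamma = 0.1, 0.9        # defaults used by QLearningAgent below
q = np.zeros((16, 4))          # 4x4 grid -> 16 states, 4 actions

# Stepping Right (action 1) from state 14 = (3, 2) into the goal state 15 = (3, 3).
state, action, reward, next_state, done = 14, 1, 1.0, 15, True

target = reward if done else reward + gamma * np.max(q[next_state])
q[state, action] += alpha * (target - q[state, action])
print(q[state, action])        # 0.1 after one update; repeated visits push it toward 1.0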
requirements.txt ADDED
@@ -0,0 +1,3 @@
+gradio==5.19.0
+matplotlib==3.10.1
+numpy==2.2.3
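Editor's note: with the three pins above installed (`pip install -r requirements.txt`), the Space's entry point can also be launched by hand. A minimal sketch, assuming the script below is importable as `rl_gradio`:

# Run the app locally (sketch; rl_gradio is the app_file declared in README.md).
from rl_gradio import create_ui

demo = create_ui()
demo.launch()  # unlike the script's __main__ block, no share=True, so the server stays local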
rl_gradio.py ADDED
@@ -0,0 +1,564 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import gradio as gr
+import time
+from matplotlib.colors import ListedColormap
+import matplotlib.patches as patches
+
+class GridWorld:
+    """A simple grid world environment with obstacles."""
+
+    def __init__(self, height=4, width=4):
+        # Grid dimensions
+        self.height = height
+        self.width = width
+
+        # Define states
+        self.n_states = self.height * self.width
+
+        # Actions: 0: up, 1: right, 2: down, 3: left
+        self.n_actions = 4
+        self.action_names = ['Up', 'Right', 'Down', 'Left']
+
+        # Define rewards
+        self.rewards = np.zeros((self.height, self.width))
+        # Goal state
+        self.rewards[self.height-1, self.width-1] = 1.0
+        # Obstacles (negative reward)
+        self.obstacles = []
+        if height >= 4 and width >= 4:
+            self.rewards[1, 1] = -1.0
+            self.rewards[1, 2] = -1.0
+            self.rewards[2, 1] = -1.0
+            self.obstacles = [(1, 1), (1, 2), (2, 1)]
+
+        # Start state
+        self.start_state = (0, 0)
+
+        # Goal state
+        self.goal_state = (self.height-1, self.width-1)
+
+        # Reset the environment
+        self.reset()
+
+    def reset(self):
+        """Reset the agent to the start state."""
+        self.agent_position = self.start_state
+        return self._get_state()
+
+    def _get_state(self):
+        """Convert the agent's (row, col) position to a state number."""
+        row, col = self.agent_position
+        return row * self.width + col
+
+    def _get_pos_from_state(self, state):
+        """Convert a state number to (row, col) position."""
+        row = state // self.width
+        col = state % self.width
+        return (row, col)
+
+    def step(self, action):
+        """Take an action and return next_state, reward, done."""
+        row, col = self.agent_position
+
+        # Apply the action
+        if action == 0:  # up
+            row = max(0, row - 1)
+        elif action == 1:  # right
+            col = min(self.width - 1, col + 1)
+        elif action == 2:  # down
+            row = min(self.height - 1, row + 1)
+        elif action == 3:  # left
+            col = max(0, col - 1)
+
+        # Update agent position
+        self.agent_position = (row, col)
+
+        # Get reward
+        reward = self.rewards[row, col]
+
+        # Check if episode is done
+        done = (row, col) == self.goal_state
+
+        return self._get_state(), reward, done
+
+class QLearningAgent:
+    """A simple Q-learning agent."""
+
+    def __init__(self, n_states, n_actions, learning_rate=0.1, discount_factor=0.9, exploration_rate=1.0, exploration_decay=0.995):
+        """Initialize the Q-learning agent."""
+        self.n_states = n_states
+        self.n_actions = n_actions
+        self.learning_rate = learning_rate
+        self.discount_factor = discount_factor
+        self.exploration_rate = exploration_rate
+        self.exploration_decay = exploration_decay
+
+        # Initialize Q-table
+        self.q_table = np.zeros((n_states, n_actions))
+
+        # Track visited states for visualization
+        self.visit_counts = np.zeros(n_states)
+
+        # Training metrics
+        self.rewards_history = []
+        self.exploration_rates = []
+
+    def select_action(self, state):
+        """Select an action using epsilon-greedy policy."""
+        if np.random.random() < self.exploration_rate:
+            # Explore: select a random action
+            return np.random.randint(self.n_actions)
+        else:
+            # Exploit: select the action with the highest Q-value
+            return np.argmax(self.q_table[state])
+
+    def update(self, state, action, reward, next_state, done):
+        """Update the Q-table using the Q-learning update rule."""
+        # Calculate the Q-target
+        if done:
+            q_target = reward
+        else:
+            q_target = reward + self.discount_factor * np.max(self.q_table[next_state])
+
+        # Update the Q-value
+        self.q_table[state, action] += self.learning_rate * (q_target - self.q_table[state, action])
+
+        # Update visit count for visualization
+        self.visit_counts[state] += 1
+
+    def decay_exploration(self):
+        """Decay the exploration rate."""
+        self.exploration_rate *= self.exploration_decay
+        self.exploration_rates.append(self.exploration_rate)
+
+    def get_policy(self):
+        """Return the current greedy policy."""
+        return np.argmax(self.q_table, axis=1)
+
+    def reset(self):
+        """Reset the agent for a new training session."""
+        self.q_table = np.zeros((self.n_states, self.n_actions))
+        self.visit_counts = np.zeros(self.n_states)
+        self.rewards_history = []
+        self.exploration_rates = []
+
+
+def create_gridworld_figure(env, agent, episode_count=0, total_reward=0):
+    """Create a figure with environment, visit heatmap, and Q-values."""
+    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+    fig.suptitle(f"Episode: {episode_count}, Total Reward: {total_reward:.2f}, Exploration Rate: {agent.exploration_rate:.2f}")
+
+    # Define colors for different cell types
+    colors = {
+        'empty': 'white',
+        'obstacle': 'black',
+        'goal': 'green',
+        'start': 'blue',
+        'agent': 'red'
+    }
+
+    # Helper function to draw grid
+    def draw_grid(ax):
+        # Create a grid
+        for i in range(env.height + 1):
+            ax.axhline(i, color='black', lw=1)
+        for j in range(env.width + 1):
+            ax.axvline(j, color='black', lw=1)
+
+        # Set limits and remove ticks
+        ax.set_xlim(0, env.width)
+        ax.set_ylim(0, env.height)
+        ax.invert_yaxis()  # Invert y-axis to match grid coordinates
+        ax.set_xticks(np.arange(0.5, env.width, 1))
+        ax.set_yticks(np.arange(0.5, env.height, 1))
+        ax.set_xticklabels(range(env.width))
+        ax.set_yticklabels(range(env.height))
+
+    # Helper function to draw a cell
+    def draw_cell(ax, row, col, cell_type):
+        color = colors.get(cell_type, 'white')
+        rect = patches.Rectangle((col, row), 1, 1, linewidth=1, edgecolor='black', facecolor=color, alpha=0.7)
+        ax.add_patch(rect)
+
+    # Helper function to draw an arrow
+    def draw_arrow(ax, row, col, action):
+        # Coordinates for arrows
+        arrow_starts = {
+            0: (col + 0.5, row + 0.7),  # up
+            1: (col + 0.3, row + 0.5),  # right
+            2: (col + 0.5, row + 0.3),  # down
+            3: (col + 0.7, row + 0.5)   # left
+        }
+
+        arrow_ends = {
+            0: (col + 0.5, row + 0.3),  # up
+            1: (col + 0.7, row + 0.5),  # right
+            2: (col + 0.5, row + 0.7),  # down
+            3: (col + 0.3, row + 0.5)   # left
+        }
+
+        ax.annotate('', xy=arrow_ends[action], xytext=arrow_starts[action],
+                    arrowprops=dict(arrowstyle='->', lw=2, color='blue'))
+
+    # Draw Environment
+    ax = axes[0]
+    ax.set_title('GridWorld Environment')
+    draw_grid(ax)
+
+    # Draw cells
+    for i in range(env.height):
+        for j in range(env.width):
+            if (i, j) in env.obstacles:
+                draw_cell(ax, i, j, 'obstacle')
+            elif (i, j) == env.goal_state:
+                draw_cell(ax, i, j, 'goal')
+            elif (i, j) == env.start_state:
+                draw_cell(ax, i, j, 'start')
+
+    # Draw agent
+    row, col = env.agent_position
+    draw_cell(ax, row, col, 'agent')
+
+    # Draw policy arrows
+    policy = agent.get_policy()
+    for state in range(env.n_states):
+        row, col = env._get_pos_from_state(state)
+        if (row, col) not in env.obstacles and (row, col) != env.goal_state:
+            draw_arrow(ax, row, col, policy[state])
+
+    # Ensure proper aspect ratio
+    ax.set_aspect('equal')
+
+    # Draw Visit Heatmap
+    ax = axes[1]
+    ax.set_title('State Visitation Heatmap')
+    draw_grid(ax)
+
+    # Create heatmap data
+    heatmap_data = np.zeros((env.height, env.width))
+    for state in range(env.n_states):
+        row, col = env._get_pos_from_state(state)
+        heatmap_data[row, col] = agent.visit_counts[state]
+
+    # Normalize values for coloring
+    max_visits = max(1, np.max(heatmap_data))
+
+    # Draw heatmap
+    for i in range(env.height):
+        for j in range(env.width):
+            if (i, j) in env.obstacles:
+                draw_cell(ax, i, j, 'obstacle')
+            elif (i, j) == env.goal_state:
+                draw_cell(ax, i, j, 'goal')
+            else:
+                intensity = heatmap_data[i, j] / max_visits
+                color = plt.cm.viridis(intensity)
+                rect = patches.Rectangle((j, i), 1, 1, linewidth=1, edgecolor='black', facecolor=color, alpha=0.7)
+                ax.add_patch(rect)
+                # Add visit count text
+                if heatmap_data[i, j] > 0:
+                    ax.text(j + 0.5, i + 0.5, int(heatmap_data[i, j]), ha='center', va='center', color='white' if intensity > 0.5 else 'black')
+
+    # Ensure proper aspect ratio
+    ax.set_aspect('equal')
+
+    # Draw Q-values
+    ax = axes[2]
+    ax.set_title('Q-Values')
+    draw_grid(ax)
+
+    # Draw Q-values for each cell
+    for state in range(env.n_states):
+        row, col = env._get_pos_from_state(state)
+
+        if (row, col) in env.obstacles:
+            draw_cell(ax, row, col, 'obstacle')
+            continue
+
+        if (row, col) == env.goal_state:
+            draw_cell(ax, row, col, 'goal')
+            continue
+
+        # Calculate q-values for each action
+        q_values = agent.q_table[state]
+
+        # Draw arrows proportional to Q-values
+        for action in range(env.n_actions):
+            q_value = q_values[action]
+
+            # Only draw arrows for positive Q-values
+            if q_value > 0:
+                # Normalize arrow size
+                max_q = max(0.1, np.max(q_values))
+                arrow_size = 0.3 * (q_value / max_q)
+
+                # Position calculations
+                center_x = col + 0.5
+                center_y = row + 0.5
+
+                # Direction vectors
+                directions = [
+                    (0, -arrow_size),  # up
+                    (arrow_size, 0),   # right
+                    (0, arrow_size),   # down
+                    (-arrow_size, 0)   # left
+                ]
+
+                dx, dy = directions[action]
+
+                # Draw arrow
+                ax.arrow(center_x, center_y, dx, dy, head_width=0.1, head_length=0.1,
+                         fc='blue', ec='blue', alpha=0.7)
+
+                # Add Q-value text
+                text_positions = [
+                    (center_x, center_y - 0.25),  # up
+                    (center_x + 0.25, center_y),  # right
+                    (center_x, center_y + 0.25),  # down
+                    (center_x - 0.25, center_y)   # left
+                ]
+
+                tx, ty = text_positions[action]
+                ax.text(tx, ty, f"{q_value:.2f}", ha='center', va='center', fontsize=8,
+                        bbox=dict(facecolor='white', alpha=0.7, boxstyle='round,pad=0.1'))
+
+    # Ensure proper aspect ratio
+    ax.set_aspect('equal')
+
+    plt.tight_layout()
+    return fig
+
+def create_metrics_figure(agent):
+    """Create a figure with training metrics."""
+    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
+
+    # Plot rewards
+    if agent.rewards_history:
+        axes[0].plot(agent.rewards_history)
+        axes[0].set_title('Rewards per Episode')
+        axes[0].set_xlabel('Episode')
+        axes[0].set_ylabel('Total Reward')
+        axes[0].grid(True)
+    else:
+        axes[0].set_title('No reward data yet')
+
+    # Plot exploration rate
+    if agent.exploration_rates:
+        axes[1].plot(agent.exploration_rates)
+        axes[1].set_title('Exploration Rate Decay')
+        axes[1].set_xlabel('Episode')
+        axes[1].set_ylabel('Exploration Rate (ε)')
+        axes[1].grid(True)
+    else:
+        axes[1].set_title('No exploration rate data yet')
+
+    plt.tight_layout()
+    return fig
+
+def train_single_episode(env, agent):
+    """Train for a single episode and return the total reward."""
+    state = env.reset()
+    total_reward = 0
+    done = False
+    steps = 0
+    max_steps = env.width * env.height * 3  # Prevent infinite loops
+
+    while not done and steps < max_steps:
+        # Select action
+        action = agent.select_action(state)
+
+        # Take the action
+        next_state, reward, done = env.step(action)
+
+        # Update the Q-table
+        agent.update(state, action, reward, next_state, done)
+
+        # Update state and total reward
+        state = next_state
+        total_reward += reward
+        steps += 1
+
+    # Decay exploration rate
+    agent.decay_exploration()
+
+    # Store the total reward
+    agent.rewards_history.append(total_reward)
+
+    return total_reward
+
+def train_agent(env, agent, episodes, progress=gr.Progress()):
+    """Train the agent for a specified number of episodes."""
+    progress_text = ""
+    progress(0, desc="Starting training...")
+
+    for episode in progress.tqdm(range(episodes)):
+        total_reward = train_single_episode(env, agent)
+
+        if (episode + 1) % 10 == 0 or episode == episodes - 1:
+            progress_text += f"Episode {episode + 1}/{episodes}, Reward: {total_reward}, Exploration: {agent.exploration_rate:.3f}\n"
+
+    # Create final visualization
+    env_fig = create_gridworld_figure(env, agent, episode_count=episodes, total_reward=total_reward)
+    metrics_fig = create_metrics_figure(agent)
+
+    return env_fig, metrics_fig, progress_text
+
+def run_test_episode(env, agent):
+    """Run a test episode using the learned policy."""
+    state = env.reset()
+    total_reward = 0
+    done = False
+    path = [env._get_pos_from_state(state)]
+    steps = 0
+    max_steps = env.width * env.height * 3  # Prevent infinite loops
+
+    while not done and steps < max_steps:
+        # Select the best action from the learned policy
+        action = np.argmax(agent.q_table[state])
+
+        # Take the action
+        next_state, reward, done = env.step(action)
+
+        # Update state and total reward
+        state = next_state
+        total_reward += reward
+        path.append(env._get_pos_from_state(state))
+        steps += 1
+
+    # Create visualization
+    env_fig = create_gridworld_figure(env, agent, episode_count="Test", total_reward=total_reward)
+
+    # Format path for display
+    path_text = "Path taken:\n"
+    for i, pos in enumerate(path):
+        path_text += f"Step {i}: {pos}\n"
+
+    return env_fig, path_text, f"Test completed with total reward: {total_reward}"
+
+def create_ui():
+    """Create the Gradio interface."""
+    # Create environment and agent
+    env = GridWorld(height=4, width=4)
+    agent = QLearningAgent(
+        n_states=env.n_states,
+        n_actions=env.n_actions,
+        learning_rate=0.1,
+        discount_factor=0.9,
+        exploration_rate=1.0,
+        exploration_decay=0.995
+    )
+
+    # Create initial visualizations
+    init_env_fig = create_gridworld_figure(env, agent)
+    init_metrics_fig = create_metrics_figure(agent)
+
+    with gr.Blocks(title="Q-Learning GridWorld Simulator") as demo:
+        gr.Markdown("# Q-Learning GridWorld Simulator")
+
+        with gr.Tab("Environment Setup"):
+            with gr.Row():
+                with gr.Column():
+                    grid_height = gr.Slider(minimum=3, maximum=8, value=4, step=1, label="Grid Height")
+                    grid_width = gr.Slider(minimum=3, maximum=8, value=4, step=1, label="Grid Width")
+                    setup_btn = gr.Button("Setup Environment")
+
+                env_display = gr.Plot(value=init_env_fig, label="Environment")
+
+            with gr.Row():
+                setup_info = gr.Textbox(label="Environment Info", value="4x4 GridWorld with start at (0,0) and goal at (3,3)")
+
+        with gr.Tab("Train Agent"):
+            with gr.Row():
+                with gr.Column():
+                    learning_rate = gr.Slider(minimum=0.01, maximum=1.0, value=0.1, step=0.01, label="Learning Rate (α)")
+                    discount_factor = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Discount Factor (γ)")
+                    exploration_rate = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.01, label="Initial Exploration Rate (ε)")
+                    exploration_decay = gr.Slider(minimum=0.9, maximum=0.999, value=0.995, step=0.001, label="Exploration Decay Rate")
+                    episodes = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Number of Episodes")
+                    train_btn = gr.Button("Train Agent")
+
+            with gr.Row():
+                train_env_display = gr.Plot(label="Training Environment")
+                train_metrics_display = gr.Plot(label="Training Metrics")
+
+            train_log = gr.Textbox(label="Training Log", lines=10)
+
+        with gr.Tab("Test Agent"):
+            with gr.Row():
+                test_btn = gr.Button("Test Trained Agent")
+
+            with gr.Row():
+                test_env_display = gr.Plot(label="Test Environment")
+
+            with gr.Row():
+                with gr.Column():
+                    path_display = gr.Textbox(label="Path Taken", lines=10)
+                    test_result = gr.Textbox(label="Test Result")
+
+        # Setup environment callback
+        def setup_environment(height, width):
+            nonlocal env, agent
+            env = GridWorld(height=int(height), width=int(width))
+            agent = QLearningAgent(
+                n_states=env.n_states,
+                n_actions=env.n_actions,
+                learning_rate=0.1,
+                discount_factor=0.9,
+                exploration_rate=1.0,
+                exploration_decay=0.995
+            )
+            env_fig = create_gridworld_figure(env, agent)
+            info_text = f"{height}x{width} GridWorld with start at (0,0) and goal at ({height-1},{width-1})"
+            if env.obstacles:
+                info_text += f"\nObstacles at: {env.obstacles}"
+            return env_fig, info_text
+
+        setup_btn.click(
+            setup_environment,
+            inputs=[grid_height, grid_width],
+            outputs=[env_display, setup_info]
+        )
+
+        # Train agent callback
+        def start_training(lr, df, er, ed, eps):
+            nonlocal env, agent
+            agent = QLearningAgent(
+                n_states=env.n_states,
+                n_actions=env.n_actions,
+                learning_rate=float(lr),
+                discount_factor=float(df),
+                exploration_rate=float(er),
+                exploration_decay=float(ed)
+            )
+            env_fig, metrics_fig, log = train_agent(env, agent, int(eps))
+            return env_fig, metrics_fig, log
+
+        train_btn.click(
+            start_training,
+            inputs=[learning_rate, discount_factor, exploration_rate, exploration_decay, episodes],
+            outputs=[train_env_display, train_metrics_display, train_log]
+        )
+
+        # Test agent callback
+        def test_trained_agent():
+            nonlocal env, agent
+            env_fig, path_text, result = run_test_episode(env, agent)
+            return env_fig, path_text, result
+
+        test_btn.click(
+            test_trained_agent,
+            inputs=[],
+            outputs=[test_env_display, path_display, test_result]
+        )
+
+    return demo
+
+if __name__ == "__main__":
+    # Install required packages
+    # !pip install gradio matplotlib numpy
+
+    # Create and launch the UI
+    demo = create_ui()
+    demo.launch(share=True)
+
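Editor's note: for a quick sanity check outside the Gradio UI, the classes in this file can be driven headlessly. The sketch below is not part of the commit; it assumes the file imports as `rl_gradio` (nothing launches on import, since `demo.launch()` sits under the `__main__` guard) and picks 200 episodes arbitrarily, within the range the UI's episode slider exposes.

# Headless training sketch using the classes defined above (illustrative).
from rl_gradio import GridWorld, QLearningAgent, train_single_episode

env = GridWorld(height=4, width=4)
agent = QLearningAgent(n_states=env.n_states, n_actions=env.n_actions)

for _ in range(200):                       # episode count chosen for illustration
    train_single_episode(env, agent)       # epsilon-greedy rollout plus Q-table updates

policy = agent.get_policy()                # greedy action per state: 0=Up, 1=Right, 2=Down, 3=Left
print(policy.reshape(env.height, env.width))
print("last-episode reward:", agent.rewards_history[-1])

Reshaping the greedy policy onto the grid makes it easy to see whether the learned actions route around the obstacle cells toward the goal at (3,3).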