import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import gradio as gr


class GridWorld:
    """A simple grid world environment with a goal cell and penalty cells ("obstacles")."""

    def __init__(self, height=4, width=4):
        self.height = height
        self.width = width
        self.n_states = self.height * self.width

        # Four actions, indexed 0-3.
        self.n_actions = 4
        self.action_names = ['Up', 'Right', 'Down', 'Left']

        # Reward layout: +1 at the goal (bottom-right corner), 0 elsewhere.
        self.rewards = np.zeros((self.height, self.width))
        self.rewards[self.height - 1, self.width - 1] = 1.0

        # For grids of at least 4x4, mark three penalty cells near the start.
        # They are not impassable: entering one simply yields a -1 reward.
        self.obstacles = []
        if height >= 4 and width >= 4:
            self.rewards[1, 1] = -1.0
            self.rewards[1, 2] = -1.0
            self.rewards[2, 1] = -1.0
            self.obstacles = [(1, 1), (1, 2), (2, 1)]

        self.start_state = (0, 0)
        self.goal_state = (self.height - 1, self.width - 1)

        self.reset()

    def reset(self):
        """Reset the agent to the start state."""
        self.agent_position = self.start_state
        return self._get_state()

    def _get_state(self):
        """Convert the agent's (row, col) position to a state number."""
        row, col = self.agent_position
        return row * self.width + col

    def _get_pos_from_state(self, state):
        """Convert a state number to a (row, col) position."""
        row = state // self.width
        col = state % self.width
        return (row, col)

    def step(self, action):
        """Take an action and return (next_state, reward, done)."""
        row, col = self.agent_position

        # Moves that would leave the grid keep the agent in place.
        if action == 0:    # Up
            row = max(0, row - 1)
        elif action == 1:  # Right
            col = min(self.width - 1, col + 1)
        elif action == 2:  # Down
            row = min(self.height - 1, row + 1)
        elif action == 3:  # Left
            col = max(0, col - 1)

        self.agent_position = (row, col)
        reward = self.rewards[row, col]
        done = (row, col) == self.goal_state

        return self._get_state(), reward, done
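

# Hedged sanity-check sketch (our addition, not used by the Gradio app below):
# on the default 4x4 grid, moving Right three times and then Down three times
# skirts the penalty cells and ends on the goal with reward +1.0. The helper
# name `_demo_env_walk` is ours.
def _demo_env_walk():
    env = GridWorld()
    env.reset()
    for action in [1, 1, 1, 2, 2, 2]:  # 1 = Right, 2 = Down
        state, reward, done = env.step(action)
        print(f"state={state}, reward={reward}, done={done}")  # final line: reward 1.0, done True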


class QLearningAgent:
    """A simple tabular Q-learning agent."""

    def __init__(self, n_states, n_actions, learning_rate=0.1, discount_factor=0.9,
                 exploration_rate=1.0, exploration_decay=0.995):
        """Initialize the Q-learning agent."""
        self.n_states = n_states
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.exploration_decay = exploration_decay

        # Q-table: one row per state, one column per action.
        self.q_table = np.zeros((n_states, n_actions))

        # How often each state has been updated (used for the visitation heatmap).
        self.visit_counts = np.zeros(n_states)

        # Per-episode training metrics.
        self.rewards_history = []
        self.exploration_rates = []

    def select_action(self, state):
        """Select an action using an epsilon-greedy policy."""
        if np.random.random() < self.exploration_rate:
            # Explore: random action.
            return np.random.randint(self.n_actions)
        else:
            # Exploit: action with the highest Q-value for this state.
            return np.argmax(self.q_table[state])

    def update(self, state, action, reward, next_state, done):
        """Update the Q-table using the Q-learning update rule."""
        # Target is r for terminal transitions, r + gamma * max_a' Q(s', a') otherwise.
        if done:
            q_target = reward
        else:
            q_target = reward + self.discount_factor * np.max(self.q_table[next_state])

        # Move Q(s, a) toward the target by a step of size learning_rate.
        self.q_table[state, action] += self.learning_rate * (q_target - self.q_table[state, action])

        self.visit_counts[state] += 1

    def decay_exploration(self):
        """Decay the exploration rate and record its new value."""
        self.exploration_rate *= self.exploration_decay
        self.exploration_rates.append(self.exploration_rate)

    def get_policy(self):
        """Return the current greedy policy (best action for each state)."""
        return np.argmax(self.q_table, axis=1)

    def reset(self):
        """Reset the agent for a new training session."""
        self.q_table = np.zeros((self.n_states, self.n_actions))
        self.visit_counts = np.zeros(self.n_states)
        self.rewards_history = []
        self.exploration_rates = []
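

# Hedged numeric check of the update rule (our addition, not part of the app):
# with learning_rate=0.5 and an all-zero Q-table, a terminal transition with
# reward 1.0 has target 1.0, so Q(s, a) moves from 0.0 to 0.5. The helper name
# `_demo_single_update` is ours.
def _demo_single_update():
    agent = QLearningAgent(n_states=2, n_actions=4, learning_rate=0.5)
    agent.update(state=0, action=1, reward=1.0, next_state=1, done=True)
    print(agent.q_table[0, 1])  # expected: 0.5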


def create_gridworld_figure(env, agent, episode_count=0, total_reward=0):
    """Create a figure with the environment, visit heatmap, and Q-values."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f"Episode: {episode_count}, Total Reward: {total_reward:.2f}, "
                 f"Exploration Rate: {agent.exploration_rate:.2f}")

    # Cell colors for the environment panel.
    colors = {
        'empty': 'white',
        'obstacle': 'black',
        'goal': 'green',
        'start': 'blue',
        'agent': 'red'
    }

    def draw_grid(ax):
        """Draw grid lines and set up axis limits and ticks."""
        for i in range(env.height + 1):
            ax.axhline(i, color='black', lw=1)
        for j in range(env.width + 1):
            ax.axvline(j, color='black', lw=1)

        ax.set_xlim(0, env.width)
        ax.set_ylim(0, env.height)
        ax.invert_yaxis()  # row 0 is drawn at the top
        ax.set_xticks(np.arange(0.5, env.width, 1))
        ax.set_yticks(np.arange(0.5, env.height, 1))
        ax.set_xticklabels(range(env.width))
        ax.set_yticklabels(range(env.height))

    def draw_cell(ax, row, col, cell_type):
        """Fill a single cell with the color for its type."""
        color = colors.get(cell_type, 'white')
        rect = patches.Rectangle((col, row), 1, 1, linewidth=1, edgecolor='black',
                                 facecolor=color, alpha=0.7)
        ax.add_patch(rect)

    def draw_arrow(ax, row, col, action):
        """Draw an arrow in a cell indicating the given action."""
        arrow_starts = {
            0: (col + 0.5, row + 0.7),  # Up
            1: (col + 0.3, row + 0.5),  # Right
            2: (col + 0.5, row + 0.3),  # Down
            3: (col + 0.7, row + 0.5)   # Left
        }
        arrow_ends = {
            0: (col + 0.5, row + 0.3),
            1: (col + 0.7, row + 0.5),
            2: (col + 0.5, row + 0.7),
            3: (col + 0.3, row + 0.5)
        }
        ax.annotate('', xy=arrow_ends[action], xytext=arrow_starts[action],
                    arrowprops=dict(arrowstyle='->', lw=2, color='blue'))

    # Panel 1: the environment with the agent and the greedy policy arrows.
    ax = axes[0]
    ax.set_title('GridWorld Environment')
    draw_grid(ax)

    for i in range(env.height):
        for j in range(env.width):
            if (i, j) in env.obstacles:
                draw_cell(ax, i, j, 'obstacle')
            elif (i, j) == env.goal_state:
                draw_cell(ax, i, j, 'goal')
            elif (i, j) == env.start_state:
                draw_cell(ax, i, j, 'start')

    # Draw the agent on top of whatever cell it occupies.
    row, col = env.agent_position
    draw_cell(ax, row, col, 'agent')

    # Overlay the current greedy policy (skip obstacles and the goal).
    policy = agent.get_policy()
    for state in range(env.n_states):
        row, col = env._get_pos_from_state(state)
        if (row, col) not in env.obstacles and (row, col) != env.goal_state:
            draw_arrow(ax, row, col, policy[state])

    ax.set_aspect('equal')

    # Panel 2: how often each state has been visited during training.
    ax = axes[1]
    ax.set_title('State Visitation Heatmap')
    draw_grid(ax)

    heatmap_data = np.zeros((env.height, env.width))
    for state in range(env.n_states):
        row, col = env._get_pos_from_state(state)
        heatmap_data[row, col] = agent.visit_counts[state]

    # Avoid division by zero when nothing has been visited yet.
    max_visits = max(1, np.max(heatmap_data))

    for i in range(env.height):
        for j in range(env.width):
            if (i, j) in env.obstacles:
                draw_cell(ax, i, j, 'obstacle')
            elif (i, j) == env.goal_state:
                draw_cell(ax, i, j, 'goal')
            else:
                intensity = heatmap_data[i, j] / max_visits
                color = plt.cm.viridis(intensity)
                rect = patches.Rectangle((j, i), 1, 1, linewidth=1, edgecolor='black',
                                         facecolor=color, alpha=0.7)
                ax.add_patch(rect)

                if heatmap_data[i, j] > 0:
                    # Dark text on bright (high-intensity) viridis cells, white on dark ones.
                    ax.text(j + 0.5, i + 0.5, f"{int(heatmap_data[i, j])}",
                            ha='center', va='center',
                            color='black' if intensity > 0.5 else 'white')

    ax.set_aspect('equal')

    # Panel 3: per-action Q-values drawn as arrows scaled by magnitude.
    ax = axes[2]
    ax.set_title('Q-Values')
    draw_grid(ax)

    for state in range(env.n_states):
        row, col = env._get_pos_from_state(state)

        if (row, col) in env.obstacles:
            draw_cell(ax, row, col, 'obstacle')
            continue

        if (row, col) == env.goal_state:
            draw_cell(ax, row, col, 'goal')
            continue

        q_values = agent.q_table[state]

        for action in range(env.n_actions):
            q_value = q_values[action]

            # Only positive Q-values are drawn, scaled relative to the best action.
            if q_value > 0:
                max_q = max(0.1, np.max(q_values))
                arrow_size = 0.3 * (q_value / max_q)

                center_x = col + 0.5
                center_y = row + 0.5

                # (dx, dy) offsets for Up, Right, Down, Left on an inverted y-axis.
                directions = [
                    (0, -arrow_size),
                    (arrow_size, 0),
                    (0, arrow_size),
                    (-arrow_size, 0)
                ]
                dx, dy = directions[action]

                ax.arrow(center_x, center_y, dx, dy, head_width=0.1, head_length=0.1,
                         fc='blue', ec='blue', alpha=0.7)

                # Text positions for the numeric Q-value, one per action direction.
                text_positions = [
                    (center_x, center_y - 0.25),
                    (center_x + 0.25, center_y),
                    (center_x, center_y + 0.25),
                    (center_x - 0.25, center_y)
                ]
                tx, ty = text_positions[action]
                ax.text(tx, ty, f"{q_value:.2f}", ha='center', va='center', fontsize=8,
                        bbox=dict(facecolor='white', alpha=0.7, boxstyle='round,pad=0.1'))

    ax.set_aspect('equal')

    plt.tight_layout()
    return fig


def create_metrics_figure(agent):
    """Create a figure with training metrics."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    if agent.rewards_history:
        axes[0].plot(agent.rewards_history)
        axes[0].set_title('Rewards per Episode')
        axes[0].set_xlabel('Episode')
        axes[0].set_ylabel('Total Reward')
        axes[0].grid(True)
    else:
        axes[0].set_title('No reward data yet')

    if agent.exploration_rates:
        axes[1].plot(agent.exploration_rates)
        axes[1].set_title('Exploration Rate Decay')
        axes[1].set_xlabel('Episode')
        axes[1].set_ylabel('Exploration Rate (ε)')
        axes[1].grid(True)
    else:
        axes[1].set_title('No exploration rate data yet')

    plt.tight_layout()
    return fig


def train_single_episode(env, agent):
    """Train for a single episode and return the total reward."""
    state = env.reset()
    total_reward = 0
    done = False
    steps = 0
    max_steps = env.width * env.height * 3  # safety cap against endless wandering

    while not done and steps < max_steps:
        # Choose an action (epsilon-greedy), act, and learn from the transition.
        action = agent.select_action(state)
        next_state, reward, done = env.step(action)
        agent.update(state, action, reward, next_state, done)

        state = next_state
        total_reward += reward
        steps += 1

    # Decay exploration once per episode and record the episode return.
    agent.decay_exploration()
    agent.rewards_history.append(total_reward)

    return total_reward
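

# Hedged headless usage sketch (our addition): train without the Gradio UI and
# print the greedy policy as a grid of action names. The helper name
# `_headless_training_demo` and the episode count are ours.
def _headless_training_demo(episodes=200):
    env = GridWorld(height=4, width=4)
    agent = QLearningAgent(n_states=env.n_states, n_actions=env.n_actions)
    for _ in range(episodes):
        train_single_episode(env, agent)
    policy = agent.get_policy().reshape(env.height, env.width)
    for row in policy:
        print(' '.join(env.action_names[a] for a in row))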


def train_agent(env, agent, episodes, progress=gr.Progress()):
    """Train the agent for a specified number of episodes."""
    progress_text = ""
    progress(0, desc="Starting training...")

    for episode in progress.tqdm(range(episodes)):
        total_reward = train_single_episode(env, agent)

        if (episode + 1) % 10 == 0 or episode == episodes - 1:
            progress_text += f"Episode {episode + 1}/{episodes}, Reward: {total_reward}, Exploration: {agent.exploration_rate:.3f}\n"

    env_fig = create_gridworld_figure(env, agent, episode_count=episodes, total_reward=total_reward)
    metrics_fig = create_metrics_figure(agent)

    return env_fig, metrics_fig, progress_text


def run_test_episode(env, agent):
    """Run a test episode using the learned greedy policy."""
    state = env.reset()
    total_reward = 0
    done = False
    path = [env._get_pos_from_state(state)]
    steps = 0
    max_steps = env.width * env.height * 3  # safety cap in case the policy never reaches the goal

    while not done and steps < max_steps:
        # Always exploit: take the action with the highest Q-value.
        action = np.argmax(agent.q_table[state])
        next_state, reward, done = env.step(action)

        state = next_state
        total_reward += reward
        path.append(env._get_pos_from_state(state))
        steps += 1

    env_fig = create_gridworld_figure(env, agent, episode_count="Test", total_reward=total_reward)

    path_text = "Path taken:\n"
    for i, pos in enumerate(path):
        path_text += f"Step {i}: {pos}\n"

    return env_fig, path_text, f"Test completed with total reward: {total_reward}"


def create_ui():
    """Create the Gradio interface."""
    env = GridWorld(height=4, width=4)
    agent = QLearningAgent(
        n_states=env.n_states,
        n_actions=env.n_actions,
        learning_rate=0.1,
        discount_factor=0.9,
        exploration_rate=1.0,
        exploration_decay=0.995
    )

    init_env_fig = create_gridworld_figure(env, agent)
    init_metrics_fig = create_metrics_figure(agent)

    with gr.Blocks(title="Q-Learning GridWorld Simulator") as demo:
        gr.Markdown("# Q-Learning GridWorld Simulator")

        with gr.Tab("Environment Setup"):
            with gr.Row():
                with gr.Column():
                    grid_height = gr.Slider(minimum=3, maximum=8, value=4, step=1, label="Grid Height")
                    grid_width = gr.Slider(minimum=3, maximum=8, value=4, step=1, label="Grid Width")
                    setup_btn = gr.Button("Setup Environment")

                env_display = gr.Plot(value=init_env_fig, label="Environment")

            with gr.Row():
                setup_info = gr.Textbox(label="Environment Info",
                                        value="4x4 GridWorld with start at (0,0) and goal at (3,3)")

        with gr.Tab("Train Agent"):
            with gr.Row():
                with gr.Column():
                    learning_rate = gr.Slider(minimum=0.01, maximum=1.0, value=0.1, step=0.01, label="Learning Rate (α)")
                    discount_factor = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Discount Factor (γ)")
                    exploration_rate = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.01, label="Initial Exploration Rate (ε)")
                    exploration_decay = gr.Slider(minimum=0.9, maximum=0.999, value=0.995, step=0.001, label="Exploration Decay Rate")
                    episodes = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Number of Episodes")
                    train_btn = gr.Button("Train Agent")

            with gr.Row():
                train_env_display = gr.Plot(label="Training Environment")
                train_metrics_display = gr.Plot(value=init_metrics_fig, label="Training Metrics")

            train_log = gr.Textbox(label="Training Log", lines=10)

        with gr.Tab("Test Agent"):
            with gr.Row():
                test_btn = gr.Button("Test Trained Agent")

            with gr.Row():
                test_env_display = gr.Plot(label="Test Environment")

            with gr.Row():
                with gr.Column():
                    path_display = gr.Textbox(label="Path Taken", lines=10)
                    test_result = gr.Textbox(label="Test Result")

        # Callbacks are defined inside the Blocks context so they can be wired
        # to the buttons above.
        def setup_environment(height, width):
            nonlocal env, agent
            height, width = int(height), int(width)
            env = GridWorld(height=height, width=width)
            agent = QLearningAgent(
                n_states=env.n_states,
                n_actions=env.n_actions,
                learning_rate=0.1,
                discount_factor=0.9,
                exploration_rate=1.0,
                exploration_decay=0.995
            )
            env_fig = create_gridworld_figure(env, agent)
            info_text = f"{height}x{width} GridWorld with start at (0,0) and goal at ({height - 1},{width - 1})"
            if env.obstacles:
                info_text += f"\nObstacles at: {env.obstacles}"
            return env_fig, info_text

        setup_btn.click(
            setup_environment,
            inputs=[grid_height, grid_width],
            outputs=[env_display, setup_info]
        )

        def start_training(lr, df, er, ed, eps):
            nonlocal env, agent
            # Re-create the agent with the chosen hyperparameters before training.
            agent = QLearningAgent(
                n_states=env.n_states,
                n_actions=env.n_actions,
                learning_rate=float(lr),
                discount_factor=float(df),
                exploration_rate=float(er),
                exploration_decay=float(ed)
            )
            env_fig, metrics_fig, log = train_agent(env, agent, int(eps))
            return env_fig, metrics_fig, log

        train_btn.click(
            start_training,
            inputs=[learning_rate, discount_factor, exploration_rate, exploration_decay, episodes],
            outputs=[train_env_display, train_metrics_display, train_log]
        )

        def test_trained_agent():
            nonlocal env, agent
            env_fig, path_text, result = run_test_episode(env, agent)
            return env_fig, path_text, result

        test_btn.click(
            test_trained_agent,
            inputs=[],
            outputs=[test_env_display, path_display, test_result]
        )

    return demo


if __name__ == "__main__":
    demo = create_ui()
    demo.launch(share=True)