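"""Q-Learning GridWorld Simulator.

A tabular Q-learning agent learns to navigate a small grid with obstacles
toward a goal cell; training progress and the learned policy are visualized
through a Gradio interface.
"""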
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; figures are returned to Gradio, never shown directly
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import gradio as gr
class GridWorld:
"""A simple grid world environment with obstacles."""
def __init__(self, height=4, width=4):
# Grid dimensions
self.height = height
self.width = width
# Define states
self.n_states = self.height * self.width
# Actions: 0: up, 1: right, 2: down, 3: left
self.n_actions = 4
self.action_names = ['Up', 'Right', 'Down', 'Left']
# Define rewards
self.rewards = np.zeros((self.height, self.width))
# Goal state
self.rewards[self.height-1, self.width-1] = 1.0
# Obstacles (negative reward)
self.obstacles = []
if height >= 4 and width >= 4:
self.rewards[1, 1] = -1.0
self.rewards[1, 2] = -1.0
self.rewards[2, 1] = -1.0
self.obstacles = [(1, 1), (1, 2), (2, 1)]
# Start state
self.start_state = (0, 0)
# Goal state
self.goal_state = (self.height-1, self.width-1)
# Reset the environment
self.reset()
def reset(self):
"""Reset the agent to the start state."""
self.agent_position = self.start_state
return self._get_state()
def _get_state(self):
"""Convert the agent's (row, col) position to a state number."""
row, col = self.agent_position
return row * self.width + col
def _get_pos_from_state(self, state):
"""Convert a state number to (row, col) position."""
row = state // self.width
col = state % self.width
return (row, col)
def step(self, action):
"""Take an action and return next_state, reward, done."""
row, col = self.agent_position
# Apply the action
if action == 0: # up
row = max(0, row - 1)
elif action == 1: # right
col = min(self.width - 1, col + 1)
elif action == 2: # down
row = min(self.height - 1, row + 1)
elif action == 3: # left
col = max(0, col - 1)
# Update agent position
self.agent_position = (row, col)
# Get reward
reward = self.rewards[row, col]
# Check if episode is done
done = (row, col) == self.goal_state
return self._get_state(), reward, done
class QLearningAgent:
"""A simple Q-learning agent."""
def __init__(self, n_states, n_actions, learning_rate=0.1, discount_factor=0.9, exploration_rate=1.0, exploration_decay=0.995):
"""Initialize the Q-learning agent."""
self.n_states = n_states
self.n_actions = n_actions
self.learning_rate = learning_rate
self.discount_factor = discount_factor
self.exploration_rate = exploration_rate
self.exploration_decay = exploration_decay
# Initialize Q-table
self.q_table = np.zeros((n_states, n_actions))
# Track visited states for visualization
self.visit_counts = np.zeros(n_states)
# Training metrics
self.rewards_history = []
self.exploration_rates = []
def select_action(self, state):
"""Select an action using epsilon-greedy policy."""
if np.random.random() < self.exploration_rate:
# Explore: select a random action
return np.random.randint(self.n_actions)
else:
# Exploit: select the action with the highest Q-value
return np.argmax(self.q_table[state])
def update(self, state, action, reward, next_state, done):
"""Update the Q-table using the Q-learning update rule."""
# Calculate the Q-target
if done:
q_target = reward
else:
q_target = reward + self.discount_factor * np.max(self.q_table[next_state])
# Update the Q-value
self.q_table[state, action] += self.learning_rate * (q_target - self.q_table[state, action])
# Update visit count for visualization
self.visit_counts[state] += 1
def decay_exploration(self):
"""Decay the exploration rate."""
self.exploration_rate *= self.exploration_decay
self.exploration_rates.append(self.exploration_rate)
def get_policy(self):
"""Return the current greedy policy."""
return np.argmax(self.q_table, axis=1)
def reset(self):
"""Reset the agent for a new training session."""
self.q_table = np.zeros((self.n_states, self.n_actions))
self.visit_counts = np.zeros(self.n_states)
self.rewards_history = []
self.exploration_rates = []
def create_gridworld_figure(env, agent, episode_count=0, total_reward=0):
"""Create a figure with environment, visit heatmap, and Q-values."""
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle(f"Episode: {episode_count}, Total Reward: {total_reward:.2f}, Exploration Rate: {agent.exploration_rate:.2f}")
# Define colors for different cell types
colors = {
'empty': 'white',
'obstacle': 'black',
'goal': 'green',
'start': 'blue',
'agent': 'red'
}
# Helper function to draw grid
def draw_grid(ax):
# Create a grid
for i in range(env.height + 1):
ax.axhline(i, color='black', lw=1)
for j in range(env.width + 1):
ax.axvline(j, color='black', lw=1)
# Set limits and remove ticks
ax.set_xlim(0, env.width)
ax.set_ylim(0, env.height)
ax.invert_yaxis() # Invert y-axis to match grid coordinates
ax.set_xticks(np.arange(0.5, env.width, 1))
ax.set_yticks(np.arange(0.5, env.height, 1))
ax.set_xticklabels(range(env.width))
ax.set_yticklabels(range(env.height))
# Helper function to draw a cell
def draw_cell(ax, row, col, cell_type):
color = colors.get(cell_type, 'white')
rect = patches.Rectangle((col, row), 1, 1, linewidth=1, edgecolor='black', facecolor=color, alpha=0.7)
ax.add_patch(rect)
# Helper function to draw an arrow
def draw_arrow(ax, row, col, action):
# Coordinates for arrows
arrow_starts = {
0: (col + 0.5, row + 0.7), # up
1: (col + 0.3, row + 0.5), # right
2: (col + 0.5, row + 0.3), # down
3: (col + 0.7, row + 0.5) # left
}
arrow_ends = {
0: (col + 0.5, row + 0.3), # up
1: (col + 0.7, row + 0.5), # right
2: (col + 0.5, row + 0.7), # down
3: (col + 0.3, row + 0.5) # left
}
ax.annotate('', xy=arrow_ends[action], xytext=arrow_starts[action],
arrowprops=dict(arrowstyle='->', lw=2, color='blue'))
# Draw Environment
ax = axes[0]
ax.set_title('GridWorld Environment')
draw_grid(ax)
# Draw cells
for i in range(env.height):
for j in range(env.width):
if (i, j) in env.obstacles:
draw_cell(ax, i, j, 'obstacle')
elif (i, j) == env.goal_state:
draw_cell(ax, i, j, 'goal')
elif (i, j) == env.start_state:
draw_cell(ax, i, j, 'start')
# Draw agent
row, col = env.agent_position
draw_cell(ax, row, col, 'agent')
# Draw policy arrows
policy = agent.get_policy()
for state in range(env.n_states):
row, col = env._get_pos_from_state(state)
if (row, col) not in env.obstacles and (row, col) != env.goal_state:
draw_arrow(ax, row, col, policy[state])
# Ensure proper aspect ratio
ax.set_aspect('equal')
# Draw Visit Heatmap
ax = axes[1]
ax.set_title('State Visitation Heatmap')
draw_grid(ax)
# Create heatmap data
heatmap_data = np.zeros((env.height, env.width))
for state in range(env.n_states):
row, col = env._get_pos_from_state(state)
heatmap_data[row, col] = agent.visit_counts[state]
# Normalize values for coloring
max_visits = max(1, np.max(heatmap_data))
# Draw heatmap
for i in range(env.height):
for j in range(env.width):
if (i, j) in env.obstacles:
draw_cell(ax, i, j, 'obstacle')
elif (i, j) == env.goal_state:
draw_cell(ax, i, j, 'goal')
else:
intensity = heatmap_data[i, j] / max_visits
color = plt.cm.viridis(intensity)
rect = patches.Rectangle((j, i), 1, 1, linewidth=1, edgecolor='black', facecolor=color, alpha=0.7)
ax.add_patch(rect)
# Add visit count text
if heatmap_data[i, j] > 0:
                    ax.text(j + 0.5, i + 0.5, int(heatmap_data[i, j]), ha='center', va='center',
                            color='white' if intensity < 0.5 else 'black')  # viridis is dark at low intensity
# Ensure proper aspect ratio
ax.set_aspect('equal')
# Draw Q-values
ax = axes[2]
ax.set_title('Q-Values')
draw_grid(ax)
# Draw Q-values for each cell
for state in range(env.n_states):
row, col = env._get_pos_from_state(state)
if (row, col) in env.obstacles:
draw_cell(ax, row, col, 'obstacle')
continue
if (row, col) == env.goal_state:
draw_cell(ax, row, col, 'goal')
continue
# Calculate q-values for each action
q_values = agent.q_table[state]
# Draw arrows proportional to Q-values
for action in range(env.n_actions):
q_value = q_values[action]
# Only draw arrows for positive Q-values
if q_value > 0:
# Normalize arrow size
max_q = max(0.1, np.max(q_values))
arrow_size = 0.3 * (q_value / max_q)
# Position calculations
center_x = col + 0.5
center_y = row + 0.5
# Direction vectors
directions = [
(0, -arrow_size), # up
(arrow_size, 0), # right
(0, arrow_size), # down
(-arrow_size, 0) # left
]
dx, dy = directions[action]
# Draw arrow
ax.arrow(center_x, center_y, dx, dy, head_width=0.1, head_length=0.1,
fc='blue', ec='blue', alpha=0.7)
# Add Q-value text
text_positions = [
(center_x, center_y - 0.25), # up
(center_x + 0.25, center_y), # right
(center_x, center_y + 0.25), # down
(center_x - 0.25, center_y) # left
]
tx, ty = text_positions[action]
ax.text(tx, ty, f"{q_value:.2f}", ha='center', va='center', fontsize=8,
bbox=dict(facecolor='white', alpha=0.7, boxstyle='round,pad=0.1'))
# Ensure proper aspect ratio
ax.set_aspect('equal')
    fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave headroom for the suptitle
return fig
def create_metrics_figure(agent):
"""Create a figure with training metrics."""
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# Plot rewards
if agent.rewards_history:
axes[0].plot(agent.rewards_history)
axes[0].set_title('Rewards per Episode')
axes[0].set_xlabel('Episode')
axes[0].set_ylabel('Total Reward')
axes[0].grid(True)
else:
axes[0].set_title('No reward data yet')
# Plot exploration rate
if agent.exploration_rates:
axes[1].plot(agent.exploration_rates)
axes[1].set_title('Exploration Rate Decay')
axes[1].set_xlabel('Episode')
axes[1].set_ylabel('Exploration Rate (ε)')
axes[1].grid(True)
else:
axes[1].set_title('No exploration rate data yet')
plt.tight_layout()
return fig
def train_single_episode(env, agent):
"""Train for a single episode and return the total reward."""
state = env.reset()
total_reward = 0
done = False
steps = 0
max_steps = env.width * env.height * 3 # Prevent infinite loops
while not done and steps < max_steps:
# Select action
action = agent.select_action(state)
# Take the action
next_state, reward, done = env.step(action)
# Update the Q-table
agent.update(state, action, reward, next_state, done)
# Update state and total reward
state = next_state
total_reward += reward
steps += 1
# Decay exploration rate
agent.decay_exploration()
# Store the total reward
agent.rewards_history.append(total_reward)
return total_reward
def train_agent(env, agent, episodes, progress=gr.Progress()):
"""Train the agent for a specified number of episodes."""
progress_text = ""
progress(0, desc="Starting training...")
for episode in progress.tqdm(range(episodes)):
total_reward = train_single_episode(env, agent)
if (episode + 1) % 10 == 0 or episode == episodes - 1:
            progress_text += f"Episode {episode + 1}/{episodes}, Reward: {total_reward:.2f}, Exploration: {agent.exploration_rate:.3f}\n"
# Create final visualization
env_fig = create_gridworld_figure(env, agent, episode_count=episodes, total_reward=total_reward)
metrics_fig = create_metrics_figure(agent)
return env_fig, metrics_fig, progress_text
def run_test_episode(env, agent):
"""Run a test episode using the learned policy."""
state = env.reset()
total_reward = 0
done = False
path = [env._get_pos_from_state(state)]
steps = 0
max_steps = env.width * env.height * 3 # Prevent infinite loops
while not done and steps < max_steps:
# Select the best action from the learned policy
action = np.argmax(agent.q_table[state])
# Take the action
next_state, reward, done = env.step(action)
# Update state and total reward
state = next_state
total_reward += reward
path.append(env._get_pos_from_state(state))
steps += 1
# Create visualization
env_fig = create_gridworld_figure(env, agent, episode_count="Test", total_reward=total_reward)
# Format path for display
path_text = "Path taken:\n"
for i, pos in enumerate(path):
path_text += f"Step {i}: {pos}\n"
    return env_fig, path_text, f"Test completed with total reward: {total_reward:.2f}"
def create_ui():
"""Create the Gradio interface."""
# Create environment and agent
env = GridWorld(height=4, width=4)
agent = QLearningAgent(
n_states=env.n_states,
n_actions=env.n_actions,
learning_rate=0.1,
discount_factor=0.9,
exploration_rate=1.0,
exploration_decay=0.995
)
# Create initial visualizations
init_env_fig = create_gridworld_figure(env, agent)
init_metrics_fig = create_metrics_figure(agent)
with gr.Blocks(title="Q-Learning GridWorld Simulator") as demo:
gr.Markdown("# Q-Learning GridWorld Simulator")
with gr.Tab("Environment Setup"):
with gr.Row():
with gr.Column():
grid_height = gr.Slider(minimum=3, maximum=8, value=4, step=1, label="Grid Height")
grid_width = gr.Slider(minimum=3, maximum=8, value=4, step=1, label="Grid Width")
setup_btn = gr.Button("Setup Environment")
env_display = gr.Plot(value=init_env_fig, label="Environment")
with gr.Row():
setup_info = gr.Textbox(label="Environment Info", value="4x4 GridWorld with start at (0,0) and goal at (3,3)")
with gr.Tab("Train Agent"):
with gr.Row():
with gr.Column():
learning_rate = gr.Slider(minimum=0.01, maximum=1.0, value=0.1, step=0.01, label="Learning Rate (α)")
discount_factor = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Discount Factor (γ)")
exploration_rate = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.01, label="Initial Exploration Rate (ε)")
exploration_decay = gr.Slider(minimum=0.9, maximum=0.999, value=0.995, step=0.001, label="Exploration Decay Rate")
episodes = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Number of Episodes")
train_btn = gr.Button("Train Agent")
with gr.Row():
train_env_display = gr.Plot(label="Training Environment")
train_metrics_display = gr.Plot(label="Training Metrics")
train_log = gr.Textbox(label="Training Log", lines=10)
with gr.Tab("Test Agent"):
with gr.Row():
test_btn = gr.Button("Test Trained Agent")
with gr.Row():
test_env_display = gr.Plot(label="Test Environment")
with gr.Row():
with gr.Column():
path_display = gr.Textbox(label="Path Taken", lines=10)
test_result = gr.Textbox(label="Test Result")
# Setup environment callback
def setup_environment(height, width):
nonlocal env, agent
env = GridWorld(height=int(height), width=int(width))
agent = QLearningAgent(
n_states=env.n_states,
n_actions=env.n_actions,
learning_rate=0.1,
discount_factor=0.9,
exploration_rate=1.0,
exploration_decay=0.995
)
env_fig = create_gridworld_figure(env, agent)
            info_text = f"{int(height)}x{int(width)} GridWorld with start at (0,0) and goal at ({int(height)-1},{int(width)-1})"
if env.obstacles:
info_text += f"\nObstacles at: {env.obstacles}"
return env_fig, info_text
setup_btn.click(
setup_environment,
inputs=[grid_height, grid_width],
outputs=[env_display, setup_info]
)
# Train agent callback
def start_training(lr, df, er, ed, eps):
nonlocal env, agent
agent = QLearningAgent(
n_states=env.n_states,
n_actions=env.n_actions,
learning_rate=float(lr),
discount_factor=float(df),
exploration_rate=float(er),
exploration_decay=float(ed)
)
env_fig, metrics_fig, log = train_agent(env, agent, int(eps))
return env_fig, metrics_fig, log
train_btn.click(
start_training,
inputs=[learning_rate, discount_factor, exploration_rate, exploration_decay, episodes],
outputs=[train_env_display, train_metrics_display, train_log]
)
# Test agent callback
def test_trained_agent():
nonlocal env, agent
env_fig, path_text, result = run_test_episode(env, agent)
return env_fig, path_text, result
test_btn.click(
test_trained_agent,
inputs=[],
outputs=[test_env_display, path_display, test_result]
)
return demo
if __name__ == "__main__":
    # Requires: gradio, matplotlib, numpy (pip install gradio matplotlib numpy)
# Create and launch the UI
demo = create_ui()
    demo.launch(share=True)  # share=True creates a public link when run locally; Hugging Face Spaces ignores it