# Spaces:
# Runtime error
# Runtime error
from abc import ABC
from typing import Any, Dict, List, Optional, Sequence, Tuple

import gymnasium as gym
import numpy as np
from gymnasium import spaces

class NimGameEnv(gym.Env, ABC):
    """Custom environment for a two-player Nim game (misere variant).

    There are two players and several piles of sticks. On each turn the
    current player chooses one pile and removes one or more sticks from it.
    The player who takes the last stick loses.

    Observation space: ``MultiDiscrete([max_stones + 1] * num_piles)``, i.e.
    the number of sticks in each pile. Observations are returned as plain
    Python lists of ints (always a fresh copy).

    Action space: ``MultiDiscrete([num_piles, max_stones + 1])``, a pair
    ``(pile_index, num_stones)``. Removing 0 sticks lies inside the space but
    is rejected by ``step`` as an invalid action.

    NOTE(review): ``step`` returns the legacy ``(obs, reward, done, info)``
    4-tuple and ``reset`` returns ``(text_observation, piles)``. These differ
    from Gymnasium's ``(obs, reward, terminated, truncated, info)`` and
    ``(obs, info)`` contracts. They are kept for compatibility with existing
    callers, but Gymnasium wrappers / ``check_env`` may reject them.
    """

    # Pile sizes used when no starting configuration is supplied.
    DEFAULT_STICK_PILES = (3, 5, 7)

    def __init__(self, starting_stick_piles: Optional[Sequence[int]] = None) -> None:
        """Create the environment.

        Parameters
        ----------
        starting_stick_piles:
            Initial number of sticks in each pile. Defaults to
            ``DEFAULT_STICK_PILES`` (3, 5, 7). The sequence is copied, so later
            mutation by the caller does not affect the environment.

        Raises
        ------
        ValueError
            If no piles are given or any pile size is negative.
        """
        super().__init__()
        if starting_stick_piles is None:
            starting_stick_piles = self.DEFAULT_STICK_PILES
        piles = list(starting_stick_piles)
        if not piles:
            raise ValueError("starting_stick_piles must contain at least one pile")
        if any(p < 0 for p in piles):
            raise ValueError("starting_stick_piles must not contain negative pile sizes")

        self.starting_stick_piles = piles
        self.num_piles = len(piles)
        self.max_stones = max(piles)
        self.piles = self._init_piles()
        self.current_player = 0
        self.action_space = spaces.MultiDiscrete([self.num_piles, self.max_stones + 1])
        self.observation_space = spaces.MultiDiscrete([self.max_stones + 1] * self.num_piles)

    def step(self, action) -> Tuple[List[int], int, bool, Dict[str, Any]]:
        """Apply ``action`` for the current player and pass the turn.

        Parameters
        ----------
        action:
            Pair ``(pile_index, num_stones)`` of integers (Python or NumPy),
            e.g. a sample from ``action_space``.

        Returns
        -------
        observation: list of int
            A copy of the current number of sticks in each pile.
        reward: int
            1 if this move emptied the last pile, otherwise 0.
        done: bool
            Whether every pile is now empty.
        info: dict
            Always empty.

        Raises
        ------
        ValueError
            If the action is not a legal move in the current position.
        """
        if not self._is_valid_action(action):
            raise ValueError("Invalid action")

        pile, num_stones = action
        # Store plain ints so pile counts never silently become NumPy scalars.
        self.piles[int(pile)] -= int(num_stones)

        done = self._is_game_over()
        reward = self._calculate_reward()

        # Two players alternate: 0 -> 1 -> 0 ...
        self.current_player = (self.current_player + 1) % 2

        # Return a copy so observations the caller already stored are not
        # mutated by later steps.
        return list(self.piles), reward, done, {}

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, List[int]]:
        """Reset the environment to the starting piles.

        Parameters
        ----------
        seed:
            If given, re-seeds ``self.np_random`` (via ``gym.Env.reset``).
        options:
            Accepted for Gymnasium API compatibility; currently ignored.

        Returns
        -------
        text_observation: str
            Human-readable description of the piles.
        piles: list of int
            A copy of the number of sticks in each pile.
        """
        super().reset(seed=seed)
        self.piles = self._init_piles()
        self.current_player = 0
        text_observation = "The piles contain " + ", ".join(str(x) for x in self.piles) + " sticks."
        return text_observation, list(self.piles)

    def _init_piles(self) -> List[int]:
        """Return a fresh list holding the configured starting piles."""
        # Copy so that step() never mutates the stored starting configuration.
        return list(self.starting_stick_piles)

    def _generate_random_stones(self) -> int:
        """Return a random pile size between 1 and ``max_stones`` inclusive."""
        # Use the environment's RNG (seeded by reset(seed=...)) rather than the
        # global np.random state, so runs are reproducible per environment.
        return int(self.np_random.integers(1, self.max_stones + 1))

    def _is_valid_action(self, action) -> bool:
        """Return True if ``action`` is a legal move in the current position."""
        pile, num_stones = action
        # Only integer components are meaningful; NumPy integers (as produced
        # by action_space.sample()) are accepted, floats are rejected.
        if not all(isinstance(v, (int, np.integer)) for v in (pile, num_stones)):
            return False
        # The `and` chain short-circuits, so self.piles[pile] is only read
        # once the pile index is known to be in range.
        return (
            0 <= pile < self.num_piles
            and 0 < num_stones <= self.max_stones
            and num_stones <= self.piles[pile]
        )

    def _is_game_over(self) -> bool:
        """Return True when every pile is empty."""
        return all(pile == 0 for pile in self.piles)

    def _calculate_reward(self) -> int:
        """Return 1 if the game has ended, otherwise 0."""
        # NOTE(review): this reward is returned from the step() whose move
        # ended the game, i.e. the move that took the last stick. The class
        # docstring says that player loses, so confirm whether this should be
        # a penalty (e.g. -1) rather than +1.
        return 1 if self._is_game_over() else 0