# NimGPT-3.5 / nim_game_env.py
# Committed by JLW in 1afe246 ("Initial commit"), 3.25 kB.
from abc import ABC
import gymnasium as gym
from gymnasium import spaces
import numpy as np
class NimGameEnv(gym.Env, ABC):
    """Custom Gymnasium environment for a simple two-player Nim game.

    There are several piles of sticks (stones). On each turn the current
    player chooses a pile and removes one or more sticks from it. Per the
    game description, the player who takes the last stick loses.

    Observations are the current pile sizes (a list of ints);
    ``observation_space`` is a ``MultiDiscrete`` with one ``0..max_stones``
    entry per pile.

    Actions are ``(pile_index, num_stones)`` pairs; ``action_space`` is
    ``MultiDiscrete([num_piles, max_stones + 1])``. That space also contains
    pairs that are illegal in a given state (removing 0 sticks, or more
    sticks than the pile holds); ``step`` rejects those with ``ValueError``.

    Parameters
    ----------
    starting_stick_piles: sequence of int, default (3, 5, 7)
        Initial number of sticks in each pile; ``reset`` restores these.

    Raises
    ------
    ValueError
        If ``starting_stick_piles`` is empty or contains a negative size.
    """

    def __init__(self, starting_stick_piles=(3, 5, 7)):
        # Store a private copy so later mutation of the caller's sequence
        # (or of a shared default) cannot change this environment.
        self.starting_stick_piles = list(starting_stick_piles)
        if not self.starting_stick_piles:
            raise ValueError("starting_stick_piles must contain at least one pile")
        if any(n < 0 for n in self.starting_stick_piles):
            raise ValueError("pile sizes in starting_stick_piles must be non-negative")
        self.num_piles = len(self.starting_stick_piles)
        self.max_stones = max(self.starting_stick_piles)
        self.piles = self._init_piles()
        self.current_player = 0
        self.action_space = spaces.MultiDiscrete([self.num_piles, self.max_stones + 1])
        self.observation_space = spaces.MultiDiscrete(
            [self.max_stones + 1] * self.num_piles
        )

    def step(self, action):
        """Apply one move for the current player, then pass the turn.

        Parameters
        ----------
        action: sequence of two ints
            ``(pile_index, num_stones)``: remove ``num_stones`` sticks from
            pile ``pile_index``.

        Returns
        -------
        observation: list of int
            A copy of the current number of sticks in each pile.
        reward: int
            ``_calculate_reward()`` for the player who just moved: 1 if this
            move emptied the last pile, otherwise 0.
        done: bool
            Whether every pile is now empty.
        info: dict
            Always empty.

        Raises
        ------
        ValueError
            If ``action`` is not legal in the current state.
        """
        # NOTE(review): Gymnasium's Env.step API returns a 5-tuple
        # (obs, reward, terminated, truncated, info); this env keeps the
        # older 4-tuple so existing callers keep working.
        if not self._is_valid_action(action):
            raise ValueError("Invalid action")
        # Coerce to plain ints so numpy scalars (e.g. from
        # action_space.sample()) do not leak into the pile list.
        pile, num_stones = (int(v) for v in action)
        self.piles[pile] -= num_stones
        done = self._is_game_over()
        # Computed before the turn switches, so it is the mover's reward.
        reward = self._calculate_reward()
        self.current_player = (self.current_player + 1) % 2
        # Return a copy so stored observations do not alias internal state.
        return list(self.piles), reward, done, {}

    def reset(self, *, seed=None, options=None):
        """Restore the starting piles and give the first turn to player 0.

        Parameters
        ----------
        seed: int, optional
            Forwarded to ``gym.Env.reset`` to (re)seed ``self.np_random``.
        options: dict, optional
            Accepted for Gymnasium API compatibility; currently ignored.

        Returns
        -------
        text_observation: str
            Human-readable description, e.g.
            ``"The piles contain 3, 5, 7 sticks."``.
        piles: list of int
            A copy of the current pile sizes.
        """
        # NOTE(review): Gymnasium's standard reset return is (obs, info);
        # this env returns (text, piles) for its existing callers.
        super().reset(seed=seed)
        self.piles = self._init_piles()
        self.current_player = 0
        text_observation = "The piles contain " + ", ".join(str(x) for x in self.piles) + " sticks."
        return text_observation, list(self.piles)

    def _init_piles(self):
        """Return a fresh, mutable copy of the configured starting piles."""
        # A new list each call, so step() never mutates starting_stick_piles.
        return list(self.starting_stick_piles)

    def _generate_random_stones(self):
        """Return a random stick count in ``[1, max_stones]`` (inclusive).

        Uses numpy's global RNG, not ``self.np_random``. Not called by any
        other method of this class.
        """
        return np.random.randint(1, self.max_stones + 1)

    def _is_valid_action(self, action):
        """Return True if ``action`` is a legal move in the current state.

        A move is legal when the pile index is in range and it removes at
        least one stick but no more than the chosen pile currently holds.
        """
        pile, num_stones = action
        # Short-circuiting guards the self.piles[pile] lookup against an
        # out-of-range index.
        return 0 <= pile < self.num_piles and 0 < num_stones <= self.max_stones and num_stones <= self.piles[pile]

    def _is_game_over(self):
        """Return True once every pile is empty."""
        return all(pile == 0 for pile in self.piles)

    def _calculate_reward(self):
        """Return 1 if the game has ended, otherwise 0."""
        # NOTE(review): the class docstring says the player who takes the
        # last stick loses, yet the mover who empties the final pile gets
        # +1 here. Confirm whether this reward is meant as "game ended"
        # rather than "won".
        return 1 if self._is_game_over() else 0