Train a Terrible Tic-Tac-Toe AI

Community Article Published June 18, 2024

This project demonstrates how to build and train a neural network to play Tic-Tac-Toe using PyTorch. The model learns optimal moves from a dataset of all possible game states and their corresponding best moves.

0. Set the device

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Training on {device}.")

1. Board Representation and Conversion

We represent the Tic-Tac-Toe board as a 3x3 list of lists, where each cell can be 'x', 'o', or None. To feed this into our neural network, we convert it into a tensor:

def board_to_tensor(board):
    """Encode a board (a list of rows of 'x' / 'o' / None) as a flat float tensor.

    'x' becomes 1, 'o' becomes -1 and an empty cell (None) becomes 0. The
    rows are laid out one after another, so a 3x3 board yields 9 values.
    Any other cell value raises KeyError.
    """
    symbol_values = {'x': 1, 'o': -1, None: 0}
    encoded_rows = []
    for row in board:
        encoded_rows.append([symbol_values[cell] for cell in row])
    grid = torch.tensor(encoded_rows, dtype=torch.float32)
    return torch.flatten(grid)

This function maps 'x' to 1, 'o' to -1, and empty cells to 0, then flattens the board into a 1D tensor.

2. Dataset Creation

We create a custom PyTorch Dataset to hold our game states and their corresponding best moves:

class TicTacToeDataset(Dataset):
    """Pairs each stored board with the move stored at the same position."""

    def __init__(self, boards, moves):
        # Both containers are kept as given; they are indexed in parallel.
        self.boards, self.moves = boards, moves

    def __len__(self):
        """Return the number of stored boards."""
        return len(self.boards)

    def __getitem__(self, idx):
        """Return the (board, move) pair found at position ``idx``."""
        return self.boards[idx], self.moves[idx]

3. Neural Network Architecture

Our Tic-Tac-Toe neural network is a simple feedforward network:

class TicTacToeNN(nn.Module):
    """Feedforward network that scores the 9 cells of a flattened board.

    Layers: Linear(9, 128) -> ReLU -> Linear(128, 64) -> ReLU -> Linear(64, 9).

    The output depends on the module mode:

    * training mode (``model.train()``, the default for a new module):
      raw, unnormalized logits. ``nn.CrossEntropyLoss`` applies log-softmax
      itself, so feeding it already-softmaxed values would normalize twice
      and flatten the gradients.
    * evaluation mode (``model.eval()``): softmax probabilities over the
      9 cells, one row per input board.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 9)
        # Normalizes across the 9 move scores of each board in the batch.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a (batch, 9) tensor of encoded boards to a (batch, 9) tensor.

        Returns logits while training and probabilities while evaluating;
        the arg-max cell is the same in both cases.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        logits = self.fc3(x)
        if self.training:
            # Leave normalization to the loss function.
            return logits
        return self.softmax(logits)

It takes a flattened board (9 inputs) and outputs probabilities for each of the 9 possible moves.

4. Data Generation

We generate the boards by brute force: every arrangement of 'x', 'o', and empty cells that still has at least one empty cell (this includes some positions that could never arise in a real game). Each board is then labeled with a move chosen by the find_best_move function, which applies a fixed set of heuristic rules and always plays as 'x'.

# The three values a single cell may hold.
possible_items = ["x", "o", None]

# Brute force: every assignment of a value to each of the 9 cells.
all_boards = [list(cells) for cells in itertools.product(possible_items, repeat=9)]

# Keep only the arrangements that still contain at least one empty cell.
valid_boards = [flat for flat in all_boards if None in flat]

# Cut each flat 9-element list into three rows of three cells.
boards = [
    [flat[start:start + 3] for start in range(0, 9, 3)]
    for flat in valid_boards
]

boards[:9]
def find_best_move(board):
    """Choose a move for 'x' on a 3x3 board using a fixed priority list.

    ``board`` is a list of three rows, each a list of three cells holding
    'x', 'o' or None (empty). The first rule that applies decides the move;
    within a rule, cells are tried row by row, left to right:

    1. Complete a line of three 'x'.
    2. Occupy the cell where 'o' would complete a line of three.
    3. Create a fork for 'x': a position with more than one line that holds
       two 'x' and one empty cell.
    4. Occupy a cell where 'o' would create such a fork.
    5. Take the centre, then the first free corner, then the first free edge.

    The board is modified temporarily while candidate moves are tried, but
    every cell is restored before the function returns.

    Returns:
        A ``(row, col)`` tuple, or ``None`` when the board has no empty cell.
    """
    # The eight winning lines, each as a list of (row, col) coordinates.
    lines = [[(r, c) for c in range(3)] for r in range(3)]    # rows
    lines += [[(r, c) for r in range(3)] for c in range(3)]   # columns
    lines.append([(i, i) for i in range(3)])                  # main diagonal
    lines.append([(i, 2 - i) for i in range(3)])              # anti-diagonal

    def cells(line):
        """Return the current contents of the three cells on ``line``."""
        return [board[r][c] for r, c in line]

    def check_win(player):
        """Return True when ``player`` fills any complete line."""
        return any(cells(line).count(player) == 3 for line in lines)

    def count_forks(player):
        """Count lines where ``player`` has two marks and the third is empty."""
        threats = 0
        for line in lines:
            marks = cells(line)
            if marks.count(player) == 2 and marks.count(None) == 1:
                threats += 1
        return threats

    def creates_fork(player):
        """Return True when ``player`` threatens more than one line at once."""
        return count_forks(player) > 1

    def first_cell_where(player, condition):
        """Return the first empty cell where placing ``player`` meets ``condition``."""
        for row in range(3):
            for col in range(3):
                if board[row][col] is not None:
                    continue
                board[row][col] = player
                try:
                    satisfied = condition(player)
                finally:
                    # Always undo the trial move so the caller's board is intact.
                    board[row][col] = None
                if satisfied:
                    return (row, col)
        return None

    # Tactical rules, highest priority first.
    tactics = (
        ('x', check_win),      # win now
        ('o', check_win),      # block an immediate loss
        ('x', creates_fork),   # set up two threats at once
        ('o', creates_fork),   # take the cell the opponent would fork with
    )
    for player, condition in tactics:
        move = first_cell_where(player, condition)
        if move is not None:
            return move

    # Positional fallback: centre, then corners, then edges.
    preferred_cells = [
        (1, 1),
        (0, 0), (0, 2), (2, 0), (2, 2),
        (0, 1), (1, 0), (1, 2), (2, 1),
    ]
    for row, col in preferred_cells:
        if board[row][col] is None:
            return (row, col)

    return None

# Label every board with the move find_best_move picks, stored as [row, col].
moves = [list(find_best_move(board)) for board in boards]

moves[:9]

5. Set up the Dataloader

Make sure to send everything to the device.

# Encode every board as a flat float tensor placed on the target device.
tensor_boards = [board_to_tensor(grid).to(device) for grid in boards]

# Collapse each [row, col] label into a single class index (row * 3 + col).
tensor_moves = torch.tensor(
    [row * 3 + col for row, col in moves],
    device=device,
)

# NOTE(review): this generator is created on `device`; confirm that the
# shuffling sampler accepts a non-CPU generator when training on CUDA.
g = torch.Generator(device=device)
dataset = TicTacToeDataset(tensor_boards, tensor_moves)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32, generator=g)

6. Training Loop

We use the Adam optimizer and Cross-Entropy Loss to train our model:

model = TicTacToeNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 100
# Make the training mode explicit so re-running this after evaluation works.
model.train()
for epoch in range(epochs):
    # Batch variables get their own names so the module-level `boards` and
    # `moves` lists (the source of the dataset) are not overwritten.
    for batch_boards, batch_moves in dataloader:
        batch_boards = batch_boards.to(device)
        batch_moves = batch_moves.to(device)
        optimizer.zero_grad()
        outputs = model(batch_boards)
        loss = criterion(outputs, batch_moves)
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 10 == 0:
        # `loss` is the loss of the last mini-batch of the epoch, not a mean.
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

7. Model Evaluation

After training, we can evaluate our model on a test board:

# Sample position: 'x' holds two cells of the bottom row, 'o' holds three cells.
test_board = [
    [None, "o", "o"],
    [None, "o", None],
    [None, "x", "x"],
]
# unsqueeze(0) adds the leading batch dimension the model is called with.
test_tensor = board_to_tensor(test_board).unsqueeze(0).to(device)
model.eval()
with torch.no_grad():
    prediction = model(test_tensor)
    # Highest-scoring of the 9 outputs, converted back to [row, col].
    best_move_index = prediction.argmax().item()
    best_move = list(divmod(best_move_index, 3))
    print(f"Best move for the test board: {best_move}")

This is a dumb project, and it won't work

There are many problems with this setup. Most notably, it will often produce illegal moves: the model never checks which cells are already occupied, so it simply predicts whichever move is most likely in general.