RingoDingo
committed on
Upload 6 files
- 64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth +3 -0
- autoplay_muliproc.py +151 -0
- chesstransformer.py +251 -0
- environment.yml +66 -0
- play.py +72 -0
- tokenizer.py +163 -0
64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:82fb0554f04255f854344432380ba0719af4e14c631ff8a0c9905a8e99cfbaf2
size 9746197380
autoplay_muliproc.py
ADDED
@@ -0,0 +1,151 @@
import torch
import chess
import chess.engine
import logging
import math
import argparse
import multiprocessing as mp
from chesstransformer import ChessTransformer
import tokenizer as tk
from tqdm import tqdm

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description='Chess Transformer Testing')
parser.add_argument('--cores', type=int, default=2, help='Cores to use for CPU chess engine')
parser.add_argument('--games', type=int, default=10, help='Number of games to play')
parser.add_argument('--stockfish_elo', type=int, default=1320, help='ELO rating for Stockfish. Min 1320')
parser.add_argument('--stockfish_path', type=str, default='./stockfish/stockfish-ubuntu-x86-64', help='Path to Stockfish binary')

args = parser.parse_args()

def setup_model():
    logger.info("Loading ChessTransformer model...")
    model = ChessTransformer()
    model.load_state_dict(torch.load('./64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth')["model_state_dict"])
    model.eval().cuda()
    logger.info("Model loaded successfully.")
    return model

def predict_top_k_moves(model, tokenizer, game_sequence, k=100, device='cuda'):
    game_sequence = torch.tensor([tokenizer.tokenize_game(game_sequence)], dtype=torch.long).to(device)

    with torch.no_grad():
        output = model(game_sequence)
        next_move = output[0, -1, :]
        next_softmax = torch.nn.functional.softmax(next_move, dim=-1)
        top_k_probs, top_k_indices = torch.topk(next_softmax, k)
        top_k_moves = [tokenizer.get_move(idx.item()) for idx in top_k_indices]

    return list(zip(top_k_moves, top_k_probs.tolist()))

def get_legal_move(board, moves):
    for move, prob in moves:
        try:
            if chess.Move.from_uci(move) in board.legal_moves:
                return move, prob
        except ValueError:
            continue
    return None, None

def play_game(model, tokenizer, stockfish_path, stockfish_elo, model_is_white, game_number):
    #logger.info(f"Game {game_number}: Starting. Model playing as {'white' if model_is_white else 'black'}")
    engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    engine.configure({"UCI_LimitStrength": True, "UCI_Elo": stockfish_elo})

    board = chess.Board()
    game_sequence = ['start']
    move_count = 0

    while not board.is_game_over():
        move_count += 1
        if (board.turn == chess.WHITE) == model_is_white:
            top_k_moves = predict_top_k_moves(model, tokenizer, game_sequence)
            legal_move, prob = get_legal_move(board, top_k_moves)
            if legal_move is None:
                logger.warning(f"Game {game_number}: No legal moves found in top-k on move {move_count}. Game over.")
                engine.quit()  # shut down the engine before the early return so the Stockfish process is not leaked
                return "0-1" if model_is_white else "1-0", move_count
            board.push_uci(legal_move)
            game_sequence.append(legal_move)
            logger.debug(f"Game {game_number}: Model's move: {legal_move} (probability: {prob:.4f})")
        else:
            result = engine.play(board, chess.engine.Limit(time=0.1))
            board.push(result.move)
            game_sequence.append(result.move.uci())
            logger.debug(f"Game {game_number}: Stockfish's move: {result.move.uci()}")

    engine.quit()
    result = board.result()
    #logger.info(f"Game {game_number}: Finished. Result: {result}. Total moves: {move_count}")
    return result, move_count

def worker(args):
    model, tokenizer, stockfish_path, stockfish_elo, game_number = args
    model_is_white = game_number % 2 == 0
    result, move_count = play_game(model, tokenizer, stockfish_path, stockfish_elo, model_is_white, game_number)
    return result, game_number, move_count

def calculate_elo_from_win_rate(win_rate, opponent_elo):
    """Calculate ELO based on win rate against an opponent."""
    if win_rate == 0:
        return float('-inf')
    if win_rate == 1:
        return float('inf')
    elo_diff = -400 * math.log10(1 / win_rate - 1)
    return opponent_elo + elo_diff
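
# A worked example of the inversion above (the standard Elo expected-score
# formula; the numbers are illustrative, not measured results): a 25% score
# against a 1320-rated opponent gives
#   elo_diff = -400 * log10(1 / 0.25 - 1) = -400 * log10(3) ≈ -190.8
# so the estimated model rating is roughly 1320 - 190.8 ≈ 1129.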

def main():
    mp.set_start_method('spawn')  # Set start method to 'spawn' for CUDA support

    num_games = args.games
    stockfish_elo = args.stockfish_elo
    stockfish_path = args.stockfish_path

    logger.info(f"Starting tournament: {num_games} games, Stockfish ELO: {stockfish_elo}")

    model = setup_model()
    tokenizer = tk.Tokenizer()

    num_processes = args.cores
    logger.info(f"Using {num_processes} CPU cores for parallel processing")

    tasks = [(model, tokenizer, stockfish_path, stockfish_elo, i) for i in range(num_games)]

    results = []
    with mp.Pool(processes=num_processes) as pool:
        with tqdm(total=num_games, desc="Games Progress") as pbar:
            for result in pool.imap_unordered(worker, tasks):
                results.append(result)
                pbar.update()

    # Process results
    wins = draws = losses = 0
    total_moves = 0
    for result, game_number, move_count in results:
        if result == "1-0" and game_number % 2 == 0:
            wins += 1
        elif result == "0-1" and game_number % 2 == 1:
            wins += 1
        elif result == "1/2-1/2":
            draws += 1
        else:
            losses += 1
        total_moves += move_count

    win_rate = (wins + 0.5 * draws) / num_games
    final_model_elo = calculate_elo_from_win_rate(win_rate, stockfish_elo)
    elo_change = final_model_elo - stockfish_elo

    logger.info("Tournament completed. Final results:")
    logger.info(f"Total games: {num_games}")
    logger.info(f"Wins: {wins}, Losses: {losses}, Draws: {draws}")
    logger.info(f"Win rate: {win_rate:.2%}")
    logger.info(f"Average moves per game: {total_moves/num_games:.2f}")
    logger.info(f"Stockfish ELO: {stockfish_elo}")
    logger.info(f"Final Model ELO: {final_model_elo:.2f}")
    logger.info(f"ELO Change: {elo_change:+.2f}")

if __name__ == "__main__":
    main()
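
Assuming a local Stockfish binary at the default --stockfish_path and the checkpoint next to the script, a typical evaluation run looks like (all flags are defined above):

python autoplay_muliproc.py --cores 4 --games 100 --stockfish_elo 1320

Each worker process plays one game at a time, alternating the model's color by game number, and the final log lines report the win rate and the Elo estimate derived from it.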
chesstransformer.py
ADDED
@@ -0,0 +1,251 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x

class StochasticDepth(nn.Module):
    def __init__(self, p=0.8):
        super().__init__()
        self.p = p

    def forward(self, x, residual):
        if self.training:
            if torch.rand(1).item() < self.p:
                return x + residual
            else:
                return x
        else:
            return x + self.p * residual
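
# Note on the eval branch above: during training the residual is added with
# probability p, so its expected contribution is p * residual; scaling the
# residual by p at inference keeps activations matched in expectation, the
# same rescaling idea classic dropout uses.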

class AdvancedTransformerLayer(nn.Module):
    def __init__(self, d_model, nhead, dropout=0.1, stoch_depth_p=0.8):
        super().__init__()
        dim_feedforward = 4 * d_model
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.stoch_depth = StochasticDepth(stoch_depth_p)

    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        # x shape: (seq_len, batch_size, d_model)
        norm_x = self.norm1(x)

        # Convert boolean mask to float mask
        if src_key_padding_mask is not None:
            src_key_padding_mask = src_key_padding_mask.float().masked_fill(
                src_key_padding_mask, float('-inf')).masked_fill(~src_key_padding_mask, float(0.0))

        attn_output, _ = self.self_attn(norm_x, norm_x, norm_x,
                                        attn_mask=src_mask,
                                        key_padding_mask=src_key_padding_mask)
        x = self.stoch_depth(x, self.dropout(attn_output))

        norm_x = self.norm2(x)
        ff_output = self.ff(norm_x)
        x = self.stoch_depth(x, self.dropout(ff_output))
        return x

class ChessTransformer(nn.Module):
    def __init__(self, num_layers=64, d_model=1024, nhead=8, dropout=0.1, stoch_depth_p=0.9, num_tokens=2066, pad_token_id=2064):
        super().__init__()
        self.embedding = nn.Embedding(num_tokens, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            AdvancedTransformerLayer(d_model, nhead, dropout, stoch_depth_p)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, num_tokens)
        self.d_model = d_model
        self.padding_idx = pad_token_id

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def pad_sequences(self, sequences):
        padding_value = self.padding_idx
        max_len = max(len(seq) for seq in sequences)
        padded_seqs = [seq + [padding_value] * (max_len - len(seq)) for seq in sequences]
        return torch.LongTensor(padded_seqs)

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        batch_size, seq_len = x.size()

        # Create padding mask
        padding_mask = (x == self.padding_idx)

        # Create causal mask
        causal_mask = self.generate_square_subsequent_mask(seq_len).to(x.device)

        # Embed and add positional encoding
        x = self.embedding(x).transpose(0, 1) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)

        # Pass through each layer
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask, src_key_padding_mask=padding_mask)

        x = self.norm(x)
        output = self.output(x.transpose(0, 1))

        return output

def winning_moves_loss(output, ground_truth, win_labels, pad_token_id=2064, start_token_id=2065):
    """
    Compute the loss only for the winning moves of white and black.
    """
    output = output.cuda()
    ground_truth = ground_truth.cuda()
    win_labels = win_labels.cuda()

    batch_size, seq_len, num_tokens = output.shape

    # Shift the ground truth to align with the output predictions
    ground_truth_shifted = ground_truth[:, 1:].contiguous()
    output_shifted = output[:, :-1, :].contiguous()

    # Flatten the output and ground truth for easier masking
    output_flat = output_shifted.view(-1, num_tokens)
    ground_truth_flat = ground_truth_shifted.view(-1)

    # Apply log softmax to the flattened output
    output_log_softmax = F.log_softmax(output_flat, dim=-1)

    # Repeat win_labels for each move in the sequence
    win_labels_expanded = win_labels.unsqueeze(1).repeat(1, seq_len - 1).view(-1)

    # Create a mask for the winning moves
    move_indices = torch.arange(seq_len - 1, device=output.device).unsqueeze(0).repeat(batch_size, 1).view(-1)
    white_win_mask = (win_labels_expanded == 1) & (move_indices % 2 == 0)
    black_win_mask = (win_labels_expanded == 0) & (move_indices % 2 == 1)

    # Combine the masks
    selected_moves_mask = (white_win_mask | black_win_mask) & (ground_truth_flat != pad_token_id) & (ground_truth_flat != start_token_id)

    # Calculate the negative log-likelihood loss only for the selected moves
    loss = F.nll_loss(output_log_softmax, ground_truth_flat, reduction='none')

    loss = loss * selected_moves_mask.float()

    # Average the loss over the selected moves
    selected_moves_count = selected_moves_mask.float().sum()
    if selected_moves_count > 0:
        loss = loss.sum() / selected_moves_count
    else:
        loss = loss.sum()  # If no moves are selected, return 0 loss

    return loss

def all_moves_loss(output, ground_truth, pad_token_id=2064, start_token_id=2065):
    """
    Compute the loss for all valid moves in the sequence, excluding start and padding tokens.
    """
    batch_size, seq_len, num_tokens = output.shape

    output = output.cuda()
    ground_truth = ground_truth.cuda()

    # Shift the output and ground truth to align them
    output_shifted = output[:, :-1, :].contiguous()
    ground_truth_shifted = ground_truth[:, 1:].contiguous()

    # Flatten the shifted output and ground truth
    output_flat = output_shifted.view(-1, num_tokens)
    ground_truth_flat = ground_truth_shifted.view(-1)

    # Apply log softmax to the flattened output
    output_log_softmax = F.log_softmax(output_flat, dim=-1)

    # Create a mask for all valid moves (excluding padding and start tokens)
    valid_moves_mask = ((ground_truth_flat != pad_token_id) &
                        (ground_truth_flat != start_token_id))

    # Calculate the negative log-likelihood loss for all moves
    loss = F.nll_loss(output_log_softmax, ground_truth_flat, reduction='none')

    # Apply the mask to exclude padding and start tokens
    loss = loss * valid_moves_mask.float()

    # Average the loss over all valid moves
    valid_moves_count = valid_moves_mask.float().sum()
    if valid_moves_count > 0:
        loss = loss.sum() / valid_moves_count
    else:
        loss = loss.sum()  # If no valid moves, return 0 loss

    return loss

def weighted_chess_loss(output, ground_truth, win_labels, winning_weight=1.0, losing_weight=0.1, pad_token_id=2064, start_token_id=2065):
    """
    Compute a weighted loss for all moves, with higher weight for winning moves.
    """
    output = output.cuda()
    ground_truth = ground_truth.cuda()
    win_labels = win_labels.cuda()

    batch_size, seq_len, num_tokens = output.shape

    # Shift the ground truth to align with the output predictions
    ground_truth_shifted = ground_truth[:, 1:].contiguous()
    output_shifted = output[:, :-1, :].contiguous()

    # Flatten the output and ground truth for easier masking
    output_flat = output_shifted.view(-1, num_tokens)
    ground_truth_flat = ground_truth_shifted.view(-1)

    # Apply log softmax to the flattened output
    output_log_softmax = F.log_softmax(output_flat, dim=-1)

    # Repeat win_labels for each move in the sequence
    win_labels_expanded = win_labels.unsqueeze(1).repeat(1, seq_len - 1).view(-1)

    # Create masks for winning and losing moves
    move_indices = torch.arange(seq_len - 1, device=output.device).unsqueeze(0).repeat(batch_size, 1).view(-1)
    white_win_mask = (win_labels_expanded == 1) & (move_indices % 2 == 0)
    black_win_mask = (win_labels_expanded == 0) & (move_indices % 2 == 1)
    winning_moves_mask = white_win_mask | black_win_mask

    # Create a mask for all valid moves (excluding padding and start tokens)
    valid_moves_mask = (ground_truth_flat != pad_token_id) & (ground_truth_flat != start_token_id)

    # Calculate the negative log-likelihood loss for all valid moves
    loss = F.nll_loss(output_log_softmax, ground_truth_flat, reduction='none')

    # Apply weights based on whether the move is winning or losing
    weights = torch.where(winning_moves_mask & valid_moves_mask, winning_weight, losing_weight)

    # Apply the weights and the valid moves mask to the loss
    weighted_loss = loss * weights * valid_moves_mask.float()

    # Average the loss over all valid moves
    valid_moves_count = valid_moves_mask.float().sum()
    if valid_moves_count > 0:
        avg_loss = weighted_loss.sum() / valid_moves_count
    else:
        avg_loss = weighted_loss.sum()  # If no valid moves, return 0 loss

    return avg_loss
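
A minimal forward-pass sketch for the model above (a smoke test under assumed small dimensions so it runs on CPU; the released checkpoint itself requires the default 64-layer, 1024-d configuration):

import torch
from chesstransformer import ChessTransformer

model = ChessTransformer(num_layers=2, d_model=64, nhead=4).eval()
tokens = torch.full((1, 8), 2065, dtype=torch.long)  # a batch of 'start' tokens
with torch.no_grad():
    logits = model(tokens)
print(logits.shape)  # torch.Size([1, 8, 2066]): one score per move token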
environment.yml
ADDED
@@ -0,0 +1,66 @@
name: chessbot
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.9.24=h06a4308_0
  - expat=2.6.3=h6a678d5_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.15=h5eee18b_0
  - pip=24.2=py312h06a4308_0
  - python=3.12.7=h5148396_0
  - readline=8.2=h5eee18b_0
  - setuptools=75.1.0=py312h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - tk=8.6.14=h39e8969_0
  - wheel=0.44.0=py312h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zlib=1.2.13=h5eee18b_1
  - pip:
      - absl-py==2.1.0
      - chess==1.11.0
      - filelock==3.13.1
      - fsspec==2024.2.0
      - grpcio==1.66.2
      - jinja2==3.1.3
      - markdown==3.7
      - markupsafe==2.1.5
      - mpmath==1.3.0
      - networkx==3.2.1
      - numpy==2.1.2
      - nvidia-cublas-cu12==12.4.2.65
      - nvidia-cuda-cupti-cu12==12.4.99
      - nvidia-cuda-nvrtc-cu12==12.4.99
      - nvidia-cuda-runtime-cu12==12.4.99
      - nvidia-cudnn-cu12==9.1.0.70
      - nvidia-cufft-cu12==11.2.0.44
      - nvidia-curand-cu12==10.3.5.119
      - nvidia-cusolver-cu12==11.6.0.99
      - nvidia-cusparse-cu12==12.3.0.142
      - nvidia-nccl-cu12==2.20.5
      - nvidia-nvjitlink-cu12==12.4.99
      - nvidia-nvtx-cu12==12.4.99
      - packaging==24.1
      - pandas==2.2.3
      - protobuf==5.28.2
      - pyarrow==17.0.0
      - python-dateutil==2.9.0.post0
      - pytz==2024.2
      - six==1.16.0
      - sympy==1.12
      - tensorboard==2.18.0
      - tensorboard-data-server==0.7.2
      - torch==2.4.1+cu124
      - tqdm==4.66.5
      - triton==3.0.0
      - typing-extensions==4.9.0
      - tzdata==2024.2
      - werkzeug==3.0.4
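
Assuming conda is available, the environment can be recreated with conda env create -f environment.yml followed by conda activate chessbot. Note the pinned torch==2.4.1+cu124 wheel targets CUDA 12.4, and all three scripts call .cuda(), so a compatible NVIDIA GPU is expected.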
play.py
ADDED
@@ -0,0 +1,72 @@
import torch
import torch.nn.functional as F
from chesstransformer import ChessTransformer
import tokenizer as tk

model = ChessTransformer()
model.load_state_dict(torch.load('./64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth')["model_state_dict"])
model.eval().cuda()

# Initialize tokenizer
t = tk.Tokenizer()

def predict_move(model, game_sequence, tokenizer, device='cuda', top_k=5):
    model.eval()
    game_sequence = torch.tensor([tokenizer.tokenize_game(game_sequence)], dtype=torch.long).to(device)

    with torch.no_grad():
        output = model(game_sequence)
        logits = output[0, -1, :]  # Get logits for the last move
        top_k_logits, top_k_indices = torch.topk(logits, top_k)

        # Apply softmax to get probabilities
        probs = F.softmax(top_k_logits, dim=-1)

        # Sample from the probability distribution
        sampled_index = torch.multinomial(probs, 1).item()
        sampled_token = top_k_indices[sampled_index].item()

    sampled_move = tokenizer.untokenize_game([sampled_token])[0]

    # Get all top_k moves and their probabilities for display
    top_k_moves = [tokenizer.untokenize_game([idx.item()])[0] for idx in top_k_indices]
    top_k_probs = probs.cpu().numpy()

    return sampled_move, top_k_moves, top_k_probs

def play_game():
    input_game = []
    print("Let's play chess! Enter your moves in UCI format (e.g., 'e2e4'). Type 'exit' to quit or 'undo' to undo the last move.")

    while True:
        user_move = input("Your move: ").strip()
        if user_move.lower() == 'exit':
            print("Game over. Thanks for playing!")
            break
        elif user_move.lower() == 'undo':
            if len(input_game) >= 2:
                input_game.pop()  # Remove bot's move
                input_game.pop()  # Remove user's move
                print("Last move undone. Current game sequence:", input_game)
            else:
                print("Cannot undo. No moves to undo.")
            continue

        input_game.append(user_move)
        print("Current game sequence:", input_game)

        try:
            bot_move, top_moves, top_probs = predict_move(model, input_game, t)

            # Display top moves and their probabilities
            moves_probs_str = ', '.join(f"{move} ({prob:.2%})" for move, prob in zip(top_moves, top_probs))
            print(f"Top {len(top_moves)} moves and probabilities: {moves_probs_str}")

            print(f"Bot's sampled move: {bot_move}")
            input_game.append(bot_move)
        except Exception as e:
            print("An error occurred:", e)
            break

if __name__ == "__main__":
    play_game()
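
A standalone sketch of the top-k sampling step used in predict_move above (the logits here are made-up stand-ins for the model's last-position output):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])  # hypothetical move logits
top_k_logits, top_k_indices = torch.topk(logits, 3)  # keep the 3 best moves
probs = F.softmax(top_k_logits, dim=-1)              # renormalize over those 3
sampled = top_k_indices[torch.multinomial(probs, 1)].item()
print(sampled)  # index of the sampled move token

Restricting the softmax to the top k before sampling keeps the bot's play varied without letting it draw from the long tail of implausible moves.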
tokenizer.py
ADDED
@@ -0,0 +1,163 @@
class Tokenizer:
    def __init__(self):
        self.move_dict = create_move_dict()
        self.inverse_dict = inverse_move_dict(self.move_dict)

    def tokenize_game(self, moves_list):
        tokenized_moves = []
        for move in moves_list:
            tokenized_moves.append(self.move_dict[move])
        return tokenized_moves

    def untokenize_game(self, tokenized_moves):
        inverse_moves = []
        for move in tokenized_moves:
            if move == 2064:
                inverse_moves.append("[pad]")
                continue
            if move == 2065:
                inverse_moves.append("[start]")
                continue
            inverse_moves.append(self.inverse_dict[move])
        return inverse_moves

    def tokenize_move(self, move):
        return self.move_dict[move]

    def get_move(self, tokenized_move):
        return self.inverse_dict[tokenized_move]


# Helper function to convert square index to algebraic notation
def square_to_algebraic(square):
    files = 'abcdefgh'
    ranks = '12345678'
    file = files[square % 8]
    rank = ranks[square // 8]
    return file + rank

# Modified chess_moves function to account for all moves
def chess_moves(starting_square):
    moves = []
    ss = starting_square

    # Calculate file and rank
    file_start = (ss // 8) * 8
    file_end = file_start + 7

    # Horizontal moves - to left
    for i in range(ss - 1, file_start - 1, -1):
        moves.append((ss, i))

    # Horizontal moves - to right
    for i in range(ss + 1, file_end + 1):
        moves.append((ss, i))

    # Vertical moves - above
    for i in range(ss + 8, 64, 8):
        moves.append((ss, i))

    # Vertical moves - below
    for i in range(ss - 8, -1, -8):
        moves.append((ss, i))

    # Diagonal moves
    # Upper left
    i = ss
    while (i := i + 7) < 64 and i % 8 != 7:
        moves.append((ss, i))

    # Lower left
    i = ss
    while (i := i - 9) >= 0 and i % 8 != 7:
        moves.append((ss, i))

    # Upper right
    i = ss
    while (i := i + 9) < 64 and i % 8 != 0:
        moves.append((ss, i))

    # Lower right
    i = ss
    while (i := i - 7) >= 0 and i % 8 != 0:
        moves.append((ss, i))

    # Inner 5x5 square
    for j in range(-2, 3):
        for i in range(-2, 3):
            target = ss + i + j * 8
            if 0 <= target < 64 and (target // 8 == (ss // 8) + j) and target != ss:
                moves.append((ss, target))

    # Pawn moves (including promotions)
    if ss // 8 == 1:  # White pawn's initial position
        if ss + 8 < 64:
            moves.append((ss, ss + 8))
        if (ss + 16) < 64:
            moves.append((ss, ss + 16))
        if ss + 9 < 64 and (ss + 9) % 8 != 0:
            moves.append((ss, ss + 9))
        if ss + 7 < 64 and (ss + 7) % 8 != 7:
            moves.append((ss, ss + 7))
    elif ss // 8 == 6:  # Black pawn's initial position
        if ss - 8 >= 0:
            moves.append((ss, ss - 8))
        if (ss - 16) >= 0:
            moves.append((ss, ss - 16))
        if ss - 9 >= 0 and (ss - 9) % 8 != 7:
            moves.append((ss, ss - 9))
        if ss - 7 >= 0 and (ss - 7) % 8 != 0:
            moves.append((ss, ss - 7))

    # Remove duplicate tuples
    seen = set()
    result = []
    for item in moves:
        if item not in seen:
            seen.add(item)
            result.append(item)

    return result


# Function to create a dictionary of moves with promotion
def create_move_dict():
    move_dict = {}
    count = 0
    promotion_pieces = ['q', 'r', 'b', 'n']  # Queen, Rook, Bishop, Knight

    for i in range(64):
        for move in chess_moves(i):
            start_sq_algebraic = square_to_algebraic(move[0])
            end_sq_algebraic = square_to_algebraic(move[1])
            move_dict[f"{start_sq_algebraic}{end_sq_algebraic}"] = count
            count += 1
            # Add promotions if applicable
            if move[1] // 8 == 7 and i // 8 == 6:  # White pawn reaching last rank
                for piece in promotion_pieces:
                    move_dict[f"{start_sq_algebraic}{end_sq_algebraic}{piece}"] = count
                    count += 1
            elif move[1] // 8 == 0 and i // 8 == 1:  # Black pawn reaching last rank
                for piece in promotion_pieces:
                    move_dict[f"{start_sq_algebraic}{end_sq_algebraic}{piece}"] = count
                    count += 1

    move_dict["pad"] = 2064
    move_dict["start"] = 2065
    return move_dict

def inverse_move_dict(move_dict):
    inverse_dict = {}
    for k, v in move_dict.items():
        inverse_dict[v] = k
    return inverse_dict

def tokenize_game(moves_list):
    move_dict = create_move_dict()
    tokenized_moves = []
    for move in moves_list:
        tokenized_moves.append(move_dict[move])
    return tokenized_moves

if __name__ == "__main__":
    t = Tokenizer()
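
A minimal round-trip sketch for the tokenizer (assuming the module is importable as tokenizer; note the asymmetry that untokenize_game renders the special ids as '[start]' and '[pad]', while tokenize_game expects the bare keys 'start' and 'pad'):

import tokenizer as tk

t = tk.Tokenizer()
tokens = t.tokenize_game(['start', 'e2e4', 'e7e5'])
print(tokens)                     # three integer ids; 'start' maps to 2065
print(t.untokenize_game(tokens))  # ['[start]', 'e2e4', 'e7e5']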