RingoDingo commited on
Commit
aa2269b
·
verified ·
1 Parent(s): 1e50c52

Upload 6 files

Browse files
64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82fb0554f04255f854344432380ba0719af4e14c631ff8a0c9905a8e99cfbaf2
3
+ size 9746197380
autoplay_muliproc.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import chess
3
+ import chess.engine
4
+ import logging
5
+ import math
6
+ import argparse
7
+ import multiprocessing as mp
8
+ from chesstransformer import ChessTransformer
9
+ import tokenizer as tk
10
+ from tqdm import tqdm
11
+
12
+ # Set up logging
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s')
14
+ logger = logging.getLogger(__name__)
15
+
16
+ parser = argparse.ArgumentParser(description='Chess Transformer Testing')
17
+ parser.add_argument('--cores', type=int, default=2, help='Cores to use for CPU chess engine')
18
+ parser.add_argument('--games', type=int, default=10, help='Number of games to play')
19
+ parser.add_argument('--stockfish_elo', type=int, default=1320, help='ELO rating for Stockfish. Min 1320')
20
+ parser.add_argument('--stockfish_path', type=str, default='./stockfish/stockfish-ubuntu-x86-64', help='Path to Stockfish binary')
21
+
22
+ args = parser.parse_args()
23
+
24
+ def setup_model():
25
+ logger.info("Loading ChessTransformer model...")
26
+ model = ChessTransformer()
27
+ model.load_state_dict(torch.load('./64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth')["model_state_dict"])
28
+ model.eval().cuda()
29
+ logger.info("Model loaded successfully.")
30
+ return model
31
+
32
+ def predict_top_k_moves(model, tokenizer, game_sequence, k=100, device='cuda'):
33
+ game_sequence = torch.tensor([tokenizer.tokenize_game(game_sequence)], dtype=torch.long).to(device)
34
+
35
+ with torch.no_grad():
36
+ output = model(game_sequence)
37
+ next_move = output[0, -1, :]
38
+ next_softmax = torch.nn.functional.softmax(next_move, dim=-1)
39
+ top_k_probs, top_k_indices = torch.topk(next_softmax, k)
40
+ top_k_moves = [tokenizer.get_move(idx.item()) for idx in top_k_indices]
41
+
42
+ return list(zip(top_k_moves, top_k_probs.tolist()))
43
+
44
+ def get_legal_move(board, moves):
45
+ for move, prob in moves:
46
+ try:
47
+ if chess.Move.from_uci(move) in board.legal_moves:
48
+ return move, prob
49
+ except ValueError:
50
+ continue
51
+ return None, None
52
+
53
+ def play_game(model, tokenizer, stockfish_path, stockfish_elo, model_is_white, game_number):
54
+ #logger.info(f"Game {game_number}: Starting. Model playing as {'white' if model_is_white else 'black'}")
55
+ engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
56
+ engine.configure({"UCI_LimitStrength": True, "UCI_Elo": stockfish_elo})
57
+
58
+ board = chess.Board()
59
+ game_sequence = ['start']
60
+ move_count = 0
61
+
62
+ while not board.is_game_over():
63
+ move_count += 1
64
+ if (board.turn == chess.WHITE) == model_is_white:
65
+ top_k_moves = predict_top_k_moves(model, tokenizer, game_sequence)
66
+ legal_move, prob = get_legal_move(board, top_k_moves)
67
+ if legal_move is None:
68
+ logger.warning(f"Game {game_number}: No legal moves found in top-k on move {move_count}. Game over.")
69
+ return "0-1" if model_is_white else "1-0", move_count
70
+ board.push_uci(legal_move)
71
+ game_sequence.append(legal_move)
72
+ logger.debug(f"Game {game_number}: Model's move: {legal_move} (probability: {prob:.4f})")
73
+ else:
74
+ result = engine.play(board, chess.engine.Limit(time=0.1))
75
+ board.push(result.move)
76
+ game_sequence.append(result.move.uci())
77
+ logger.debug(f"Game {game_number}: Stockfish's move: {result.move.uci()}")
78
+
79
+ engine.quit()
80
+ result = board.result()
81
+ #logger.info(f"Game {game_number}: Finished. Result: {result}. Total moves: {move_count}")
82
+ return result, move_count
83
+
84
+ def worker(args):
85
+ model, tokenizer, stockfish_path, stockfish_elo, game_number = args
86
+ model_is_white = game_number % 2 == 0
87
+ result, move_count = play_game(model, tokenizer, stockfish_path, stockfish_elo, model_is_white, game_number)
88
+ return result, game_number, move_count
89
+
90
+ def calculate_elo_from_win_rate(win_rate, opponent_elo):
91
+ """Calculate ELO based on win rate against an opponent."""
92
+ if win_rate == 0:
93
+ return float('-inf')
94
+ if win_rate == 1:
95
+ return float('inf')
96
+ elo_diff = -400 * math.log10(1 / win_rate - 1)
97
+ return opponent_elo + elo_diff
98
+
99
+ def main():
100
+ mp.set_start_method('spawn') # Set start method to 'spawn' for CUDA support
101
+
102
+ num_games = args.games
103
+ stockfish_elo = args.stockfish_elo
104
+ stockfish_path = args.stockfish_path
105
+
106
+ logger.info(f"Starting tournament: {num_games} games, Stockfish ELO: {stockfish_elo}")
107
+
108
+ model = setup_model()
109
+ tokenizer = tk.Tokenizer()
110
+
111
+ num_processes = args.cores
112
+ logger.info(f"Using {num_processes} CPU cores for parallel processing")
113
+
114
+ tasks = [(model, tokenizer, stockfish_path, stockfish_elo, i) for i in range(num_games)]
115
+
116
+ results = []
117
+ with mp.Pool(processes=num_processes) as pool:
118
+ with tqdm(total=num_games, desc="Games Progress") as pbar:
119
+ for result in pool.imap_unordered(worker, tasks):
120
+ results.append(result)
121
+ pbar.update()
122
+
123
+ # Process results
124
+ wins = draws = losses = 0
125
+ total_moves = 0
126
+ for result, game_number, move_count in results:
127
+ if result == "1-0" and game_number % 2 == 0:
128
+ wins += 1
129
+ elif result == "0-1" and game_number % 2 == 1:
130
+ wins += 1
131
+ elif result == "1/2-1/2":
132
+ draws += 1
133
+ else:
134
+ losses += 1
135
+ total_moves += move_count
136
+
137
+ win_rate = (wins + 0.5 * draws) / num_games
138
+ final_model_elo = calculate_elo_from_win_rate(win_rate, stockfish_elo)
139
+ elo_change = final_model_elo - stockfish_elo
140
+
141
+ logger.info("Tournament completed. Final results:")
142
+ logger.info(f"Total games: {num_games}")
143
+ logger.info(f"Wins: {wins}, Losses: {losses}, Draws: {draws}")
144
+ logger.info(f"Win rate: {win_rate:.2%}")
145
+ logger.info(f"Average moves per game: {total_moves/num_games:.2f}")
146
+ logger.info(f"Stockfish ELO: {stockfish_elo}")
147
+ logger.info(f"Final Model ELO: {final_model_elo:.2f}")
148
+ logger.info(f"ELO Change: {elo_change:+.2f}")
149
+
150
+ if __name__ == "__main__":
151
+ main()
chesstransformer.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ class PositionalEncoding(nn.Module):
7
+ def __init__(self, d_model, max_len=5000):
8
+ super(PositionalEncoding, self).__init__()
9
+ pe = torch.zeros(max_len, d_model)
10
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
11
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
12
+ pe[:, 0::2] = torch.sin(position * div_term)
13
+ pe[:, 1::2] = torch.cos(position * div_term)
14
+ pe = pe.unsqueeze(0).transpose(0, 1)
15
+ self.register_buffer('pe', pe)
16
+
17
+ def forward(self, x):
18
+ x = x + self.pe[:x.size(0), :]
19
+ return x
20
+
21
+ class StochasticDepth(nn.Module):
22
+ def __init__(self, p=0.8):
23
+ super().__init__()
24
+ self.p = p
25
+
26
+ def forward(self, x, residual):
27
+ if self.training:
28
+ if torch.rand(1).item() < self.p:
29
+ return x + residual
30
+ else:
31
+ return x
32
+ else:
33
+ return x + self.p * residual
34
+
35
+ class AdvancedTransformerLayer(nn.Module):
36
+ def __init__(self, d_model, nhead, dropout=0.1, stoch_depth_p=0.8):
37
+ super().__init__()
38
+ dim_feedforward = 4 * d_model
39
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
40
+ self.ff = nn.Sequential(
41
+ nn.Linear(d_model, dim_feedforward),
42
+ nn.ReLU(),
43
+ nn.Linear(dim_feedforward, d_model)
44
+ )
45
+ self.norm1 = nn.LayerNorm(d_model)
46
+ self.norm2 = nn.LayerNorm(d_model)
47
+ self.dropout = nn.Dropout(dropout)
48
+ self.stoch_depth = StochasticDepth(stoch_depth_p)
49
+
50
+ def forward(self, x, src_mask=None, src_key_padding_mask=None):
51
+ # x shape: (seq_len, batch_size, d_model)
52
+ norm_x = self.norm1(x)
53
+
54
+ # Convert boolean mask to float mask
55
+ if src_key_padding_mask is not None:
56
+ src_key_padding_mask = src_key_padding_mask.float().masked_fill(
57
+ src_key_padding_mask, float('-inf')).masked_fill(~src_key_padding_mask, float(0.0))
58
+
59
+ attn_output, _ = self.self_attn(norm_x, norm_x, norm_x,
60
+ attn_mask=src_mask,
61
+ key_padding_mask=src_key_padding_mask)
62
+ x = self.stoch_depth(x, self.dropout(attn_output))
63
+
64
+ norm_x = self.norm2(x)
65
+ ff_output = self.ff(norm_x)
66
+ x = self.stoch_depth(x, self.dropout(ff_output))
67
+ return x
68
+
69
+ class ChessTransformer(nn.Module):
70
+ def __init__(self, num_layers=64, d_model=1024, nhead=8, dropout=0.1, stoch_depth_p=0.9, num_tokens=2066, pad_token_id=2064):
71
+ super().__init__()
72
+ self.embedding = nn.Embedding(num_tokens, d_model)
73
+ self.pos_encoder = PositionalEncoding(d_model)
74
+ self.layers = nn.ModuleList([
75
+ AdvancedTransformerLayer(d_model, nhead, dropout, stoch_depth_p)
76
+ for _ in range(num_layers)
77
+ ])
78
+ self.norm = nn.LayerNorm(d_model)
79
+ self.output = nn.Linear(d_model, num_tokens)
80
+ self.d_model = d_model
81
+ self.padding_idx = pad_token_id
82
+
83
+ def generate_square_subsequent_mask(self, sz):
84
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
85
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
86
+ return mask
87
+
88
+ def pad_sequences(self, sequences):
89
+ padding_value = self.padding_idx
90
+ max_len = max(len(seq) for seq in sequences)
91
+ padded_seqs = [seq + [padding_value] * (max_len - len(seq)) for seq in sequences]
92
+ return torch.LongTensor(padded_seqs)
93
+
94
+ def forward(self, x):
95
+ # x shape: (batch_size, seq_len)
96
+ batch_size, seq_len = x.size()
97
+
98
+ # Create padding mask
99
+ padding_mask = (x == self.padding_idx)
100
+
101
+ # Create causal mask
102
+ causal_mask = self.generate_square_subsequent_mask(seq_len).to(x.device)
103
+
104
+ # Embed and add positional encoding
105
+ x = self.embedding(x).transpose(0, 1) * math.sqrt(self.d_model)
106
+ x = self.pos_encoder(x)
107
+
108
+ # Pass through each layer
109
+ for layer in self.layers:
110
+ x = layer(x, src_mask=causal_mask, src_key_padding_mask=padding_mask)
111
+
112
+ x = self.norm(x)
113
+ output = self.output(x.transpose(0, 1))
114
+
115
+ return output
116
+
117
+ def winning_moves_loss(output, ground_truth, win_labels, pad_token_id=2064, start_token_id=2065):
118
+ """
119
+ Compute the loss only for the winning moves of white and black.
120
+ """
121
+ output = output.cuda()
122
+ ground_truth = ground_truth.cuda()
123
+ win_labels = win_labels.cuda()
124
+
125
+ batch_size, seq_len, num_tokens = output.shape
126
+
127
+ # Shift the ground truth to align with the output predictions
128
+ ground_truth_shifted = ground_truth[:, 1:].contiguous()
129
+ output_shifted = output[:, :-1, :].contiguous()
130
+
131
+ # Flatten the output and ground truth for easier masking
132
+ output_flat = output_shifted.view(-1, num_tokens)
133
+ ground_truth_flat = ground_truth_shifted.view(-1)
134
+
135
+ # Apply log softmax to the flattened output
136
+ output_log_softmax = F.log_softmax(output_flat, dim=-1)
137
+
138
+ # Repeat win_labels for each move in the sequence
139
+ win_labels_expanded = win_labels.unsqueeze(1).repeat(1, seq_len - 1).view(-1)
140
+
141
+ # Create a mask for the winning moves
142
+ move_indices = torch.arange(seq_len - 1, device=output.device).unsqueeze(0).repeat(batch_size, 1).view(-1)
143
+ white_win_mask = (win_labels_expanded == 1) & (move_indices % 2 == 0)
144
+ black_win_mask = (win_labels_expanded == 0) & (move_indices % 2 == 1)
145
+
146
+ # Combine the masks
147
+ selected_moves_mask = (white_win_mask | black_win_mask) & (ground_truth_flat != pad_token_id) & (ground_truth_flat != start_token_id)
148
+
149
+ # Calculate the negative log-likelihood loss only for the selected moves
150
+ loss = F.nll_loss(output_log_softmax, ground_truth_flat, reduction='none')
151
+
152
+ loss = loss * selected_moves_mask.float()
153
+
154
+ # Average the loss over the selected moves
155
+ selected_moves_count = selected_moves_mask.float().sum()
156
+ if selected_moves_count > 0:
157
+ loss = loss.sum() / selected_moves_count
158
+ else:
159
+ loss = loss.sum() # If no moves are selected, return 0 loss
160
+
161
+ return loss
162
+
163
+ def all_moves_loss(output, ground_truth, pad_token_id=2064, start_token_id=2065):
164
+ """
165
+ Compute the loss for all valid moves in the sequence, excluding start and padding tokens.
166
+ """
167
+ batch_size, seq_len, num_tokens = output.shape
168
+
169
+ output = output.cuda()
170
+ ground_truth = ground_truth.cuda()
171
+
172
+ # Shift the output and ground truth to align them
173
+ output_shifted = output[:, :-1, :].contiguous()
174
+ ground_truth_shifted = ground_truth[:, 1:].contiguous()
175
+
176
+ # Flatten the shifted output and ground truth
177
+ output_flat = output_shifted.view(-1, num_tokens)
178
+ ground_truth_flat = ground_truth_shifted.view(-1)
179
+
180
+ # Apply log softmax to the flattened output
181
+ output_log_softmax = F.log_softmax(output_flat, dim=-1)
182
+
183
+ # Create a mask for all valid moves (excluding padding and start tokens)
184
+ valid_moves_mask = ((ground_truth_flat != pad_token_id) &
185
+ (ground_truth_flat != start_token_id))
186
+
187
+ # Calculate the negative log-likelihood loss for all moves
188
+ loss = F.nll_loss(output_log_softmax, ground_truth_flat, reduction='none')
189
+
190
+ # Apply the mask to exclude padding and start tokens
191
+ loss = loss * valid_moves_mask.float()
192
+
193
+ # Average the loss over all valid moves
194
+ valid_moves_count = valid_moves_mask.float().sum()
195
+ if valid_moves_count > 0:
196
+ loss = loss.sum() / valid_moves_count
197
+ else:
198
+ loss = loss.sum() # If no valid moves, return 0 loss
199
+
200
+ return loss
201
+
202
+ def weighted_chess_loss(output, ground_truth, win_labels, winning_weight=1.0, losing_weight=0.1, pad_token_id=2064, start_token_id=2065):
203
+ """
204
+ Compute a weighted loss for all moves, with higher weight for winning moves.
205
+ """
206
+ output = output.cuda()
207
+ ground_truth = ground_truth.cuda()
208
+ win_labels = win_labels.cuda()
209
+
210
+ batch_size, seq_len, num_tokens = output.shape
211
+
212
+ # Shift the ground truth to align with the output predictions
213
+ ground_truth_shifted = ground_truth[:, 1:].contiguous()
214
+ output_shifted = output[:, :-1, :].contiguous()
215
+
216
+ # Flatten the output and ground truth for easier masking
217
+ output_flat = output_shifted.view(-1, num_tokens)
218
+ ground_truth_flat = ground_truth_shifted.view(-1)
219
+
220
+ # Apply log softmax to the flattened output
221
+ output_log_softmax = F.log_softmax(output_flat, dim=-1)
222
+
223
+ # Repeat win_labels for each move in the sequence
224
+ win_labels_expanded = win_labels.unsqueeze(1).repeat(1, seq_len - 1).view(-1)
225
+
226
+ # Create masks for winning and losing moves
227
+ move_indices = torch.arange(seq_len - 1, device=output.device).unsqueeze(0).repeat(batch_size, 1).view(-1)
228
+ white_win_mask = (win_labels_expanded == 1) & (move_indices % 2 == 0)
229
+ black_win_mask = (win_labels_expanded == 0) & (move_indices % 2 == 1)
230
+ winning_moves_mask = white_win_mask | black_win_mask
231
+
232
+ # Create a mask for all valid moves (excluding padding and start tokens)
233
+ valid_moves_mask = (ground_truth_flat != pad_token_id) & (ground_truth_flat != start_token_id)
234
+
235
+ # Calculate the negative log-likelihood loss for all valid moves
236
+ loss = F.nll_loss(output_log_softmax, ground_truth_flat, reduction='none')
237
+
238
+ # Apply weights based on whether the move is winning or losing
239
+ weights = torch.where(winning_moves_mask & valid_moves_mask, winning_weight, losing_weight)
240
+
241
+ # Apply the weights and the valid moves mask to the loss
242
+ weighted_loss = loss * weights * valid_moves_mask.float()
243
+
244
+ # Average the loss over all valid moves
245
+ valid_moves_count = valid_moves_mask.float().sum()
246
+ if valid_moves_count > 0:
247
+ avg_loss = weighted_loss.sum() / valid_moves_count
248
+ else:
249
+ avg_loss = weighted_loss.sum() # If no valid moves, return 0 loss
250
+
251
+ return avg_loss
environment.yml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: chessbot
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - bzip2=1.0.8=h5eee18b_6
8
+ - ca-certificates=2024.9.24=h06a4308_0
9
+ - expat=2.6.3=h6a678d5_0
10
+ - ld_impl_linux-64=2.40=h12ee557_0
11
+ - libffi=3.4.4=h6a678d5_1
12
+ - libgcc-ng=11.2.0=h1234567_1
13
+ - libgomp=11.2.0=h1234567_1
14
+ - libstdcxx-ng=11.2.0=h1234567_1
15
+ - libuuid=1.41.5=h5eee18b_0
16
+ - ncurses=6.4=h6a678d5_0
17
+ - openssl=3.0.15=h5eee18b_0
18
+ - pip=24.2=py312h06a4308_0
19
+ - python=3.12.7=h5148396_0
20
+ - readline=8.2=h5eee18b_0
21
+ - setuptools=75.1.0=py312h06a4308_0
22
+ - sqlite=3.45.3=h5eee18b_0
23
+ - tk=8.6.14=h39e8969_0
24
+ - wheel=0.44.0=py312h06a4308_0
25
+ - xz=5.4.6=h5eee18b_1
26
+ - zlib=1.2.13=h5eee18b_1
27
+ - pip:
28
+ - absl-py==2.1.0
29
+ - chess==1.11.0
30
+ - filelock==3.13.1
31
+ - fsspec==2024.2.0
32
+ - grpcio==1.66.2
33
+ - jinja2==3.1.3
34
+ - markdown==3.7
35
+ - markupsafe==2.1.5
36
+ - mpmath==1.3.0
37
+ - networkx==3.2.1
38
+ - numpy==2.1.2
39
+ - nvidia-cublas-cu12==12.4.2.65
40
+ - nvidia-cuda-cupti-cu12==12.4.99
41
+ - nvidia-cuda-nvrtc-cu12==12.4.99
42
+ - nvidia-cuda-runtime-cu12==12.4.99
43
+ - nvidia-cudnn-cu12==9.1.0.70
44
+ - nvidia-cufft-cu12==11.2.0.44
45
+ - nvidia-curand-cu12==10.3.5.119
46
+ - nvidia-cusolver-cu12==11.6.0.99
47
+ - nvidia-cusparse-cu12==12.3.0.142
48
+ - nvidia-nccl-cu12==2.20.5
49
+ - nvidia-nvjitlink-cu12==12.4.99
50
+ - nvidia-nvtx-cu12==12.4.99
51
+ - packaging==24.1
52
+ - pandas==2.2.3
53
+ - protobuf==5.28.2
54
+ - pyarrow==17.0.0
55
+ - python-dateutil==2.9.0.post0
56
+ - pytz==2024.2
57
+ - six==1.16.0
58
+ - sympy==1.12
59
+ - tensorboard==2.18.0
60
+ - tensorboard-data-server==0.7.2
61
+ - torch==2.4.1+cu124
62
+ - tqdm==4.66.5
63
+ - triton==3.0.0
64
+ - typing-extensions==4.9.0
65
+ - tzdata==2024.2
66
+ - werkzeug==3.0.4
play.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from chesstransformer import ChessTransformer
4
+ import tokenizer as tk
5
+
6
+ model = ChessTransformer()
7
+ model.load_state_dict(torch.load('./64L1024D_1e-3maxlr_470k_step_1ep_1480ELO.pth')["model_state_dict"])
8
+ model.eval().cuda()
9
+
10
+ # Initialize tokenizer
11
+ t = tk.Tokenizer()
12
+
13
+ def predict_move(model, game_sequence, tokenizer, device='cuda', top_k=5):
14
+ model.eval()
15
+ game_sequence = torch.tensor([tokenizer.tokenize_game(game_sequence)], dtype=torch.long).to(device)
16
+
17
+ with torch.no_grad():
18
+ output = model(game_sequence)
19
+ logits = output[0, -1, :] # Get logits for the last move
20
+ top_k_logits, top_k_indices = torch.topk(logits, top_k)
21
+
22
+ # Apply softmax to get probabilities
23
+ probs = F.softmax(top_k_logits, dim=-1)
24
+
25
+ # Sample from the probability distribution
26
+ sampled_index = torch.multinomial(probs, 1).item()
27
+ sampled_token = top_k_indices[sampled_index].item()
28
+
29
+ sampled_move = tokenizer.untokenize_game([sampled_token])[0]
30
+
31
+ # Get all top_k moves and their probabilities for display
32
+ top_k_moves = [tokenizer.untokenize_game([idx.item()])[0] for idx in top_k_indices]
33
+ top_k_probs = probs.cpu().numpy()
34
+
35
+ return sampled_move, top_k_moves, top_k_probs
36
+
37
+ def play_game():
38
+ input_game = []
39
+ print("Let's play chess! Enter your moves in UCI format (e.g., 'e2e4'). Type 'exit' to quit or 'undo' to undo the last move.")
40
+
41
+ while True:
42
+ user_move = input("Your move: ").strip()
43
+ if user_move.lower() == 'exit':
44
+ print("Game over. Thanks for playing!")
45
+ break
46
+ elif user_move.lower() == 'undo':
47
+ if len(input_game) >= 2:
48
+ input_game.pop() # Remove bot's move
49
+ input_game.pop() # Remove user's move
50
+ print("Last move undone. Current game sequence:", input_game)
51
+ else:
52
+ print("Cannot undo. No moves to undo.")
53
+ continue
54
+
55
+ input_game.append(user_move)
56
+ print("Current game sequence:", input_game)
57
+
58
+ try:
59
+ bot_move, top_moves, top_probs = predict_move(model, input_game, t)
60
+
61
+ # Display top moves and their probabilities
62
+ moves_probs_str = ', '.join(f"{move} ({prob:.2%})" for move, prob in zip(top_moves, top_probs))
63
+ print(f"Top {len(top_moves)} moves and probabilities: {moves_probs_str}")
64
+
65
+ print(f"Bot's sampled move: {bot_move}")
66
+ input_game.append(bot_move)
67
+ except Exception as e:
68
+ print("An error occurred:", e)
69
+ break
70
+
71
+ if __name__ == "__main__":
72
+ play_game()
tokenizer.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Tokenizer:
2
+ def __init__(self):
3
+ self.move_dict = create_move_dict()
4
+ self.inverse_dict = inverse_move_dict(self.move_dict)
5
+
6
+ def tokenize_game(self, moves_list):
7
+ tokenized_moves = []
8
+ for move in moves_list:
9
+ tokenized_moves.append(self.move_dict[move])
10
+ return tokenized_moves
11
+
12
+ def untokenize_game(self, tokenized_moves):
13
+ inverse_moves = []
14
+ for move in tokenized_moves:
15
+ if move == 2064:
16
+ inverse_moves.append("[pad]")
17
+ continue
18
+ if move == 2065:
19
+ inverse_moves.append("[start]")
20
+ continue
21
+ inverse_moves.append(self.inverse_dict[move])
22
+ return inverse_moves
23
+
24
+ def tokenize_move(self, move):
25
+ return self.move_dict[move]
26
+
27
+ def get_move(self, tokenized_move):
28
+ return self.inverse_dict[tokenized_move]
29
+
30
+
31
+ # Helper function to convert square index to algebraic notation
32
+ def square_to_algebraic(square):
33
+ files = 'abcdefgh'
34
+ ranks = '12345678'
35
+ file = files[square % 8]
36
+ rank = ranks[square // 8]
37
+ return file + rank
38
+
39
+ # Modified chess_moves function to account for all moves
40
+ def chess_moves(starting_square):
41
+ moves = []
42
+ ss = starting_square
43
+
44
+ # Calculate file and rank
45
+ file_start = (ss // 8) * 8
46
+ file_end = file_start + 7
47
+
48
+ # Horizontal moves - to left
49
+ for i in range(ss - 1, file_start - 1, -1):
50
+ moves.append((ss, i))
51
+
52
+ # Horizontal moves - to right
53
+ for i in range(ss + 1, file_end + 1):
54
+ moves.append((ss, i))
55
+
56
+ # Vertical moves - above
57
+ for i in range(ss + 8, 64, 8):
58
+ moves.append((ss, i))
59
+
60
+ # Vertical moves - below
61
+ for i in range(ss - 8, -1, -8):
62
+ moves.append((ss, i))
63
+
64
+ # Diagonal moves
65
+ # Upper left
66
+ i = ss
67
+ while (i := i + 7) < 64 and i % 8 != 7:
68
+ moves.append((ss, i))
69
+
70
+ # Lower left
71
+ i = ss
72
+ while (i := i - 9) >= 0 and i % 8 != 7:
73
+ moves.append((ss, i))
74
+
75
+ # Upper right
76
+ i = ss
77
+ while (i := i + 9) < 64 and i % 8 != 0:
78
+ moves.append((ss, i))
79
+
80
+ # Lower right
81
+ i = ss
82
+ while (i := i - 7) >= 0 and i % 8 != 0:
83
+ moves.append((ss, i))
84
+
85
+ # Inner 5x5 square
86
+ for j in range(-2, 3):
87
+ for i in range(-2, 3):
88
+ target = ss + i + j * 8
89
+ if 0 <= target < 64 and (target // 8 == (ss // 8) + j) and target != ss:
90
+ moves.append((ss, target))
91
+
92
+ # Pawn moves (including promotions)
93
+ if ss // 8 == 1: # White pawn's initial position
94
+ if ss + 8 < 64:
95
+ moves.append((ss, ss + 8))
96
+ if (ss + 16) < 64:
97
+ moves.append((ss, ss + 16))
98
+ if ss + 9 < 64 and (ss + 9) % 8 != 0:
99
+ moves.append((ss, ss + 9))
100
+ if ss + 7 < 64 and (ss + 7) % 8 != 7:
101
+ moves.append((ss, ss + 7))
102
+ elif ss // 8 == 6: # Black pawn's initial position
103
+ if ss - 8 >= 0:
104
+ moves.append((ss, ss - 8))
105
+ if (ss - 16) >= 0:
106
+ moves.append((ss, ss - 16))
107
+ if ss - 9 >= 0 and (ss - 9) % 8 != 7:
108
+ moves.append((ss, ss - 9))
109
+ if ss - 7 >= 0 and (ss - 7) % 8 != 0:
110
+ moves.append((ss, ss - 7))
111
+
112
+ #remove duplicate tuples
113
+ seen = set()
114
+ result = []
115
+ for item in moves:
116
+ if item not in seen:
117
+ seen.add(item)
118
+ result.append(item)
119
+
120
+ return result
121
+
122
+
123
+ # Function to create a dictionary of moves with promotion
124
+ def create_move_dict():
125
+ move_dict = {}
126
+ count = 0
127
+ promotion_pieces = ['q', 'r', 'b', 'n'] # Queen, Rook, Bishop, Knight
128
+
129
+ for i in range(64):
130
+ for move in chess_moves(i):
131
+ start_sq_algebraic = square_to_algebraic(move[0])
132
+ end_sq_algebraic = square_to_algebraic(move[1])
133
+ move_dict[f"{start_sq_algebraic}{end_sq_algebraic}"] = count
134
+ count += 1
135
+ # Add promotions if applicable
136
+ if move[1] // 8 == 7 and i // 8 == 6: # White pawn reaching last rank
137
+ for piece in promotion_pieces:
138
+ move_dict[f"{start_sq_algebraic}{end_sq_algebraic}{piece}"] = count
139
+ count += 1
140
+ elif move[1] // 8 == 0 and i // 8 == 1: # Black pawn reaching last rank
141
+ for piece in promotion_pieces:
142
+ move_dict[f"{start_sq_algebraic}{end_sq_algebraic}{piece}"] = count
143
+ count += 1
144
+
145
+ move_dict["pad"] = 2064
146
+ move_dict["start"] = 2065
147
+ return move_dict
148
+
149
+ def inverse_move_dict(move_dict):
150
+ inverse_dict = {}
151
+ for k, v in move_dict.items():
152
+ inverse_dict[v] = k
153
+ return inverse_dict
154
+
155
+ def tokenize_game(moves_list):
156
+ move_dict = create_move_dict()
157
+ tokenized_moves = []
158
+ for move in moves_list:
159
+ tokenized_moves.append(move_dict[move])
160
+ return tokenized_moves
161
+
162
+ if __name__ == "__main__":
163
+ t = Tokenizer()