# File size: 4,896 bytes
# Source revision (as shown by the file viewer): aa2269b
class Tokenizer:
    """Two-way lookup between move strings and integer token ids.

    ``move_dict`` (string -> id) is produced by ``create_move_dict()`` and
    ``inverse_dict`` (id -> string) by ``inverse_move_dict()``. Both tables
    are built once, when the instance is created.
    """

    def __init__(self):
        forward_table = create_move_dict()
        self.move_dict = forward_table
        self.inverse_dict = inverse_move_dict(forward_table)

    def tokenize_game(self, moves_list):
        """Return the token id of every move string in ``moves_list``.

        Raises ``KeyError`` for a string that is not in ``move_dict``.
        """
        return [self.move_dict[move] for move in moves_list]

    def untokenize_game(self, tokenized_moves):
        """Return the move string for every token id in ``tokenized_moves``.

        Ids 2064 and 2065 are rendered as the bracketed markers ``"[pad]"``
        and ``"[start]"`` rather than looked up in ``inverse_dict``; every
        other id goes through ``inverse_dict`` and raises ``KeyError`` when
        it is missing.
        """

        def decode(token):
            # Equality (not dict lookup) is used for the two markers so any
            # value that compares equal to the id is accepted.
            if token == 2064:
                return "[pad]"
            if token == 2065:
                return "[start]"
            return self.inverse_dict[token]

        return [decode(token) for token in tokenized_moves]

    def tokenize_move(self, move):
        """Return the token id of a single move string."""
        token_id = self.move_dict[move]
        return token_id

    def get_move(self, tokenized_move):
        """Return the move string stored for a single token id."""
        move_string = self.inverse_dict[tokenized_move]
        return move_string
# Helper: turn a square index into its algebraic name (file letter + rank digit)
def square_to_algebraic(square):
    """Return the algebraic name of ``square``, e.g. ``0 -> "a1"``, ``63 -> "h8"``.

    The file letter comes from ``square % 8`` and the rank digit from
    ``square // 8``.
    """
    rank_index, file_index = divmod(square, 8)
    file_letter = "abcdefgh"[file_index]
    rank_digit = "12345678"[rank_index]
    return file_letter + rank_digit
# Enumerate every (from, to) pair needed to describe a move from one square
def chess_moves(starting_square):
    """Return all queen-like and knight-like moves from ``starting_square``.

    Squares are indexed ``rank * 8 + file`` (the layout used by
    ``square_to_algebraic``), and ``starting_square`` is expected to be an
    int in ``0..63``.

    Each move is a ``(starting_square, target)`` tuple. Together the
    straight-line, diagonal and knight targets cover every destination any
    piece can reach on an empty board: pawn pushes, pawn captures and
    two-square pushes all coincide with queen-like moves, so they need no
    separate generation.

    The ORDER of the returned list is significant, because
    ``create_move_dict`` hands out token ids in iteration order. The order is:
    left, right, up, down, the four diagonals (up-left, down-left, up-right,
    down-right), then knight jumps from the lowest rank offset to the highest.
    Every target is produced exactly once, so no de-duplication is needed.
    """
    ss = starting_square
    rank_index = ss // 8
    rank_start = rank_index * 8  # index of the first square on this rank

    targets = []

    # Straight lines: walk outward from the square in each direction.
    targets.extend(range(ss - 1, rank_start - 1, -1))  # left along the rank
    targets.extend(range(ss + 1, rank_start + 8))      # right along the rank
    targets.extend(range(ss + 8, 64, 8))               # up the file
    targets.extend(range(ss - 8, -1, -8))              # down the file

    # Diagonals: (index step, file on which the ray would have wrapped around
    # the board edge). Stepping left must never land on file 7, stepping
    # right must never land on file 0.
    for step, wrapped_file in ((7, 7), (-9, 7), (9, 0), (-7, 0)):
        target = ss + step
        while 0 <= target < 64 and target % 8 != wrapped_file:
            targets.append(target)
            target += step

    # Knight jumps as (rank_delta, file_delta). The rank check rejects jumps
    # whose file offset would wrap onto a neighbouring rank.
    knight_offsets = (
        (-2, -1), (-2, 1),
        (-1, -2), (-1, 2),
        (1, -2), (1, 2),
        (2, -1), (2, 1),
    )
    for rank_delta, file_delta in knight_offsets:
        target = ss + file_delta + rank_delta * 8
        if 0 <= target < 64 and target // 8 == rank_index + rank_delta:
            targets.append(target)

    return [(ss, target) for target in targets]
# Build the move-string -> token-id table, including promotion variants
def create_move_dict():
    """Return a dict mapping move strings to consecutive integer ids.

    For every square 0..63 and every move returned by ``chess_moves``, the
    plain four-character string (from-square + to-square) receives the next
    id. When the move goes from the second-to-last rank onto the last rank
    (rank index 6 -> 7, or 1 -> 0), four extra entries follow immediately,
    one per promotion suffix ``q``, ``r``, ``b``, ``n``.

    Finally the special keys ``"pad"`` and ``"start"`` are stored with the
    fixed ids 2064 and 2065.
    NOTE(review): these fixed ids are presumably meant to sit directly after
    the last generated id -- confirm whenever the move generation changes.
    """
    move_dict = {}
    next_id = 0
    promotion_rank_pairs = ((6, 7), (1, 0))  # (from-rank, to-rank) indices

    for origin in range(64):
        for move in chess_moves(origin):
            start_name = square_to_algebraic(move[0])
            end_name = square_to_algebraic(move[1])
            base_key = start_name + end_name

            # The plain move always comes first, then its promotion variants.
            keys = [base_key]
            if (origin // 8, move[1] // 8) in promotion_rank_pairs:
                keys.extend(base_key + piece for piece in ('q', 'r', 'b', 'n'))

            for key in keys:
                move_dict[key] = next_id
                next_id += 1

    move_dict["pad"] = 2064
    move_dict["start"] = 2065
    return move_dict
def inverse_move_dict(move_dict):
    """Return a new dict with the keys and values of ``move_dict`` swapped.

    If several keys share a value, the key that appears last in iteration
    order wins.
    """
    return {token_id: move for move, token_id in move_dict.items()}
# Move table shared by every call to the module-level tokenize_game(); it is
# built lazily on first use instead of being regenerated on each call.
_MOVE_DICT_CACHE = None


def tokenize_game(moves_list):
    """Return the token id of every move string in ``moves_list``.

    Module-level convenience equivalent to ``Tokenizer().tokenize_game``.
    The table from ``create_move_dict()`` is built on the first call and
    reused afterwards, so repeated calls no longer regenerate the whole table.

    Raises ``KeyError`` for a string that is not in the move table.
    """
    global _MOVE_DICT_CACHE
    if _MOVE_DICT_CACHE is None:
        _MOVE_DICT_CACHE = create_move_dict()
    move_dict = _MOVE_DICT_CACHE
    return [move_dict[move] for move in moves_list]
if __name__ == "__main__":
    # Constructing a Tokenizer builds both lookup tables, so running the
    # module directly acts as a quick smoke test of the table generation.
    t = Tokenizer()