File size: 4,896 Bytes
aa2269b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class Tokenizer:
    """Convert between move strings (e.g. ``"e2e4"``, ``"e7e8q"``) and token ids.

    The vocabulary comes from ``create_move_dict``; ``inverse_dict`` is its
    id -> string inverse built by ``inverse_move_dict``.
    """

    def __init__(self):
        self.move_dict = create_move_dict()
        self.inverse_dict = inverse_move_dict(self.move_dict)

    def tokenize_game(self, moves_list):
        """Return the token id of every move in ``moves_list``, in order.

        Raises:
            KeyError: if a move is not in the vocabulary.
        """
        move_dict = self.move_dict
        return [move_dict[move] for move in moves_list]

    def untokenize_game(self, tokenized_moves):
        """Convert token ids back to move strings.

        The padding and start ids are rendered as ``"[pad]"`` and
        ``"[start]"`` (not the raw vocabulary keys ``"pad"``/``"start"``);
        every other id is looked up in ``inverse_dict``.

        Raises:
            KeyError: if an id is neither special nor in the vocabulary.
        """
        # Read the special ids from the vocabulary instead of hard-coding
        # them, so this method can never drift out of sync with it.
        pad_id = self.move_dict["pad"]
        start_id = self.move_dict["start"]
        inverse_moves = []
        for move in tokenized_moves:
            if move == pad_id:
                inverse_moves.append("[pad]")
            elif move == start_id:
                inverse_moves.append("[start]")
            else:
                inverse_moves.append(self.inverse_dict[move])
        return inverse_moves

    def tokenize_move(self, move):
        """Return the token id of a single move string (KeyError if unknown)."""
        return self.move_dict[move]

    def get_move(self, tokenized_move):
        """Return the vocabulary key for a single token id (KeyError if unknown)."""
        return self.inverse_dict[tokenized_move]


# Helper function to convert square index to algebraic notation
def square_to_algebraic(square):
    """Return the algebraic name of a 0-63 square index, e.g. 0 -> "a1", 63 -> "h8".

    Squares are numbered rank by rank starting at a1: ``square % 8`` selects
    the file (a..h) and ``square // 8`` the rank (1..8).

    Raises:
        IndexError: if ``square`` is not in 0-63. Negative values are
            rejected explicitly; otherwise Python's negative indexing would
            silently turn e.g. -1 into "h8".
    """
    if not 0 <= square < 64:
        raise IndexError(f"square index out of range: {square!r}")
    rank, file = divmod(square, 8)
    return "abcdefgh"[file] + "12345678"[rank]

# Enumerate every (from, to) square pair reachable along queen lines or by a knight jump
def chess_moves(starting_square):
    """Return every (from, to) square pair reachable from ``starting_square``.

    The result is the union of queen moves (along the rank, the file and both
    diagonals, any distance) and knight jumps. Together these cover the move
    geometry of every piece: king steps, pawn pushes/captures and castling
    are all short queen-line moves. Squares are indexed 0-63 rank by rank
    from a1.

    The ORDER of the returned list matters: ``create_move_dict`` assigns
    token ids in this order, so reordering the sections below would change
    the vocabulary.

    Args:
        starting_square: origin square index, expected in 0-63.

    Returns:
        A list of ``(starting_square, target)`` tuples without duplicates,
        in first-generated order.
    """
    ss = starting_square
    moves = []

    # First and last square index of the rank (row) containing ss; a rank is
    # 8 consecutive indices because squares are numbered rank by rank.
    rank_start = (ss // 8) * 8
    rank_end = rank_start + 7

    # Horizontal moves - to the left, nearest square first
    for target in range(ss - 1, rank_start - 1, -1):
        moves.append((ss, target))

    # Horizontal moves - to the right
    for target in range(ss + 1, rank_end + 1):
        moves.append((ss, target))

    # Vertical moves - up the board (+8 per rank)
    for target in range(ss + 8, 64, 8):
        moves.append((ss, target))

    # Vertical moves - down the board
    for target in range(ss - 8, -1, -8):
        moves.append((ss, target))

    # Diagonals: step until leaving the board. A step to the left that wraps
    # past the a-file lands on the h-file (index % 8 == 7); a step to the
    # right that wraps past the h-file lands on the a-file (index % 8 == 0).
    # Upper left (+7)
    target = ss
    while (target := target + 7) < 64 and target % 8 != 7:
        moves.append((ss, target))

    # Lower left (-9)
    target = ss
    while (target := target - 9) >= 0 and target % 8 != 7:
        moves.append((ss, target))

    # Upper right (+9)
    target = ss
    while (target := target + 9) < 64 and target % 8 != 0:
        moves.append((ss, target))

    # Lower right (-7)
    target = ss
    while (target := target - 7) >= 0 and target % 8 != 0:
        moves.append((ss, target))

    # Every square within two ranks and two files of ss; the row check
    # rejects targets whose file offset wrapped onto a neighbouring rank.
    # Apart from the eight knight jumps, all of these already lie on the
    # lines above and are removed by the de-duplication below, so this scan
    # (rank offset outer, file offset inner) fixes the order of the knight
    # moves.
    for rank_offset in range(-2, 3):
        for file_offset in range(-2, 3):
            target = ss + file_offset + rank_offset * 8
            if 0 <= target < 64 and target // 8 == ss // 8 + rank_offset and target != ss:
                moves.append((ss, target))

    # Pawn pushes (single and double) and pawn captures need no extra code:
    # they are one/two-square vertical and one-square diagonal moves, all
    # generated above. Promotion variants are added by create_move_dict.

    # Remove duplicates while keeping first-seen order (ids depend on it).
    return list(dict.fromkeys(moves))


# Function to create a dictionary of moves with promotion
def create_move_dict():
    """Build the move vocabulary, mapping move strings to consecutive token ids.

    Origin squares are visited 0..63, and for each one the moves are taken in
    ``chess_moves`` order. Each move is keyed as the concatenated algebraic
    start and end squares (e.g. "e2e4"). A move from the 7th to the 8th rank,
    or from the 2nd to the 1st rank, is followed immediately by its four
    promotion variants (suffixes "q", "r", "b", "n"). The promotion test looks
    only at ranks, so non-pawn geometry such as a knight jump onto the last
    rank (e.g. "a7c8q") also gets promotion tokens. These are kept so that
    existing ids stay stable.

    Two special tokens, "pad" and then "start", are placed right after the
    last move id. For the current move geometry this gives pad=2064 and
    start=2065.

    Returns:
        dict mapping move strings (plus "pad" and "start") to int ids.
    """
    move_dict = {}
    count = 0
    promotion_pieces = ['q', 'r', 'b', 'n']  # Queen, Rook, Bishop, Knight

    for i in range(64):
        from_rank = i // 8
        for move in chess_moves(i):
            start_sq_algebraic = square_to_algebraic(move[0])
            end_sq_algebraic = square_to_algebraic(move[1])
            base_move = f"{start_sq_algebraic}{end_sq_algebraic}"
            move_dict[base_move] = count
            count += 1

            to_rank = move[1] // 8
            # White reaching the 8th rank from the 7th, or black reaching the
            # 1st rank from the 2nd.
            is_promotion = (from_rank == 6 and to_rank == 7) or (from_rank == 1 and to_rank == 0)
            if is_promotion:
                for piece in promotion_pieces:
                    move_dict[f"{base_move}{piece}"] = count
                    count += 1

    # Derive the special ids from the counter rather than hard-coding them,
    # so they can never collide with a real move id.
    move_dict["pad"] = count
    move_dict["start"] = count + 1
    return move_dict

def inverse_move_dict(move_dict):
    """Swap keys and values of a vocabulary, giving a token id -> move string map.

    When several keys share one value, the key that appears last in
    ``move_dict`` wins.
    """
    return {token_id: move for move, token_id in move_dict.items()}

def tokenize_game(moves_list):
    """Convert a sequence of move strings into their token ids.

    The vocabulary is rebuilt with ``create_move_dict`` on every call; a
    ``Tokenizer`` instance builds it once and is cheaper for repeated use.

    Raises:
        KeyError: if any move is not in the vocabulary.
    """
    vocab = create_move_dict()
    return [vocab[m] for m in moves_list]

if __name__ == "__main__":
    # Script entry point: only constructs a Tokenizer (which builds the full
    # move vocabulary); the instance is not used further here.
    t = Tokenizer()