austindavis commited on
Commit
8e4af20
1 Parent(s): 0a4d537

Create agents/uci_tokenizers.py

Browse files
Files changed (1) hide show
  1. agents/uci_tokenizers.py +314 -0
agents/uci_tokenizers.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import chess
4
+ import tiktoken
5
+ import tokenizers
6
+ from tokenizers import models, pre_tokenizers, processors
7
+ from torch import Tensor as TT
8
+ from transformers import PreTrainedTokenizerFast
9
+ from transformers.tokenization_utils_fast import BatchEncoding
10
+
11
+
12
+ def getTiktokenizer() -> tiktoken.Encoding:
13
+ """
14
+ Defines a tiktoken-based BPE encoder for UCI chess moves. This
15
+ tokenizer effectively tokenizes UCI moves by the square names.
16
+ One notable variation is that promotions must be in upper-case.
17
+
18
+ Vocabulary:
19
+ Special Tokens (4): "\<|pad|\>", "\<|startoftext|\>", "\<|endoftext|\>", "\<|unknown|\>"
20
+ Square Tokens (64): a1 through h8
21
+ Promote Tokens (4): Q, B, R, N
22
+ UNUSED (8120): Need 8192-4-64-4=8120 unused tokens of the form <|unused####|>
23
+ """
24
+ special_tokens = ["<|pad|>", "<|startoftext|>", "<|endoftext|>", "<|unknown|>"]
25
+ unused_tokens = [f"<|unused{i:04d}" for i in range(8120)]
26
+ chess_vocab = special_tokens + chess.SQUARE_NAMES + list("QBRN") + unused_tokens
27
+ mergeable_ranks = {k.encode():v for (v,k) in enumerate(chess_vocab)}
28
+ chess_pat_str = r'[a-h][1-8]|[QBRN]'
29
+
30
+ enc = tiktoken.Encoding(
31
+ name="chess_enc",
32
+ pat_str=chess_pat_str, # or \d|\s
33
+ mergeable_ranks=mergeable_ranks,
34
+ special_tokens={k:v for (v,k) in enumerate(special_tokens)},
35
+ )
36
+
37
+ return enc
38
+
39
+
40
+ class UciTokenizer(PreTrainedTokenizerFast):
41
+ _PAD_TOKEN: str
42
+ _UNK_TOKEN: str
43
+ _EOS_TOKEN: str
44
+ _BOS_TOKEN: str
45
+
46
+
47
+ stoi: dict[str, int]
48
+ """Integer to String mapping"""
49
+
50
+ itos: dict[int, str]
51
+ """String to Integer Mapping. This is the vocab"""
52
+
53
+ def __init__(
54
+ self,
55
+ stoi,
56
+ itos,
57
+ pad_token,
58
+ unk_token,
59
+ bos_token,
60
+ eos_token,
61
+ name_or_path,
62
+ **kwargs
63
+ ):
64
+ self.stoi = stoi
65
+ self.itos = itos
66
+
67
+ self._PAD_TOKEN = pad_token
68
+ self._UNK_TOKEN = unk_token
69
+ self._EOS_TOKEN = eos_token
70
+ self._BOS_TOKEN = bos_token
71
+
72
+ # Define the model
73
+ tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)
74
+
75
+ slow_tokenizer = tokenizers.Tokenizer(tok_model)
76
+ slow_tokenizer.pre_tokenizer = self._init_pretokenizer()
77
+
78
+ # post processing adds special tokens unless explicitly ignored
79
+ post_proc = processors.TemplateProcessing(
80
+ single=f"{bos_token} $0",
81
+ pair=None,
82
+ special_tokens=[(bos_token, 1)],
83
+ )
84
+ slow_tokenizer.post_processor=post_proc
85
+
86
+ super().__init__(
87
+ tokenizer_object=slow_tokenizer,
88
+ unk_token=self._UNK_TOKEN,
89
+ bos_token=self._BOS_TOKEN,
90
+ eos_token=self._EOS_TOKEN,
91
+ pad_token=self._PAD_TOKEN,
92
+ name_or_path=name_or_path,
93
+ **kwargs
94
+ )
95
+
96
+ # Override the decode behavior to ensure spaces are correctly handled
97
+ def _decode(
98
+ token_ids: int | List[int] | dict | TT,
99
+ skip_special_tokens=False,
100
+ clean_up_tokenization_spaces=False,
101
+ ) -> int | List[int]:
102
+
103
+ if isinstance(token_ids, int):
104
+ return self.itos.get(token_ids, self._UNK_TOKEN)
105
+
106
+ if isinstance(token_ids, dict):
107
+ token_ids = token_ids["input_ids"]
108
+
109
+ if isinstance(token_ids, TT):
110
+ token_ids = token_ids.tolist()
111
+
112
+ if isinstance(token_ids, list):
113
+ tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
114
+ processed_tokens = self._process_str_tokens(tokens_str)
115
+
116
+ return " ".join(processed_tokens)
117
+
118
+ raise ValueError(f"Unknown input type to decode() for argument 'token_ids'. Received: {type(token_ids)} ")
119
+
120
+
121
+ self._decode = _decode
122
+
123
+ def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
124
+ raise NotImplementedError
125
+
126
+ def _process_str_tokens(self, tokens_str: list[str], return_player_ids: bool) -> list[str]:
127
+ raise NotImplementedError
128
+
129
+ def get_id2square_list() -> list[int]:
130
+ raise NotImplementedError
131
+
132
+
133
+ class UciTileTokenizer(UciTokenizer):
134
+ """ Uci tokenizer converting start/end tiles and promotion types each into individual tokens"""
135
+
136
+ SPECIAL_TOKENS = ["<|pad|>", "<|startoftext|>", "<|endoftext|>", "<|unknown|>"]
137
+
138
+ stoi = {
139
+ tok: idx
140
+ for tok, idx in list(
141
+ zip(SPECIAL_TOKENS + chess.SQUARE_NAMES + list("QRBN"), range(72))
142
+ )
143
+ }
144
+
145
+ itos = {
146
+ idx: tok
147
+ for tok, idx in list(
148
+ zip(SPECIAL_TOKENS + chess.SQUARE_NAMES + list("QRBN"), range(72))
149
+ )
150
+ }
151
+
152
+ id2square:List[int] = list(range(4,68))
153
+ """
154
+ List mapping token IDs to squares on the chess board. Order is file then rank, i.e.:
155
+ `A1, B1, C1, ..., F8, G8, H8`
156
+ """
157
+
158
+ def get_id2square_list(self) -> List[int]:
159
+ return self.id2square
160
+
161
+ def __init__(self, **kwargs):
162
+ super().__init__(
163
+ self.stoi,
164
+ self.itos,
165
+ pad_token="<|pad|>",
166
+ unk_token="<|unknown|>",
167
+ bos_token="<|startoftext|>",
168
+ eos_token="<|endoftext|>",
169
+ name_or_path="austindavis/uci_tile_tokenizer",
170
+ clean_up_tokenization_spaces=False,
171
+ **kwargs
172
+ )
173
+
174
+ def _init_pretokenizer(self):
175
+ # Pre-tokenizer to split input into UCI moves
176
+ pattern = tokenizers.Regex(r"\d|[QBRN]")
177
+ pre_tokenizer = pre_tokenizers.Sequence(
178
+ [
179
+ pre_tokenizers.Whitespace(),
180
+ pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
181
+ ]
182
+ )
183
+ return pre_tokenizer
184
+
185
+ def _process_str_tokens(self, token_str: list[str]):
186
+ moves = []
187
+ next_move = ""
188
+ for token in token_str:
189
+
190
+ # skip special tokens
191
+ if token in self.all_special_tokens:
192
+ continue
193
+
194
+ # handle promotions
195
+ if len(token) == 1:
196
+ next_move += token
197
+ continue
198
+
199
+ # handle regular tokens if there's room
200
+ if len(next_move) < 4:
201
+ next_move += token
202
+ continue
203
+
204
+ moves.append(next_move)
205
+ next_move = token
206
+
207
+ moves.append(next_move)
208
+ return moves
209
+
210
+ @staticmethod
211
+ def compute_players(encoding: BatchEncoding, according_to='output'):
212
+ """
213
+ Determines which player (white=True, black=False) is associated with each token in the sequence.
214
+ This method works based on chess move sequences tokenized using the UciTileTokenizer.
215
+
216
+ # Parameters:
217
+ ----------
218
+ **`encoding`** : BatchEncoding
219
+ Tokenized input of a chess game, where each token represents a move or special token.
220
+
221
+ **`according_to`** : str (optional, default='output')
222
+ Specifies the perspective for associating players:
223
+ - 'output': Returns the player whose next move is predicted by the sequence (the output move).
224
+ - Otherwise: Returns the player associated with the input tokens (i.e., which player made each move).
225
+
226
+ # Returns:
227
+ -------
228
+ List[bool]
229
+ A list of boolean values indicating the player for each token:
230
+ - True for white (player 1),
231
+ - False for black (player 2).
232
+
233
+ The list length corresponds to the number of tokens in the sequence, including special tokens if any.
234
+
235
+ # Example Usage:
236
+ ```
237
+ >>> tok = UciTileTokenizer()
238
+ >>> encoding = tok('e2e4 d7d5 e4d5 e7e6 d5e6 d8g5 e6e7 g5f6 e7f8Q')
239
+ >>> print(encoding['input_ids'])
240
+ [1, 16, 32, 55, 39, 32, 39, 56, 48, 39, 48, 63, 42, 48, 56, 42, 49, 56, 65, 68]
241
+ >>> tok.compute_players(encoding)
242
+ [True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, True, False]
243
+ >>> tok.compute_players(encoding, according_to='input')
244
+ [True, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, True]
245
+ ```
246
+
247
+ # Notes:
248
+ -------
249
+ This method does not rely on board position calculations. Therefore, when
250
+ using `according_to='output'`, it cannot reliably predict which player is
251
+ responsible for selecting the final token of the sequence. For instance,
252
+ if a pawn is moved to the back rank (e.g., 'e7e8'), then white must select
253
+ the promotion class on the next token; however, this algorithm will predict
254
+ that black is responsible for selecting the next token instead of white.
255
+ """
256
+
257
+ return [UciTileTokenizer._compute_players_single(encoding[i].ids) for i in range(len(encoding['input_ids']))]
258
+
259
+
260
+
261
+ @staticmethod
262
+ def _compute_players_single(input_ids: list[int], according_to: str='output'):
263
+ players = [] if according_to == "output" else [True]
264
+ current_player = False
265
+ num_tokens_in_ply = 0
266
+ has_specials = False
267
+
268
+ for i, token_id in enumerate(input_ids):
269
+ if token_id == 1:
270
+ has_specials = True
271
+ continue
272
+
273
+ if num_tokens_in_ply == 0:
274
+ # check if promotion OR unknown token ID
275
+ if token_id > 67 or token_id == 3:
276
+ players.append(current_player)
277
+ num_tokens_in_ply = 0
278
+ else:
279
+ num_tokens_in_ply += 1
280
+ current_player = not current_player
281
+ players.append(current_player)
282
+ elif num_tokens_in_ply == 1:
283
+ num_tokens_in_ply = 0
284
+ players.append(current_player)
285
+ else:
286
+ raise ValueError("Illegal move sequence")
287
+
288
+ if according_to == "output":
289
+ # anticipate what output should be based on the final input token
290
+ # see notes for more detail
291
+ if num_tokens_in_ply == 0:
292
+ if token_id > 67:
293
+ players.append(not current_player)
294
+ else:
295
+ players.append(current_player)
296
+ else:
297
+ players.append(current_player)
298
+
299
+ return players if has_specials else players[1:]
300
+
301
+ if __name__ == "__main__":
302
+ tok = UciTileTokenizer()
303
+ encoding = tok('e2e4Q b7b8N e2e7 a1',add_special_tokens=True)
304
+ print(f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='output')=}")
305
+ print(f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='input')=}")
306
+
307
+ encoding = tok('e2e4Q b7b8N e2e7 a1',add_special_tokens=False)
308
+ print(f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='output')=}")
309
+ print(f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='input')=}")
310
+
311
+ encoding = tok('e2e4 d7d5 e4d5 e7e6 d5e6 d8g5 e6e7 g5f6 e7f8Q')
312
+ print(encoding['input_ids'])
313
+ print(tok.compute_players(encoding))
314
+ print(tok.compute_players(encoding, according_to='input'))