austindavis commited on
Commit
44c8b7b
·
verified ·
1 Parent(s): 6d52781

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +292 -113
app.py CHANGED
@@ -1,128 +1,307 @@
1
- !pip install transformers chess torch
 
 
2
 
3
- # %%
4
- from transformers import AutoModelForCausalLM, GPT2LMHeadModel
5
- import torch
6
- import chess
7
  import chess.pgn
8
  import chess.svg
9
  import gradio as gr
10
- from modeling.utils import uci_to_board
11
- from modeling.uci_tokenizers import UciTileTokenizer
12
  import numpy as np
13
- import io
14
- import traceback
15
- # %%
16
- checkpoint_name = "austindavis/chess-gpt2-uci-8x8x512"
17
- tokenizer = UciTileTokenizer()
18
- model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
19
- model.requires_grad_(False)
20
 
21
- # %%
22
- # Initialize the chess board
23
- board = chess.Board()
24
- game:chess.pgn.GameNode = chess.pgn.Game()
 
25
 
 
 
26
 
 
 
27
 
28
- game.headers["Event"] = "Example"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- generate_kwargs = {
31
- "max_new_tokens": 3,
32
- "num_return_sequences": 10,
33
- "temperature": 0.5,
34
- "output_scores": True,
35
- "output_logits": True,
36
- "return_dict_in_generate": True
37
- }
38
 
39
- def make_move(input:str, node=game, board = board):
40
- # check for reset
41
- if input.lower() == 'reset':
42
- board.reset()
43
- node.root().variations.clear()
44
- return chess.svg.board(board=board), "New game!"
45
-
46
- # check for pgn
47
- if input[0] == '[' or input[:3] == '1. ':
48
- pgn = io.StringIO(input)
49
- game = chess.pgn.read_game(pgn)
50
- board.reset()
51
- node.root().variations.clear()
52
-
53
- for move in game.mainline_moves():
54
- board.push(move)
55
- node.add_variation(move)
56
-
57
- return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
58
-
59
-
60
- try:
61
- move = chess.Move.from_uci(input)
62
- if move in board.legal_moves:
63
- board.push(move)
64
-
65
- while node.next() is not None:
66
- node = node.next()
67
- node = node.add_variation(move)
68
-
69
- # get computer's move
70
-
71
- prefix = ' '.join([x.uci() for x in board.move_stack])
72
- encoding = tokenizer(text=prefix,
73
- return_tensors='pt',
74
- )['input_ids']
75
-
76
- output = model.generate(encoding, **generate_kwargs) # [b,p,v]
77
- new_tokens = tokenizer.batch_decode(output.sequences[:,-3:])
78
- unique_moves, unique_indices = np.unique([x[:4] if ' ' in x else x for x in new_tokens], return_index=True)
79
- unique_indices = torch.Tensor(list(unique_indices)).to(dtype=torch.int)
80
- logits = torch.stack(output.logits) # [token, batch, vocab]
81
- logits = logits[:,unique_indices] # [token, batch, vocab]
82
 
83
- # select moves based on mean logit value for tokens 1 and 2
84
- logit_priority_order = logits.max(dim=-1).values.T[:,:2].mean(-1).topk(len(unique_indices)).indices
85
- priority_ordered_moves = unique_moves[logit_priority_order]
 
 
86
 
87
- # if there's only 1 option, we have to pack it back into a list
88
- if isinstance(priority_ordered_moves, str):
89
- priority_ordered_moves = [priority_ordered_moves]
90
-
91
- # test if any moves are valid
92
- for uci in priority_ordered_moves:
93
- move = chess.Move.from_uci(uci)
94
- if move in board.legal_moves:
95
- board.push(move)
96
- while node.next() is not None:
97
- node = node.next()
98
- node = node.add_variation(move)
99
- return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
100
 
101
- # no moves are valid
102
- bad_from_tiles = [chess.parse_square(x) for x in [x[:2] for x in unique_moves]]
103
- bad_to_tiles = [chess.parse_square(x) for x in [x[2:] for x in unique_moves]]
104
- arrows = [chess.svg.Arrow(tail, head, color="red") for (tail, head) in zip(bad_from_tiles, bad_to_tiles)]
105
- return chess.svg.board(board=board,arrows=arrows), '|'.join(unique_moves)
106
- else:
107
- return chess.svg.board(board=board,lastmove=move), f"Illegal move: {input}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- except ValueError:
110
- return chess.svg.board(board=board), f"Invalid UCI format: {uci} {list(unique_moves)}"
111
- except Exception:
112
- return chess.svg.board(board=board), traceback.format_exc()
113
-
114
- # Define the Gradio interface
115
- iface = gr.Interface(
116
- fn=make_move,
117
- inputs="text",
118
- outputs=["html", "text"],
119
- examples=[['e2e4'], ['d2d4'], ['Reset']],
120
- title="Play Versus ChessGPT",
121
- description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
122
- allow_flagging='never',
123
- )
124
-
125
- # Launch the Gradio app
126
- iface.launch()
127
-
128
- # %%
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import traceback
3
+ from typing import List
4
 
 
 
 
 
5
  import chess.pgn
6
  import chess.svg
7
  import gradio as gr
 
 
8
  import numpy as np
9
+ import tokenizers
10
+ import torch
11
+ from tokenizers import models, pre_tokenizers, processors
12
+ from torch import Tensor as TT
13
+ from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PreTrainedTokenizerFast
14
+
15
+ import chess
16
 
17
+ class UciTokenizer(PreTrainedTokenizerFast):
18
+ _PAD_TOKEN: str
19
+ _UNK_TOKEN: str
20
+ _EOS_TOKEN: str
21
+ _BOS_TOKEN: str
22
 
23
+ stoi: dict[str, int]
24
+ """Integer to String mapping"""
25
 
26
+ itos: dict[int, str]
27
+ """String to Integer Mapping. This is the vocab"""
28
 
29
+ def __init__(
30
+ self,
31
+ stoi,
32
+ itos,
33
+ pad_token,
34
+ unk_token,
35
+ bos_token,
36
+ eos_token,
37
+ name_or_path,
38
+ ):
39
+ self.stoi = stoi
40
+ self.itos = itos
41
+
42
+ self._PAD_TOKEN = pad_token
43
+ self._UNK_TOKEN = unk_token
44
+ self._EOS_TOKEN = eos_token
45
+ self._BOS_TOKEN = bos_token
46
 
47
+ # Define the model
48
+ tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)
 
 
 
 
 
 
49
 
50
+ slow_tokenizer = tokenizers.Tokenizer(tok_model)
51
+ slow_tokenizer.pre_tokenizer = self._init_pretokenizer()
52
+
53
+ # post processing adds special tokens unless explicitly ignored
54
+ post_proc = processors.TemplateProcessing(
55
+ single=f"{bos_token} $0",
56
+ pair=None,
57
+ special_tokens=[(bos_token, 1)],
58
+ )
59
+ slow_tokenizer.post_processor=post_proc
60
+
61
+ super().__init__(
62
+ tokenizer_object=slow_tokenizer,
63
+ unk_token=self._UNK_TOKEN,
64
+ bos_token=self._BOS_TOKEN,
65
+ eos_token=self._EOS_TOKEN,
66
+ pad_token=self._PAD_TOKEN,
67
+ name_or_path=name_or_path,
68
+ )
69
+
70
+ # Override the decode behavior to ensure spaces are correctly handled
71
+ def _decode(
72
+ token_ids: int | List[int],
73
+ skip_special_tokens=False,
74
+ clean_up_tokenization_spaces=False,
75
+ ) -> int | List[int]:
76
+
77
+ if isinstance(token_ids, int):
78
+ return self.itos.get(token_ids, self._UNK_TOKEN)
79
+
80
+ if isinstance(token_ids, dict):
81
+ token_ids = token_ids["input_ids"]
82
+
83
+ if isinstance(token_ids, TT):
84
+ token_ids = token_ids.tolist()
 
 
 
 
 
 
 
 
85
 
86
+ if isinstance(token_ids, list):
87
+ tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
88
+ moves = self._process_str_tokens(tokens_str)
89
+
90
+ return " ".join(moves)
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+
94
+ self._decode = _decode
95
+
96
+ def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
97
+ raise NotImplementedError
98
+
99
+ def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
100
+ raise NotImplementedError
101
+
102
+ def get_id2square_list() -> list[int]:
103
+ raise NotImplementedError
104
+
105
+ class UciTileTokenizer(UciTokenizer):
106
+ """ Uci tokenizer converting start/end tiles and promotion types each into individual tokens"""
107
+ stoi = {
108
+ tok: idx
109
+ for tok, idx in list(
110
+ zip(["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn"), range(72))
111
+ )
112
+ }
113
+
114
+ itos = {
115
+ idx: tok
116
+ for tok, idx in list(
117
+ zip(["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn"), range(72))
118
+ )
119
+ }
120
+
121
+ id2square:List[int] = [None]*4 + list(range(64))+[None]*4
122
+ """
123
+ List mapping token IDs to squares on the chess board. Order is file then row, i.e.:
124
+ `A1, B1, C1, ..., F8, G8, H8`
125
+ """
126
+
127
+ def get_id2square_list(self) -> List[int]:
128
+ return self.id2square
129
+
130
+ def __init__(self):
131
+
132
+ super().__init__(
133
+ self.stoi,
134
+ self.itos,
135
+ pad_token="<pad>",
136
+ unk_token="<unk>",
137
+ bos_token="<s>",
138
+ eos_token="</s>",
139
+ name_or_path="austindavis/uci_tile_tokenizer",
140
+ )
141
+
142
+ def _init_pretokenizer(self):
143
+ # Pre-tokenizer to split input into UCI moves
144
+ pattern = tokenizers.Regex(r"\d")
145
+ pre_tokenizer = pre_tokenizers.Sequence(
146
+ [
147
+ pre_tokenizers.Whitespace(),
148
+ pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
149
+ ]
150
+ )
151
+ return pre_tokenizer
152
+
153
+ def _process_str_tokens(self, token_str):
154
+ moves = []
155
+ next_move = ""
156
+ for token in token_str:
157
+
158
+ # skip special tokens
159
+ if token in self.all_special_tokens:
160
+ continue
161
+
162
+ # handle promotions
163
+ if len(token) == 1:
164
+ moves.append(next_move + token)
165
+ continue
166
+
167
+ # handle regular tokens
168
+ if len(next_move) == 4:
169
+ moves.append(next_move)
170
+ next_move = token
171
+ else:
172
+ next_move += token
173
+
174
+ moves.append(next_move)
175
+ return moves
176
 
177
+ def setup_app(model: GPT2LMHeadModel):
178
+ """
179
+ Configures a Gradio App to use the GPT model for move generation.
180
+ The model must be compatible with a UciTileTokenizer.
181
+ """
182
+ tokenizer = UciTileTokenizer()
183
+
184
+ # Initialize the chess board
185
+ board = chess.Board()
186
+ game:chess.pgn.GameNode = chess.pgn.Game()
187
+
188
+
189
+
190
+ game.headers["Event"] = "Example"
191
+
192
+ generate_kwargs = {
193
+ "max_new_tokens": 3,
194
+ "num_return_sequences": 10,
195
+ "temperature": 0.5,
196
+ "output_scores": True,
197
+ "output_logits": True,
198
+ "return_dict_in_generate": True
199
+ }
200
+
201
+ def make_move(input:str, node=game, board = board):
202
+ # check for reset
203
+ if input.lower() == 'reset':
204
+ board.reset()
205
+ node.root().variations.clear()
206
+ return chess.svg.board(board=board), "New game!"
207
+
208
+ # check for pgn
209
+ if input[0] == '[' or input[:3] == '1. ':
210
+ pgn = io.StringIO(input)
211
+ game = chess.pgn.read_game(pgn)
212
+ board.reset()
213
+ node.root().variations.clear()
214
+
215
+ for move in game.mainline_moves():
216
+ board.push(move)
217
+ node.add_variation(move)
218
+
219
+ return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
220
+
221
+
222
+ try:
223
+ move = chess.Move.from_uci(input)
224
+ if move in board.legal_moves:
225
+ board.push(move)
226
+
227
+ while node.next() is not None:
228
+ node = node.next()
229
+ node = node.add_variation(move)
230
+
231
+ # get computer's move
232
+
233
+ prefix = ' '.join([x.uci() for x in board.move_stack])
234
+ encoding = tokenizer(text=prefix,
235
+ return_tensors='pt',
236
+ )['input_ids']
237
+
238
+ output = model.generate(encoding, **generate_kwargs) # [b,p,v]
239
+ new_tokens = tokenizer.batch_decode(output.sequences[:,-3:])
240
+ unique_moves, unique_indices = np.unique([x[:4] if ' ' in x else x for x in new_tokens], return_index=True)
241
+ unique_indices = torch.Tensor(list(unique_indices)).to(dtype=torch.int)
242
+ logits = torch.stack(output.logits) # [token, batch, vocab]
243
+ logits = logits[:,unique_indices] # [token, batch, vocab]
244
+
245
+ # select moves based on mean logit value for tokens 1 and 2
246
+ logit_priority_order = logits.max(dim=-1).values.T[:,:2].mean(-1).topk(len(unique_indices)).indices
247
+ priority_ordered_moves = unique_moves[logit_priority_order]
248
+
249
+ # if there's only 1 option, we have to pack it back into a list
250
+ if isinstance(priority_ordered_moves, str):
251
+ priority_ordered_moves = [priority_ordered_moves]
252
+
253
+ # test if any moves are valid
254
+ for uci in priority_ordered_moves:
255
+ move = chess.Move.from_uci(uci)
256
+ if move in board.legal_moves:
257
+ board.push(move)
258
+ while node.next() is not None:
259
+ node = node.next()
260
+ node = node.add_variation(move)
261
+ return chess.svg.board(board=board,lastmove=move), "".join(str(node.root()).split("]")[-1]).strip()
262
+
263
+ # no moves are valid
264
+ bad_from_tiles = [chess.parse_square(x) for x in [x[:2] for x in unique_moves]]
265
+ bad_to_tiles = [chess.parse_square(x) for x in [x[2:] for x in unique_moves]]
266
+ arrows = [chess.svg.Arrow(tail, head, color="red") for (tail, head) in zip(bad_from_tiles, bad_to_tiles)]
267
+ checks = None
268
+ if board.is_check():
269
+ checks = board.pieces(chess.PIECE_TYPES[-1],board.turn).pop()
270
+
271
+ return chess.svg.board(board=board,arrows=arrows, check=checks), '|'.join(unique_moves)
272
+ else:
273
+ return chess.svg.board(board=board,lastmove=move), f"Illegal move: {input}"
274
+
275
+ except chess.InvalidMoveError:
276
+ return chess.svg.board(board=board), f"Invalid UCI format: {input}"
277
+ except Exception:
278
+ return chess.svg.board(board=board), traceback.format_exc()
279
+
280
+ input_box = gr.Textbox(None,placeholder="Enter your move in UCI format")
281
+
282
+ # Define the Gradio interface
283
+ iface = gr.Interface(
284
+ fn=make_move,
285
+ inputs=input_box,
286
+ outputs=["html", "text"],
287
+ examples=[['e2e4'], ['d2d4'], ['Reset']],
288
+ title="Play Versus ChessGPT",
289
+ description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
290
+ allow_flagging='never',
291
+ submit_btn = "Move",
292
+ stop_btn = "Stop",
293
+ clear_btn = "Clear w/o reset",
294
+ )
295
+
296
+ iface.output_components[0].label = "Board"
297
+ iface.output_components[0].show_label = True
298
+ iface.output_components[1].label = "Move Sequence"
299
+
300
+ return iface
301
+
302
+ checkpoint_name = "austindavis/gpt2-lichess-uci-202306"
303
+ model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
304
+ model.requires_grad_(False)
305
+
306
+ iface = setup_app(model)
307
+ iface.launch()