austindavis commited on
Commit
3e3ba1d
·
verified ·
1 Parent(s): a0de53d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -74
app.py CHANGED
@@ -2,6 +2,7 @@ import io
2
  import traceback
3
  from typing import List
4
 
 
5
  import chess.pgn
6
  import chess.svg
7
  import gradio as gr
@@ -10,12 +11,12 @@ import tokenizers
10
  import torch
11
  from tokenizers import models, pre_tokenizers, processors
12
  from torch import Tensor as TT
13
- from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PreTrainedTokenizerFast
14
-
15
- import chess
16
 
17
  checkpoint_name = "austindavis/chess-gpt2-uci-8x8x512"
18
 
 
19
  class UciTokenizer(PreTrainedTokenizerFast):
20
  _PAD_TOKEN: str
21
  _UNK_TOKEN: str
@@ -40,14 +41,15 @@ class UciTokenizer(PreTrainedTokenizerFast):
40
  ):
41
  self.stoi = stoi
42
  self.itos = itos
43
-
44
  self._PAD_TOKEN = pad_token
45
  self._UNK_TOKEN = unk_token
46
  self._EOS_TOKEN = eos_token
47
  self._BOS_TOKEN = bos_token
48
 
49
  # Define the model
50
- tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)
 
51
 
52
  slow_tokenizer = tokenizers.Tokenizer(tok_model)
53
  slow_tokenizer.pre_tokenizer = self._init_pretokenizer()
@@ -58,8 +60,8 @@ class UciTokenizer(PreTrainedTokenizerFast):
58
  pair=None,
59
  special_tokens=[(bos_token, 1)],
60
  )
61
- slow_tokenizer.post_processor=post_proc
62
-
63
  super().__init__(
64
  tokenizer_object=slow_tokenizer,
65
  unk_token=self._UNK_TOKEN,
@@ -84,14 +86,13 @@ class UciTokenizer(PreTrainedTokenizerFast):
84
 
85
  if isinstance(token_ids, TT):
86
  token_ids = token_ids.tolist()
87
-
88
  if isinstance(token_ids, list):
89
- tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
 
90
  moves = self._process_str_tokens(tokens_str)
91
 
92
  return " ".join(moves)
93
-
94
-
95
 
96
  self._decode = _decode
97
 
@@ -100,32 +101,45 @@ class UciTokenizer(PreTrainedTokenizerFast):
100
 
101
  def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
102
  raise NotImplementedError
103
-
104
  def get_id2square_list() -> list[int]:
105
  raise NotImplementedError
106
 
 
107
  class UciTileTokenizer(UciTokenizer):
108
- """ Uci tokenizer converting start/end tiles and promotion types each into individual tokens"""
 
 
109
  stoi = {
110
  tok: idx
111
  for tok, idx in list(
112
- zip(["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn"), range(72))
 
 
 
 
 
113
  )
114
  }
115
-
116
  itos = {
117
  idx: tok
118
  for tok, idx in list(
119
- zip(["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn"), range(72))
 
 
 
 
120
  )
121
  }
122
 
123
- id2square:List[int] = [None]*4 + list(range(64))+[None]*4
124
  """
125
- List mapping token IDs to squares on the chess board. Order is file then row, i.e.:
126
- `A1, B1, C1, ..., F8, G8, H8`
 
127
  """
128
-
129
  def get_id2square_list(self) -> List[int]:
130
  return self.id2square
131
 
@@ -147,7 +161,8 @@ class UciTileTokenizer(UciTokenizer):
147
  pre_tokenizer = pre_tokenizers.Sequence(
148
  [
149
  pre_tokenizers.Whitespace(),
150
- pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
 
151
  ]
152
  )
153
  return pre_tokenizer
@@ -175,40 +190,39 @@ class UciTileTokenizer(UciTokenizer):
175
 
176
  moves.append(next_move)
177
  return moves
178
-
 
179
  def setup_app(model: GPT2LMHeadModel):
180
  """
181
- Configures a Gradio App to use the GPT model for move generation.
182
  The model must be compatible with a UciTileTokenizer.
183
  """
184
  tokenizer = UciTileTokenizer()
185
 
186
  # Initialize the chess board
187
  board = chess.Board()
188
- game:chess.pgn.GameNode = chess.pgn.Game()
189
-
190
-
191
 
192
  game.headers["Event"] = "Example"
193
 
194
  generate_kwargs = {
195
- "max_new_tokens": 3,
196
- "num_return_sequences": 10,
197
- "temperature": 0.5,
198
- "output_scores": True,
199
- "output_logits": True,
200
- "return_dict_in_generate": True
201
- }
202
-
203
- def make_move(input:str, node=game, board = board):
204
  # check for reset
205
- if input.lower() == 'reset':
206
  board.reset()
207
  node.root().variations.clear()
208
  return chess.svg.board(board=board), "New game!"
209
-
210
  # check for pgn
211
- if input[0] == '[' or input[:3] == '1. ':
212
  pgn = io.StringIO(input)
213
  game = chess.pgn.read_game(pgn)
214
  board.reset()
@@ -218,8 +232,10 @@ def setup_app(model: GPT2LMHeadModel):
218
  board.push(move)
219
  node.add_variation(move)
220
 
221
- return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
222
-
 
 
223
 
224
  try:
225
  move = chess.Move.from_uci(input)
@@ -232,22 +248,35 @@ def setup_app(model: GPT2LMHeadModel):
232
 
233
  # get computer's move
234
 
235
- prefix = ' '.join([x.uci() for x in board.move_stack])
236
- encoding = tokenizer(text=prefix,
237
- return_tensors='pt',
238
- )['input_ids']
239
-
240
- output = model.generate(encoding, **generate_kwargs) # [b,p,v]
241
- new_tokens = tokenizer.batch_decode(output.sequences[:,-3:])
242
- unique_moves, unique_indices = np.unique([x[:4] if ' ' in x else x for x in new_tokens], return_index=True)
243
- unique_indices = torch.Tensor(list(unique_indices)).to(dtype=torch.int)
244
- logits = torch.stack(output.logits) # [token, batch, vocab]
245
- logits = logits[:,unique_indices] # [token, batch, vocab]
246
-
 
 
 
 
 
 
 
247
  # select moves based on mean logit value for tokens 1 and 2
248
- logit_priority_order = logits.max(dim=-1).values.T[:,:2].mean(-1).topk(len(unique_indices)).indices
 
 
 
 
 
 
249
  priority_ordered_moves = unique_moves[logit_priority_order]
250
-
251
  # if there's only 1 option, we have to pack it back into a list
252
  if isinstance(priority_ordered_moves, str):
253
  priority_ordered_moves = [priority_ordered_moves]
@@ -260,40 +289,61 @@ def setup_app(model: GPT2LMHeadModel):
260
  while node.next() is not None:
261
  node = node.next()
262
  node = node.add_variation(move)
263
- return chess.svg.board(board=board,lastmove=move), "".join(str(node.root()).split("]")[-1]).strip()
264
-
 
 
 
265
  # no moves are valid
266
- bad_from_tiles = [chess.parse_square(x) for x in [x[:2] for x in unique_moves]]
267
- bad_to_tiles = [chess.parse_square(x) for x in [x[2:] for x in unique_moves]]
268
- arrows = [chess.svg.Arrow(tail, head, color="red") for (tail, head) in zip(bad_from_tiles, bad_to_tiles)]
 
 
 
 
 
 
 
 
 
269
  checks = None
270
  if board.is_check():
271
- checks = board.pieces(chess.PIECE_TYPES[-1],board.turn).pop()
272
-
273
- return chess.svg.board(board=board,arrows=arrows, check=checks), '|'.join(unique_moves)
 
 
 
 
 
274
  else:
275
- return chess.svg.board(board=board,lastmove=move), f"Illegal move: {input}"
276
-
 
 
 
277
  except chess.InvalidMoveError:
278
- return chess.svg.board(board=board), f"Invalid UCI format: {input}"
 
279
  except Exception:
280
  return chess.svg.board(board=board), traceback.format_exc()
281
 
282
- input_box = gr.Textbox(None,placeholder="Enter your move in UCI format")
283
 
284
  # Define the Gradio interface
285
  iface = gr.Interface(
286
  fn=make_move,
287
  inputs=input_box,
288
  outputs=["html", "text"],
289
- examples=[['e2e4'], ['d2d4'], ['Reset']],
290
  title="Play Versus ChessGPT",
291
- description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
292
- allow_flagging='never',
293
- submit_btn = "Move",
294
- stop_btn = "Stop",
295
- clear_btn = "Clear w/o reset",
296
- share=True
297
  )
298
 
299
  iface.output_components[0].label = "Board"
@@ -302,8 +352,9 @@ def setup_app(model: GPT2LMHeadModel):
302
 
303
  return iface
304
 
 
305
  model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
306
  model.requires_grad_(False)
307
 
308
  iface = setup_app(model)
309
- iface.launch()
 
2
  import traceback
3
  from typing import List
4
 
5
+ import chess
6
  import chess.pgn
7
  import chess.svg
8
  import gradio as gr
 
11
  import torch
12
  from tokenizers import models, pre_tokenizers, processors
13
  from torch import Tensor as TT
14
+ from transformers import (AutoModelForCausalLM, GPT2LMHeadModel,
15
+ PreTrainedTokenizerFast)
 
16
 
17
  checkpoint_name = "austindavis/chess-gpt2-uci-8x8x512"
18
 
19
+
20
  class UciTokenizer(PreTrainedTokenizerFast):
21
  _PAD_TOKEN: str
22
  _UNK_TOKEN: str
 
41
  ):
42
  self.stoi = stoi
43
  self.itos = itos
44
+
45
  self._PAD_TOKEN = pad_token
46
  self._UNK_TOKEN = unk_token
47
  self._EOS_TOKEN = eos_token
48
  self._BOS_TOKEN = bos_token
49
 
50
  # Define the model
51
+ tok_model = models.WordLevel(vocab=self.stoi,
52
+ unk_token=self._UNK_TOKEN)
53
 
54
  slow_tokenizer = tokenizers.Tokenizer(tok_model)
55
  slow_tokenizer.pre_tokenizer = self._init_pretokenizer()
 
60
  pair=None,
61
  special_tokens=[(bos_token, 1)],
62
  )
63
+ slow_tokenizer.post_processor = post_proc
64
+
65
  super().__init__(
66
  tokenizer_object=slow_tokenizer,
67
  unk_token=self._UNK_TOKEN,
 
86
 
87
  if isinstance(token_ids, TT):
88
  token_ids = token_ids.tolist()
89
+
90
  if isinstance(token_ids, list):
91
+ tokens_str = [self.itos.get(xi, self._UNK_TOKEN)
92
+ for xi in token_ids]
93
  moves = self._process_str_tokens(tokens_str)
94
 
95
  return " ".join(moves)
 
 
96
 
97
  self._decode = _decode
98
 
 
101
 
102
  def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
103
  raise NotImplementedError
104
+
105
  def get_id2square_list() -> list[int]:
106
  raise NotImplementedError
107
 
108
+
109
  class UciTileTokenizer(UciTokenizer):
110
+ """Uci tokenizer converting start/end tiles and promotion types each
111
+ into individual tokens"""
112
+
113
  stoi = {
114
  tok: idx
115
  for tok, idx in list(
116
+ zip(
117
+ ["<pad>", "<s>", "</s>", "<unk>"] +
118
+ chess.SQUARE_NAMES +
119
+ list("qrbn"),
120
+ range(72),
121
+ )
122
  )
123
  }
124
+
125
  itos = {
126
  idx: tok
127
  for tok, idx in list(
128
+ zip(
129
+ ["<pad>", "<s>", "</s>", "<unk>"] +
130
+ chess.SQUARE_NAMES + list("qrbn"),
131
+ range(72),
132
+ )
133
  )
134
  }
135
 
136
+ id2square: List[int] = [None] * 4 + list(range(64)) + [None] * 4
137
  """
138
+ List mapping token IDs to squares on the chess board.
139
+ Order is file then row, i.e.:
140
+ `A1, B1, C1, ..., F8, G8, H8`
141
  """
142
+
143
  def get_id2square_list(self) -> List[int]:
144
  return self.id2square
145
 
 
161
  pre_tokenizer = pre_tokenizers.Sequence(
162
  [
163
  pre_tokenizers.Whitespace(),
164
+ pre_tokenizers.Split(pattern=pattern,
165
+ behavior="merged_with_previous"),
166
  ]
167
  )
168
  return pre_tokenizer
 
190
 
191
  moves.append(next_move)
192
  return moves
193
+
194
+
195
  def setup_app(model: GPT2LMHeadModel):
196
  """
197
+ Configures a Gradio App to use the GPT model for move generation.
198
  The model must be compatible with a UciTileTokenizer.
199
  """
200
  tokenizer = UciTileTokenizer()
201
 
202
  # Initialize the chess board
203
  board = chess.Board()
204
+ game: chess.pgn.GameNode = chess.pgn.Game()
 
 
205
 
206
  game.headers["Event"] = "Example"
207
 
208
  generate_kwargs = {
209
+ "max_new_tokens": 3,
210
+ "num_return_sequences": 10,
211
+ "temperature": 0.5,
212
+ "output_scores": True,
213
+ "output_logits": True,
214
+ "return_dict_in_generate": True,
215
+ }
216
+
217
+ def make_move(input: str, node=game, board=board):
218
  # check for reset
219
+ if input.lower() == "reset":
220
  board.reset()
221
  node.root().variations.clear()
222
  return chess.svg.board(board=board), "New game!"
223
+
224
  # check for pgn
225
+ if input[0] == "[" or input[:3] == "1. ":
226
  pgn = io.StringIO(input)
227
  game = chess.pgn.read_game(pgn)
228
  board.reset()
 
232
  board.push(move)
233
  node.add_variation(move)
234
 
235
+ return (
236
+ chess.svg.board(board=board, lastmove=move),
237
+ "",
238
+ ) # str(node.root()).split(']')[-1].strip()
239
 
240
  try:
241
  move = chess.Move.from_uci(input)
 
248
 
249
  # get computer's move
250
 
251
+ prefix = " ".join([x.uci() for x in board.move_stack])
252
+ encoding = tokenizer(
253
+ text=prefix,
254
+ return_tensors="pt",
255
+ )["input_ids"]
256
+
257
+ output = model.generate(encoding, **generate_kwargs) # [b,p,v]
258
+ new_tokens = tokenizer.batch_decode(output.sequences[:, -3:])
259
+ unique_moves, unique_indices = np.unique(
260
+ [x[:4] if " " in x else x for x in new_tokens],
261
+ return_index=True
262
+ )
263
+ unique_indices = (
264
+ torch.Tensor(list(unique_indices))
265
+ .to(dtype=torch.int)
266
+ )
267
+ logits = torch.stack(output.logits) # [token, batch, vocab]
268
+ logits = logits[:, unique_indices] # [token, batch, vocab]
269
+
270
  # select moves based on mean logit value for tokens 1 and 2
271
+ logit_priority_order = (
272
+ logits.max(dim=-1)
273
+ .values.T[:, :2]
274
+ .mean(-1)
275
+ .topk(len(unique_indices))
276
+ .indices
277
+ )
278
  priority_ordered_moves = unique_moves[logit_priority_order]
279
+
280
  # if there's only 1 option, we have to pack it back into a list
281
  if isinstance(priority_ordered_moves, str):
282
  priority_ordered_moves = [priority_ordered_moves]
 
289
  while node.next() is not None:
290
  node = node.next()
291
  node = node.add_variation(move)
292
+ return (
293
+ chess.svg.board(board=board, lastmove=move),
294
+ "".join(str(node.root()).split("]")[-1]).strip(),
295
+ )
296
+
297
  # no moves are valid
298
+ bad_from_tiles = [
299
+ chess.parse_square(x) for x in [x[:2]
300
+ for x in unique_moves]
301
+ ]
302
+ bad_to_tiles = [
303
+ chess.parse_square(x) for x in [x[2:]
304
+ for x in unique_moves]
305
+ ]
306
+ arrows = [
307
+ chess.svg.Arrow(tail, head, color="red")
308
+ for (tail, head) in zip(bad_from_tiles, bad_to_tiles)
309
+ ]
310
  checks = None
311
  if board.is_check():
312
+ checks = (board
313
+ .pieces(chess.PIECE_TYPES[-1], board.turn)
314
+ .pop()
315
+ )
316
+
317
+ return chess.svg.board(
318
+ board=board, arrows=arrows, check=checks
319
+ ), "|".join(unique_moves)
320
  else:
321
+ return (
322
+ chess.svg.board(board=board, lastmove=move),
323
+ f"Illegal move: {input}",
324
+ )
325
+
326
  except chess.InvalidMoveError:
327
+ return (chess.svg.board(board=board),
328
+ f"Invalid UCI format: {input}")
329
  except Exception:
330
  return chess.svg.board(board=board), traceback.format_exc()
331
 
332
+ input_box = gr.Textbox(None, placeholder="Enter your move in UCI format")
333
 
334
  # Define the Gradio interface
335
  iface = gr.Interface(
336
  fn=make_move,
337
  inputs=input_box,
338
  outputs=["html", "text"],
339
+ examples=[["e2e4"], ["d2d4"], ["Reset"]],
340
  title="Play Versus ChessGPT",
341
+ description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 \
342
+ to e4). Enter 'reset' to restart the game.",
343
+ allow_flagging="never",
344
+ submit_btn="Move",
345
+ stop_btn="Stop",
346
+ clear_btn="Clear w/o reset",
347
  )
348
 
349
  iface.output_components[0].label = "Board"
 
352
 
353
  return iface
354
 
355
+
356
  model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
357
  model.requires_grad_(False)
358
 
359
  iface = setup_app(model)
360
+ iface.launch(share=True)