austindavis
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,128 +1,307 @@
|
|
1 |
-
|
|
|
|
|
2 |
|
3 |
-
# %%
|
4 |
-
from transformers import AutoModelForCausalLM, GPT2LMHeadModel
|
5 |
-
import torch
|
6 |
-
import chess
|
7 |
import chess.pgn
|
8 |
import chess.svg
|
9 |
import gradio as gr
|
10 |
-
from modeling.utils import uci_to_board
|
11 |
-
from modeling.uci_tokenizers import UciTileTokenizer
|
12 |
import numpy as np
|
13 |
-
import
|
14 |
-
import
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
25 |
|
|
|
|
|
26 |
|
|
|
|
|
27 |
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
"num_return_sequences": 10,
|
33 |
-
"temperature": 0.5,
|
34 |
-
"output_scores": True,
|
35 |
-
"output_logits": True,
|
36 |
-
"return_dict_in_generate": True
|
37 |
-
}
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
)['input_ids']
|
75 |
-
|
76 |
-
output = model.generate(encoding, **generate_kwargs) # [b,p,v]
|
77 |
-
new_tokens = tokenizer.batch_decode(output.sequences[:,-3:])
|
78 |
-
unique_moves, unique_indices = np.unique([x[:4] if ' ' in x else x for x in new_tokens], return_index=True)
|
79 |
-
unique_indices = torch.Tensor(list(unique_indices)).to(dtype=torch.int)
|
80 |
-
logits = torch.stack(output.logits) # [token, batch, vocab]
|
81 |
-
logits = logits[:,unique_indices] # [token, batch, vocab]
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
|
87 |
-
# if there's only 1 option, we have to pack it back into a list
|
88 |
-
if isinstance(priority_ordered_moves, str):
|
89 |
-
priority_ordered_moves = [priority_ordered_moves]
|
90 |
-
|
91 |
-
# test if any moves are valid
|
92 |
-
for uci in priority_ordered_moves:
|
93 |
-
move = chess.Move.from_uci(uci)
|
94 |
-
if move in board.legal_moves:
|
95 |
-
board.push(move)
|
96 |
-
while node.next() is not None:
|
97 |
-
node = node.next()
|
98 |
-
node = node.add_variation(move)
|
99 |
-
return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import traceback
|
3 |
+
from typing import List
|
4 |
|
|
|
|
|
|
|
|
|
5 |
import chess.pgn
|
6 |
import chess.svg
|
7 |
import gradio as gr
|
|
|
|
|
8 |
import numpy as np
|
9 |
+
import tokenizers
|
10 |
+
import torch
|
11 |
+
from tokenizers import models, pre_tokenizers, processors
|
12 |
+
from torch import Tensor as TT
|
13 |
+
from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PreTrainedTokenizerFast
|
14 |
+
|
15 |
+
import chess
|
16 |
|
17 |
+
class UciTokenizer(PreTrainedTokenizerFast):
|
18 |
+
_PAD_TOKEN: str
|
19 |
+
_UNK_TOKEN: str
|
20 |
+
_EOS_TOKEN: str
|
21 |
+
_BOS_TOKEN: str
|
22 |
|
23 |
+
stoi: dict[str, int]
|
24 |
+
"""Integer to String mapping"""
|
25 |
|
26 |
+
itos: dict[int, str]
|
27 |
+
"""String to Integer Mapping. This is the vocab"""
|
28 |
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
stoi,
|
32 |
+
itos,
|
33 |
+
pad_token,
|
34 |
+
unk_token,
|
35 |
+
bos_token,
|
36 |
+
eos_token,
|
37 |
+
name_or_path,
|
38 |
+
):
|
39 |
+
self.stoi = stoi
|
40 |
+
self.itos = itos
|
41 |
+
|
42 |
+
self._PAD_TOKEN = pad_token
|
43 |
+
self._UNK_TOKEN = unk_token
|
44 |
+
self._EOS_TOKEN = eos_token
|
45 |
+
self._BOS_TOKEN = bos_token
|
46 |
|
47 |
+
# Define the model
|
48 |
+
tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
+
slow_tokenizer = tokenizers.Tokenizer(tok_model)
|
51 |
+
slow_tokenizer.pre_tokenizer = self._init_pretokenizer()
|
52 |
+
|
53 |
+
# post processing adds special tokens unless explicitly ignored
|
54 |
+
post_proc = processors.TemplateProcessing(
|
55 |
+
single=f"{bos_token} $0",
|
56 |
+
pair=None,
|
57 |
+
special_tokens=[(bos_token, 1)],
|
58 |
+
)
|
59 |
+
slow_tokenizer.post_processor=post_proc
|
60 |
+
|
61 |
+
super().__init__(
|
62 |
+
tokenizer_object=slow_tokenizer,
|
63 |
+
unk_token=self._UNK_TOKEN,
|
64 |
+
bos_token=self._BOS_TOKEN,
|
65 |
+
eos_token=self._EOS_TOKEN,
|
66 |
+
pad_token=self._PAD_TOKEN,
|
67 |
+
name_or_path=name_or_path,
|
68 |
+
)
|
69 |
+
|
70 |
+
# Override the decode behavior to ensure spaces are correctly handled
|
71 |
+
def _decode(
|
72 |
+
token_ids: int | List[int],
|
73 |
+
skip_special_tokens=False,
|
74 |
+
clean_up_tokenization_spaces=False,
|
75 |
+
) -> int | List[int]:
|
76 |
+
|
77 |
+
if isinstance(token_ids, int):
|
78 |
+
return self.itos.get(token_ids, self._UNK_TOKEN)
|
79 |
+
|
80 |
+
if isinstance(token_ids, dict):
|
81 |
+
token_ids = token_ids["input_ids"]
|
82 |
+
|
83 |
+
if isinstance(token_ids, TT):
|
84 |
+
token_ids = token_ids.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
+
if isinstance(token_ids, list):
|
87 |
+
tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
|
88 |
+
moves = self._process_str_tokens(tokens_str)
|
89 |
+
|
90 |
+
return " ".join(moves)
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
+
|
94 |
+
self._decode = _decode
|
95 |
+
|
96 |
+
def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
|
97 |
+
raise NotImplementedError
|
98 |
+
|
99 |
+
def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
|
100 |
+
raise NotImplementedError
|
101 |
+
|
102 |
+
def get_id2square_list() -> list[int]:
|
103 |
+
raise NotImplementedError
|
104 |
+
|
105 |
+
class UciTileTokenizer(UciTokenizer):
|
106 |
+
""" Uci tokenizer converting start/end tiles and promotion types each into individual tokens"""
|
107 |
+
stoi = {
|
108 |
+
tok: idx
|
109 |
+
for tok, idx in list(
|
110 |
+
zip(["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn"), range(72))
|
111 |
+
)
|
112 |
+
}
|
113 |
+
|
114 |
+
itos = {
|
115 |
+
idx: tok
|
116 |
+
for tok, idx in list(
|
117 |
+
zip(["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn"), range(72))
|
118 |
+
)
|
119 |
+
}
|
120 |
+
|
121 |
+
id2square:List[int] = [None]*4 + list(range(64))+[None]*4
|
122 |
+
"""
|
123 |
+
List mapping token IDs to squares on the chess board. Order is file then row, i.e.:
|
124 |
+
`A1, B1, C1, ..., F8, G8, H8`
|
125 |
+
"""
|
126 |
+
|
127 |
+
def get_id2square_list(self) -> List[int]:
|
128 |
+
return self.id2square
|
129 |
+
|
130 |
+
def __init__(self):
|
131 |
+
|
132 |
+
super().__init__(
|
133 |
+
self.stoi,
|
134 |
+
self.itos,
|
135 |
+
pad_token="<pad>",
|
136 |
+
unk_token="<unk>",
|
137 |
+
bos_token="<s>",
|
138 |
+
eos_token="</s>",
|
139 |
+
name_or_path="austindavis/uci_tile_tokenizer",
|
140 |
+
)
|
141 |
+
|
142 |
+
def _init_pretokenizer(self):
|
143 |
+
# Pre-tokenizer to split input into UCI moves
|
144 |
+
pattern = tokenizers.Regex(r"\d")
|
145 |
+
pre_tokenizer = pre_tokenizers.Sequence(
|
146 |
+
[
|
147 |
+
pre_tokenizers.Whitespace(),
|
148 |
+
pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
|
149 |
+
]
|
150 |
+
)
|
151 |
+
return pre_tokenizer
|
152 |
+
|
153 |
+
def _process_str_tokens(self, token_str):
|
154 |
+
moves = []
|
155 |
+
next_move = ""
|
156 |
+
for token in token_str:
|
157 |
+
|
158 |
+
# skip special tokens
|
159 |
+
if token in self.all_special_tokens:
|
160 |
+
continue
|
161 |
+
|
162 |
+
# handle promotions
|
163 |
+
if len(token) == 1:
|
164 |
+
moves.append(next_move + token)
|
165 |
+
continue
|
166 |
+
|
167 |
+
# handle regular tokens
|
168 |
+
if len(next_move) == 4:
|
169 |
+
moves.append(next_move)
|
170 |
+
next_move = token
|
171 |
+
else:
|
172 |
+
next_move += token
|
173 |
+
|
174 |
+
moves.append(next_move)
|
175 |
+
return moves
|
176 |
|
177 |
+
def setup_app(model: GPT2LMHeadModel):
|
178 |
+
"""
|
179 |
+
Configures a Gradio App to use the GPT model for move generation.
|
180 |
+
The model must be compatible with a UciTileTokenizer.
|
181 |
+
"""
|
182 |
+
tokenizer = UciTileTokenizer()
|
183 |
+
|
184 |
+
# Initialize the chess board
|
185 |
+
board = chess.Board()
|
186 |
+
game:chess.pgn.GameNode = chess.pgn.Game()
|
187 |
+
|
188 |
+
|
189 |
+
|
190 |
+
game.headers["Event"] = "Example"
|
191 |
+
|
192 |
+
generate_kwargs = {
|
193 |
+
"max_new_tokens": 3,
|
194 |
+
"num_return_sequences": 10,
|
195 |
+
"temperature": 0.5,
|
196 |
+
"output_scores": True,
|
197 |
+
"output_logits": True,
|
198 |
+
"return_dict_in_generate": True
|
199 |
+
}
|
200 |
+
|
201 |
+
def make_move(input:str, node=game, board = board):
|
202 |
+
# check for reset
|
203 |
+
if input.lower() == 'reset':
|
204 |
+
board.reset()
|
205 |
+
node.root().variations.clear()
|
206 |
+
return chess.svg.board(board=board), "New game!"
|
207 |
+
|
208 |
+
# check for pgn
|
209 |
+
if input[0] == '[' or input[:3] == '1. ':
|
210 |
+
pgn = io.StringIO(input)
|
211 |
+
game = chess.pgn.read_game(pgn)
|
212 |
+
board.reset()
|
213 |
+
node.root().variations.clear()
|
214 |
+
|
215 |
+
for move in game.mainline_moves():
|
216 |
+
board.push(move)
|
217 |
+
node.add_variation(move)
|
218 |
+
|
219 |
+
return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
|
220 |
+
|
221 |
+
|
222 |
+
try:
|
223 |
+
move = chess.Move.from_uci(input)
|
224 |
+
if move in board.legal_moves:
|
225 |
+
board.push(move)
|
226 |
+
|
227 |
+
while node.next() is not None:
|
228 |
+
node = node.next()
|
229 |
+
node = node.add_variation(move)
|
230 |
+
|
231 |
+
# get computer's move
|
232 |
+
|
233 |
+
prefix = ' '.join([x.uci() for x in board.move_stack])
|
234 |
+
encoding = tokenizer(text=prefix,
|
235 |
+
return_tensors='pt',
|
236 |
+
)['input_ids']
|
237 |
+
|
238 |
+
output = model.generate(encoding, **generate_kwargs) # [b,p,v]
|
239 |
+
new_tokens = tokenizer.batch_decode(output.sequences[:,-3:])
|
240 |
+
unique_moves, unique_indices = np.unique([x[:4] if ' ' in x else x for x in new_tokens], return_index=True)
|
241 |
+
unique_indices = torch.Tensor(list(unique_indices)).to(dtype=torch.int)
|
242 |
+
logits = torch.stack(output.logits) # [token, batch, vocab]
|
243 |
+
logits = logits[:,unique_indices] # [token, batch, vocab]
|
244 |
+
|
245 |
+
# select moves based on mean logit value for tokens 1 and 2
|
246 |
+
logit_priority_order = logits.max(dim=-1).values.T[:,:2].mean(-1).topk(len(unique_indices)).indices
|
247 |
+
priority_ordered_moves = unique_moves[logit_priority_order]
|
248 |
+
|
249 |
+
# if there's only 1 option, we have to pack it back into a list
|
250 |
+
if isinstance(priority_ordered_moves, str):
|
251 |
+
priority_ordered_moves = [priority_ordered_moves]
|
252 |
+
|
253 |
+
# test if any moves are valid
|
254 |
+
for uci in priority_ordered_moves:
|
255 |
+
move = chess.Move.from_uci(uci)
|
256 |
+
if move in board.legal_moves:
|
257 |
+
board.push(move)
|
258 |
+
while node.next() is not None:
|
259 |
+
node = node.next()
|
260 |
+
node = node.add_variation(move)
|
261 |
+
return chess.svg.board(board=board,lastmove=move), "".join(str(node.root()).split("]")[-1]).strip()
|
262 |
+
|
263 |
+
# no moves are valid
|
264 |
+
bad_from_tiles = [chess.parse_square(x) for x in [x[:2] for x in unique_moves]]
|
265 |
+
bad_to_tiles = [chess.parse_square(x) for x in [x[2:] for x in unique_moves]]
|
266 |
+
arrows = [chess.svg.Arrow(tail, head, color="red") for (tail, head) in zip(bad_from_tiles, bad_to_tiles)]
|
267 |
+
checks = None
|
268 |
+
if board.is_check():
|
269 |
+
checks = board.pieces(chess.PIECE_TYPES[-1],board.turn).pop()
|
270 |
+
|
271 |
+
return chess.svg.board(board=board,arrows=arrows, check=checks), '|'.join(unique_moves)
|
272 |
+
else:
|
273 |
+
return chess.svg.board(board=board,lastmove=move), f"Illegal move: {input}"
|
274 |
+
|
275 |
+
except chess.InvalidMoveError:
|
276 |
+
return chess.svg.board(board=board), f"Invalid UCI format: {input}"
|
277 |
+
except Exception:
|
278 |
+
return chess.svg.board(board=board), traceback.format_exc()
|
279 |
+
|
280 |
+
input_box = gr.Textbox(None,placeholder="Enter your move in UCI format")
|
281 |
+
|
282 |
+
# Define the Gradio interface
|
283 |
+
iface = gr.Interface(
|
284 |
+
fn=make_move,
|
285 |
+
inputs=input_box,
|
286 |
+
outputs=["html", "text"],
|
287 |
+
examples=[['e2e4'], ['d2d4'], ['Reset']],
|
288 |
+
title="Play Versus ChessGPT",
|
289 |
+
description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
|
290 |
+
allow_flagging='never',
|
291 |
+
submit_btn = "Move",
|
292 |
+
stop_btn = "Stop",
|
293 |
+
clear_btn = "Clear w/o reset",
|
294 |
+
)
|
295 |
+
|
296 |
+
iface.output_components[0].label = "Board"
|
297 |
+
iface.output_components[0].show_label = True
|
298 |
+
iface.output_components[1].label = "Move Sequence"
|
299 |
+
|
300 |
+
return iface
|
301 |
+
|
302 |
+
checkpoint_name = "austindavis/gpt2-lichess-uci-202306"
|
303 |
+
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
|
304 |
+
model.requires_grad_(False)
|
305 |
+
|
306 |
+
iface = setup_app(model)
|
307 |
+
iface.launch()
|