austindavis
commited on
Commit
•
1c8df7b
1
Parent(s):
224a696
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
from transformers import AutoModelForCausalLM, GPT2LMHeadModel
|
3 |
+
import torch
|
4 |
+
import chess
|
5 |
+
import chess.pgn
|
6 |
+
import chess.svg
|
7 |
+
import gradio as gr
|
8 |
+
from modeling.utils import uci_to_board
|
9 |
+
from modeling.uci_tokenizers import UciTileTokenizer
|
10 |
+
import numpy as np
|
11 |
+
import io
|
12 |
+
import traceback
|
13 |
+
# %%
|
14 |
+
checkpoint_name = "austindavis/gpt2-lichess-uci-201601"
|
15 |
+
tokenizer = UciTileTokenizer()
|
16 |
+
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
|
17 |
+
model.requires_grad_(False)
|
18 |
+
|
19 |
+
# %%
|
20 |
+
# Initialize the chess board
|
21 |
+
board = chess.Board()
|
22 |
+
game:chess.pgn.GameNode = chess.pgn.Game()
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
game.headers["Event"] = "Example"
|
27 |
+
|
28 |
+
generate_kwargs = {
|
29 |
+
"max_new_tokens": 3,
|
30 |
+
"num_return_sequences": 10,
|
31 |
+
"temperature": 0.5,
|
32 |
+
"output_scores": True,
|
33 |
+
"output_logits": True,
|
34 |
+
"return_dict_in_generate": True
|
35 |
+
}
|
36 |
+
|
37 |
+
def make_move(input:str, node=game, board = board):
|
38 |
+
# check for reset
|
39 |
+
if input.lower() == 'reset':
|
40 |
+
board.reset()
|
41 |
+
node.root().variations.clear()
|
42 |
+
return chess.svg.board(board=board), "New game!"
|
43 |
+
|
44 |
+
# check for pgn
|
45 |
+
if input[0] == '[' or input[:3] == '1. ':
|
46 |
+
pgn = io.StringIO(input)
|
47 |
+
game = chess.pgn.read_game(pgn)
|
48 |
+
board.reset()
|
49 |
+
node.root().variations.clear()
|
50 |
+
|
51 |
+
for move in game.mainline_moves():
|
52 |
+
board.push(move)
|
53 |
+
node.add_variation(move)
|
54 |
+
|
55 |
+
return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
|
56 |
+
|
57 |
+
|
58 |
+
try:
|
59 |
+
move = chess.Move.from_uci(input)
|
60 |
+
if move in board.legal_moves:
|
61 |
+
board.push(move)
|
62 |
+
|
63 |
+
while node.next() is not None:
|
64 |
+
node = node.next()
|
65 |
+
node = node.add_variation(move)
|
66 |
+
|
67 |
+
# get computer's move
|
68 |
+
|
69 |
+
prefix = ' '.join([x.uci() for x in board.move_stack])
|
70 |
+
encoding = tokenizer(text=prefix,
|
71 |
+
return_tensors='pt',
|
72 |
+
)['input_ids']
|
73 |
+
|
74 |
+
output = model.generate(encoding, **generate_kwargs) # [b,p,v]
|
75 |
+
new_tokens = tokenizer.batch_decode(output.sequences[:,-3:])
|
76 |
+
unique_moves, unique_indices = np.unique([x[:4] if ' ' in x else x for x in new_tokens], return_index=True)
|
77 |
+
unique_indices = torch.Tensor(list(unique_indices)).to(dtype=torch.int)
|
78 |
+
logits = torch.stack(output.logits) # [token, batch, vocab]
|
79 |
+
logits = logits[:,unique_indices] # [token, batch, vocab]
|
80 |
+
|
81 |
+
# select moves based on mean logit value for tokens 1 and 2
|
82 |
+
logit_priority_order = logits.max(dim=-1).values.T[:,:2].mean(-1).topk(len(unique_indices)).indices
|
83 |
+
priority_ordered_moves = unique_moves[logit_priority_order]
|
84 |
+
|
85 |
+
# if there's only 1 option, we have to pack it back into a list
|
86 |
+
if isinstance(priority_ordered_moves, str):
|
87 |
+
priority_ordered_moves = [priority_ordered_moves]
|
88 |
+
|
89 |
+
# test if any moves are valid
|
90 |
+
for uci in priority_ordered_moves:
|
91 |
+
move = chess.Move.from_uci(uci)
|
92 |
+
if move in board.legal_moves:
|
93 |
+
board.push(move)
|
94 |
+
while node.next() is not None:
|
95 |
+
node = node.next()
|
96 |
+
node = node.add_variation(move)
|
97 |
+
return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
|
98 |
+
|
99 |
+
# no moves are valid
|
100 |
+
bad_from_tiles = [chess.parse_square(x) for x in [x[:2] for x in unique_moves]]
|
101 |
+
bad_to_tiles = [chess.parse_square(x) for x in [x[2:] for x in unique_moves]]
|
102 |
+
arrows = [chess.svg.Arrow(tail, head, color="red") for (tail, head) in zip(bad_from_tiles, bad_to_tiles)]
|
103 |
+
return chess.svg.board(board=board,arrows=arrows), '|'.join(unique_moves)
|
104 |
+
else:
|
105 |
+
return chess.svg.board(board=board,lastmove=move), f"Illegal move: {input}"
|
106 |
+
|
107 |
+
except ValueError:
|
108 |
+
return chess.svg.board(board=board), f"Invalid UCI format: {uci} {list(unique_moves)}"
|
109 |
+
except Exception:
|
110 |
+
return chess.svg.board(board=board), traceback.format_exc()
|
111 |
+
|
112 |
+
# Define the Gradio interface
|
113 |
+
iface = gr.Interface(
|
114 |
+
fn=make_move,
|
115 |
+
inputs="text",
|
116 |
+
outputs=["html", "text"],
|
117 |
+
examples=[['e2e4'], ['d2d4'], ['Reset']],
|
118 |
+
title="Play Versus ChessGPT",
|
119 |
+
description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
|
120 |
+
allow_flagging='never',
|
121 |
+
)
|
122 |
+
|
123 |
+
# Launch the Gradio app
|
124 |
+
iface.launch()
|
125 |
+
|
126 |
+
# %%
|