austindavis commited on
Commit
1c8df7b
1 Parent(s): 224a696

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ from transformers import AutoModelForCausalLM, GPT2LMHeadModel
3
+ import torch
4
+ import chess
5
+ import chess.pgn
6
+ import chess.svg
7
+ import gradio as gr
8
+ from modeling.utils import uci_to_board
9
+ from modeling.uci_tokenizers import UciTileTokenizer
10
+ import numpy as np
11
+ import io
12
+ import traceback
13
+ # %%
14
+ checkpoint_name = "austindavis/gpt2-lichess-uci-201601"
15
+ tokenizer = UciTileTokenizer()
16
+ model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
17
+ model.requires_grad_(False)
18
+
19
+ # %%
20
+ # Initialize the chess board
21
+ board = chess.Board()
22
+ game:chess.pgn.GameNode = chess.pgn.Game()
23
+
24
+
25
+
26
+ game.headers["Event"] = "Example"
27
+
28
+ generate_kwargs = {
29
+ "max_new_tokens": 3,
30
+ "num_return_sequences": 10,
31
+ "temperature": 0.5,
32
+ "output_scores": True,
33
+ "output_logits": True,
34
+ "return_dict_in_generate": True
35
+ }
36
+
37
+ def make_move(input:str, node=game, board = board):
38
+ # check for reset
39
+ if input.lower() == 'reset':
40
+ board.reset()
41
+ node.root().variations.clear()
42
+ return chess.svg.board(board=board), "New game!"
43
+
44
+ # check for pgn
45
+ if input[0] == '[' or input[:3] == '1. ':
46
+ pgn = io.StringIO(input)
47
+ game = chess.pgn.read_game(pgn)
48
+ board.reset()
49
+ node.root().variations.clear()
50
+
51
+ for move in game.mainline_moves():
52
+ board.push(move)
53
+ node.add_variation(move)
54
+
55
+ return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
56
+
57
+
58
+ try:
59
+ move = chess.Move.from_uci(input)
60
+ if move in board.legal_moves:
61
+ board.push(move)
62
+
63
+ while node.next() is not None:
64
+ node = node.next()
65
+ node = node.add_variation(move)
66
+
67
+ # get computer's move
68
+
69
+ prefix = ' '.join([x.uci() for x in board.move_stack])
70
+ encoding = tokenizer(text=prefix,
71
+ return_tensors='pt',
72
+ )['input_ids']
73
+
74
+ output = model.generate(encoding, **generate_kwargs) # [b,p,v]
75
+ new_tokens = tokenizer.batch_decode(output.sequences[:,-3:])
76
+ unique_moves, unique_indices = np.unique([x[:4] if ' ' in x else x for x in new_tokens], return_index=True)
77
+ unique_indices = torch.Tensor(list(unique_indices)).to(dtype=torch.int)
78
+ logits = torch.stack(output.logits) # [token, batch, vocab]
79
+ logits = logits[:,unique_indices] # [token, batch, vocab]
80
+
81
+ # select moves based on mean logit value for tokens 1 and 2
82
+ logit_priority_order = logits.max(dim=-1).values.T[:,:2].mean(-1).topk(len(unique_indices)).indices
83
+ priority_ordered_moves = unique_moves[logit_priority_order]
84
+
85
+ # if there's only 1 option, we have to pack it back into a list
86
+ if isinstance(priority_ordered_moves, str):
87
+ priority_ordered_moves = [priority_ordered_moves]
88
+
89
+ # test if any moves are valid
90
+ for uci in priority_ordered_moves:
91
+ move = chess.Move.from_uci(uci)
92
+ if move in board.legal_moves:
93
+ board.push(move)
94
+ while node.next() is not None:
95
+ node = node.next()
96
+ node = node.add_variation(move)
97
+ return chess.svg.board(board=board,lastmove=move), ""#str(node.root()).split(']')[-1].strip()
98
+
99
+ # no moves are valid
100
+ bad_from_tiles = [chess.parse_square(x) for x in [x[:2] for x in unique_moves]]
101
+ bad_to_tiles = [chess.parse_square(x) for x in [x[2:] for x in unique_moves]]
102
+ arrows = [chess.svg.Arrow(tail, head, color="red") for (tail, head) in zip(bad_from_tiles, bad_to_tiles)]
103
+ return chess.svg.board(board=board,arrows=arrows), '|'.join(unique_moves)
104
+ else:
105
+ return chess.svg.board(board=board,lastmove=move), f"Illegal move: {input}"
106
+
107
+ except ValueError:
108
+ return chess.svg.board(board=board), f"Invalid UCI format: {uci} {list(unique_moves)}"
109
+ except Exception:
110
+ return chess.svg.board(board=board), traceback.format_exc()
111
+
112
+ # Define the Gradio interface
113
+ iface = gr.Interface(
114
+ fn=make_move,
115
+ inputs="text",
116
+ outputs=["html", "text"],
117
+ examples=[['e2e4'], ['d2d4'], ['Reset']],
118
+ title="Play Versus ChessGPT",
119
+ description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
120
+ allow_flagging='never',
121
+ )
122
+
123
+ # Launch the Gradio app
124
+ iface.launch()
125
+
126
+ # %%