File size: 4,651 Bytes
6c227b9
 
1c8df7b
 
 
 
 
 
 
 
 
 
 
 
 
44e120c
1c8df7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
!pip install transformers chess torch 

# %%
from transformers import AutoModelForCausalLM, GPT2LMHeadModel
import torch
import chess
import chess.pgn
import chess.svg
import gradio as gr
from modeling.utils import uci_to_board
from modeling.uci_tokenizers import UciTileTokenizer
import numpy as np
import io
import traceback
# %%
checkpoint_name = "austindavis/chess-gpt2-uci-8x8x512"

# Tokenizer and model are created once at import time and reused by every
# call to make_move.
tokenizer = UciTileTokenizer()
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
# Inference only: switch off gradient tracking on all model parameters.
model.requires_grad_(False)

# %%
# Shared game state: the live position and the PGN game tree. Both objects
# are bound as default arguments of make_move, so they persist across calls.
board = chess.Board()
game: chess.pgn.GameNode = chess.pgn.Game()
game.headers["Event"] = "Example"

# Keyword arguments forwarded verbatim to model.generate() by make_move.
# make_move decodes the last three tokens of each returned sequence and
# ranks candidates using the per-step logits, hence output_logits and
# return_dict_in_generate.
# NOTE(review): temperature is set without do_sample; whether sampling is
# active depends on the checkpoint's generation config -- confirm.
generate_kwargs = dict(
    max_new_tokens=3,
    num_return_sequences=10,
    temperature=0.5,
    output_scores=True,
    output_logits=True,
    return_dict_in_generate=True,
)

def _extend_mainline(node, move):
    """Append *move* after the last mainline node reachable from *node*.

    Returns the newly created child node.
    """
    while node.next() is not None:
        node = node.next()
    return node.add_variation(move)


def _ranked_model_moves(board):
    """Ask the model for candidate replies to the current position.

    Returns a tuple ``(ranked, unique)`` of lists of move strings:
    ``unique`` holds the distinct decoded candidates (sorted, as returned by
    ``np.unique``) and ``ranked`` holds the same strings ordered from highest
    to lowest score. The strings are NOT validated here.
    """
    prefix = ' '.join(m.uci() for m in board.move_stack)
    input_ids = tokenizer(text=prefix, return_tensors='pt')['input_ids']

    output = model.generate(input_ids, **generate_kwargs)
    # Only the freshly generated tail of each sequence encodes the reply.
    n_new = generate_kwargs["max_new_tokens"]
    decoded = tokenizer.batch_decode(output.sequences[:, -n_new:])
    # When the decoded text contains a space, keep only its first four
    # characters; otherwise keep the whole string.
    candidates = [x[:4] if ' ' in x else x for x in decoded]

    # Deduplicate, remembering which generated sequence produced each one.
    unique_moves, first_rows = np.unique(candidates, return_index=True)
    rows = torch.as_tensor(first_rows, dtype=torch.long)

    logits = torch.stack(output.logits)[:, rows]  # [token, candidate, vocab]
    # Score = mean over the first two generation steps of the max logit.
    scores = logits.max(dim=-1).values.T[:, :2].mean(-1)
    order = scores.topk(len(rows)).indices.tolist()

    unique = [str(m) for m in unique_moves]
    ranked = [unique[i] for i in order]
    return ranked, unique


def make_move(input:str, node=game, board = board):
    """Apply one user command to the shared game and render the result.

    Accepted commands:

    * ``reset`` (case-insensitive): start a new game.
    * PGN text (starting with ``[`` or ``1. ``): replace the current game
      with the PGN's mainline. The model does not reply.
    * A UCI move (e.g. ``e2e4``): play it if legal, then let the model reply
      with its highest-ranked legal candidate.

    Args:
        input: Raw text entered by the user.
        node: Any node of the PGN game tree to record moves in. Defaults to
            the module-level ``game``, bound at definition time.
        board: Board to play on. Defaults to the module-level ``board``,
            bound at definition time, so state persists between calls.

    Returns:
        A ``(svg, message)`` tuple: the board rendered as SVG markup and a
        status string (empty when there is nothing to report). If none of
        the model's candidates is legal, they are drawn as red arrows and
        listed in the message separated by ``|``.
    """
    text = (input or "").strip()
    if not text:
        # Guard: indexing or parsing an empty command would raise.
        return chess.svg.board(board=board), "Enter a UCI move, a PGN, or 'reset'."

    # check for reset
    if text.lower() == 'reset':
        board.reset()
        node.root().variations.clear()
        return chess.svg.board(board=board), "New game!"

    # check for pgn
    if text.startswith(('[', '1. ')):
        loaded = chess.pgn.read_game(io.StringIO(text))
        if loaded is None:
            return chess.svg.board(board=board), "Could not read a game from the PGN."

        board.reset()
        cursor = node.root()
        cursor.variations.clear()

        last_move = None  # stays None when the PGN contains no moves
        for last_move in loaded.mainline_moves():
            board.push(last_move)
            # Chain each move onto the previous one so the tree holds a
            # single mainline rather than many first-move alternatives.
            cursor = cursor.add_variation(last_move)

        return chess.svg.board(board=board, lastmove=last_move), ""

    # Parse the player's move separately so a malformed string is reported
    # using only names that are guaranteed to be bound.
    try:
        move = chess.Move.from_uci(text)
    except ValueError:
        return chess.svg.board(board=board), f"Invalid UCI format:  {text}"

    if move not in board.legal_moves:
        return chess.svg.board(board=board, lastmove=move), f"Illegal move:  {text}"

    board.push(move)
    node = _extend_mainline(node, move)

    if board.is_game_over():
        # Nothing for the model to answer.
        return chess.svg.board(board=board, lastmove=move), f"Game over: {board.result()}"

    # get computer's move
    try:
        ranked, unique = _ranked_model_moves(board)

        # play the best-ranked candidate that is actually legal
        for uci in ranked:
            try:
                reply = chess.Move.from_uci(uci)
            except ValueError:
                continue  # malformed candidate; a later one may still be fine
            if reply in board.legal_moves:
                board.push(reply)
                _extend_mainline(node, reply)
                status = f"Game over: {board.result()}" if board.is_game_over() else ""
                return chess.svg.board(board=board, lastmove=reply), status

        # no moves are valid: show what the model tried
        arrows = []
        for uci in unique:
            try:
                tail = chess.parse_square(uci[:2])
                head = chess.parse_square(uci[2:4])
            except ValueError:
                continue  # not drawable as an arrow
            arrows.append(chess.svg.Arrow(tail, head, color="red"))
        return chess.svg.board(board=board, arrows=arrows), '|'.join(unique)

    except Exception:
        # Top-level boundary for the UI callback: surface the failure as
        # text instead of letting the request crash.
        return chess.svg.board(board=board), traceback.format_exc()

# Build the web UI: one text box in, the rendered board (HTML) and a status
# message (text) out, with make_move as the callback.
iface = gr.Interface(
    title="Play Versus ChessGPT",
    description=(
        "Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). "
        "Enter 'reset' to restart the game."
    ),
    fn=make_move,
    inputs="text",
    outputs=["html", "text"],
    examples=[["e2e4"], ["d2d4"], ["Reset"]],
    allow_flagging="never",
)

# Start serving the app
iface.launch()

# %%