|
|
|
from transformers import AutoModelForCausalLM, GPT2LMHeadModel |
|
import torch |
|
import chess |
|
import chess.pgn |
|
import chess.svg |
|
import gradio as gr |
|
from modeling.utils import uci_to_board |
|
from modeling.uci_tokenizers import UciTileTokenizer |
|
import numpy as np |
|
import io |
|
import traceback |
|
|
|
# Checkpoint identifier handed to ``from_pretrained`` (looks like a Hugging
# Face Hub repo id -- confirm if a local path is ever used instead).
checkpoint_name = "austindavis/chess-gpt2-uci-8x8x512"

# Project tokenizer used to encode/decode UCI move text for the model.
tokenizer = UciTileTokenizer()

# The app only runs inference, so gradients are switched off for every
# parameter right after loading (``requires_grad_`` returns the module itself).
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(
    checkpoint_name
).requires_grad_(False)
|
|
|
|
|
|
|
# Shared, module-level game state. ``make_move`` binds both objects as default
# arguments, so every request handled by this process mutates the same game.
board = chess.Board()

# Annotated as ``Game`` (not the ``GameNode`` base class) because the headers
# mapping used on the next line only exists on the root ``Game`` object.
game: chess.pgn.Game = chess.pgn.Game()
game.headers["Event"] = "Example"
|
|
|
# Keyword arguments forwarded verbatim to ``model.generate``.
# NOTE(review): ``temperature`` / ``num_return_sequences`` only matter when
# sampling is enabled; presumably the checkpoint's generation config turns
# that on -- confirm.
generate_kwargs = dict(
    # Three new tokens are generated for each returned sequence.
    max_new_tokens=3,
    # Number of continuations requested per prompt.
    num_return_sequences=10,
    temperature=0.5,
    # Return a structured output object carrying per-step scores and logits
    # in addition to the generated sequences.
    output_scores=True,
    output_logits=True,
    return_dict_in_generate=True,
)
|
|
|
def _mainline_tail(node):
    """Return the last node reached by repeatedly following ``next()`` from *node*."""
    while node.next() is not None:
        node = node.next()
    return node


def _ranked_model_moves(board):
    """Ask the model for replies to *board* and return unique candidates, best first.

    The moves played so far are joined into a space-separated UCI string,
    tokenized, and passed to ``model.generate`` with ``generate_kwargs``.
    Each returned sequence is reduced to one candidate move string and
    duplicates are dropped.

    Returns:
        A list of candidate move strings ordered by descending score. The
        strings are raw model output and are not guaranteed to be valid UCI.
    """
    prefix = " ".join(played.uci() for played in board.move_stack)
    encoding = tokenizer(text=prefix, return_tensors="pt")["input_ids"]
    output = model.generate(encoding, **generate_kwargs)

    # Decode only the newly generated tokens: everything after the prompt.
    # (Slicing by prompt length avoids duplicating ``max_new_tokens`` here.)
    new_tokens = tokenizer.batch_decode(output.sequences[:, encoding.shape[-1]:])

    # A decoded sample containing a space has run on into a following move, so
    # only its first four characters are kept; otherwise the whole string is
    # kept (presumably a move with a promotion suffix -- confirm).
    candidates = [text[:4] if " " in text else text for text in new_tokens]
    unique_moves, first_rows = np.unique(candidates, return_index=True)

    # Keep one logits row per unique candidate (its first occurrence).
    rows = torch.as_tensor(first_rows, dtype=torch.long)
    logits = torch.stack(output.logits)[:, rows]

    # Score = mean over the first two generation steps of the per-step maximum
    # logit. NOTE: this is the largest logit at each step, not necessarily the
    # logit of the token that was actually sampled.
    scores = logits.max(dim=-1).values.T[:, :2].mean(-1)
    order = scores.topk(len(rows)).indices.tolist()
    return [str(unique_moves[i]) for i in order]


def make_move(input: str, node=game, board=board):  # noqa: A002 - name kept for callers
    """Apply one user command to the shared game and let the model reply.

    Accepted commands:

    * ``reset`` (any case) -- clear the board and the recorded game.
    * PGN text (starts with ``[`` or ``1. ``) -- replace the game with the
      PGN's main line.
    * A UCI move such as ``e2e4`` -- play it, then play the model's
      highest-ranked legal reply.

    Args:
        input: Raw text entered by the user.
        node: Any node of the recorded game; new moves are appended to the end
            of its main line. Defaults to the module-level ``game``.
        board: Board to mutate. Defaults to the module-level ``board``.

    Returns:
        A ``(svg, message)`` pair: the rendered board and a status string
        (empty when there is nothing to report). If none of the model's
        candidates is legal, the user's move stays on the board, the rejected
        candidates are drawn as red arrows, and the message lists them
        separated by ``|``. Unexpected errors during model inference are
        returned as a traceback string instead of being raised.
    """
    command = (input or "").strip()
    if not command:
        # Previously an empty string crashed on ``input[0]``.
        return chess.svg.board(board=board), "Please enter a move."

    if command.lower() == "reset":
        board.reset()
        node.root().variations.clear()
        return chess.svg.board(board=board), "New game!"

    if command.startswith("[") or command.startswith("1. "):
        loaded = chess.pgn.read_game(io.StringIO(command))
        if loaded is None:
            return chess.svg.board(board=board), "Could not parse PGN."

        board.reset()
        cursor = node.root()
        cursor.variations.clear()
        last_move = None  # stays None for a PGN without moves
        for move in loaded.mainline_moves():
            board.push(move)
            # Advance the cursor so the moves form one main line instead of
            # all becoming sibling variations of the root.
            cursor = cursor.add_variation(move)
            last_move = move
        return chess.svg.board(board=board, lastmove=last_move), ""

    # Parse the user's move on its own so that a malformed string is reported
    # without touching names that only exist after the model has run.
    try:
        move = chess.Move.from_uci(command)
    except ValueError:
        return chess.svg.board(board=board), f"Invalid UCI format: {command}"

    if move not in board.legal_moves:
        return chess.svg.board(board=board, lastmove=move), f"Illegal move: {input}"

    board.push(move)
    node = _mainline_tail(node).add_variation(move)

    if board.is_game_over():
        # Nothing for the model to answer; skip generation entirely.
        return (
            chess.svg.board(board=board, lastmove=move),
            f"Game over: {board.result()}",
        )

    # UI boundary: anything unexpected from the model stack is shown to the
    # user as a traceback rather than crashing the request.
    try:
        ranked_moves = _ranked_model_moves(board)

        rejected = []
        for uci in ranked_moves:
            try:
                reply = chess.Move.from_uci(uci)
            except ValueError:
                # Not a move at all; try the next candidate instead of
                # abandoning the whole reply.
                continue
            if reply in board.legal_moves:
                board.push(reply)
                node.add_variation(reply)
                message = f"Game over: {board.result()}" if board.is_game_over() else ""
                return chess.svg.board(board=board, lastmove=reply), message
            rejected.append(reply)

        # No legal reply: show what the model tried. Squares come from the
        # parsed moves, so promotion suffixes cannot break square parsing.
        arrows = [
            chess.svg.Arrow(bad.from_square, bad.to_square, color="red")
            for bad in rejected
        ]
        return (
            chess.svg.board(board=board, arrows=arrows),
            "|".join(sorted(ranked_moves)),
        )
    except Exception:
        return chess.svg.board(board=board), traceback.format_exc()
|
|
|
|
|
# Single-textbox UI: each submission is passed to ``make_move``; the returned
# SVG markup goes to the HTML output and the status string to the text output.
iface = gr.Interface(
    title="Play Versus ChessGPT",
    description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
    fn=make_move,
    inputs="text",
    outputs=["html", "text"],
    # One inner list per example row (a single text input each).
    examples=[
        ['e2e4'],
        ['d2d4'],
        ['Reset'],
    ],
    allow_flagging='never',
)

# Start serving as soon as the script runs.
iface.launch()
|
|
|
|
|
|