# app.py, as shown on the Hugging Face Hub file view
# author: austindavis · commit 44e120c ("Update app.py", verified) · 4.61 kB
# %%
from transformers import AutoModelForCausalLM, GPT2LMHeadModel
import torch
import chess
import chess.pgn
import chess.svg
import gradio as gr
from modeling.utils import uci_to_board
from modeling.uci_tokenizers import UciTileTokenizer
import numpy as np
import io
import traceback
# %%
checkpoint_name = "austindavis/chess-gpt2-uci-8x8x512"
# Project-local tokenizer used to encode the UCI move history for the model.
tokenizer = UciTileTokenizer()
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
# Inference only: no parameter needs gradients.
model.requires_grad_(False)
# %%
# Shared, module-level game state. make_move() binds `game` and `board` as its
# default arguments, so every call operates on this same board and game tree.
board = chess.Board()
game: chess.pgn.Game = chess.pgn.Game()
game.headers["Event"] = "Example"

# Generation settings for the model's reply: up to 3 new tokens per sequence,
# 10 sequences per request (deduplicated and ranked by make_move).
generate_kwargs = {
    "max_new_tokens": 3,
    "num_return_sequences": 10,
    # `temperature` only applies in sampling mode, and transformers rejects
    # num_return_sequences > 1 under plain greedy decoding. Request sampling
    # explicitly instead of relying on the checkpoint's generation config.
    "do_sample": True,
    "temperature": 0.5,
    "output_scores": True,
    "output_logits": True,
    "return_dict_in_generate": True,
}
def _append_to_mainline(node, move):
    """Attach ``move`` after the last node of ``node``'s mainline and return the new node."""
    while node.next() is not None:
        node = node.next()
    return node.add_variation(move)


def _load_pgn(pgn_text, node, board):
    """Replace the current game tree and board with the mainline of ``pgn_text``.

    Returns the ``(svg, message)`` pair expected by ``make_move``.
    """
    parsed = chess.pgn.read_game(io.StringIO(pgn_text))
    if parsed is None:
        # read_game signals "no game found in the input" with None.
        return chess.svg.board(board=board), "Could not read a game from the PGN."
    # NOTE(review): any [FEN] setup header is ignored; moves are replayed
    # from the standard starting position.
    board.reset()
    root = node.root()
    root.variations.clear()
    tail, last_move = root, None
    for last_move in parsed.mainline_moves():
        board.push(last_move)
        # Chain each move onto the previous one so they form a single mainline.
        tail = tail.add_variation(last_move)
    return chess.svg.board(board=board, lastmove=last_move), ""


def _model_reply(node, board):
    """Ask the model for a reply to the current position and play it if legal.

    Candidate continuations are sampled, deduplicated and ranked by a logit
    score; the first legal one is pushed onto ``board`` and appended to the
    game tree. If none is legal, the board is left unchanged and the parsable
    candidates are drawn as red arrows.

    Returns the ``(svg, message)`` pair expected by ``make_move``.
    """
    prefix = " ".join(m.uci() for m in board.move_stack)
    encoding = tokenizer(text=prefix, return_tensors="pt")["input_ids"]
    output = model.generate(encoding, **generate_kwargs)
    # Decode only the newly generated tokens (everything after the prompt).
    new_tokens = tokenizer.batch_decode(output.sequences[:, encoding.shape[-1]:])

    # When the decoded text contains a space, keep only its first 4 characters.
    candidates = [s[:4] if " " in s else s for s in new_tokens]
    unique_moves, first_index = np.unique(candidates, return_index=True)
    first_index = torch.as_tensor(first_index, dtype=torch.long)
    # output.logits: one [batch, vocab] tensor per generated step.
    logits = torch.stack(output.logits)[:, first_index]  # [step, n_unique, vocab]
    # Score = mean over the first two steps of that step's largest logit.
    scores = logits.max(dim=-1).values.T[:, :2].mean(dim=-1)
    ranking = scores.topk(len(first_index)).indices.tolist()

    replies = []
    for i in ranking:
        try:
            replies.append(chess.Move.from_uci(str(unique_moves[i])))
        except ValueError:
            continue  # not a well-formed UCI move; try the next candidate

    for reply in replies:
        if reply in board.legal_moves:
            board.push(reply)
            _append_to_mainline(node, reply)
            status = f"Game over: {board.result()}" if board.is_game_over() else ""
            return chess.svg.board(board=board, lastmove=reply), status

    # No legal candidate: show what the model proposed.
    arrows = [
        chess.svg.Arrow(m.from_square, m.to_square, color="red")
        for m in replies
        if m  # skip the null move "0000"
    ]
    return chess.svg.board(board=board, arrows=arrows), "|".join(unique_moves)


def make_move(input: str, node=game, board=board):
    """Apply one command from the UI to the shared game.

    ``input`` may be:
      * ``"reset"`` (any case): start a new game;
      * a PGN, recognised by a leading ``[`` tag or ``1.`` movetext: replace
        the current game with its mainline;
      * a single move in UCI notation (e.g. ``"e2e4"``): play it and let the
        model reply.

    Args:
        input: Raw text from the Gradio textbox (surrounding whitespace ignored).
        node: A node of the game tree; defaults to the module-level ``game``.
        board: Board to mutate; defaults to the module-level ``board``.

    Returns:
        ``(svg, message)``: the board rendered as SVG and a status string,
        empty on success and otherwise an explanation or a traceback.
    """
    text = (input or "").strip()
    if not text:
        return (
            chess.svg.board(board=board),
            "Enter a move in UCI notation, a PGN, or 'reset'.",
        )

    if text.lower() == "reset":
        board.reset()
        node.root().variations.clear()
        return chess.svg.board(board=board), "New game!"

    try:
        if text.startswith(("[", "1.")):
            return _load_pgn(text, node, board)

        if board.is_game_over():
            return (
                chess.svg.board(board=board),
                f"Game over: {board.result()}. Enter 'reset' to play again.",
            )

        try:
            move = chess.Move.from_uci(text)
        except ValueError:
            return chess.svg.board(board=board), f"Invalid UCI format: {text}"

        if move not in board.legal_moves:
            return chess.svg.board(board=board, lastmove=move), f"Illegal move: {text}"

        board.push(move)
        node = _append_to_mainline(node, move)
        if board.is_game_over():
            return (
                chess.svg.board(board=board, lastmove=move),
                f"Game over: {board.result()}",
            )
        return _model_reply(node, board)
    except Exception:
        # UI boundary: show unexpected errors in the message box instead of
        # failing the whole request.
        return chess.svg.board(board=board), traceback.format_exc()
# Web UI: one text box feeds make_move, whose (svg_html, status_text) return
# value fills the "html" and "text" output panes respectively.
iface = gr.Interface(
    fn=make_move,
    inputs="text",
    outputs=["html", "text"],
    examples=[[sample] for sample in ("e2e4", "d2d4", "Reset")],
    title="Play Versus ChessGPT",
    description=(
        "Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). "
        "Enter 'reset' to restart the game."
    ),
    allow_flagging="never",
)

# Start serving the app.
iface.launch()
# %%