# %%
import io
import traceback

import chess
import chess.pgn
import chess.svg
import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForCausalLM, GPT2LMHeadModel

from modeling.uci_tokenizers import UciTileTokenizer
from modeling.utils import uci_to_board

# %%
checkpoint_name = "austindavis/gpt2-lichess-uci-201601"
tokenizer = UciTileTokenizer()
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
model.requires_grad_(False)

# %%
# Initialize the chess board
board = chess.Board()
game: chess.pgn.GameNode = chess.pgn.Game()
game.headers["Event"] = "Example"

generate_kwargs = {
    "max_new_tokens": 3,
    "num_return_sequences": 10,
    "temperature": 0.5,
    "output_scores": True,
    "output_logits": True,
    "return_dict_in_generate": True,
}


def _append_to_mainline(node, move):
    """Walk to the end of *node*'s main line, append *move*, return the new node."""
    while node.next() is not None:
        node = node.next()
    return node.add_variation(move)


def _load_pgn(pgn_text, node, board):
    """Replace the current game with the main line of the PGN in *pgn_text*."""
    parsed = chess.pgn.read_game(io.StringIO(pgn_text))
    if parsed is None:
        # read_game returns None when there is nothing to read.
        return chess.svg.board(board=board), "Could not read a game from the PGN."

    board.reset()
    cursor = node.root()
    cursor.variations.clear()

    # Stays None when the PGN has no moves, so the board is drawn unhighlighted.
    last_move = None
    for last_move in parsed.mainline_moves():
        board.push(last_move)
        # Advance the cursor so the moves form one line instead of siblings.
        cursor = cursor.add_variation(last_move)
    return chess.svg.board(board=board, lastmove=last_move), ""


def _reply_with_model_move(node, board):
    """Sample candidate replies from the model and play the best legal one.

    Returns the rendered board and a status string. When no sampled candidate
    is legal, nothing is played: the candidates are drawn as red arrows and
    returned joined by '|'.
    """
    prefix = " ".join(played.uci() for played in board.move_stack)
    encoding = tokenizer(text=prefix, return_tensors="pt")["input_ids"]
    output = model.generate(encoding, **generate_kwargs)

    new_tokens = tokenizer.batch_decode(output.sequences[:, -3:])
    # A decode containing a space is cut to its first four characters;
    # presumably a space-free decode is a 5-character promotion - TODO confirm.
    candidates = [x[:4] if " " in x else x for x in new_tokens]
    unique_moves, unique_indices = np.unique(candidates, return_index=True)
    unique_indices = torch.as_tensor(unique_indices, dtype=torch.long)

    logits = torch.stack(output.logits)  # [token, batch, vocab]
    logits = logits[:, unique_indices]  # [token, unique, vocab]

    # Rank candidates by the mean of the max logit for tokens 1 and 2.
    scores = logits.max(dim=-1).values.T[:, :2].mean(-1)
    priority = scores.topk(len(unique_indices)).indices.tolist()
    # Plain Python indices always yield a list, even for a single candidate.
    priority_ordered_moves = [str(unique_moves[i]) for i in priority]

    for uci in priority_ordered_moves:
        try:
            reply = chess.Move.from_uci(uci)
        except ValueError:
            # A malformed candidate must not block lower-priority ones.
            continue
        if reply in board.legal_moves:
            board.push(reply)
            _append_to_mainline(node, reply)
            return chess.svg.board(board=board, lastmove=reply), ""

    # No candidate was legal: show what the model attempted.
    arrows = []
    for candidate in unique_moves:
        try:
            tail = chess.parse_square(candidate[:2])
            # [2:4] rather than [2:] so a promotion suffix is ignored.
            head = chess.parse_square(candidate[2:4])
        except ValueError:
            continue
        arrows.append(chess.svg.Arrow(tail, head, color="red"))
    return chess.svg.board(board=board, arrows=arrows), "|".join(unique_moves)


def make_move(input: str, node=game, board=board):
    """Handle one line of user input and return the board plus a message.

    The input is interpreted, in order, as:

    * ``reset`` (any case) - clear the board and the recorded game;
    * a PGN (starts with ``[`` or ``1. ``) - replace the game with its
      main line;
    * a UCI move - play it, then let the model answer.

    Args:
        input: Text typed by the user.
        node: Game node whose tree records the moves; defaults to the
            module-level game so state persists between calls.
        board: Board to play on; defaults to the module-level board.

    Returns:
        A ``(svg, message)`` pair: the rendered board and a status string,
        which is empty when a move was played normally.
    """
    text = input.strip()
    if not text:
        return chess.svg.board(board=board), "Please enter a move."

    # check for reset
    if text.lower() == "reset":
        board.reset()
        node.root().variations.clear()
        return chess.svg.board(board=board), "New game!"

    # check for pgn
    if text[0] == "[" or text[:3] == "1. ":
        return _load_pgn(text, node, board)

    # Parse the user's move separately so a malformed entry is reported
    # without touching names that only exist after the model has run.
    try:
        player_move = chess.Move.from_uci(text)
    except ValueError:
        return chess.svg.board(board=board), f"Invalid UCI format: {text}"

    if player_move not in board.legal_moves:
        return (
            chess.svg.board(board=board, lastmove=player_move),
            f"Illegal move: {text}",
        )

    try:
        board.push(player_move)
        node = _append_to_mainline(node, player_move)
        # get computer's move
        return _reply_with_model_move(node, board)
    except Exception:
        # UI boundary: surface the traceback instead of crashing the app.
        return chess.svg.board(board=board), traceback.format_exc()


# Define the Gradio interface
iface = gr.Interface(
    fn=make_move,
    inputs="text",
    outputs=["html", "text"],
    examples=[["e2e4"], ["d2d4"], ["Reset"]],
    title="Play Versus ChessGPT",
    description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
    allow_flagging="never",
)

# Launch the Gradio app
iface.launch()

# %%