import io
import traceback
from typing import List

import chess
import chess.pgn
import chess.svg
import gradio as gr
import numpy as np
import tokenizers
import torch
from tokenizers import models, pre_tokenizers, processors
from torch import Tensor as TT
from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PreTrainedTokenizerFast

checkpoint_name = "austindavis/chess-gpt2-uci-8x8x512"

# Vocabulary of the tile tokenizer, in token-id order: four special tokens,
# the 64 board squares, then the four promotion piece letters.
# NOTE(review): the four special-token strings are empty in this copy of the
# file, so they collapse onto a single key in the string-to-id mapping (the
# last index wins). Presumably markup-like token names were stripped somewhere
# upstream -- confirm against the published tokenizer before relying on them.
_TILE_TOKENS = ["", "", "", ""] + chess.SQUARE_NAMES + list("qrbn")


class UciTokenizer(PreTrainedTokenizerFast):
    """Base class for word-level tokenizers over UCI move text.

    The constructor builds a ``tokenizers`` WordLevel model from ``stoi`` and
    installs a post-processor that prepends the BOS token to every single
    sequence. Subclasses provide the pre-tokenizer and the rule for merging
    decoded tokens back into moves.
    """

    _PAD_TOKEN: str
    _UNK_TOKEN: str
    _EOS_TOKEN: str
    _BOS_TOKEN: str

    # String-to-integer mapping; this is the vocab given to the WordLevel model.
    stoi: dict[str, int]
    # Integer-to-string mapping used when decoding token ids.
    itos: dict[int, str]

    def __init__(
        self,
        stoi,
        itos,
        pad_token,
        unk_token,
        bos_token,
        eos_token,
        name_or_path,
    ):
        self.stoi = stoi
        self.itos = itos

        self._PAD_TOKEN = pad_token
        self._UNK_TOKEN = unk_token
        self._EOS_TOKEN = eos_token
        self._BOS_TOKEN = bos_token

        # Define the model
        tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)

        slow_tokenizer = tokenizers.Tokenizer(tok_model)
        slow_tokenizer.pre_tokenizer = self._init_pretokenizer()

        # Post-processing adds special tokens unless explicitly ignored.
        # The BOS id is hard-coded to 1 here, so BOS must sit at index 1 of
        # the vocabulary.
        post_proc = processors.TemplateProcessing(
            single=f"{bos_token} $0",
            pair=None,
            special_tokens=[(bos_token, 1)],
        )
        slow_tokenizer.post_processor = post_proc

        super().__init__(
            tokenizer_object=slow_tokenizer,
            unk_token=self._UNK_TOKEN,
            bos_token=self._BOS_TOKEN,
            eos_token=self._EOS_TOKEN,
            pad_token=self._PAD_TOKEN,
            name_or_path=name_or_path,
        )

    def _decode(
        self,
        token_ids: int | List[int],
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ) -> str:
        """Decode token ids into a space-separated string of moves.

        Accepts a single id, a sequence of ids, a tensor/array of ids, or a
        dict carrying the ids under ``"input_ids"``. A single id is returned
        as its raw token string. ``skip_special_tokens``,
        ``clean_up_tokenization_spaces`` and any extra keyword arguments are
        accepted for signature compatibility and are not consulted here.
        """
        if isinstance(token_ids, dict):
            token_ids = token_ids["input_ids"]

        if isinstance(token_ids, (TT, np.ndarray)):
            # A 0-d tensor/array becomes a plain int, handled just below.
            token_ids = token_ids.tolist()

        if isinstance(token_ids, int):
            return self.itos.get(token_ids, self._UNK_TOKEN)

        tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
        return " ".join(self._process_str_tokens(tokens_str))

    def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
        """Return the pre-tokenizer for this vocabulary (subclass hook)."""
        raise NotImplementedError

    def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
        """Merge decoded token strings into a list of moves (subclass hook)."""
        raise NotImplementedError

    def get_id2square_list(self) -> list[int | None]:
        """Return the token-id-to-square lookup list (subclass hook)."""
        raise NotImplementedError


class UciTileTokenizer(UciTokenizer):
    """UCI tokenizer converting start/end tiles and promotion types each into individual tokens."""

    stoi = {tok: idx for idx, tok in enumerate(_TILE_TOKENS)}
    itos = dict(enumerate(_TILE_TOKENS))

    # List mapping token IDs to squares on the chess board; ids that are not
    # squares (special tokens, promotion letters) map to None.
    # Order is file then row, i.e.: `A1, B1, C1, ..., F8, G8, H8`
    id2square: list[int | None] = [None] * 4 + list(range(64)) + [None] * 4

    def get_id2square_list(self) -> list[int | None]:
        """Return the token-id-to-square lookup list."""
        return self.id2square

    def __init__(self):
        super().__init__(
            self.stoi,
            self.itos,
            pad_token="",
            unk_token="",
            bos_token="",
            eos_token="",
            name_or_path="austindavis/uci_tile_tokenizer",
        )

    def _init_pretokenizer(self):
        """Split on whitespace, then after every digit, so "e2e4" -> "e2", "e4"."""
        # Pre-tokenizer to split input into UCI moves
        pattern = tokenizers.Regex(r"\d")
        pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Whitespace(),
                pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
            ]
        )
        return pre_tokenizer

    def _process_str_tokens(self, token_str):
        """Merge square/promotion tokens into UCI move strings.

        Special tokens are dropped. Two square tokens form a move; a following
        one-character token is treated as its promotion piece. A trailing
        partial move (e.g. a lone start square) is kept as-is.
        """
        special_tokens = set(self.all_special_tokens)
        moves = []
        next_move = ""
        for token in token_str:
            # skip special tokens
            if token in special_tokens:
                continue

            # handle promotions: the piece letter completes the pending move,
            # so the buffer must be cleared or the move would be emitted twice
            if len(token) == 1:
                moves.append(next_move + token)
                next_move = ""
                continue

            # handle regular tokens
            if len(next_move) == 4:
                moves.append(next_move)
                next_move = token
            else:
                next_move += token

        if next_move:
            moves.append(next_move)
        return moves


def setup_app(model: GPT2LMHeadModel):
    """
    Configures a Gradio App to use the GPT model for move generation.
    The model must be compatible with a UciTileTokenizer.

    A single board and PGN game are created here and shared by every call to
    the move handler. Returns the configured ``gr.Interface`` (not launched).
    """
    tokenizer = UciTileTokenizer()

    # Initialize the chess board
    board = chess.Board()
    game: chess.pgn.GameNode = chess.pgn.Game()
    game.headers["Event"] = "Example"

    # NOTE(review): `temperature` and `num_return_sequences` > 1 normally need
    # sampling (or beam search) to be enabled; no `do_sample` is passed here,
    # so this presumably relies on the checkpoint's generation config -- confirm.
    generate_kwargs = {
        "max_new_tokens": 3,
        "num_return_sequences": 10,
        "temperature": 0.5,
        "output_scores": True,
        "output_logits": True,
        "return_dict_in_generate": True,
    }

    def render(**svg_kwargs) -> str:
        """Render the shared board as SVG."""
        return chess.svg.board(board=board, **svg_kwargs)

    def record(move: chess.Move) -> None:
        """Play `move` on the board and append it to the end of the mainline."""
        board.push(move)
        game.end().add_variation(move)

    def reset_game() -> None:
        """Return the board and the PGN game to the starting position."""
        board.reset()
        game.variations.clear()

    def failed_candidate_arrows(candidates) -> list:
        """Build a red arrow for every candidate that names two real squares."""
        arrows = []
        for uci in candidates:
            try:
                tail = chess.parse_square(uci[:2])
                # [2:4], not [2:], so a promotion suffix ("e7e8q") is ignored
                head = chess.parse_square(uci[2:4])
            except ValueError:
                continue
            arrows.append(chess.svg.Arrow(tail, head, color="red"))
        return arrows

    # The parameter keeps the name `input` (shadowing the builtin) because
    # Gradio may derive the textbox label from the parameter name.
    def make_move(input: str):
        """Handle one submission: 'reset', a PGN to load, or a UCI move.

        Returns a ``(board_svg, text)`` pair for the two output components.
        """
        input = (input or "").strip()

        # check for reset
        if input.lower() == "reset":
            reset_game()
            return render(), "New game!"

        # check for pgn
        if input.startswith(("[", "1. ")):
            loaded = chess.pgn.read_game(io.StringIO(input))
            if loaded is None:
                return render(), f"Could not parse PGN: {input}"
            reset_game()
            node = game
            for move in loaded.mainline_moves():
                board.push(move)
                # advance so the moves form one line instead of sibling
                # variations hanging off the root
                node = node.add_variation(move)
            lastmove = board.peek() if board.move_stack else None
            return render(lastmove=lastmove), ""

        try:
            move = chess.Move.from_uci(input)
            if move not in board.legal_moves:
                return render(lastmove=move), f"Illegal move: {input}"
            record(move)

            # get computer's move
            prefix = " ".join(x.uci() for x in board.move_stack)
            encoding = tokenizer(text=prefix, return_tensors="pt")["input_ids"]
            output = model.generate(encoding, **generate_kwargs)  # [b,p,v]
            new_tokens = tokenizer.batch_decode(output.sequences[:, -3:])
            unique_moves, unique_indices = np.unique(
                [x[:4] if " " in x else x for x in new_tokens], return_index=True
            )
            unique_indices = torch.as_tensor(unique_indices, dtype=torch.long)
            logits = torch.stack(output.logits)  # [token, batch, vocab]
            logits = logits[:, unique_indices]  # [token, batch, vocab]

            # select moves based on mean logit value for tokens 1 and 2
            logit_priority_order = (
                logits.max(dim=-1).values.T[:, :2].mean(-1).topk(len(unique_indices)).indices
            )
            # index one at a time so a single candidate still yields a list
            priority_ordered_moves = [str(unique_moves[i]) for i in logit_priority_order.tolist()]

            # test if any moves are valid
            for uci in priority_ordered_moves:
                try:
                    reply = chess.Move.from_uci(uci)
                except (chess.InvalidMoveError, ValueError):
                    # malformed model output is just a bad candidate, not a
                    # problem with the user's input
                    continue
                if reply in board.legal_moves:
                    record(reply)
                    return render(lastmove=reply), str(game).split("]")[-1].strip()

            # no moves are valid
            arrows = failed_candidate_arrows(unique_moves)
            checks = board.king(board.turn) if board.is_check() else None
            return render(arrows=arrows, check=checks), "|".join(unique_moves)
        except chess.InvalidMoveError:
            return render(), f"Invalid UCI format: {input}"
        except Exception:
            # top-level UI boundary: surface the traceback instead of crashing
            return render(), traceback.format_exc()

    input_box = gr.Textbox(None, placeholder="Enter your move in UCI format")

    # Define the Gradio interface
    iface = gr.Interface(
        fn=make_move,
        inputs=input_box,
        outputs=["html", "text"],
        examples=[["e2e4"], ["d2d4"], ["Reset"]],
        title="Play Versus ChessGPT",
        description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
        allow_flagging="never",
        submit_btn="Move",
        stop_btn="Stop",
        clear_btn="Clear w/o reset",
    )

    iface.output_components[0].label = "Board"
    iface.output_components[0].show_label = True
    iface.output_components[1].label = "Move Sequence"

    return iface


model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
model.requires_grad_(False)

iface = setup_app(model)
iface.launch()