austindavis's picture
Update app.py
44c8b7b verified
raw
history blame
No virus
10.4 kB
import io
import traceback
from typing import List
import chess.pgn
import chess.svg
import gradio as gr
import numpy as np
import tokenizers
import torch
from tokenizers import models, pre_tokenizers, processors
from torch import Tensor as TT
from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PreTrainedTokenizerFast
import chess
class UciTokenizer(PreTrainedTokenizerFast):
    """Base class for word-level tokenizers over UCI move strings.

    Wraps a ``tokenizers.Tokenizer`` built from a ``WordLevel`` model over the
    subclass's vocabulary, a subclass-provided pre-tokenizer, and a
    post-processor that prepends the BOS token to every single sequence.
    Subclasses must implement ``_init_pretokenizer``, ``_process_str_tokens``
    and ``get_id2square_list``.
    """

    _PAD_TOKEN: str
    _UNK_TOKEN: str
    _EOS_TOKEN: str
    _BOS_TOKEN: str
    # String-to-integer mapping (token -> id). This is the vocab.
    stoi: dict[str, int]
    # Integer-to-string mapping (id -> token), the inverse of ``stoi``.
    itos: dict[int, str]

    def __init__(
        self,
        stoi: dict[str, int],
        itos: dict[int, str],
        pad_token: str,
        unk_token: str,
        bos_token: str,
        eos_token: str,
        name_or_path: str,
    ):
        """Build the fast tokenizer.

        Args:
            stoi: token -> id vocabulary; must contain ``bos_token``.
            itos: id -> token mapping used when decoding.
            pad_token, unk_token, bos_token, eos_token: special token strings.
            name_or_path: forwarded to ``PreTrainedTokenizerFast``.

        Raises:
            KeyError: if ``bos_token`` is not in ``stoi``.
        """
        self.stoi = stoi
        self.itos = itos
        self._PAD_TOKEN = pad_token
        self._UNK_TOKEN = unk_token
        self._EOS_TOKEN = eos_token
        self._BOS_TOKEN = bos_token

        tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)
        slow_tokenizer = tokenizers.Tokenizer(tok_model)
        slow_tokenizer.pre_tokenizer = self._init_pretokenizer()
        # Post-processing prepends BOS unless special tokens are explicitly
        # suppressed. The id must be BOS's real id in this vocabulary, not a
        # hard-coded constant, or encodings would carry a mismatched id.
        slow_tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{bos_token} $0",
            pair=None,
            special_tokens=[(bos_token, self.stoi[bos_token])],
        )

        super().__init__(
            tokenizer_object=slow_tokenizer,
            unk_token=self._UNK_TOKEN,
            bos_token=self._BOS_TOKEN,
            eos_token=self._EOS_TOKEN,
            pad_token=self._PAD_TOKEN,
            name_or_path=name_or_path,
        )

    def _decode(
        self,
        token_ids: int | List[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = False,
        **kwargs,
    ) -> str:
        """Decode token ids into a space-separated string of moves.

        Overrides the fast tokenizer's decoding so that sub-move tokens are
        reassembled by ``_process_str_tokens`` and moves are joined with single
        spaces. A single int decodes to that token's string. Unknown ids decode
        to the UNK token. ``skip_special_tokens``,
        ``clean_up_tokenization_spaces`` and any extra keyword arguments are
        accepted for API compatibility and otherwise ignored.
        """
        if isinstance(token_ids, dict):
            token_ids = token_ids["input_ids"]
        if isinstance(token_ids, TT):
            # A 0-d tensor becomes an int here and is handled just below.
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            return self.itos.get(token_ids, self._UNK_TOKEN)
        # Accept any iterable of ids (list, tuple, numpy array, ...).
        tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
        return " ".join(self._process_str_tokens(tokens_str))

    def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
        """Return the pre-tokenizer that splits raw text into vocabulary pieces."""
        raise NotImplementedError

    def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
        """Reassemble decoded token strings into a list of move strings."""
        raise NotImplementedError

    def get_id2square_list(self) -> list[int | None]:
        """Return a list mapping token id -> board square (None for non-squares)."""
        raise NotImplementedError
class UciTileTokenizer(UciTokenizer):
    """UCI tokenizer that emits one token per start tile, end tile and promotion piece.

    For example ``"e7e8q"`` is split into ``["e7", "e8", "q"]``. The vocabulary
    has 72 ids: 0-3 are ``<pad>``, ``<s>``, ``</s>``, ``<unk>``; 4-67 are the
    squares in ``chess.SQUARE_NAMES`` order; 68-71 are ``q``, ``r``, ``b``, ``n``.
    """

    # Ordered vocabulary; a token's position in this list is its id.
    _VOCAB: List[str] = ["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn")

    stoi = {tok: idx for idx, tok in enumerate(_VOCAB)}
    itos = dict(enumerate(_VOCAB))

    # Maps token id -> square index, or None for special and promotion tokens.
    # Squares run file-first within each rank: a1, b1, c1, ..., f8, g8, h8.
    id2square: List[int | None] = [None] * 4 + list(range(64)) + [None] * 4

    def get_id2square_list(self) -> List[int | None]:
        """Return the token-id -> square-index list (see ``id2square``)."""
        return self.id2square

    def __init__(self):
        super().__init__(
            self.stoi,
            self.itos,
            pad_token="<pad>",
            unk_token="<unk>",
            bos_token="<s>",
            eos_token="</s>",
            name_or_path="austindavis/uci_tile_tokenizer",
        )

    def _init_pretokenizer(self):
        """Split on whitespace, then after every digit, so "e2e4" -> "e2", "e4"."""
        pattern = tokenizers.Regex(r"\d")
        return pre_tokenizers.Sequence(
            [
                pre_tokenizers.Whitespace(),
                pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
            ]
        )

    def _process_str_tokens(self, token_str):
        """Reassemble decoded tile/promotion tokens into UCI move strings.

        Special tokens are dropped. Two consecutive square tokens form a move,
        and a single-character promotion token completes the move in progress.
        A trailing incomplete move (e.g. a lone start square) is kept as-is.
        No empty strings are emitted.
        """
        special = set(self.all_special_tokens)
        moves = []
        pending = ""
        for token in token_str:
            if token in special:
                continue
            if len(token) == 1:
                # Promotion piece: it finishes the pending move. Reset so the
                # move is not emitted a second time without its promotion.
                moves.append(pending + token)
                pending = ""
            elif len(pending) == 4:
                # Pending move is complete; this square starts the next one.
                moves.append(pending)
                pending = token
            else:
                pending += token
        if pending:
            moves.append(pending)
        return moves
def _mainline_end(node: chess.pgn.GameNode) -> chess.pgn.GameNode:
    """Return the last node reached by following main-line children from *node*."""
    while node.next() is not None:
        node = node.next()
    return node


def _last_move(board: chess.Board) -> chess.Move | None:
    """Return the most recent move on *board*, or None when no move was played."""
    return board.peek() if board.move_stack else None


def _game_status(node: chess.pgn.GameNode, board: chess.Board) -> str:
    """Return the movetext of *node*'s game, plus the result once the game is over."""
    # str(game) is the PGN; everything after the last "]" is the movetext.
    text = str(node.root()).split("]")[-1].strip()
    if board.is_game_over():
        text += f"\nGame over: {board.result()}"
    return text


def _rank_engine_moves(model, tokenizer, board: chess.Board, generate_kwargs: dict) -> List[str]:
    """Sample continuations of the current game and return distinct candidate
    move strings, best-scored first.

    Each sample is scored by the mean raw logit of the first two tokens it
    actually sampled (the start and end tiles). The per-step maximum logit is
    not used: at step 0 it is identical for every sample because all samples
    share the same prompt. A move sampled more than once keeps its best score.
    Candidates may be malformed; callers must validate them.
    """
    prefix = " ".join(m.uci() for m in board.move_stack)
    input_ids = tokenizer(text=prefix, return_tensors="pt")["input_ids"]
    output = model.generate(input_ids, **generate_kwargs)
    new_ids = output.sequences[:, input_ids.shape[-1]:]  # [batch, steps]
    # output.logits holds one [batch, vocab] tensor per generated step.
    step_logits = torch.stack(output.logits, dim=1)  # [batch, steps, vocab]
    sampled = step_logits.gather(-1, new_ids.unsqueeze(-1)).squeeze(-1)  # [batch, steps]
    scores = sampled[:, :2].mean(dim=-1).tolist()

    best: dict[str, float] = {}
    for text, score in zip(tokenizer.batch_decode(new_ids), scores):
        # Only the first decoded move is the reply; anything after the space is
        # the start of the following move. Splitting (rather than truncating to
        # 4 chars) keeps promotion suffixes such as "e7e8q".
        candidate = text.split(" ")[0]
        if candidate not in best or score > best[candidate]:
            best[candidate] = score
    return sorted(best, key=best.__getitem__, reverse=True)


def setup_app(model: GPT2LMHeadModel):
    """
    Configures a Gradio App to use the GPT model for move generation.
    The model must be compatible with a UciTileTokenizer.

    The board and game record are created once here and captured by the
    request handler, so every visitor of the app shares the same game.
    """
    tokenizer = UciTileTokenizer()

    # Shared game state.
    board = chess.Board()
    game = chess.pgn.Game()
    game.headers["Event"] = "Example"

    generate_kwargs = {
        "max_new_tokens": 3,
        "num_return_sequences": 10,
        # Sampling must be on: greedy decoding rejects num_return_sequences > 1,
        # and temperature only applies when sampling.
        "do_sample": True,
        "temperature": 0.5,
        "output_scores": True,
        "output_logits": True,
        "return_dict_in_generate": True,
    }

    def make_move(move_text: str):
        """Handle one submission and return (board SVG, status text).

        Accepts 'reset', a PGN game (text starting with '[' or '1. '), or a
        single UCI move. After a legal user move, the model plays a reply.
        """
        text = (move_text or "").strip()
        try:
            if not text:
                return (
                    chess.svg.board(board=board, lastmove=_last_move(board)),
                    "Enter a move in UCI format (e.g. e2e4), a PGN, or 'reset'.",
                )

            if text.lower() == "reset":
                board.reset()
                game.variations.clear()
                return chess.svg.board(board=board), "New game!"

            if text.startswith(("[", "1. ")):
                loaded = chess.pgn.read_game(io.StringIO(text))
                if loaded is None:
                    return (
                        chess.svg.board(board=board, lastmove=_last_move(board)),
                        f"Could not parse PGN: {move_text}",
                    )
                board.reset()
                game.variations.clear()
                node = game
                for pgn_move in loaded.mainline_moves():
                    board.push(pgn_move)
                    # Chain each move onto the previous node so the record is a
                    # single main line rather than sibling variations of the root.
                    node = node.add_variation(pgn_move)
                return chess.svg.board(board=board, lastmove=_last_move(board)), _game_status(game, board)

            # Raises chess.InvalidMoveError on malformed input (handled below).
            move = chess.Move.from_uci(text.lower())
            if move not in board.legal_moves:
                return (
                    chess.svg.board(board=board, lastmove=_last_move(board)),
                    f"Illegal move: {move_text}",
                )
            board.push(move)
            _mainline_end(game).add_variation(move)
            if board.is_game_over():
                return chess.svg.board(board=board, lastmove=move), _game_status(game, board)

            # Computer's reply: parse the candidates, skipping malformed samples so
            # they are not mistaken for an invalid user move.
            candidates = _rank_engine_moves(model, tokenizer, board, generate_kwargs)
            parsed = []
            for uci in candidates:
                try:
                    parsed.append(chess.Move.from_uci(uci))
                except chess.InvalidMoveError:
                    continue
            for reply in parsed:
                if reply in board.legal_moves:
                    board.push(reply)
                    _mainline_end(game).add_variation(reply)
                    return chess.svg.board(board=board, lastmove=reply), _game_status(game, board)

            # No candidate is legal: draw the parsable ones as red arrows.
            arrows = [chess.svg.Arrow(m.from_square, m.to_square, color="red") for m in parsed]
            check = board.king(board.turn) if board.is_check() else None
            return chess.svg.board(board=board, arrows=arrows, check=check), "|".join(candidates)
        except chess.InvalidMoveError:
            return chess.svg.board(board=board), f"Invalid UCI format: {move_text}"
        except Exception:
            # Top-level UI boundary: show the traceback instead of crashing.
            return chess.svg.board(board=board), traceback.format_exc()

    input_box = gr.Textbox(None, placeholder="Enter your move in UCI format")

    # Define the Gradio interface
    iface = gr.Interface(
        fn=make_move,
        inputs=input_box,
        outputs=["html", "text"],
        examples=[['e2e4'], ['d2d4'], ['Reset']],
        title="Play Versus ChessGPT",
        description="Enter moves in UCI notation (e.g., e2e4 for pawn from e2 to e4). Enter 'reset' to restart the game.",
        allow_flagging='never',
        submit_btn="Move",
        stop_btn="Stop",
        clear_btn="Clear w/o reset",
    )
    iface.output_components[0].label = "Board"
    iface.output_components[0].show_label = True
    iface.output_components[1].label = "Move Sequence"
    return iface
checkpoint_name = "austindavis/gpt2-lichess-uci-202306"
# Load the pretrained chess GPT-2 and freeze it: the app only runs inference.
# nn.Module.requires_grad_ returns the module itself, so the call chains.
model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name).requires_grad_(False)
# Wrap the model in the Gradio UI and start serving it.
iface = setup_app(model)
iface.launch()