""" | |
Gradio interface for plotting attention. | |
""" | |
import chess | |
import gradio as gr | |
import torch | |
import uuid | |
import re | |
from . import constants, state, visualisation | |
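# Assumed interface of the sibling modules, inferred from the calls below:
#   - state.model_cache(fen) -> (out, cache): out[0] is the token sequence and
#     cache[layer] holds the per-step attention tensors.
#   - state.model.tokenizer.decode(tokens) -> str.
#   - visualisation.render_heatmap(board, heatmap, arrows=...) -> (svg string,
#     figure accepted by gr.Plot).
#   - constants.FIGURE_DIRECTORY: directory where the board SVGs are written.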


def compute_cache(
    game_pgn,
    board_fen,
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Build the list of (FEN, model cache) pairs for the given game or position."""
    if game_pgn == "" and board_fen != "":
        board = chess.Board(board_fen)
        fen_list = [board.fen()]
    else:
        board = chess.Board()
        fen_list = [board.fen()]
        for move in game_pgn.split():
            # Skip move numbers such as "1." so that only SAN moves are pushed.
            if move.endswith("."):
                continue
            try:
                board.push_san(move)
                fen_list.append(board.fen())
            except ValueError:
                gr.Warning(f"Invalid move {move}, stopping before it.")
                break
    state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_cache,
    )


def make_plot(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Render the attention heatmap for the currently selected board."""
    if state_cache is None:
        gr.Warning("Cache not computed!")
        return None, None, None, None, None
    # Guard against a stale board index left over from a previously computed, longer game.
    state_board_index = max(0, min(state_board_index, len(state_cache) - 1))
    fen, (out, cache) = state_cache[state_board_index]
    attn_list = [a[0, attention_head - 1] for a in cache[attention_layer - 1]]
    prompt_attn, *comp_attn = attn_list
    comp_attn.insert(0, prompt_attn[-1:])
    comp_attn = [a.squeeze(0) for a in comp_attn]
    if len(comp_attn) != 5:
        raise NotImplementedError("Only completions of 5 tokens are supported.")
    config_total = meta_total = dump_total = 0
    config_done = False
    heatmap = torch.zeros(64)
    h_index = 0
    for i, t_o in enumerate(out[0]):
        try:
            t_attn = comp_attn[comp_index - 1][i]
            # Leading and trailing tokens are not part of the FEN body; count them as 'Dump'.
            if (i < 3) or (i > len(out[0]) - 10):
                dump_total += t_attn
                continue
            t_str = state.model.tokenizer.decode(t_o)
            # A token starting with a space after the board part marks the 'Meta' section.
            if t_str.startswith(" ") and h_index > 0:
                config_done = True
            if not config_done:
                if t_str == "/":
                    dump_total += t_attn
                    continue
                # Expand FEN digits into placeholder squares ("3p4" -> "000p0000")
                # so that each character maps to exactly one square.
                t_str = re.sub(r"\d", lambda m: "0" * int(m.group(0)), t_str)
                config_total += t_attn
                t_str_len = len(t_str.strip())
                # Spread the token's attention evenly over the squares it covers.
                pre_t_attn = t_attn / t_str_len
                for j in range(t_str_len):
                    heatmap[h_index + j] = pre_t_attn
                h_index += t_str_len
            else:
                meta_total += t_attn
        except IndexError:
            break
    raw_attention = comp_attn[comp_index - 1]
    highlighted_tokens = [
        (state.model.tokenizer.decode(out[0][i]), raw_attention[i])
        for i in range(len(raw_attention))
    ]
    uci_move = state.model.tokenizer.decode(out[0][-5:-1]).strip()
    board = chess.Board(fen)
    # The heatmap was filled in FEN order (rank 8 first); flip the rows so that
    # indices match python-chess square numbering (a1 = 0).
    heatmap = heatmap.view(8, 8).flip(0).view(64)
    move = chess.Move.from_uci(uci_move)
    svg_board, fig = visualisation.render_heatmap(
        board, heatmap, arrows=[(move.from_square, move.to_square)]
    )
    info = (
        f"[Completion] Complete: '{state.model.tokenizer.decode(out[0][-5:])}'"
        f" Chosen: '{state.model.tokenizer.decode(out[0][-5:][comp_index - 1])}'"
        f"\n[Distribution] Config: {config_total:.2f} Meta: {meta_total:.2f} Dump: {dump_total:.2f}"
    )
    figure_id = str(uuid.uuid4())
    with open(f"{constants.FIGURE_DIRECTORY}/board_{figure_id}.svg", "w") as f:
        f.write(svg_board)
    return (
        board.fen(),
        info,
        fig,
        f"{constants.FIGURE_DIRECTORY}/board_{figure_id}.svg",
        highlighted_tokens,
    )


def previous_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    state_board_index -= 1
    if state_board_index < 0:
        gr.Warning("Already at first board.")
        state_board_index = 0
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_board_index,
    )


def next_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    state_board_index += 1
    if state_cache is not None and state_board_index >= len(state_cache):
        gr.Warning("Already at last board.")
        state_board_index = len(state_cache) - 1
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_board_index,
    )


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                )
                compute_cache_button = gr.Button("Compute cache")
            with gr.Group():
                with gr.Row():
                    attention_layer = gr.Slider(
                        label="Attention layer",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    attention_head = gr.Slider(
                        label="Attention head",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    comp_index = gr.Slider(
                        label="Completion index",
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=1,
                    )
                with gr.Row():
                    previous_board_button = gr.Button("Previous board")
                    next_board_button = gr.Button("Next board")
            current_board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
            )
            info = gr.Textbox(
                label="Info",
                lines=1,
                info=(
                    "'Config' refers to the board configuration tokens."
                    "\n'Meta' to the additional board tokens (like color or castling)."
                    "\n'Dump' to the rest of the tokens (including '/')."
                ),
            )
            gr.Markdown(
                "Note that only the 'Config' attention is plotted.\n\nSee below for the raw attention."
            )
            raw_attention_html = gr.HighlightedText(
                label="Raw attention",
            )
        with gr.Column():
            image_board = gr.Image(label="Board")
            colorbar = gr.Plot(label="Colorbar")

    static_inputs = [
        attention_layer,
        attention_head,
        comp_index,
    ]
    static_outputs = [
        current_board_fen,
        info,
        colorbar,
        image_board,
        raw_attention_html,
    ]
    state_cache = gr.State(value=None)
    state_board_index = gr.State(value=0)
    compute_cache_button.click(
        compute_cache,
        inputs=[game_pgn, board_fen, *static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_cache],
    )
    previous_board_button.click(
        previous_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    next_board_button.click(
        next_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    attention_layer.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
    attention_head.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
    comp_index.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
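

# Minimal standalone launch sketch (assumption: the package's real entry point
# may wire this interface differently). Because of the relative imports above,
# run it as a module, e.g. `python -m <package>.<this_module>`.
if __name__ == "__main__":
    interface.launch()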