# viz-gpt2-stockfish-debug / src/attention_interface.py
# (Hosting-page header residue: commit 2f281df "typo" by Xmaster6y, unverified.)
"""
Gradio interface for plotting attention.
"""
import chess
import gradio as gr
import torch
import uuid
import re
from . import constants, state, visualisation
def compute_cache(
    game_pgn,
    board_fen,
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Build the per-board model cache for a game (PGN) or a single position (FEN).

    A non-empty PGN takes precedence over the FEN. One cache entry
    ``(fen, state.model_cache(fen))`` is produced per position.

    Returns:
        The five ``make_plot`` outputs for the current board, followed by
        the freshly computed cache (to be stored in ``gr.State``).
    """
    if game_pgn == "" and board_fen != "":
        # Single position supplied as FEN.
        try:
            board = chess.Board(board_fen)
        except ValueError:
            # Invalid FEN: warn instead of crashing the callback, and fall
            # back to the "no cache" outputs (five Nones + empty cache).
            gr.Warning(f"Invalid FEN '{board_fen}'.")
            return (
                *make_plot(
                    attention_layer, attention_head, comp_index, None, 0
                ),
                None,
            )
        fen_list = [board.fen()]
    else:
        # Replay the PGN from the initial position, collecting one FEN per ply.
        board = chess.Board()
        fen_list = [board.fen()]
        for move in game_pgn.split():
            if move.endswith("."):
                # Skip move-number tokens such as "1.".
                continue
            try:
                board.push_san(move)
                fen_list.append(board.fen())
            except ValueError:
                gr.Warning(f"Invalid move {move}, stopping before it.")
                break
    state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
    # Clamp the board index: the user may have navigated deep into a previous
    # (longer) game, and a stale index would raise IndexError in make_plot.
    state_board_index = min(state_board_index, len(state_cache) - 1)
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_cache,
    )
def make_plot(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Render the attention heatmap over the board for one cached position.

    Args:
        attention_layer: 1-based layer index (UI sliders are 1-based).
        attention_head: 1-based head index.
        comp_index: 1-based index of the completion step whose attention
            is visualised.
        state_cache: list of ``(fen, (out, cache))`` pairs produced by
            ``compute_cache``, or ``None`` if not computed yet.
        state_board_index: index of the board within ``state_cache``.

    Returns:
        ``(board fen, info text, colorbar figure, path to rendered SVG board,
        list of (token, attention) pairs)`` — or five ``None``s when no
        cache is available.
    """
    if state_cache is None:
        gr.Warning("Cache not computed!")
        return None, None, None, None, None
    fen, (out, cache) = state_cache[state_board_index]
    # Pick the requested head from every attention tensor of the requested
    # layer (sliders are 1-based, tensors 0-based).
    attn_list = [a[0, attention_head - 1] for a in cache[attention_layer - 1]]
    # The first tensor attends over the whole prompt; keep only its last row
    # so each list entry is the attention of a single generated token.
    prompt_attn, *comp_attn = attn_list
    comp_attn.insert(0, prompt_attn[-1:])
    comp_attn = [a.squeeze(0) for a in comp_attn]
    if len(comp_attn) != 5:
        # NOTE(review): exactly 5 completion steps are assumed (presumably the
        # 4 UCI-move tokens plus one more) — confirm against the model.
        raise NotImplementedError("This is not implemented yet.")
    # Attention mass is split into three buckets for the info box.
    config_total = meta_total = dump_total = 0
    config_done = False
    heatmap = torch.zeros(64)  # one attention weight per board square
    h_index = 0  # next square index to fill in the heatmap
    for i, t_o in enumerate(out[0]):
        try:
            t_attn = comp_attn[comp_index - 1][i]
            if (i < 3) or (i > len(out[0]) - 10):
                # Leading/trailing prompt tokens are not board tokens.
                dump_total += t_attn
                continue
            t_str = state.model.tokenizer.decode(t_o)
            if t_str.startswith(" ") and h_index > 0:
                # A leading space marks the end of the board-configuration
                # part of the FEN; what follows is metadata (color, castling, ...).
                config_done = True
            if not config_done:
                if t_str == "/":
                    # Rank separators carry no square information.
                    dump_total += t_attn
                    continue
                # Expand FEN digits (run-lengths of empty squares) into that
                # many placeholder chars so each char maps to one square.
                t_str = re.sub(r"\d", lambda m: "0" * int(m.group(0)), t_str)
                config_total += t_attn
                t_str_len = len(t_str.strip())
                # Spread the token's attention uniformly over its squares.
                pre_t_attn = t_attn / t_str_len
                for j in range(t_str_len):
                    heatmap[h_index + j] = pre_t_attn
                h_index += t_str_len
            else:
                meta_total += t_attn
        except IndexError:
            # Attention vector shorter than the token sequence: stop here.
            break
    raw_attention = comp_attn[comp_index - 1]
    # Token/attention pairs for the HighlightedText widget.
    highlited_tokens = [
        (state.model.tokenizer.decode(out[0][i]), raw_attention[i])
        for i in range(len(raw_attention))
    ]
    # The last 5 tokens are the completion; tokens [-5:-1] spell the UCI move.
    uci_move = state.model.tokenizer.decode(out[0][-5:-1]).strip()
    board = chess.Board(fen)
    # FEN lists ranks from the 8th down; flip rows so the heatmap matches
    # the renderer's square ordering.
    heatmap = heatmap.view(8, 8).flip(0).view(64)
    move = chess.Move.from_uci(uci_move)
    svg_board, fig = visualisation.render_heatmap(
        board, heatmap, arrows=[(move.from_square, move.to_square)]
    )
    info = (
        f"[Completion] Complete: '{state.model.tokenizer.decode(out[0][-5:])}'"
        f" Chosen: '{state.model.tokenizer.decode(out[0][-5:][comp_index-1])}'"
        f"\n[Distribution] Config: {config_total:.2f} Meta: {meta_total:.2f} Dump: {dump_total:.2f}"
    )
    # Unique filename so concurrent sessions don't overwrite each other's SVG.
    id = str(uuid.uuid4())
    with open(f"{constants.FIGURE_DIRECTORY}/board_{id}.svg", "w") as f:
        f.write(svg_board)
    return (
        board.fen(),
        info,
        fig,
        f"{constants.FIGURE_DIRECTORY}/board_{id}.svg",
        highlited_tokens,
    )
def previous_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Step to the previous board and re-plot; clamp at the first board."""
    new_index = state_board_index - 1
    if new_index < 0:
        gr.Warning("Already at first board.")
        new_index = 0
    plot_outputs = make_plot(
        attention_layer, attention_head, comp_index, state_cache, new_index
    )
    return (*plot_outputs, new_index)
def next_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Step to the next board and re-plot; clamp at the last board."""
    last = len(state_cache) - 1
    new_index = state_board_index + 1
    if new_index > last:
        gr.Warning("Already at last board.")
        new_index = last
    plot_outputs = make_plot(
        attention_layer, attention_head, comp_index, state_cache, new_index
    )
    return (*plot_outputs, new_index)
# ---------------------------------------------------------------------------
# Gradio UI: inputs (PGN/FEN + sliders) on the left, rendered board and
# colorbar on the right. Event handlers above are wired to the widgets below.
# ---------------------------------------------------------------------------
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            # --- Position input -------------------------------------------
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                )
                compute_cache_button = gr.Button("Compute cache")
            # --- Attention selection (1-based sliders) --------------------
            with gr.Group():
                with gr.Row():
                    attention_layer = gr.Slider(
                        label="Attention layer",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    attention_head = gr.Slider(
                        label="Attention head",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                comp_index = gr.Slider(
                    label="Completion index",
                    minimum=1,
                    maximum=6,
                    step=1,
                    value=1,
                )
            # --- Navigation across the cached boards ----------------------
            with gr.Row():
                previous_board_button = gr.Button("Previous board")
                next_board_button = gr.Button("Next board")
            current_board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
            )
            info = gr.Textbox(
                label="Info",
                lines=1,
                info=(
                    "'Config' refers to the board configuration tokens."
                    "\n'Meta' to the additional board tokens (like color or castling)."
                    "\n'Dump' to the rest of the tokens (including '/')."
                ),
            )
            gr.Markdown(
                "Note that only the 'Config' attention is plotted.\n\nSee below for the raw attention."
            )
            raw_attention_html = gr.HighlightedText(
                label="Raw attention",
            )
        with gr.Column():
            # --- Rendered output ------------------------------------------
            image_board = gr.Image(label="Board")
            colorbar = gr.Plot(label="Colorbar")
    # Inputs shared by every handler; must match the handler signatures.
    static_inputs = [
        attention_layer,
        attention_head,
        comp_index,
    ]
    # Outputs shared by every handler; order matches make_plot's return tuple.
    static_outputs = [
        current_board_fen,
        info,
        colorbar,
        image_board,
        raw_attention_html,
    ]
    # Per-session state: the computed cache and the currently shown board.
    state_cache = gr.State(value=None)
    state_board_index = gr.State(value=0)
    # compute_cache additionally writes the cache state.
    compute_cache_button.click(
        compute_cache,
        inputs=[game_pgn, board_fen, *static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_cache],
    )
    # Navigation handlers additionally write the board-index state.
    previous_board_button.click(
        previous_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    next_board_button.click(
        next_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    # Slider changes re-plot without touching any state.
    attention_layer.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
    attention_head.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
    comp_index.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )