Xmaster6y committed on
Commit
55ecc31
1 Parent(s): c50cbfd

attention interface

Browse files
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Gpt2 Stockfish Debug Viz
3
  emoji: 🔥
4
  colorFrom: blue
5
  colorTo: red
 
1
  ---
2
+ title: GPT-2 Stockfish Debug
3
  emoji: 🔥
4
  colorFrom: blue
5
  colorTo: red
app.py CHANGED
@@ -7,6 +7,7 @@ import wandb
7
  import gradio as gr
8
 
9
  from src import (
 
10
  call_interface,
11
  play_interface,
12
  constants,
@@ -16,10 +17,12 @@ from src import (
16
  demo = gr.TabbedInterface(
17
  [
18
  play_interface.interface,
 
19
  call_interface.interface,
20
  ],
21
  [
22
  "Play",
 
23
  "Call",
24
  ],
25
  title="GPT-2 Stockfish Debug",
 
7
  import gradio as gr
8
 
9
  from src import (
10
+ attention_interface,
11
  call_interface,
12
  play_interface,
13
  constants,
 
17
  demo = gr.TabbedInterface(
18
  [
19
  play_interface.interface,
20
+ attention_interface.interface,
21
  call_interface.interface,
22
  ],
23
  [
24
  "Play",
25
+ "Attention Viz",
26
  "Call",
27
  ],
28
  title="GPT-2 Stockfish Debug",
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  python-chess
2
  wandb
 
 
1
  python-chess
2
  wandb
3
+ nnsight
src/attention_interface.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting attention.
3
+ """
4
+
5
+ import chess
6
+ import gradio as gr
7
+ import torch
8
+ import uuid
9
+ import re
10
+
11
+ from . import constants, state, visualisation
12
+
13
+
14
def compute_cache(
    game_pgn,
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Replay a game, cache the model output for every position, and plot one.

    ``game_pgn`` is split on whitespace; tokens ending with "." (move numbers)
    are skipped and every other token is pushed as a SAN move. Replay stops at
    the first move that ``push_san`` rejects, after emitting a Gradio warning.

    The incoming ``state_cache`` value is ignored and replaced by a fresh list
    of ``(fen, state.model_cache(fen))`` pairs, one per reached position
    (the starting position included).

    Returns the five outputs of ``make_plot`` followed by the new cache.
    """
    board = chess.Board()
    fen_list = [board.fen()]
    for move in game_pgn.split():
        if move.endswith("."):
            continue
        try:
            board.push_san(move)
        except ValueError:
            gr.Warning(f"Invalid move {move}, stopping before it.")
            break
        fen_list.append(board.fen())

    state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]

    # The board index is not reset by this handler, so it may still point past
    # the end of a newly computed (shorter) game. Clamp it for plotting rather
    # than letting the lookup raise IndexError.
    plot_index = max(0, min(state_board_index, len(state_cache) - 1))
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, plot_index
        ),
        state_cache,
    )
40
+
41
+
42
def make_plot(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Render the attention of one completion token over the prompt.

    Parameters
    ----------
    attention_layer, attention_head, comp_index
        1-based selections (they are decremented before indexing).
    state_cache
        List of ``(fen, (out, cache))`` entries as built by ``compute_cache``,
        or ``None`` when nothing has been computed yet.
    state_board_index
        Index of the entry of ``state_cache`` to display.

    Returns
    -------
    tuple
        ``(fen, info, colorbar_figure, svg_path, highlighted_tokens)``, or five
        ``None`` values (after a warning) when the cache is missing.

    Raises
    ------
    NotImplementedError
        If the cached generation does not yield exactly 5 completion rows.
    """
    if not state_cache:
        gr.Warning("Cache not computed!")
        return None, None, None, None, None

    # The stored index may be stale (e.g. a shorter game was cached after
    # navigating), so clamp it instead of raising IndexError.
    state_board_index = max(0, min(state_board_index, len(state_cache) - 1))
    fen, (out, cache) = state_cache[state_board_index]

    # assumes each cached tensor is indexed [batch, head, query, key] -- TODO confirm
    attn_list = [a[0, attention_head - 1] for a in cache[attention_layer - 1]]
    prompt_attn, *comp_attn = attn_list
    # The last row of the prompt attention belongs to the first completion token.
    comp_attn.insert(0, prompt_attn[-1:])
    comp_attn = [a.squeeze(0) for a in comp_attn]
    if len(comp_attn) != 5:
        raise NotImplementedError("This is not implemented yet.")

    # Only len(comp_attn) completion rows exist; a larger index used to escape
    # as an uncaught IndexError after the loop below.
    if comp_index > len(comp_attn):
        gr.Warning(
            f"Completion index {comp_index} is out of range, "
            f"using {len(comp_attn)} instead."
        )
        comp_index = len(comp_attn)

    tokens = out[0]
    raw_attention = comp_attn[comp_index - 1]
    decode = state.model.tokenizer.decode

    config_total = meta_total = dump_total = 0
    config_done = False
    heatmap = torch.zeros(64)
    h_index = 0
    for i, t_o in enumerate(tokens):
        try:
            t_attn = raw_attention[i]
            # The first 3 and the last 9 positions are counted as 'Dump'.
            if (i < 3) or (i > len(tokens) - 10):
                dump_total += t_attn
                continue
            t_str = decode(t_o)
            # A token with a leading space, once squares have been filled,
            # marks the end of the board-configuration tokens.
            if t_str.startswith(" ") and h_index > 0:
                config_done = True
            if not config_done:
                if t_str == "/":
                    dump_total += t_attn
                    continue
                # Expand FEN digits into that many empty-square placeholders.
                t_str = re.sub(r"\d", lambda m: "0" * int(m.group(0)), t_str)
                config_total += t_attn
                t_str_len = len(t_str.strip())
                # Spread the token attention evenly over the squares it covers.
                pre_t_attn = t_attn / t_str_len
                for j in range(t_str_len):
                    heatmap[h_index + j] = pre_t_attn
                h_index += t_str_len
            else:
                meta_total += t_attn
        except IndexError:
            # Raised when the attention row is shorter than the token sequence
            # (or the heatmap is full); stop accumulating at that point.
            break

    highlighted_tokens = [
        (decode(token), weight) for token, weight in zip(tokens, raw_attention)
    ]
    uci_move = decode(tokens[-5:-1]).strip()
    board = chess.Board(fen)
    # Flip the rows of the 8x8 grid before handing the heatmap to the renderer.
    heatmap = heatmap.view(8, 8).flip(0).view(64)
    move = chess.Move.from_uci(uci_move)
    svg_board, fig = visualisation.render_heatmap(
        board, heatmap, arrows=[(move.from_square, move.to_square)]
    )
    info = (
        f"[Completion] Complete: '{decode(tokens[-5:])}'"
        f" Chosen: '{decode(tokens[-5:][comp_index-1])}'"
        f"\n[Distribution] Config: {config_total:.2f} Meta: {meta_total:.2f} Dump: {dump_total:.2f}"
    )
    figure_id = str(uuid.uuid4())
    svg_path = f"{constants.FIGURE_DIRECTORY}/board_{figure_id}.svg"
    with open(svg_path, "w") as f:
        f.write(svg_board)
    return (
        board.fen(),
        info,
        fig,
        svg_path,
        highlighted_tokens,
    )
116
+
117
+
118
def previous_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Move one position back (never below the first) and redraw.

    Returns the five ``make_plot`` outputs followed by the new board index.
    """
    target_index = state_board_index - 1
    if target_index < 0:
        gr.Warning("Already at first board.")
        target_index = 0
    plot_outputs = make_plot(
        attention_layer, attention_head, comp_index, state_cache, target_index
    )
    return (*plot_outputs, target_index)
135
+
136
+
137
def next_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Move one position forward (never past the last) and redraw.

    Returns the five ``make_plot`` outputs followed by the new board index.
    When no cache has been computed yet, the index is left untouched and
    ``make_plot`` is still called so that it reports the missing cache.
    """
    if state_cache is None:
        # len(None) would raise TypeError; let make_plot emit its warning.
        return (
            *make_plot(
                attention_layer,
                attention_head,
                comp_index,
                state_cache,
                state_board_index,
            ),
            state_board_index,
        )

    state_board_index += 1
    if state_board_index >= len(state_cache):
        gr.Warning("Already at last board.")
        state_board_index = len(state_cache) - 1
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_board_index,
    )
154
+
155
+
156
# Layout and event wiring of the "Attention Viz" tab.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            game_pgn = gr.Textbox(
                label="Game PGN",
                lines=1,
            )
            compute_cache_button = gr.Button("Compute cache")
            with gr.Group():
                with gr.Row():
                    attention_layer = gr.Slider(
                        label="Attention layer",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    attention_head = gr.Slider(
                        label="Attention head",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    # make_plot only handles exactly 5 completion rows, so the
                    # slider must not offer a 6th index.
                    comp_index = gr.Slider(
                        label="Completion index",
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=1,
                    )
                with gr.Row():
                    previous_board_button = gr.Button("Previous board")
                    next_board_button = gr.Button("Next board")
            current_board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
            )
            info = gr.Textbox(
                label="Info",
                lines=1,
                info=(
                    "'Config' refers to the board configuration tokens."
                    "\n'Meta' to the additional board tokens (like color or castling)."
                    "\n'Dump' to the rest of the tokens (including '/')."
                ),
            )
            gr.Markdown(
                "Note that only the 'Config' attention is plotted.\n\nSee below for the raw attention."
            )
            raw_attention_html = gr.HighlightedText(
                label="Raw attention",
            )
        with gr.Column():
            image_board = gr.Image(label="Board")
            colorbar = gr.Plot(label="Colorbar")

    # Sliders shared by every handler, in make_plot's positional order.
    static_inputs = [
        attention_layer,
        attention_head,
        comp_index,
    ]
    # Components filled by make_plot's 5-tuple, in the order it returns them.
    static_outputs = [
        current_board_fen,
        info,
        colorbar,
        image_board,
        raw_attention_html,
    ]

    # Per-session state: the cached model outputs and the displayed position.
    state_cache = gr.State(value=None)
    state_board_index = gr.State(value=0)

    compute_cache_button.click(
        compute_cache,
        inputs=[game_pgn, *static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_cache],
    )

    previous_board_button.click(
        previous_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    next_board_button.click(
        next_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )

    # Moving any slider redraws the current position with the new selection.
    for slider in static_inputs:
        slider.change(
            make_plot,
            inputs=[*static_inputs, state_cache, state_board_index],
            outputs=[*static_outputs],
        )
src/call_interface.py CHANGED
@@ -9,9 +9,7 @@ import gradio as gr
9
  model_name = "yp-edu/gpt2-stockfish-debug"
10
 
11
  headers = {"X-Wait-For-Model": "true"}
12
- client = huggingface_hub.InferenceClient(
13
- model=model_name, headers=headers
14
- )
15
 
16
  inputs = gr.Textbox(label="Prompt")
17
  outputs = gr.Textbox(label="Completion")
 
9
  model_name = "yp-edu/gpt2-stockfish-debug"
10
 
11
  headers = {"X-Wait-For-Model": "true"}
12
+ client = huggingface_hub.InferenceClient(model=model_name, headers=headers)
 
 
13
 
14
  inputs = gr.Textbox(label="Prompt")
15
  outputs = gr.Textbox(label="Completion")
src/play_interface.py CHANGED
@@ -15,17 +15,20 @@ import gradio as gr
15
  from . import constants
16
 
17
  model_name = "yp-edu/gpt2-stockfish-debug"
18
- headers = {"X-Wait-For-Model": "true"}
19
- client = huggingface_hub.InferenceClient(
20
- model=model_name, headers=headers
21
- )
 
22
  inference_fn = client.text_generation
23
 
24
 
25
  def plot_board(
26
  board: chess.Board,
27
- orientation: bool = chess.WHITE,
28
  ):
 
 
29
  try:
30
  last_move = board.peek()
31
  arrows = [(last_move.from_square, last_move.to_square)]
@@ -47,17 +50,17 @@ def plot_board(
47
  f.write(svg_board)
48
  return f"{constants.FIGURE_DIRECTORY}/board_{id}.svg"
49
 
 
50
  def render_board(
51
  current_board: chess.Board,
52
- orientation: Optional[bool] = chess.WHITE,
53
  ):
54
  fen = current_board.fen()
55
  pgn = current_board.root().variation_san(current_board.move_stack)
56
- if orientation is None:
57
- orientation = current_board.turn
58
  image_board = plot_board(current_board, orientation=orientation)
59
  return fen, pgn, "", image_board
60
 
 
61
  def play_user_move(
62
  uci_move: str,
63
  current_board: chess.Board,
@@ -65,6 +68,7 @@ def play_user_move(
65
  current_board.push_uci(uci_move)
66
  return current_board
67
 
 
68
  def play_ai_move(
69
  current_board: chess.Board,
70
  temperature: float = 0.1,
@@ -76,6 +80,7 @@ def play_ai_move(
76
  current_board.push_uci(uci_move.strip())
77
  return current_board
78
 
 
79
  def try_play_move(
80
  username: str,
81
  move_to_play: str,
@@ -83,7 +88,10 @@ def try_play_move(
83
  ):
84
  if current_board.is_game_over():
85
  gr.Warning("The game is already over")
86
- return *render_board(current_board), current_board
 
 
 
87
  try:
88
  current_board = play_user_move(move_to_play.strip(), current_board)
89
  if current_board.is_game_over():
@@ -93,17 +101,20 @@ def try_play_move(
93
  {
94
  "username": username,
95
  "winin": current_board.fullmove_number,
96
- "pgn": current_board.root().variation_san(current_board.move_stack),
 
 
97
  }
98
  )
99
  run.finish()
100
- return *render_board(current_board, orientation=not current_board.turn), current_board
 
 
 
101
  except:
102
  gr.Warning("Invalid move")
103
  return *render_board(current_board), current_board
104
- temperature_retries = [
105
- (i+1)/10 for i in range(10)
106
- ]
107
  for temperature in temperature_retries:
108
  try:
109
  current_board = play_ai_move(current_board, temperature=temperature)
@@ -187,6 +198,7 @@ with gr.Blocks() as interface:
187
  if is_ai_white:
188
  board = play_ai_move(board)
189
  return *render_board(board), board
 
190
  reset_button.click(
191
  reset_board,
192
  outputs=[*static_outputs, state_board],
 
15
  from . import constants
16
 
17
  model_name = "yp-edu/gpt2-stockfish-debug"
18
+ headers = {
19
+ "X-Wait-For-Model": "true",
20
+ "X-Use-Cache": "false",
21
+ }
22
+ client = huggingface_hub.InferenceClient(model=model_name, headers=headers)
23
  inference_fn = client.text_generation
24
 
25
 
26
  def plot_board(
27
  board: chess.Board,
28
+ orientation: Optional[bool] = None,
29
  ):
30
+ if orientation is None:
31
+ orientation = board.turn
32
  try:
33
  last_move = board.peek()
34
  arrows = [(last_move.from_square, last_move.to_square)]
 
50
  f.write(svg_board)
51
  return f"{constants.FIGURE_DIRECTORY}/board_{id}.svg"
52
 
53
+
54
  def render_board(
55
  current_board: chess.Board,
56
+ orientation: Optional[bool] = None,
57
  ):
58
  fen = current_board.fen()
59
  pgn = current_board.root().variation_san(current_board.move_stack)
 
 
60
  image_board = plot_board(current_board, orientation=orientation)
61
  return fen, pgn, "", image_board
62
 
63
+
64
  def play_user_move(
65
  uci_move: str,
66
  current_board: chess.Board,
 
68
  current_board.push_uci(uci_move)
69
  return current_board
70
 
71
+
72
  def play_ai_move(
73
  current_board: chess.Board,
74
  temperature: float = 0.1,
 
80
  current_board.push_uci(uci_move.strip())
81
  return current_board
82
 
83
+
84
  def try_play_move(
85
  username: str,
86
  move_to_play: str,
 
88
  ):
89
  if current_board.is_game_over():
90
  gr.Warning("The game is already over")
91
+ return (
92
+ *render_board(current_board, orientation=not current_board.turn),
93
+ current_board,
94
+ )
95
  try:
96
  current_board = play_user_move(move_to_play.strip(), current_board)
97
  if current_board.is_game_over():
 
101
  {
102
  "username": username,
103
  "winin": current_board.fullmove_number,
104
+ "pgn": current_board.root().variation_san(
105
+ current_board.move_stack
106
+ ),
107
  }
108
  )
109
  run.finish()
110
+ return (
111
+ *render_board(current_board, orientation=not current_board.turn),
112
+ current_board,
113
+ )
114
  except:
115
  gr.Warning("Invalid move")
116
  return *render_board(current_board), current_board
117
+ temperature_retries = [(i + 1) / 10 for i in range(10)]
 
 
118
  for temperature in temperature_retries:
119
  try:
120
  current_board = play_ai_move(current_board, temperature=temperature)
 
198
  if is_ai_white:
199
  board = play_ai_move(board)
200
  return *render_board(board), board
201
+
202
  reset_button.click(
203
  reset_board,
204
  outputs=[*static_outputs, state_board],
src/state.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Global state of the app.
2
+ """
3
+
4
+ import re
5
+
6
+ from transformers import AutoConfig
7
+ import torch
8
+ from nnsight import LanguageModel
9
+
10
+ conf = AutoConfig.from_pretrained("yp-edu/gpt2-stockfish-debug")
11
+ model = LanguageModel("yp-edu/gpt2-stockfish-debug")
12
+ model.eval()
13
+
14
+
15
def make_prompt(fen):
    """Build the model prompt for a FEN string.

    The piece-placement field has every digit expanded into that many "0"
    characters and is then spelled out one character at a time, separated by
    spaces. The castling field is spaced out the same way. The side to move
    and any remaining fields are kept as they are.

    Raises ``ValueError`` when ``fen`` has fewer than three whitespace-separated
    fields.
    """
    fields = fen.split()
    placement, side_to_move, castling_rights = fields[:3]
    trailing_fields = fields[3:]

    expanded_placement = "".join(
        "0" * int(char) if char.isdecimal() else char for char in placement
    )
    prompt_fields = [
        " ".join(expanded_placement),
        side_to_move,
        " ".join(castling_rights),
        " ".join(trailing_fields),
    ]
    full_fen = " ".join(prompt_fields)
    return f"FEN: {full_fen} \nMOVE:"
22
+
23
+
24
def model_cache(fen, max_new_tokens=10, n_layers=12):
    """Generate a move for ``fen`` and collect the attention of every layer.

    Parameters
    ----------
    fen
        Position string inserted verbatim into the ``"FEN: ...\\nMOVE:"`` prompt.
    max_new_tokens
        Generation budget; also the number of generation steps that are traced.
    n_layers
        Number of transformer blocks (``model.transformer.h``) to record.

    Returns
    -------
    tuple
        ``(out, real_attentions)`` where ``out`` is the saved generator output
        and ``real_attentions`` maps each layer index to the list of saved
        attention values, in generation order.
    """
    prompt = f"FEN: {fen}\nMOVE:"
    attentions = {layer: [] for layer in range(n_layers)}
    with model.generate(
        prompt, max_new_tokens=max_new_tokens, output_attentions=True
    ) as tracer:
        out = model.generator.output.save()
        # Distinct loop names: the original reused ``i`` for both loops, so the
        # inner loop silently clobbered the step counter.
        for _step in range(max_new_tokens):
            for layer in range(n_layers):
                attentions[layer].append(
                    model.transformer.h[layer].attn.output[2].save()
                )
            tracer.next()

    real_attentions = {}
    for layer in range(n_layers):
        real_attentions[layer] = []
        for saved in attentions[layer]:
            try:
                # NOTE(review): presumably accessing ``.shape`` raises
                # ValueError for steps that never ran (generation stopped
                # early) -- confirm against nnsight.
                _ = saved.shape
            except ValueError:
                break
            real_attentions[layer].append(saved)
    return out, real_attentions
44
+
45
+
46
def attribute_seqence(fen, out, attn_tensor):
    """Decode the first sequence of ``out``; nothing is returned yet.

    NOTE(review): this looks like an unfinished stub -- ``fen`` and
    ``attn_tensor`` are not used and the decoded string is discarded.
    """
    decoded_sequences = model.tokenizer.batch_decode(out)
    out_str = decoded_sequences[0]
src/visualisation.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualisation utils.
3
+ """
4
+
5
+ import chess
6
+ import chess.svg
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+
10
+
11
+ COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
12
+ ALPHA = 1.0
13
+
14
+
15
def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
):
    """
    Colour the 64 squares of ``board`` according to ``heatmap``.

    With ``normalise="abs"`` the heatmap is divided by its largest absolute
    value (when non-zero) and the colour range is fixed to [-1, 1]. Otherwise
    missing ``vmin`` / ``vmax`` default to the heatmap minimum / maximum.
    ``square`` is an optional square name to mark as check; names that fail to
    parse are ignored. ``arrows`` defaults to no arrows.

    Returns ``(svg_board, fig)``: the board SVG and a figure that holds the
    matching horizontal colorbar.
    """
    if normalise == "abs":
        peak = heatmap.abs().max()
        if peak != 0:
            heatmap = heatmap / peak
        vmin, vmax = -1, 1
    vmin = heatmap.min() if vmin is None else vmin
    vmax = heatmap.max() if vmax is None else vmax
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    def _fill_colour(index):
        # Keep the colormap RGB but force the module-wide alpha.
        rgba = COLOR_MAP(norm(heatmap[index]))
        return matplotlib.colors.to_hex((*rgba[:3], ALPHA), keep_alpha=True)

    color_dict = {index: _fill_colour(index) for index in range(64)}

    # Stand-alone colorbar: the axes are hidden so only the bar is visible.
    fig = plt.figure(figsize=(6, 0.6))
    ax = plt.gca()
    ax.axis("off")
    fig.colorbar(
        matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
        ax=ax,
        orientation="horizontal",
        fraction=1.0,
    )

    check = None
    if square is not None:
        try:
            check = chess.parse_square(square)
        except ValueError:
            check = None
    arrows = [] if arrows is None else arrows

    plt.close()
    svg_board = chess.svg.board(
        board,
        check=check,
        fill=color_dict,
        size=350,
        arrows=arrows,
    )
    return svg_board, fig