Xmaster6y commited on
Commit
045d76c
1 Parent(s): f4a47cd

play against AI

Browse files
app.py CHANGED
@@ -2,22 +2,30 @@
2
  Main Gradio module.
3
  """
4
 
 
 
5
  import gradio as gr
6
 
7
  from src import (
8
  call_interface,
 
 
9
  )
10
 
11
 
12
  demo = gr.TabbedInterface(
13
  [
 
14
  call_interface.interface,
15
  ],
16
  [
 
17
  "Call",
18
  ],
19
  title="GPT-2 Stockfish Debug",
20
  analytics_enabled=False,
21
  )
22
 
 
 
23
  demo.launch()
 
2
  Main Gradio module.
3
  """
4
 
5
+ import wandb
6
+
7
  import gradio as gr
8
 
9
  from src import (
10
  call_interface,
11
+ play_interface,
12
+ constants,
13
  )
14
 
15
 
16
  demo = gr.TabbedInterface(
17
  [
18
+ play_interface.interface,
19
  call_interface.interface,
20
  ],
21
  [
22
+ "Play",
23
  "Call",
24
  ],
25
  title="GPT-2 Stockfish Debug",
26
  analytics_enabled=False,
27
  )
28
 
29
+ wandb.login(key=constants.WANDB_API_KEY)
30
+
31
  demo.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-chess
src/call_interface.py CHANGED
@@ -6,7 +6,6 @@ import huggingface_hub
6
  import gradio as gr
7
 
8
 
9
-
10
  model_name = "yp-edu/gpt2-stockfish-debug"
11
 
12
  headers = {"X-Wait-For-Model": "true"}
 
6
  import gradio as gr
7
 
8
 
 
9
  model_name = "yp-edu/gpt2-stockfish-debug"
10
 
11
  headers = {"X-Wait-For-Model": "true"}
src/constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Constants for the app.
2
+ """
3
+
4
+ import os
5
+
6
+ FIGURE_DIRECTORY = os.environ.get("FIGURE_DIRECTORY", "./")
7
+ WANDB_API_KEY = os.environ.get("WANDB_API_KEY", "")
src/play_interface.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Interface to play against the model.
2
+ """
3
+
4
+ import huggingface_hub
5
+ import chess
6
+ import uuid
7
+ import random
8
+ import wandb
9
+
10
+ import gradio as gr
11
+
12
+ from . import constants
13
+
14
+ model_name = "yp-edu/gpt2-stockfish-debug"
15
+ headers = {"X-Wait-For-Model": "true"}
16
+ client = huggingface_hub.InferenceClient(
17
+ model=model_name, headers=headers
18
+ )
19
+ inference_fn = client.text_generation
20
+
21
+
22
+ def plot_board(
23
+ board: chess.Board,
24
+ ):
25
+ try:
26
+ last_move = board.peek()
27
+ arrows = [(last_move.from_square, last_move.to_square)]
28
+ except IndexError:
29
+ arrows = []
30
+ if board.is_check():
31
+ check = board.king(board.turn)
32
+ else:
33
+ check = None
34
+ svg_board = chess.svg.board(
35
+ board,
36
+ check=check,
37
+ size=350,
38
+ arrows=arrows,
39
+ )
40
+ id = str(uuid.uuid4())
41
+ with open(f"{constants.FIGURE_DIRECTORY}/board_{id}.svg", "w") as f:
42
+ f.write(svg_board)
43
+ return f"{constants.FIGURE_DIRECTORY}/board_{id}.svg"
44
+
45
+ def render_board(
46
+ current_board: chess.Board,
47
+ ):
48
+ fen = current_board.fen()
49
+ pgn = current_board.variation_san(current_board.move_stack)
50
+ image_board = plot_board(current_board)
51
+ return fen, pgn, image_board
52
+
53
+ def play_user_move(
54
+ uci_move: str,
55
+ current_board: chess.Board,
56
+ ):
57
+ current_board.push_uci(uci_move)
58
+ return current_board
59
+
60
+ def play_ai_move(
61
+ current_board: chess.Board,
62
+ temperature: float = 0.1,
63
+ top_k: int = 3,
64
+ ):
65
+ uci_move = inference_fn(
66
+ inputs=f"FEN: {current_board.fen()}\nMOVE:",
67
+ temperature=temperature,
68
+ top_k=top_k,
69
+ )
70
+ current_board.push_uci(uci_move)
71
+ return current_board
72
+
73
+ def try_play_move(
74
+ username: str,
75
+ move_to_play: str,
76
+ current_board: chess.Board,
77
+ ):
78
+ if current_board.is_game_over():
79
+ gr.Warning("The game is already over")
80
+ return render_board(current_board)
81
+ try:
82
+ current_board = play_user_move(move_to_play, current_board)
83
+ if current_board.is_game_over():
84
+ gr.Info(f"Congratulations, {username}!")
85
+ with wandb.init(project="gpt2-stockfish-debug", entity="yp-edu") as run:
86
+ run.log(
87
+ {
88
+ "username": username,
89
+ "winin": current_board.fullmove_number,
90
+ "pgn": current_board.variation_san(current_board.move_stack),
91
+ }
92
+ )
93
+ run.finish()
94
+ return render_board(current_board)
95
+ except:
96
+ gr.Warning("Invalid move")
97
+ return render_board(current_board)
98
+ temperature_retries = [
99
+ (i+1)/10 for i in range(10)
100
+ ]
101
+ for temperature in temperature_retries:
102
+ try:
103
+ current_board = play_ai_move(current_board, temperature=temperature)
104
+ break
105
+ except:
106
+ gr.Warning(f"AI move failed with temperature {temperature}")
107
+ else:
108
+ gr.Warning("AI move failed with all temperatures")
109
+ current_board.pop()
110
+ return render_board(current_board)
111
+
112
+
113
+ with gr.Blocks() as interface:
114
+ with gr.Row():
115
+ current_fen = gr.Textbox(
116
+ label="Board FEN",
117
+ lines=1,
118
+ max_lines=1,
119
+ value=chess.STARTING_FEN,
120
+ )
121
+ current_pgn = gr.Textbox(
122
+ label="Action sequence",
123
+ lines=1,
124
+ value="",
125
+ )
126
+ with gr.Row():
127
+ username = gr.Textbox(
128
+ label="Username to record on leaderboard (should you win)",
129
+ lines=1,
130
+ max_lines=1,
131
+ value="",
132
+ )
133
+ with gr.Row():
134
+ with gr.Column():
135
+ with gr.Row():
136
+ move_to_play = gr.Textbox(
137
+ label="Move to play (UCI)",
138
+ lines=1,
139
+ max_lines=1,
140
+ value="",
141
+ )
142
+ play_button = gr.Button("Play")
143
+
144
+ with gr.Column():
145
+ image_board = gr.Image(label="Board")
146
+
147
+ static_inputs = [
148
+ username,
149
+ move_to_play,
150
+ ]
151
+ static_outputs = [
152
+ current_fen,
153
+ current_pgn,
154
+ image_board,
155
+ ]
156
+ is_ai_white = random.choice([True, False])
157
+ init_board = chess.Board()
158
+ if is_ai_white:
159
+ init_board = play_ai_move(init_board)
160
+ state_board = gr.State(value=init_board)
161
+ play_button.click(
162
+ try_play_move,
163
+ inputs=[*static_inputs, state_board],
164
+ outputs=[*static_outputs, gr.State()],
165
+ )
166
+ interface.load(render_board, inputs=[state_board], outputs=[*static_outputs, gr.State()])