Spaces:
Runtime error
Runtime error
import os | |
import chess | |
import chess.pgn | |
# credit code https://github.com/notnil/chess-gpt | |
def get_legal_moves(board): | |
"""Returns a list of legal moves in UCI notation.""" | |
return list(map(board.san, board.legal_moves)) | |
def init_game() -> tuple[chess.pgn.Game, chess.Board]: | |
"""Initializes a new game.""" | |
board = chess.Board() | |
game = chess.pgn.Game() | |
game.headers["White"] = "User" | |
game.headers["Black"] = "Chess-engine" | |
del game.headers["Event"] | |
del game.headers["Date"] | |
del game.headers["Site"] | |
del game.headers["Round"] | |
del game.headers["Result"] | |
game.setup(board) | |
return game, board | |
def generate_prompt(game: chess.pgn.Game, board: chess.Board) -> str: | |
moves = get_legal_moves(board) | |
moves_str = ",".join(moves) | |
return f""" | |
The task is play a Chess game: | |
You are the Chess-engine playing a chess match against the user as black and trying to win. | |
The current FEN notation is: | |
{board.fen()} | |
The next valid moves are: | |
{moves_str} | |
Continue the game. | |
{str(game)[:-2]}""" | |
def get_move(content, moves): | |
lines = content.splitlines() | |
for line in lines: | |
for lm in moves: | |
if lm in line: | |
return lm | |
class ChessGame: | |
def __init__(self, docschatllm): | |
self.docschatllm = docschatllm | |
def start_game(self): | |
self.game, self.board = init_game() | |
self.game_cp, _ = init_game() | |
self.node = self.game | |
self.node_copy = self.game_cp | |
svg_board = chess.svg.board(self.board, size=350) | |
return svg_board, "Valid moves: "+",".join(get_legal_moves(self.board)) # display(self.board) | |
def user_move(self, move_input): | |
try: | |
self.board.push_san(move_input) | |
except ValueError: | |
print("Invalid move") | |
svg_board = chess.svg.board(self.board, size=350) | |
return svg_board, "Valid moves: "+",".join(get_legal_moves(self.board)), 'Invalid move' | |
self.node = self.node.add_variation(self.board.move_stack[-1]) | |
self.node_copy = self.node_copy.add_variation(self.board.move_stack[-1]) | |
if self.board.is_game_over(): | |
svg_board = chess.svg.board(self.board, size=350) | |
return svg_board, ",".join(get_legal_moves(self.board)), 'GAME OVER' | |
prompt = generate_prompt(self.game, self.board) | |
print("Prompt: \n"+prompt) | |
print("#############") | |
for i in range(10): #tries | |
if i == 9: | |
svg_board = chess.svg.board(self.board, size=350) | |
return svg_board, ",".join(get_legal_moves(self.board)), "The model can't do a valid move. Restart the game" | |
try: | |
"""Returns the move from the prompt.""" | |
content = self.docschatllm.llm.predict(prompt) ### from selected model ### | |
#print(moves) | |
print("Response: \n"+content) | |
print("#############") | |
moves = get_legal_moves(self.board) | |
move = get_move(content, moves) | |
print(move) | |
print("#############") | |
self.board.push_san(move) | |
break | |
except: | |
prompt = prompt[1:] | |
print("attempt a move.") | |
self.node = self.node.add_variation(self.board.move_stack[-1]) | |
self.node_copy = self.node_copy.add_variation(self.board.move_stack[-1]) | |
svg_board = chess.svg.board(self.board, size=350) | |
return svg_board, "Valid moves: "+",".join(get_legal_moves(self.board)), '' | |