# Hugging Face Space status at capture time: Runtime error
import math
import os
import random

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer

try:
    from rwkv import RWKV
except ImportError:
    # The `rwkv` pip package defines RWKV in `rwkv.model`; fall back to that
    # path when the top-level package does not re-export it.
    from rwkv.model import RWKV
# Define the Node class for MCTS
class Node:
    """A node in the Monte Carlo search tree.

    Attributes:
        state: the game state this node represents.
        parent: the parent ``Node``, or ``None`` for the root.
        children: nodes expanded from this one, in expansion order.
        visits: number of simulations backpropagated through this node.
        wins: accumulated reward credited to this node by those simulations.
    """

    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.wins = 0

    def is_fully_expanded(self):
        """Return True once every successor of ``state`` has a child node.

        If the state can enumerate its successors (``get_child_states()``),
        the child count is compared against that number. For other states
        only "has at least one child" can be checked, so that weaker test is
        used as a fallback.
        """
        get_child_states = getattr(self.state, "get_child_states", None)
        if get_child_states is None:
            return len(self.children) > 0
        return len(self.children) >= len(get_child_states())

    def best_child(self, c_param=1.4):
        """Return the child with the highest UCB1 score.

        score = wins / visits + c_param * sqrt(2 * ln(self.visits) / visits)

        With ``c_param=0`` this is pure exploitation (highest average reward).
        Unvisited children score +inf so they are tried first. Ties go to the
        earliest-expanded child (``np.argmax`` returns the first maximum).

        Raises:
            ValueError: if this node has no children.
        """
        if not self.children:
            raise ValueError("best_child() called on a node with no children")
        # Guard log(0): a parent that was never visited has no exploration bonus.
        log_visits = math.log(self.visits) if self.visits > 0 else 0.0
        choices_weights = []
        for child in self.children:
            if child.visits == 0:
                choices_weights.append(float("inf"))
                continue
            exploit = child.wins / child.visits
            explore = c_param * (2 * log_visits / child.visits) ** 0.5
            choices_weights.append(exploit + explore)
        return self.children[int(np.argmax(choices_weights))]

    def expand(self, state):
        """Attach a new child node for ``state`` and return it."""
        new_node = Node(state, self)
        self.children.append(new_node)
        return new_node
# Define the MCTS class
class MCTS:
    """Monte Carlo Tree Search with UCB1 selection and random rollouts.

    States must provide ``is_terminal()``, ``get_random_child_state()`` and
    ``get_reward()``. Two optional extras are used when present:
    ``get_child_states()`` lets expansion pick an untried successor directly,
    and a ``player`` attribute (1 or 2, the side to move) enables two-player
    reward crediting in ``backpropagate``.
    """

    def __init__(self, simulation_limit=1000):
        self.root = None
        self.simulation_limit = simulation_limit

    def search(self, initial_state):
        """Run ``simulation_limit`` simulations and return the best next state.

        The best next state is the root child with the highest average reward
        (UCB with ``c_param=0``). If ``initial_state`` is already terminal, or
        no simulation was run, there is no move to choose and
        ``initial_state`` is returned unchanged.
        """
        self.root = Node(initial_state)
        if initial_state.is_terminal():
            return initial_state
        for _ in range(self.simulation_limit):
            node = self.tree_policy(self.root)
            reward = self.default_policy(node.state)
            self.backpropagate(node, reward)
        if not self.root.children:
            return initial_state
        return self.root.best_child(c_param=0).state

    def tree_policy(self, node):
        """Descend via UCB1 until a node can still be expanded.

        Returns the newly expanded child, or a terminal node reached on the way.
        """
        while not node.state.is_terminal():
            if not node.is_fully_expanded():
                return self.expand(node)
            node = node.best_child()
        return node

    def expand(self, node):
        """Add one not-yet-tried successor of ``node`` to the tree and return it.

        Raises:
            ValueError: if the state enumerates its successors and all of them
                already have child nodes.
        """
        tried_states = [child.state for child in node.children]
        get_child_states = getattr(node.state, "get_child_states", None)
        if get_child_states is not None:
            untried = [s for s in get_child_states() if s not in tried_states]
            if not untried:
                raise ValueError("expand() called on a fully expanded node")
            return node.expand(random.choice(untried))
        # Fallback for states that cannot enumerate successors: resample until
        # the candidate is not already a child.
        new_state = node.state.get_random_child_state()
        while new_state in tried_states:
            new_state = node.state.get_random_child_state()
        return node.expand(new_state)

    def default_policy(self, state):
        """Play random moves until the game ends; return the final reward."""
        while not state.is_terminal():
            state = state.get_random_child_state()
        return state.get_reward()

    def backpropagate(self, node, reward):
        """Propagate a rollout ``reward`` from ``node`` up to the root.

        ``reward`` is interpreted as player 1's score (positive is good for
        player 1). Each node is credited from the perspective of the player
        who made the move *into* it, i.e. the opponent of
        ``node.state.player``. This way, ``best_child`` at every level
        maximises the interests of whoever is choosing there. States without
        a ``player`` attribute receive the raw reward.
        """
        while node is not None:
            node.visits += 1
            node.wins += self._reward_for_mover(node.state, reward)
            node = node.parent

    @staticmethod
    def _reward_for_mover(state, reward):
        """Convert a player-1 ``reward`` to the view of the player who just moved."""
        # ``state.player`` is the side to move, so player 2 made the last move.
        if getattr(state, "player", None) == 1:
            return -reward
        return reward
# Define the Game State and Rules
class GameState:
    """A tic-tac-toe position on a square board.

    ``board`` is a square grid given as a sequence of rows. 0 marks an empty
    cell and 1 / 2 mark the two players' pieces. ``player`` (1 or 2) is the
    side to move. A player wins by filling a whole row, column or one of the
    two main diagonals.
    """

    def __init__(self, board, player):
        self.board = board
        self.player = player

    def _key(self):
        """Hashable snapshot of the position, used for equality and hashing."""
        return tuple(tuple(row) for row in self.board), self.player

    def __eq__(self, other):
        # Value equality lets callers detect duplicate successor states
        # with ``in`` / sets.
        if not isinstance(other, GameState):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    def _winner(self):
        """Return the mark that fills a complete line, or 0 if none does."""
        size = len(self.board)
        lines = [list(row) for row in self.board]
        lines += [[self.board[r][c] for r in range(size)] for c in range(size)]
        lines.append([self.board[i][i] for i in range(size)])
        lines.append([self.board[i][size - 1 - i] for i in range(size)])
        for line in lines:
            if line and line[0] != 0 and all(cell == line[0] for cell in line):
                return line[0]
        return 0

    def is_terminal(self):
        """Return True if someone has won or the board is full."""
        return self.check_win() or self.check_draw()

    def check_win(self):
        """Return True if any row, column or main diagonal is filled by one mark."""
        return self._winner() != 0

    def check_draw(self):
        """Return True if no empty (0) cell remains."""
        return all(cell != 0 for row in self.board for cell in row)

    def _available_moves(self):
        """(row, col) of every empty cell, in row-major order."""
        return [
            (r, c)
            for r, row in enumerate(self.board)
            for c, cell in enumerate(row)
            if cell == 0
        ]

    def _play(self, move):
        """Return the state after ``self.player`` marks ``move``; self is untouched."""
        r, c = move
        new_board = [list(row) for row in self.board]
        new_board[r][c] = self.player
        return GameState(new_board, 3 - self.player)

    def get_child_states(self):
        """Return every state reachable by one move onto an empty cell."""
        return [self._play(move) for move in self._available_moves()]

    def get_random_child_state(self):
        """Return the state after a uniformly random move onto an empty cell.

        If the board has no empty cell, ``self`` is returned.
        """
        moves = self._available_moves()
        if not moves:
            return self
        return self._play(random.choice(moves))

    def get_reward(self):
        """Return the outcome from player 1's point of view.

        +1 if player 1 has a completed line, -1 if player 2 does, else 0.
        The winner is read from the board, not from ``self.player``: on a
        finished position ``self.player`` is the side that would move next,
        not the one who completed the line.
        """
        winner = self._winner()
        if winner == 0:
            return 0
        return 1 if winner == 1 else -1

    def __str__(self):
        return "\n".join(" ".join(str(cell) for cell in row) for row in self.board)
# Initialize the RWKV model and tokenizer
# NOTE(review): rwkv's RWKV(model=...) loads a local checkpoint path (".pth"
# extension optional), not a Hugging Face repo id. Point RWKV_MODEL_PATH at a
# downloaded RWKV-4 Raven checkpoint; the default only works if such a local
# path exists.
model_name = os.environ.get("RWKV_MODEL_PATH", "BlinkDL/rwkv-4-raven")
# RWKV-4 (Pile / Raven) models use the GPT-NeoX-20B tokenizer. GPT-2's BPE
# assigns different ids, so the model would otherwise be fed the wrong tokens.
tokenizer = AutoTokenizer.from_pretrained(
    os.environ.get("RWKV_TOKENIZER", "EleutherAI/gpt-neox-20b")
)
# Load the RWKV model. "cuda fp16" fails on CPU-only hardware, so fall back to
# "cpu fp32" when no GPU is visible; RWKV_STRATEGY overrides either choice.
rwkv_strategy = os.environ.get("RWKV_STRATEGY") or (
    "cuda fp16" if torch.cuda.is_available() else "cpu fp32"
)
model = RWKV(model=model_name, strategy=rwkv_strategy)
# Generate Chain-of-Thought
def generate_cot(state, max_new_tokens=100):
    """Ask the RWKV model what the best move is in ``state``.

    ``rwkv``'s ``RWKV`` class has no Hugging Face-style ``generate()``. It is
    driven with ``forward(token_ids, rnn_state) -> (logits, rnn_state)``, so
    this feeds the prompt once and then decodes greedily, one token at a time.

    Args:
        state: position to describe; rendered with ``str()`` into the prompt.
        max_new_tokens: upper bound on generated tokens (default 100).

    Returns:
        The generated continuation (prompt excluded) as text. Generation stops
        early at the tokenizer's end-of-sequence token, if it has one.
    """
    input_text = f"Current state: {state}\nWhat is the best move?"
    prompt_ids = tokenizer.encode(input_text)
    logits, rnn_state = model.forward(prompt_ids, None)
    eos_id = getattr(tokenizer, "eos_token_id", None)
    generated = []
    for _ in range(max_new_tokens):
        next_id = int(logits.argmax())
        if next_id == eos_id:
            break
        generated.append(next_id)
        logits, rnn_state = model.forward([next_id], rnn_state)
    return tokenizer.decode(generated, skip_special_tokens=True)
# Use CoT in MCTS
def mcts_with_cot(initial_state, simulation_limit=1000):
    """Pick a move with MCTS, then ask the language model to explain it.

    Args:
        initial_state: position to search from.
        simulation_limit: number of MCTS simulations to run (default 1000).

    Returns:
        ``(best_state, cot)``: the position after the chosen move and the
        model's generated text for that position.
    """
    best_state = MCTS(simulation_limit=simulation_limit).search(initial_state)
    cot = generate_cot(best_state)
    return best_state, cot
# Function to be called by Gradio
def run_mcts_cot(initial_board):
    """Gradio entry point: choose player 1's move and explain it.

    Args:
        initial_board: square grid of cells, 0 = empty, 1 / 2 = player pieces.
            Rows may be lists, tuples or NumPy arrays, and cells anything
            ``int()`` accepts (e.g. ``"1"``). NOTE(review): the Gradio input
            component is not visible here; confirm it delivers this shape.

    Returns:
        ``(board_text, cot)``: the chosen position rendered as text and the
        model's explanation.
    """
    # GameState compares cells with the int 0, so a string "0" or a NumPy
    # scalar row would be misread; normalise to plain lists of ints first.
    board = [[int(cell) for cell in row] for row in initial_board]
    initial_state = GameState(board, 1)
    best_state, cot = mcts_with_cot(initial_state)
    return str(best_state), cot