rwkv-mcts-cot / main.py
tenet's picture
Update main.py
c0828ba verified
raw
history blame
4.76 kB
import os
import math
import random
import numpy as np
import gradio as gr
from transformers import AutoTokenizer
from rwkv import RWKV
# Define the Node class for MCTS
class Node:
    """A node in the Monte Carlo search tree.

    ``wins`` accumulates the rewards returned by ``GameState.get_reward``,
    which are always expressed from player 1's point of view (+1 player 1
    won, -1 player 2 won, 0 otherwise).  ``best_child`` flips that sign when
    player 2 is to move at this node, so each side picks the child that is
    best for itself rather than for player 1.
    """

    def __init__(self, state, parent=None):
        self.state = state  # game state represented by this node
        self.parent = parent  # None for the root
        self.children = []  # expanded successor nodes
        self.visits = 0  # simulations that passed through this node
        self.wins = 0  # summed player-1-perspective reward of those simulations

    def is_fully_expanded(self):
        """Return True once every legal successor state has a child node."""
        # Compare against the real number of successors; merely having one
        # child does not mean the alternatives have been explored.
        return len(self.children) >= len(self.state.get_child_states())

    def best_child(self, c_param=1.4):
        """Return the child with the highest UCT score.

        Args:
            c_param: exploration weight; ``0`` selects purely on mean reward.

        Raises:
            ValueError: if this node has no children.
        """
        if not self.children:
            raise ValueError("best_child() called on a node with no children")
        # Rewards are stored from player 1's perspective; player 2 wants them low.
        sign = 1 if self.state.player == 1 else -1
        log_parent_visits = math.log(self.visits) if self.visits > 0 else 0.0

        def uct(child):
            if child.visits == 0:
                return math.inf  # an unvisited child is always tried first
            exploitation = sign * child.wins / child.visits
            exploration = c_param * math.sqrt(2 * log_parent_visits / child.visits)
            return exploitation + exploration

        # max() keeps the first maximal child, matching np.argmax tie-breaking.
        return max(self.children, key=uct)

    def expand(self, state):
        """Attach a child node for ``state`` and return it."""
        new_node = Node(state, self)
        self.children.append(new_node)
        return new_node


# Define the MCTS class
class MCTS:
    """Monte Carlo Tree Search with UCT selection and uniformly random rollouts."""

    def __init__(self, simulation_limit=1000):
        self.root = None
        self.simulation_limit = simulation_limit  # select/expand/rollout/backup rounds

    def search(self, initial_state):
        """Search from ``initial_state`` and return the chosen successor state.

        If ``initial_state`` is already terminal there is no move to make, and
        it is returned unchanged.

        Raises:
            ValueError: if ``simulation_limit`` < 1 for a non-terminal state
                (no child is ever expanded, so nothing can be chosen).
        """
        self.root = Node(initial_state)
        if initial_state.is_terminal():
            return initial_state
        for _ in range(self.simulation_limit):
            node = self.tree_policy(self.root)
            reward = self.default_policy(node.state)
            self.backpropagate(node, reward)
        # Pure exploitation for the final pick: best mean reward for the mover.
        return self.root.best_child(c_param=0).state

    def tree_policy(self, node):
        """Descend by UCT until a terminal node or one with untried successors.

        In the latter case one untried successor is expanded and returned.
        """
        while not node.state.is_terminal():
            if not node.is_fully_expanded():
                return self.expand(node)
            node = node.best_child()
        return node

    def expand(self, node):
        """Add one not-yet-tried successor of ``node`` to the tree and return it.

        Raises:
            ValueError: if every successor already has a child node.
        """
        tried = [child.state for child in node.children]
        # Relies on GameState value equality; picking from the untried list
        # avoids the unbounded retry loop of rejection sampling.
        untried = [s for s in node.state.get_child_states() if s not in tried]
        if not untried:
            raise ValueError("expand() called on a fully expanded node")
        return node.expand(random.choice(untried))

    def default_policy(self, state):
        """Play uniformly random moves until the game ends; return its reward."""
        while not state.is_terminal():
            state = state.get_random_child_state()
        return state.get_reward()

    def backpropagate(self, node, reward):
        """Add one visit and ``reward`` to ``node`` and each of its ancestors."""
        while node is not None:
            node.visits += 1
            node.wins += reward
            node = node.parent


# Define the Game State and Rules
class GameState:
    """An n x n tic-tac-toe position.

    ``board`` is a square grid (sequence of rows) where 0 marks an empty cell
    and 1 / 2 mark the two players; ``player`` (1 or 2) is the side to move.
    States compare and hash by (player, board contents); do not mutate
    ``board`` while a state is stored in a set or used as a dict key.
    """

    def __init__(self, board, player):
        self.board = board
        self.player = player

    def is_terminal(self):
        """Return True if someone has won or the board is full."""
        return self.check_win() or self.check_draw()

    def _lines(self):
        """Yield every row, every column and both diagonals as lists of cells."""
        n = len(self.board)
        for row in self.board:
            yield list(row)
        for col in range(n):
            yield [self.board[r][col] for r in range(n)]
        yield [self.board[i][i] for i in range(n)]
        yield [self.board[i][n - 1 - i] for i in range(n)]

    def get_winner(self):
        """Return the mark (1 or 2) owning a complete line, or 0 if none does."""
        for line in self._lines():
            if line and line[0] != 0 and all(cell == line[0] for cell in line):
                return line[0]
        return 0

    def check_win(self):
        """Return True if any row, column or diagonal is filled by one player."""
        return self.get_winner() != 0

    def check_draw(self):
        """Return True if the board has no empty (0) cells."""
        return all(cell != 0 for row in self.board for cell in row)

    def get_legal_moves(self):
        """Return the (row, col) coordinates of every empty cell, row-major."""
        return [
            (r, c)
            for r, row in enumerate(self.board)
            for c, cell in enumerate(row)
            if cell == 0
        ]

    def play(self, move):
        """Return the state after the side to move marks ``move`` = (row, col)."""
        r, c = move
        new_board = [list(row) for row in self.board]  # copy; never mutate self
        new_board[r][c] = self.player
        return GameState(new_board, 3 - self.player)

    def get_child_states(self):
        """Return every state reachable in one move (empty list if board is full)."""
        return [self.play(move) for move in self.get_legal_moves()]

    def get_random_child_state(self):
        """Return a uniformly random successor, or ``self`` if the board is full."""
        moves = self.get_legal_moves()
        if not moves:
            return self
        return self.play(random.choice(moves))

    def get_reward(self):
        """Return +1 if player 1 has won, -1 if player 2 has won, else 0."""
        # Read the winner from the board: ``player`` is the side to move, which
        # after a winning move is the *loser*, so it cannot decide the sign.
        winner = self.get_winner()
        if winner == 0:
            return 0
        return 1 if winner == 1 else -1

    def _key(self):
        return self.player, tuple(tuple(row) for row in self.board)

    def __eq__(self, other):
        if not isinstance(other, GameState):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    def __str__(self):
        return "\n".join(" ".join(str(cell) for cell in row) for row in self.board)

    def __repr__(self):
        return f"GameState(board={self.board!r}, player={self.player!r})"
# Initialize the RWKV model and tokenizer.
# These statements run at module import time, so importing this file loads
# both the tokenizer and the model.
# NOTE(review): rwkv's RWKV(model=...) presumably expects a local checkpoint
# path; "BlinkDL/rwkv-4-raven" looks like a Hugging Face Hub repo id — confirm.
model_name = "BlinkDL/rwkv-4-raven"
# NOTE(review): the GPT-2 tokenizer may not share the RWKV-4 Raven vocabulary;
# confirm its token ids match what the checkpoint was trained on.
tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a tokenizer from a supported model
# Load the RWKV model.
# NOTE(review): the "cuda fp16" strategy presumably requires a CUDA GPU; there
# is no CPU fallback here.
model = RWKV(model=model_name, strategy="cuda fp16")
# Generate Chain-of-Thought
def generate_cot(state):
    """Ask the language model for a chain-of-thought about ``state``.

    The prompt embeds ``str(state)`` followed by a best-move question; the
    first returned sequence is decoded (special tokens skipped) and returned.

    NOTE(review): ``model`` is built with ``rwkv``'s ``RWKV`` class; confirm it
    actually exposes a Hugging Face-style ``generate`` method.
    """
    prompt = f"Current state: {state}\nWhat is the best move?"
    encoded = tokenizer(prompt, return_tensors="pt")
    generated_ids = model.generate(
        encoded.input_ids,
        max_length=100,
        num_return_sequences=1,
    )
    first_sequence = generated_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
# Use CoT in MCTS
def mcts_with_cot(initial_state, simulation_limit=1000):
    """Choose a move with MCTS, then generate a chain-of-thought for the result.

    Args:
        initial_state: the game state to search from.
        simulation_limit: number of MCTS simulations to run (defaults to 1000,
            the value that was previously hard-coded).

    Returns:
        ``(best_state, cot)``: the state after the chosen move and the text
        produced by ``generate_cot`` for that state.
    """
    best_state = MCTS(simulation_limit=simulation_limit).search(initial_state)
    # NOTE(review): the prompt is built from the position *after* the chosen
    # move (opponent to move); confirm that is the intended question.
    cot = generate_cot(best_state)
    return best_state, cot
# Function to be called by Gradio
def run_mcts_cot(initial_board):
    """Gradio callback: search from ``initial_board`` with player 1 to move.

    Returns the chosen board rendered as text, together with the generated
    chain-of-thought string.
    """
    best_state, cot = mcts_with_cot(GameState(initial_board, 1))
    rendered_board = str(best_state)
    return rendered_board, cot