File size: 2,154 Bytes
5df657d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
using Random
using StatsBase

"""
	simulate_rollout(b::Board, policy, side; rng=Random.default_rng()) -> Tuple{Board, Int64}

Play out one rollout from the `Chess.Board` state `b` until `isterminal(b)` holds, sampling
every move from the probability weights returned by `policy`.

`b` is modified in place by `domove!`; pass a copy if the caller's board must be preserved.

# Arguments
- `b::Board`: starting position; mutated during the rollout.
- `policy`: function `policy(b, movelist)` returning an `AbstractArray` of probability
  weights, where the `i`-th weight belongs to the `i`-th `Move` in `movelist`.
- `side`: not used by this function; kept for interface compatibility.

# Keywords
- `rng`: random number generator used to sample moves. Defaults to the global default RNG,
  so repeated rollouts from the same position differ; pass a seeded RNG for reproducibility.

# Returns
The (mutated) terminal board and the number of moves played during the rollout.
"""
function simulate_rollout(b::Board, policy, side; rng = Random.default_rng())::Tuple{Board, Int64}
	# One move buffer is reused for every ply and reset with `recycle!` after each move.
	movelist = MoveList(200)
	num_sim_moves = 0
	while !isterminal(b) # TODO Use `matein1` possibly to trim leaf nodes in sims?
		moves(b, movelist)
		policy_weights = ProbabilityWeights(policy(b, movelist))
		# Sample with `rng` so a caller-supplied (e.g. seeded) RNG actually drives the rollout.
		domove!(b, sample(rng, movelist, policy_weights))
		recycle!(movelist)
		num_sim_moves += 1
	end
	return b, num_sim_moves
end

"""
	CESPF(b::Board, movelist::MoveList)

Rollout policy implementing the (C)apture / (E)scape (S)tronger (P)iece (F)irst heuristic.
Each move in `movelist` is scored with `Chess.jl`'s static exchange evaluation
`see(b, move)`; moves with higher `see` scores receive proportionally larger weights.

The raw scores are shifted so that the lowest-scoring move gets weight exactly 1 (every
legal move keeps a non-zero probability), then normalized to sum to 1.

Returns a vector of probabilities whose `i`-th entry belongs to the `i`-th move in
`movelist`. An empty `movelist` yields an empty `Float64[]`.
"""
function CESPF(b::Board, movelist::MoveList)
	isempty(movelist) && return Float64[]
	scores = map(m -> see(b, m), movelist)
	# Shift so the minimum score maps to 1. (`1 + abs(minimum)` would only do this for a
	# non-positive minimum; with a positive minimum it would over-shift every weight.)
	shifted = scores .- minimum(scores) .+ 1
	return shifted ./ sum(shifted)
end

"""
	CESPF_greedy(b::Board, movelist::MoveList)

Greedy variant of the (C)apture / (E)scape (S)tronger (P)iece (F)irst rollout policy.
Every move in `movelist` is scored with `Chess.jl`'s static exchange evaluation
`see(b, move)`. Only the move(s) with the maximal score receive non-zero probability, and
ties split the probability mass uniformly.

Returns a `Vector{Float64}` aligned with `movelist` that sums to 1. An empty `movelist`
yields an empty `Float64[]`.
"""
function CESPF_greedy(b::Board, movelist::MoveList)
	isempty(movelist) && return Float64[]
	scores = map(m -> see(b, m), movelist)
	is_best = scores .== maximum(scores)
	# BitVector ./ Int promotes to Float64: 1/k for each of the k best moves, 0 elsewhere.
	return is_best ./ count(is_best)
end