Commit 50aa0ce (parent: e008cc9). Showing 20 changed files with 792 additions and 48 deletions.
dlgo/agent/__init__.py
@@ -1,5 +1,7 @@
from .alphago import *
from .base import *
from .pg import *
from .predict import *
from .naive import *
from .naive_fast import *
from .termination import *
dlgo/agent/alphago.py
@@ -0,0 +1,168 @@
# tag::alphago_imports[]
import numpy as np
from dlgo.agent.base import Agent
from dlgo.goboard_fast import Move
from dlgo import kerasutil
import operator
# end::alphago_imports[]


__all__ = [
    'AlphaGoNode',
    'AlphaGoMCTS'
]

# tag::init_alphago_node[]
class AlphaGoNode:
    def __init__(self, parent=None, probability=1.0):
        self.parent = parent  # <1>
        self.children = {}  # <1>

        self.visit_count = 0
        self.q_value = 0
        self.prior_value = probability  # <2>
        self.u_value = probability  # <3>
# <1> Tree nodes have one parent and potentially many children.
# <2> A node is initialized with a prior probability.
# <3> The utility function will be updated during search.
# end::init_alphago_node[]

    # tag::select_node[]
    def select_child(self):
        return max(self.children.items(),
                   key=lambda child: child[1].q_value + \
                   child[1].u_value)
    # end::select_node[]

    # tag::expand_children[]
    def expand_children(self, moves, probabilities):
        for move, prob in zip(moves, probabilities):
            if move not in self.children:
                # Link each child to its parent so that update_values can
                # back leaf values up through the tree.
                self.children[move] = AlphaGoNode(parent=self, probability=prob)
    # end::expand_children[]

    # tag::update_values[]
    def update_values(self, leaf_value):
        if self.parent is not None:
            self.parent.update_values(leaf_value)  # <1>

        self.visit_count += 1  # <2>

        self.q_value += leaf_value / self.visit_count  # <3>

        if self.parent is not None:
            c_u = 5
            self.u_value = c_u * np.sqrt(self.parent.visit_count) \
                * self.prior_value / (1 + self.visit_count)  # <4>
    # <1> We update parents first to ensure we traverse the tree top to bottom.
    # <2> Increment the visit count for this node.
    # <3> Add the specified leaf value to the Q-value, normalized by visit count.
    # <4> Update utility with current visit counts.
    # end::update_values[]
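
# A minimal sketch of the node life cycle (strings stand in here for the
# dlgo Move objects the real search uses as keys):
#
#     root = AlphaGoNode()
#     root.expand_children(['D4', 'Q16'], [0.6, 0.4])  # priors from a policy
#     move, child = root.select_child()   # 'D4': highest q_value + u_value
#     child.update_values(1.0)            # backs the win up to the root
#     print(move, child.visit_count, root.visit_count)  # D4 1 1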

# tag::alphago_mcts_init[]
class AlphaGoMCTS(Agent):
    def __init__(self, policy_agent, fast_policy_agent, value_agent,
                 lambda_value=0.5, num_simulations=1000,
                 depth=50, rollout_limit=100):
        self.policy = policy_agent
        self.rollout_policy = fast_policy_agent
        self.value = value_agent

        self.lambda_value = lambda_value
        self.num_simulations = num_simulations
        self.depth = depth
        self.rollout_limit = rollout_limit
        self.root = AlphaGoNode()
    # end::alphago_mcts_init[]
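
    # The defaults above (1,000 simulations of depth 50) amount to a
    # full-strength search and are slow without a GPU. A lighter smoke-test
    # configuration, assuming the three agents are already loaded (compare
    # test_4_alphago_mcts in the tests below), might look like:
    #
    #     agent = AlphaGoMCTS(strong_policy, fast_policy, value_agent,
    #                         lambda_value=0.5, num_simulations=20,
    #                         depth=5, rollout_limit=10)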

    # tag::alphago_mcts_rollout[]
    def select_move(self, game_state):
        for simulation in range(self.num_simulations):  # <1>
            current_state = game_state
            node = self.root
            for depth in range(self.depth):  # <2>
                if not node.children:  # <3>
                    if current_state.is_over():
                        break
                    moves, probabilities = self.policy_probabilities(current_state)  # <4>
                    node.expand_children(moves, probabilities)  # <4>
                    if not node.children:  # no point moves left to search
                        break

                move, node = node.select_child()  # <5>
                current_state = current_state.apply_move(move)  # <5>

            value = self.value.predict(current_state)  # <6>
            rollout = self.policy_rollout(current_state)  # <6>

            weighted_value = (1 - self.lambda_value) * value + \
                self.lambda_value * rollout  # <7>

            node.update_values(weighted_value)  # <8>
    # <1> From the current state, play out a number of simulations.
    # <2> Play moves until the specified depth is reached.
    # <3> If the current node doesn't have any children...
    # <4> ... expand them with probabilities from the strong policy.
    # <5> If there are children, we can select one and play the corresponding move.
    # <6> Compute the output of the value network and a rollout by the fast policy.
    # <7> Determine the combined value function.
    # <8> Update values for this node in the backup phase.
    # end::alphago_mcts_rollout[]

        # tag::alphago_mcts_selection[]
        move = max(self.root.children, key=lambda move:  # <1>
                   self.root.children.get(move).visit_count)  # <1>

        if move in self.root.children:  # <2>
            self.root = self.root.children[move]
            self.root.parent = None
        else:
            self.root = AlphaGoNode()

        return move
        # <1> Pick the most visited child of the root as the next move.
        # <2> Reuse the chosen child's subtree as the new search root; otherwise start a fresh tree.
        # end::alphago_mcts_selection[]
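
    # Worked example for step <7>: with the default lambda_value of 0.5,
    # a value-network estimate of 0.3 and a rollout that ends in a win (+1)
    # combine to (1 - 0.5) * 0.3 + 0.5 * 1 = 0.65, the leaf value that
    # update_values then backs up through the tree.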

    # tag::alphago_policy_probs[]
    def policy_probabilities(self, game_state):
        encoder = self.policy._encoder
        outputs = self.policy.predict(game_state)
        # Keep only point-playing moves so that moves and probabilities stay
        # aligned (passes and resignations have no board point to encode).
        legal_moves = [move for move in game_state.legal_moves()
                       if move.point is not None]
        if not legal_moves:
            return [], []
        encoded_points = [encoder.encode_point(move.point) for move in legal_moves]
        legal_outputs = outputs[encoded_points]
        normalized_outputs = legal_outputs / np.sum(legal_outputs)
        return legal_moves, normalized_outputs
    # end::alphago_policy_probs[]

    # tag::alphago_policy_rollout[]
    def policy_rollout(self, game_state):
        for step in range(self.rollout_limit):
            if game_state.is_over():
                break
            move_probabilities = self.rollout_policy.predict(game_state)
            encoder = self.rollout_policy.encoder
            # Pair each probability with its board index so the argmax can
            # be decoded back into the point it actually refers to.
            valid_moves = [(idx, prob) for idx, prob in enumerate(move_probabilities)
                           if Move(encoder.decode_point_index(idx)) in game_state.legal_moves()]
            if not valid_moves:
                break
            max_index, max_value = max(valid_moves, key=operator.itemgetter(1))
            greedy_move = Move(encoder.decode_point_index(max_index))
            game_state = game_state.apply_move(greedy_move)

        next_player = game_state.next_player
        winner = game_state.winner()
        if winner is not None:
            return 1 if winner == next_player else -1
        else:
            return 0
    # end::alphago_policy_rollout[]

    def serialize(self, h5file):
        raise IOError(
            "AlphaGoMCTS agent can't be serialized; "
            "consider serializing the three underlying "
            "neural networks instead.")
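
Since serialize raises, persisting an AlphaGoMCTS agent means saving its three component agents separately. A minimal sketch, assuming alphago is a constructed AlphaGoMCTS instance and using the same serialize(h5file) convention the agents follow in the tests below (file names are illustrative):

    import h5py

    with h5py.File('alphago_strong_policy.h5', 'w') as f:
        alphago.policy.serialize(f)          # strong policy agent
    with h5py.File('alphago_fast_policy.h5', 'w') as f:
        alphago.rollout_policy.serialize(f)  # fast rollout policy agent
    with h5py.File('alphago_value.h5', 'w') as f:
        alphago.value.serialize(f)           # value agent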
@@ -0,0 +1,86 @@
import unittest

import h5py
import numpy as np

from dlgo.agent import load_prediction_agent, load_policy_agent, AlphaGoMCTS
from dlgo.agent.pg import PolicyAgent
from dlgo.agent.predict import DeepLearningAgent
from dlgo.encoders.alphago import AlphaGoEncoder
from dlgo.goboard_fast import GameState
from dlgo.networks.alphago import alphago_model
from dlgo.rl import ValueAgent, load_experience, load_value_agent
from dlgo.rl.simulate import experience_simulation


class AlphaGoAgentTest(unittest.TestCase):
    def test_1_supervised_learning(self):
        rows, cols = 19, 19
        encoder = AlphaGoEncoder()

        input_shape = (encoder.num_planes, rows, cols)
        alphago_sl_policy = alphago_model(input_shape, is_policy_net=True)

        alphago_sl_policy.compile('sgd', 'categorical_crossentropy', metrics=['accuracy'])

        alphago_sl_agent = DeepLearningAgent(alphago_sl_policy, encoder)

        inputs = np.ones((10,) + input_shape)
        outputs = alphago_sl_policy.predict(inputs)
        self.assertEqual(outputs.shape, (10, 361))  # one probability per point on a 19x19 board

        with h5py.File('test_alphago_sl_policy.h5', 'w') as sl_agent_out:
            alphago_sl_agent.serialize(sl_agent_out)

    def test_2_reinforcement_learning(self):
        encoder = AlphaGoEncoder()

        sl_agent = load_prediction_agent(h5py.File('test_alphago_sl_policy.h5', 'r'))
        sl_opponent = load_prediction_agent(h5py.File('test_alphago_sl_policy.h5', 'r'))

        alphago_rl_agent = PolicyAgent(sl_agent.model, encoder)
        opponent = PolicyAgent(sl_opponent.model, encoder)

        num_games = 1
        experience = experience_simulation(num_games, alphago_rl_agent, opponent)

        alphago_rl_agent.train(experience)

        with h5py.File('test_alphago_rl_policy.h5', 'w') as rl_agent_out:
            alphago_rl_agent.serialize(rl_agent_out)

        with h5py.File('test_alphago_rl_experience.h5', 'w') as exp_out:
            experience.serialize(exp_out)

    def test_3_alphago_value(self):
        rows, cols = 19, 19
        encoder = AlphaGoEncoder()
        input_shape = (encoder.num_planes, rows, cols)
        alphago_value_network = alphago_model(input_shape)

        alphago_value = ValueAgent(alphago_value_network, encoder)

        experience = load_experience(h5py.File('test_alphago_rl_experience.h5', 'r'))

        alphago_value.train(experience)

        with h5py.File('test_alphago_value.h5', 'w') as value_agent_out:
            alphago_value.serialize(value_agent_out)

    def test_4_alphago_mcts(self):
        fast_policy = load_prediction_agent(h5py.File('test_alphago_sl_policy.h5', 'r'))
        strong_policy = load_policy_agent(h5py.File('test_alphago_rl_policy.h5', 'r'))
        value = load_value_agent(h5py.File('test_alphago_value.h5', 'r'))

        alphago = AlphaGoMCTS(strong_policy, fast_policy, value,
                              num_simulations=20, depth=5, rollout_limit=10)
        start = GameState.new_game(19)
        alphago.select_move(start)


if __name__ == '__main__':
    unittest.main()
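
The test methods are numbered (test_1 through test_4) so that unittest's default alphabetical ordering runs the pipeline end to end: the SL policy is serialized first, then loaded by the RL stage, and so on. Assuming the file is saved as test_alphago.py (the diff view did not preserve the path), a verbose run would be:

    python -m unittest -v test_alphago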