From 50aa0ce1cd1e638c62c5ddd6ee602db3b1c23d47 Mon Sep 17 00:00:00 2001
From: Max Pumperla
Date: Mon, 27 Aug 2018 10:58:31 +0200
Subject: [PATCH] update code for chapter 13

---
 code/dlgo/agent/__init__.py          |   2 +
 code/dlgo/agent/alphago.py           | 168 +++++++++++++++++++++++++++
 code/dlgo/agent/alphago_test.py      |  86 ++++++++++++++
 code/dlgo/agent/base.py              |   6 +-
 code/dlgo/agent/naive_fast.py        |   1 +
 code/dlgo/agent/pg.py                |  54 +++++----
 code/dlgo/agent/predict.py           |  28 +++--
 code/dlgo/agent/termination.py       |   7 +-
 code/dlgo/encoders/__init__.py       |   8 +-
 code/dlgo/encoders/alphago.py        | 143 +++++++++++++++++++++++
 code/dlgo/encoders/alphago_test.py   |  26 +++++
 code/dlgo/encoders/base.py           |   4 +-
 code/dlgo/encoders/betago.py         |   2 +-
 code/dlgo/encoders/oneplane.py       |   2 +-
 code/dlgo/encoders/sevenplane.py     |   2 +-
 code/dlgo/encoders/simple.py         |   5 +-
 code/dlgo/encoders/utils.py          | 107 +++++++++++++++++
 code/dlgo/networks/alphago.py        |  48 ++++++++
 code/dlgo/networks/alphago_zero.py   | 140 ++++++++++++++++++++++
 code/dlgo/networks/fullyconnected.py |   1 -
 20 files changed, 792 insertions(+), 48 deletions(-)
 create mode 100644 code/dlgo/agent/alphago.py
 create mode 100644 code/dlgo/agent/alphago_test.py
 create mode 100644 code/dlgo/encoders/alphago.py
 create mode 100644 code/dlgo/encoders/alphago_test.py
 create mode 100644 code/dlgo/encoders/utils.py
 create mode 100644 code/dlgo/networks/alphago.py
 create mode 100644 code/dlgo/networks/alphago_zero.py

diff --git a/code/dlgo/agent/__init__.py b/code/dlgo/agent/__init__.py
index 5f9534aa..ca5151f8 100644
--- a/code/dlgo/agent/__init__.py
+++ b/code/dlgo/agent/__init__.py
@@ -1,5 +1,7 @@
+from .alphago import *
 from .base import *
 from .pg import *
 from .predict import *
 from .naive import *
 from .naive_fast import *
+from .termination import *
diff --git a/code/dlgo/agent/alphago.py b/code/dlgo/agent/alphago.py
new file mode 100644
index 00000000..2ed41d2d
--- /dev/null
+++ b/code/dlgo/agent/alphago.py
@@ -0,0 +1,168 @@
+# tag::alphago_imports[]
+import numpy as np
+from dlgo.agent.base import Agent
+from dlgo.goboard_fast import Move
+from dlgo import kerasutil
+import operator
+# end::alphago_imports[]
+
+
+__all__ = [
+    'AlphaGoNode',
+    'AlphaGoMCTS'
+]
+
+
+# tag::init_alphago_node[]
+class AlphaGoNode:
+    def __init__(self, parent=None, probability=1.0):
+        self.parent = parent  # <1>
+        self.children = {}  # <1>
+
+        self.visit_count = 0
+        self.q_value = 0
+        self.prior_value = probability  # <2>
+        self.u_value = probability  # <3>
+# <1> Tree nodes have one parent and potentially many children.
+# <2> A node is initialized with a prior probability.
+# <3> The utility function will be updated during search.
+# end::init_alphago_node[]
+
+# tag::select_node[]
+    def select_child(self):
+        return max(self.children.items(),
+                   key=lambda child: child[1].q_value + \
+                   child[1].u_value)
+# end::select_node[]
+
+# tag::expand_children[]
+    def expand_children(self, moves, probabilities):
+        for move, prob in zip(moves, probabilities):
+            if move not in self.children:
+                self.children[move] = AlphaGoNode(probability=prob)
+# end::expand_children[]
+
+# tag::update_values[]
+    def update_values(self, leaf_value):
+        if self.parent is not None:
+            self.parent.update_values(leaf_value)  # <1>
+
+        self.visit_count += 1  # <2>
+
+        self.q_value += leaf_value / self.visit_count  # <3>
+
+        if self.parent is not None:
+            c_u = 5
+            self.u_value = c_u * np.sqrt(self.parent.visit_count) \
+                * self.prior_value / (1 + self.visit_count)  # <4>
+
+# <1> We update parents first to ensure we traverse the tree top to bottom.
+# <2> Increment the visit count for this node.
+# <3> Add the specified leaf value to the Q-value, normalized by visit count.
+# <4> Update utility with current visit counts.
+# end::update_values[]
+
+
+# tag::alphago_mcts_init[]
+class AlphaGoMCTS(Agent):
+    def __init__(self, policy_agent, fast_policy_agent, value_agent,
+                 lambda_value=0.5, num_simulations=1000,
+                 depth=50, rollout_limit=100):
+        self.policy = policy_agent
+        self.rollout_policy = fast_policy_agent
+        self.value = value_agent
+
+        self.lambda_value = lambda_value
+        self.num_simulations = num_simulations
+        self.depth = depth
+        self.rollout_limit = rollout_limit
+        self.root = AlphaGoNode()
+# end::alphago_mcts_init[]
+
+# tag::alphago_mcts_rollout[]
+    def select_move(self, game_state):
+        for simulation in range(self.num_simulations):  # <1>
+            current_state = game_state
+            node = self.root
+            for depth in range(self.depth):  # <2>
+                if not node.children:  # <3>
+                    if current_state.is_over():
+                        break
+                    moves, probabilities = self.policy_probabilities(current_state)  # <4>
+                    node.expand_children(moves, probabilities)  # <4>
+
+                move, node = node.select_child()  # <5>
+                current_state = current_state.apply_move(move)  # <5>
+
+            value = self.value.predict(current_state)  # <6>
+            rollout = self.policy_rollout(current_state)  # <6>
+
+            weighted_value = (1 - self.lambda_value) * value + \
+                self.lambda_value * rollout  # <7>
+
+            node.update_values(weighted_value)  # <8>
+# <1> From the current state, play out a number of simulations.
+# <2> Play moves until the specified depth is reached.
+# <3> If the current node doesn't have any children...
+# <4> ... expand them with probabilities from the strong policy.
+# <5> If there are children, we can select one and play the corresponding move.
+# <6> Compute output of value network and a rollout by the fast policy.
+# <7> Determine the combined value function.
+# <8> Update values for this node in the backup phase.
+# end::alphago_mcts_rollout[]
+
+# tag::alphago_mcts_selection[]
+        move = max(self.root.children, key=lambda move:  # <1>
+                   self.root.children.get(move).visit_count)  # <1>
+
+        child = self.root.children.get(move)  # <2>
+        self.root = child if child is not None else AlphaGoNode()
+        if child is not None:
+            self.root.parent = None
+
+        return move
+# <1> Pick most visited child of the root as next move.
+# <2> If the picked move is a child, set new root to this child node.
+# end::alphago_mcts_selection[]
+
+# tag::alphago_policy_probs[]
+    def policy_probabilities(self, game_state):
+        encoder = self.policy._encoder
+        outputs = self.policy.predict(game_state)
+        legal_moves = game_state.legal_moves()
+        if not legal_moves:
+            return [], []
+        encoded_points = [encoder.encode_point(move.point) for move in legal_moves if move.point]
+        legal_outputs = outputs[encoded_points]
+        normalized_outputs = legal_outputs / np.sum(legal_outputs)
+        return legal_moves, normalized_outputs
+# end::alphago_policy_probs[]
+
+# tag::alphago_policy_rollout[]
+    def policy_rollout(self, game_state):
+        for step in range(self.rollout_limit):
+            if game_state.is_over():
+                break
+            move_probabilities = self.rollout_policy.predict(game_state)
+            encoder = self.rollout_policy.encoder
+            valid_moves = [(idx, prob) for idx, prob in enumerate(move_probabilities)
+                           if Move(encoder.decode_point_index(idx)) in game_state.legal_moves()]
+            max_index, max_value = max(valid_moves, key=operator.itemgetter(1))
+            max_point = encoder.decode_point_index(max_index)
+            greedy_move = Move(max_point)
+            if greedy_move in game_state.legal_moves():
+                game_state = game_state.apply_move(greedy_move)
+
+        next_player = game_state.next_player
+        winner = game_state.winner()
+        if winner is not None:
+            return 1 if winner == next_player else -1
+        else:
+            return 0
+# end::alphago_policy_rollout[]
+
+
+    def serialize(self, h5file):
+        raise IOError("AlphaGoMCTS agent can't be serialized; "
+                      "consider serializing the three underlying "
+                      "neural networks instead.")
diff --git a/code/dlgo/agent/alphago_test.py b/code/dlgo/agent/alphago_test.py
new file mode 100644
index 00000000..bc9314e7
--- /dev/null
+++ b/code/dlgo/agent/alphago_test.py
@@ -0,0 +1,86 @@
+import unittest
+
+from dlgo.data.processor import GoDataProcessor
+from dlgo.agent.predict import DeepLearningAgent
+from dlgo.networks.alphago import alphago_model
+from dlgo.agent.pg import PolicyAgent
+from dlgo.agent.predict import load_prediction_agent
+from dlgo.encoders.alphago import AlphaGoEncoder
+from dlgo.rl.simulate import experience_simulation
+from dlgo.networks.alphago import alphago_model
+from dlgo.rl import ValueAgent, load_experience
+from dlgo.agent import load_prediction_agent, load_policy_agent, AlphaGoMCTS
+from dlgo.rl import load_value_agent
+from dlgo.goboard_fast import GameState
+
+from keras.callbacks import ModelCheckpoint
+import h5py
+import numpy as np
+
+class AlphaGoAgentTest(unittest.TestCase):
+    def test_1_supervised_learning(self):
+        rows, cols = 19, 19
+        encoder = AlphaGoEncoder()
+
+        input_shape = (encoder.num_planes, rows, cols)
+        alphago_sl_policy = alphago_model(input_shape, is_policy_net=True)
+
+        alphago_sl_policy.compile('sgd', 'categorical_crossentropy', metrics=['accuracy'])
+
+        alphago_sl_agent = DeepLearningAgent(alphago_sl_policy, encoder)
+
+        inputs = np.ones((10,) + input_shape)
+        outputs = alphago_sl_policy.predict(inputs)
+        assert(outputs.shape == (10, 361))
+
+        with h5py.File('test_alphago_sl_policy.h5', 'w') as sl_agent_out:
+            alphago_sl_agent.serialize(sl_agent_out)
+
+    def test_2_reinforcement_learning(self):
+        encoder = AlphaGoEncoder()
+
+        sl_agent = load_prediction_agent(h5py.File('test_alphago_sl_policy.h5'))
+        sl_opponent = load_prediction_agent(h5py.File('test_alphago_sl_policy.h5'))
+
+        alphago_rl_agent = PolicyAgent(sl_agent.model, encoder)
+        opponent = PolicyAgent(sl_opponent.model, encoder)
+
+        num_games = 1
+        experience = experience_simulation(num_games, alphago_rl_agent, opponent)
+
+        alphago_rl_agent.train(experience)
+
+        with h5py.File('test_alphago_rl_policy.h5', 'w') as rl_agent_out:
+            alphago_rl_agent.serialize(rl_agent_out)
+
+        with h5py.File('test_alphago_rl_experience.h5', 'w') as exp_out:
+            experience.serialize(exp_out)
+
+    def test_3_alphago_value(self):
+        rows, cols = 19, 19
+        encoder = AlphaGoEncoder()
+        input_shape = (encoder.num_planes, rows, cols)
+        alphago_value_network = alphago_model(input_shape)
+
+        alphago_value = ValueAgent(alphago_value_network, encoder)
+
+        experience = load_experience(h5py.File('test_alphago_rl_experience.h5', 'r'))
+
+        alphago_value.train(experience)
+
+        with h5py.File('test_alphago_value.h5', 'w') as value_agent_out:
+            alphago_value.serialize(value_agent_out)
+
+    def test_4_alphago_mcts(self):
+        fast_policy = load_prediction_agent(h5py.File('test_alphago_sl_policy.h5', 'r'))
+        strong_policy = load_policy_agent(h5py.File('test_alphago_rl_policy.h5', 'r'))
+        value = load_value_agent(h5py.File('test_alphago_value.h5', 'r'))
+
+        alphago = AlphaGoMCTS(strong_policy, fast_policy, value,
+                              num_simulations=20, depth=5, rollout_limit=10)
+        start = GameState.new_game(19)
+        alphago.select_move(start)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/code/dlgo/agent/base.py b/code/dlgo/agent/base.py
index 3d8df940..161dedd2 100644
--- a/code/dlgo/agent/base.py
+++ b/code/dlgo/agent/base.py
@@ -4,8 +4,10 @@
 
 
 # tag::agent[]
-class Agent():
-    """Interface for a go-playing bot."""
+class Agent:
+    def __init__(self):
+        pass
+
     def select_move(self, game_state):
         raise NotImplementedError()
 # end::agent[]
diff --git a/code/dlgo/agent/naive_fast.py b/code/dlgo/agent/naive_fast.py
index d3a49b09..f0daaa73 100644
--- a/code/dlgo/agent/naive_fast.py
+++ b/code/dlgo/agent/naive_fast.py
@@ -11,6 +11,7 @@ class FastRandomBot(Agent):
     def __init__(self):
+        Agent.__init__(self)
         self.dim = None
         self.point_cache = []
 
diff --git a/code/dlgo/agent/pg.py b/code/dlgo/agent/pg.py
index ad6732de..b8f78289 100644
--- a/code/dlgo/agent/pg.py
+++ b/code/dlgo/agent/pg.py
@@ -12,32 +12,37 @@
 __all__ = [
     'PolicyAgent',
     'load_policy_agent',
+    'policy_gradient_loss',
 ]
 
 
+# Keeping this around so we can read existing agents. But from now on
+# we'll use the built-in crossentropy loss.
+def policy_gradient_loss(y_true, y_pred):
+    clip_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
+    loss = -1 * y_true * K.log(clip_pred)
+    return K.mean(K.sum(loss, axis=1))
+
+
 def normalize(x):
     total = np.sum(x)
     return x / total
 
 
-def prepare_experience_data(experience, board_width, board_height):
-    experience_size = experience.actions.shape[0]
-    target_vectors = np.zeros((experience_size, board_width * board_height))
-    for i in range(experience_size):
-        action = experience.actions[i]
-        reward = experience.rewards[i]
-        target_vectors[i][action] = reward
-    return target_vectors
-
-
 class PolicyAgent(Agent):
     """An agent that uses a deep policy network to select moves."""
     def __init__(self, model, encoder):
+        Agent.__init__(self)
         self._model = model
         self._encoder = encoder
         self._collector = None
         self._temperature = 0.0
 
+    def predict(self, game_state):
+        encoded_state = self._encoder.encode(game_state)
+        input_tensor = np.array([encoded_state])
+        return self._model.predict(input_tensor)[0]
+
     def set_temperature(self, temperature):
         self._temperature = temperature
 
@@ -48,14 +53,14 @@ def select_move(self, game_state):
         num_moves = self._encoder.board_width * self._encoder.board_height
         board_tensor = self._encoder.encode(game_state)
-        X = np.array([board_tensor])
+        x = np.array([board_tensor])
 
         if np.random.random() < self._temperature:
             # Explore random moves.
             move_probs = np.ones(num_moves) / num_moves
         else:
             # Follow our current policy.
-            move_probs = self._model.predict(X)[0]
+            move_probs = self._model.predict(x)[0]
 
         # Prevent move probs from getting stuck at 0 or 1.
         eps = 1e-5
@@ -90,24 +95,29 @@ def serialize(self, h5file):
         h5file.create_group('model')
         kerasutil.save_model_to_hdf5_group(self._model, h5file['model'])
 
-    def train(self, experience, lr, clipnorm, batch_size):
-        self._model.compile(
-            loss='categorical_crossentropy',
-            optimizer=SGD(lr=lr, clipnorm=clipnorm))
+    def train(self, experience, lr=0.0000001, clipnorm=1.0, batch_size=512):
+        opt = SGD(lr=lr, clipnorm=clipnorm)
+        self._model.compile(loss='categorical_crossentropy', optimizer=opt)
 
-        target_vectors = prepare_experience_data(
-            experience,
-            self._encoder.board_width,
-            self._encoder.board_height)
+        n = experience.states.shape[0]
+        # Translate the actions/rewards.
+        num_moves = self._encoder.board_width * self._encoder.board_height
+        y = np.zeros((n, num_moves))
+        for i in range(n):
+            action = experience.actions[i]
+            reward = experience.rewards[i]
+            y[i][action] = reward
 
         self._model.fit(
-            experience.states, target_vectors,
+            experience.states, y,
             batch_size=batch_size,
             epochs=1)
 
 
 def load_policy_agent(h5file):
-    model = kerasutil.load_model_from_hdf5_group(h5file['model'])
+    model = kerasutil.load_model_from_hdf5_group(
+        h5file['model'],
+        custom_objects={'policy_gradient_loss': policy_gradient_loss})
     encoder_name = h5file['encoder'].attrs['name']
     if not isinstance(encoder_name, str):
         encoder_name = encoder_name.decode('ascii')
diff --git a/code/dlgo/agent/predict.py b/code/dlgo/agent/predict.py
index 7df5bfb9..99e6a237 100644
--- a/code/dlgo/agent/predict.py
+++ b/code/dlgo/agent/predict.py
@@ -16,16 +16,20 @@
 # tag::dl_agent_init[]
 class DeepLearningAgent(Agent):
     def __init__(self, model, encoder):
-        self._model = model
-        self._encoder = encoder
+        Agent.__init__(self)
+        self.model = model
+        self.encoder = encoder
 # end::dl_agent_init[]
 
 # tag::dl_agent_predict[]
+    def predict(self, game_state):
+        encoded_state = self.encoder.encode(game_state)
+        input_tensor = np.array([encoded_state])
+        return self.model.predict(input_tensor)[0]
+
     def select_move(self, game_state):
-        num_moves = self._encoder.board_width * self._encoder.board_height
-        board_tensor = self._encoder.encode(game_state)
-        X = np.array([board_tensor])
-        move_probs = self._model.predict(X)[0]
+        num_moves = self.encoder.board_width * self.encoder.board_height
+        move_probs = self.predict(game_state)
 # end::dl_agent_predict[]
 
 # tag::dl_agent_probabilities[]
@@ -43,11 +47,11 @@ def select_move(self, game_state):
         ranked_moves = np.random.choice(
             candidates, num_moves, replace=False, p=move_probs)  # <2>
         for point_idx in ranked_moves:
-            point = self._encoder.decode_point_index(point_idx)
+            point = self.encoder.decode_point_index(point_idx)
             if game_state.is_valid_move(goboard.Move.play(point)) and \
                     not is_point_an_eye(game_state.board, point, game_state.next_player):  # <3>
                 return goboard.Move.play(point)
-        return goboard.Move.pass_turn()  # <4> No legal, non-self-destructive moves less.
+        return goboard.Move.pass_turn()  # <4>
 # <1> Turn the probabilities into a ranked list of moves.
 # <2> Sample potential candidates
 # <3> Starting from the top, find a valid move that doesn't reduce eye-space.
@@ -57,11 +61,11 @@ def select_move(self, game_state):
 # tag::dl_agent_serialize[]
     def serialize(self, h5file):
         h5file.create_group('encoder')
-        h5file['encoder'].attrs['name'] = self._encoder.name()
-        h5file['encoder'].attrs['board_width'] = self._encoder.board_width
-        h5file['encoder'].attrs['board_height'] = self._encoder.board_height
+        h5file['encoder'].attrs['name'] = self.encoder.name()
+        h5file['encoder'].attrs['board_width'] = self.encoder.board_width
+        h5file['encoder'].attrs['board_height'] = self.encoder.board_height
         h5file.create_group('model')
-        kerasutil.save_model_to_hdf5_group(self._model, h5file['model'])
+        kerasutil.save_model_to_hdf5_group(self.model, h5file['model'])
 # end::dl_agent_serialize[]
 
 
diff --git a/code/dlgo/agent/termination.py b/code/dlgo/agent/termination.py
index 6a2c339e..6328c682 100644
--- a/code/dlgo/agent/termination.py
+++ b/code/dlgo/agent/termination.py
@@ -6,7 +6,7 @@
 
 
 # tag::termination_strategy[]
-class TerminationStrategy():
+class TerminationStrategy:
     def __init__(self):
         pass
 
@@ -32,6 +32,7 @@ def should_pass(self, game_state):
 
 class ResignLargeMargin(TerminationStrategy):
     def __init__(self, own_color, cut_off_move, margin):
+        TerminationStrategy.__init__(self)
         self.own_color = own_color
         self.cut_off_move = cut_off_move
         self.margin = margin
@@ -55,8 +56,10 @@ def should_resign(self, game_state):
 
 class TerminationAgent(Agent):
     def __init__(self, agent, strategy=None):
+        Agent.__init__(self)
         self.agent = agent
-        self.strategy = strategy if strategy is not None else TerminationStrategy()
+        self.strategy = strategy if strategy is not None \
+            else TerminationStrategy()
 
     def select_move(self, game_state):
         if self.strategy.should_pass(game_state):
diff --git a/code/dlgo/encoders/__init__.py b/code/dlgo/encoders/__init__.py
index eb9a0327..8594146e 100644
--- a/code/dlgo/encoders/__init__.py
+++ b/code/dlgo/encoders/__init__.py
@@ -1,2 +1,6 @@
-from .base import *
-from . import simple
+from dlgo.encoders.base import *
+#from dlgo.encoders.alphago import *
+from dlgo.encoders.betago import *
+from dlgo.encoders.oneplane import *
+from dlgo.encoders.sevenplane import *
+from dlgo.encoders.simple import *
\ No newline at end of file
diff --git a/code/dlgo/encoders/alphago.py b/code/dlgo/encoders/alphago.py
new file mode 100644
index 00000000..a646d358
--- /dev/null
+++ b/code/dlgo/encoders/alphago.py
@@ -0,0 +1,143 @@
+from dlgo.encoders.base import Encoder
+from dlgo.encoders.utils import is_ladder_escape, is_ladder_capture
+from dlgo.gotypes import Point, Player
+from dlgo.goboard_fast import Move
+from dlgo.agent.helpers_fast import is_point_an_eye
+import numpy as np
+
+"""
+Feature name          num of planes  Description
+Stone colour          3              Player stone / opponent stone / empty
+Ones                  1              A constant plane filled with 1
+Zeros                 1              A constant plane filled with 0
+Sensibleness          1              Whether a move is legal and does not fill its own eyes
+Turns since           8              How many turns since a move was played
+Liberties             8              Number of liberties (empty adjacent points)
+Liberties after move  8              Number of liberties after this move is played
+Capture size          8              How many opponent stones would be captured
+Self-atari size       8              How many of own stones would be captured
+Ladder capture        1              Whether a move at this point is a successful ladder capture
+Ladder escape         1              Whether a move at this point is a successful ladder escape
+"""
+
+FEATURE_OFFSETS = {
+    "stone_color": 0,
+    "ones": 3,
+    "zeros": 4,
+    "sensibleness": 5,
+    "turns_since": 6,
+    "liberties": 14,
+    "liberties_after": 22,
+    "capture_size": 30,
+    "self_atari_size": 38,
+    "ladder_capture": 46,
+    "ladder_escape": 47,
+    "current_player_color": 48
+}
+
+
+def offset(feature):
+    return FEATURE_OFFSETS[feature]
+
+
+class AlphaGoEncoder(Encoder):
+    def __init__(self, board_size=(19, 19), use_player_plane=True):
+        self.board_width, self.board_height = board_size
+        self.use_player_plane = use_player_plane
+        self.num_planes = 48 + use_player_plane
+
+    def name(self):
+        return 'alphago'
+
+    def encode(self, game_state):
+        board_tensor = np.zeros((self.num_planes, self.board_height, self.board_width))
+        for r in range(self.board_height):
+            for c in range(self.board_width):
+                point = Point(row=r + 1, col=c + 1)
+
+                go_string = game_state.board.get_go_string(point)
+                if go_string and go_string.color == game_state.next_player:
+                    board_tensor[offset("stone_color")][r][c] = 1
+                elif go_string and go_string.color == game_state.next_player.other:
+                    board_tensor[offset("stone_color") + 1][r][c] = 1
+                else:
+                    board_tensor[offset("stone_color") + 2][r][c] = 1
+
+                board_tensor[offset("ones")] = self.ones()
+                board_tensor[offset("zeros")] = self.zeros()
+
+                if not is_point_an_eye(game_state.board, point, game_state.next_player):
+                    board_tensor[offset("sensibleness")][r][c] = 1
+
+                ages = min(game_state.board.move_ages.get(r, c), 8)
+                if ages > 0:
+                    # Mark how many turns have passed since this move was played.
+                    board_tensor[offset("turns_since") + ages][r][c] = 1
+
+                if game_state.board.get_go_string(point):
+                    liberties = min(game_state.board.get_go_string(point).num_liberties, 8)
+                    board_tensor[offset("liberties") + liberties][r][c] = 1
+
+                move = Move(point)
+                if game_state.is_valid_move(move):
+                    new_state = game_state.apply_move(move)
+                    liberties = min(new_state.board.get_go_string(point).num_liberties, 8)
+                    board_tensor[offset("liberties_after") + liberties][r][c] = 1
+
+                    adjacent_strings = [game_state.board.get_go_string(nb)
+                                        for nb in point.neighbors()]
+                    capture_count = 0
+                    for go_string in adjacent_strings:
+                        other_player = game_state.next_player.other
+                        if go_string and go_string.num_liberties == 1 and go_string.color == other_player:
+                            capture_count += len(go_string.stones)
+                    capture_count = min(capture_count, 8)
+                    board_tensor[offset("capture_size") + capture_count][r][c] = 1
+
+                    if go_string and go_string.num_liberties == 1:
+                        go_string = game_state.board.get_go_string(point)
+                        if go_string:
+                            num_atari_stones = min(len(go_string.stones), 8)
+                            board_tensor[offset("self_atari_size") + num_atari_stones][r][c] = 1
+
+                    if is_ladder_capture(game_state, point):
+                        board_tensor[offset("ladder_capture")][r][c] = 1
+
+                    if is_ladder_escape(game_state, point):
+                        board_tensor[offset("ladder_escape")][r][c] = 1
+
+                if self.use_player_plane:
+                    if game_state.next_player == Player.black:
+                        board_tensor[offset("ones")] = self.ones()
+                    else:
+                        board_tensor[offset("zeros")] = self.zeros()
+
+        return board_tensor
+
+    def ones(self):
+        return np.ones((1, self.board_height, self.board_width))
+
+
+    def zeros(self):
+        return np.zeros((1, self.board_height, self.board_width))
+
+    def capture_size(self, game_state, num_planes=8):
+        pass
+
+    def encode_point(self, point):
+        return self.board_width * (point.row - 1) + (point.col - 1)
+
+    def decode_point_index(self, index):
+        row = index // self.board_width
+        col = index % self.board_width
+        return Point(row=row + 1, col=col + 1)
+
+    def num_points(self):
+        return self.board_width * self.board_height
+
+    def shape(self):
+        return self.num_planes, self.board_height, self.board_width
+
+
+def create(board_size):
+    return AlphaGoEncoder(board_size)
diff --git a/code/dlgo/encoders/alphago_test.py b/code/dlgo/encoders/alphago_test.py
new file mode 100644
index 00000000..866a0209
--- /dev/null
+++ b/code/dlgo/encoders/alphago_test.py
@@ -0,0 +1,26 @@
+import unittest
+
+from dlgo.agent.helpers import is_point_an_eye
+from dlgo.goboard_fast import Board, GameState, Move
+from dlgo.gotypes import Player, Point
+from dlgo.encoders.alphago import AlphaGoEncoder
+
+
+class AlphaGoEncoderTest(unittest.TestCase):
+    def test_encoder(self):
+        alphago = AlphaGoEncoder()
+
+        start = GameState.new_game(19)
+        next_state = start.apply_move(Move.play(Point(16, 16)))
+        alphago.encode(next_state)
+
+        self.assertEquals(alphago.name(), 'alphago')
+        self.assertEquals(alphago.board_height, 19)
+        self.assertEquals(alphago.board_width, 19)
+        self.assertEquals(alphago.num_planes, 49)
+        self.assertEquals(alphago.shape(), (49, 19, 19))
+
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/code/dlgo/encoders/base.py b/code/dlgo/encoders/base.py
index b0bdb60e..582f8ec1 100644
--- a/code/dlgo/encoders/base.py
+++ b/code/dlgo/encoders/base.py
@@ -9,7 +9,7 @@
 
 
 # tag::base_encoder[]
-class Encoder():
+class Encoder:
     def name(self):  # <1>
         raise NotImplementedError()
 
@@ -28,7 +28,7 @@ def num_points(self):  # <5>
 
     def shape(self):  # <6>
         raise NotImplementedError()
-# <1> Loading an encoder by name.
+# <1> Lets us support logging or saving the name of the encoder our model is using.
 # <2> Turn a Go board into a numeric data.
 # <3> Turn a Go board point into an integer index.
 # <4> Turn an integer index back into a Go board point.
diff --git a/code/dlgo/encoders/betago.py b/code/dlgo/encoders/betago.py
index 2724ab69..2f9107d5 100644
--- a/code/dlgo/encoders/betago.py
+++ b/code/dlgo/encoders/betago.py
@@ -57,7 +57,7 @@ def num_points(self):
         return self.board_width * self.board_height
 
     def shape(self):
-        return (self.num_planes, self.board_height, self.board_width)
+        return self.num_planes, self.board_height, self.board_width
 
 
 def create(board_size):
diff --git a/code/dlgo/encoders/oneplane.py b/code/dlgo/encoders/oneplane.py
index 25281d3a..d36148f6 100644
--- a/code/dlgo/encoders/oneplane.py
+++ b/code/dlgo/encoders/oneplane.py
@@ -47,7 +47,7 @@ def num_points(self):
         return self.board_width * self.board_height
 
     def shape(self):
-        return (self.num_planes, self.board_height, self.board_width)
+        return self.num_planes, self.board_height, self.board_width
 
 # <1> Turn a board point into an integer index.
 # <2> Turn an integer index into a board point.
diff --git a/code/dlgo/encoders/sevenplane.py b/code/dlgo/encoders/sevenplane.py
index 34819680..1a8b149d 100644
--- a/code/dlgo/encoders/sevenplane.py
+++ b/code/dlgo/encoders/sevenplane.py
@@ -49,7 +49,7 @@ def num_points(self):
         return self.board_width * self.board_height
 
     def shape(self):
-        return (self.num_planes, self.board_height, self.board_width)
+        return self.num_planes, self.board_height, self.board_width
 
 
 def create(board_size):
diff --git a/code/dlgo/encoders/simple.py b/code/dlgo/encoders/simple.py
index 15790031..7b36e8e8 100644
--- a/code/dlgo/encoders/simple.py
+++ b/code/dlgo/encoders/simple.py
@@ -1,7 +1,8 @@
 import numpy as np
 
 from dlgo.encoders.base import Encoder
-from dlgo.goboard import Move, Player, Point
+from dlgo.goboard import Move
+from dlgo.gotypes import Player, Point
 
 
 class SimpleEncoder(Encoder):
@@ -59,7 +60,7 @@ def num_points(self):
         return self.board_width * self.board_height
 
     def shape(self):
-        return (self.num_planes, self.board_height, self.board_width)
+        return self.num_planes, self.board_height, self.board_width
 
 
 def create(board_size):
diff --git a/code/dlgo/encoders/utils.py b/code/dlgo/encoders/utils.py
new file mode 100644
index 00000000..4d08c6db
--- /dev/null
+++ b/code/dlgo/encoders/utils.py
@@ -0,0 +1,107 @@
+from dlgo.goboard import Move
+
+
+def is_ladder_capture(game_state, candidate, recursion_depth=50):
+    return is_ladder(True, game_state, candidate, None, recursion_depth)
+
+
+def is_ladder_escape(game_state, candidate, recursion_depth=50):
+    return is_ladder(False, game_state, candidate, None, recursion_depth)
+
+
+def is_ladder(try_capture, game_state, candidate,
+              ladder_stones=None, recursion_depth=50):
+    """Ladders are played out in reversed roles, one player tries to capture,
+    the other to escape. We determine the ladder status by recursively calling
+    is_ladder in opposite roles, providing suitable capture or escape candidates.
+
+    Arguments:
+    try_capture: boolean flag to indicate if you want to capture or escape the ladder
+    game_state: current game state, instance of GameState
+    candidate: a move that potentially leads to escaping the ladder or capturing it, instance of Move
+    ladder_stones: the stones to escape or capture, list of Point. Will be inferred if not provided.
+    recursion_depth: when to stop recursively calling this function, integer valued.
+
+    Returns True if game state is a ladder and try_capture is true (the ladder captures)
+    or if game state is not a ladder and try_capture is false (you can successfully escape)
+    and False otherwise.
+ """ + + if not game_state.is_valid_move(Move(candidate)) or not recursion_depth: + return False + + next_player = game_state.next_player + capture_player = next_player if try_capture else next_player.other + escape_player = capture_player.other + + if ladder_stones is None: + ladder_stones = guess_ladder_stones(game_state, candidate, escape_player) + + for ladder_stone in ladder_stones: + current_state = game_state.apply_move(candidate) + + if try_capture: + candidates = determine_escape_candidates( + game_state, ladder_stone, capture_player) + attempted_escapes = [ # now try to escape + is_ladder(False, current_state, escape_candidate, + ladder_stone, recursion_depth - 1) + for escape_candidate in candidates] + + if not any(attempted_escapes): + return True # if at least one escape fails, we capture + else: + if count_liberties(current_state, ladder_stone) >= 3: + return True # successful escape + if count_liberties(current_state, ladder_stone) == 1: + continue # failed escape, others might still do + candidates = liberties(current_state, ladder_stone) + attempted_captures = [ # now try to capture + is_ladder(True, current_state, capture_candidate, + ladder_stone, recursion_depth - 1) + for capture_candidate in candidates] + if any(attempted_captures): + continue # failed escape, try others + return True # candidate can't be caught in a ladder, escape. + return False # no captures / no escapes + + +def is_candidate(game_state, move, player): + return game_state.next_player == player and \ + count_liberties(game_state, move) == 2 + + +def guess_ladder_stones(game_state, move, escape_player): + adjacent_strings = [game_state.board.get_go_string(nb) for nb in move.neighbors() if game_state.board.get_go_string(nb)] + if adjacent_strings: + string = adjacent_strings[0] + neighbors = [] + for string in adjacent_strings: + stones = string.stones + for stone in stones: + neighbors.append(stone) + return [Move(nb) for nb in neighbors if is_candidate(game_state, Move(nb), escape_player)] + else: + return [] + + +def determine_escape_candidates(game_state, move, capture_player): + escape_candidates = move.neighbors() + for other_ladder_stone in game_state.board.get_go_string(move).stones: + for neighbor in other_ladder_stone.neighbors(): + right_color = game_state.color(neighbor) == capture_player + one_liberty = count_liberties(game_state, neighbor) == 1 + if right_color and one_liberty: + escape_candidates.append(liberties(game_state, neighbor)) + return escape_candidates + + +def count_liberties(game_state, move): + if game_state.board.get_go_string(move): + return game_state.board.get_go_string(move).num_liberties + else: + return 0 + + +def liberties(game_state, move): + return list(game_state.board.get_go_string(move).liberties) diff --git a/code/dlgo/networks/alphago.py b/code/dlgo/networks/alphago.py new file mode 100644 index 00000000..7e3971c6 --- /dev/null +++ b/code/dlgo/networks/alphago.py @@ -0,0 +1,48 @@ +# tag::alphago_base[] +from keras.models import Sequential +from keras.layers.core import Dense, Flatten +from keras.layers.convolutional import Conv2D + + +def alphago_model(input_shape, is_policy_net=False, # <1> + num_filters=192, # <2> + first_kernel_size=5, + other_kernel_size=3): # <3> + + model = Sequential() + model.add( + Conv2D(num_filters, first_kernel_size, input_shape=input_shape, padding='same', + data_format='channels_first', activation='relu')) + + for i in range(2, 12): # <4> + model.add( + Conv2D(num_filters, other_kernel_size, padding='same', + 
+                   data_format='channels_first', activation='relu'))
+# <1> With this boolean flag you specify if you want a policy or a value network.
+# <2> All but the last convolutional layer have the same number of filters.
+# <3> The first layer has kernel size 5, all others only 3.
+# <4> The first 12 layers of AlphaGo's policy and value network are identical.
+# end::alphago_base[]
+
+# tag::alphago_policy[]
+    if is_policy_net:
+        model.add(
+            Conv2D(filters=1, kernel_size=1, padding='same',
+                   data_format='channels_first', activation='softmax'))
+        model.add(Flatten())
+        return model
+# end::alphago_policy[]
+
+# tag::alphago_value[]
+    else:
+        model.add(
+            Conv2D(num_filters, other_kernel_size, padding='same',
+                   data_format='channels_first', activation='relu'))
+        model.add(
+            Conv2D(filters=1, kernel_size=1, padding='same',
+                   data_format='channels_first', activation='relu'))
+        model.add(Flatten())
+        model.add(Dense(256, activation='relu'))
+        model.add(Dense(1, activation='tanh'))
+        return model
+# end::alphago_value[]
diff --git a/code/dlgo/networks/alphago_zero.py b/code/dlgo/networks/alphago_zero.py
new file mode 100644
index 00000000..73a4c56a
--- /dev/null
+++ b/code/dlgo/networks/alphago_zero.py
@@ -0,0 +1,140 @@
+from keras.layers import *
+from keras.models import Model
+
+
+'''The dual residual architecture is the strongest
+of the architectures tested by DeepMind for AlphaGo
+Zero. It consists of an initial convolutional block,
+followed by a number (40 for the strongest, 20 as
+baseline) of residual blocks. The network is topped
+off by two "heads", one to predict policies and one
+for value functions.
+'''
+def dual_residual_network(input_shape, blocks=20):
+    inputs = Input(shape=input_shape)
+    first_conv = conv_bn_relu_block(name="init")(inputs)
+    res_tower = residual_tower(blocks=blocks)(first_conv)
+    policy = policy_head()(res_tower)
+    value = value_head()(res_tower)
+    return Model(inputs=inputs, outputs=[policy, value])
+
+
+'''The dual convolutional architecture replaces residual
+blocks from the dual residual architecture with batch-normalized
+convolution layers. The default block size is 12.
+'''
+def dual_conv_network(input_shape, blocks=12):
+    inputs = Input(shape=input_shape)
+    first_conv = conv_bn_relu_block(name="init")(inputs)
+    conv_tower = convolutional_tower(blocks=blocks)(first_conv)
+    policy = policy_head()(conv_tower)
+    value = value_head()(conv_tower)
+    return Model(inputs=inputs, outputs=[policy, value])
+
+
+''' In the separate residual architecture policy and value
+head don't share a common "tail", i.e. there's two sets of
+residual blocks for policy and value networks, respectively.
+'''
+def separate_residual_network(input_shape, blocks=20):
+    inputs_pol = Input(shape=input_shape)
+    first_conv_pol = conv_bn_relu_block(name="init")(inputs_pol)
+    res_tower_pol = residual_tower(blocks=blocks)(first_conv_pol)
+    policy = policy_head()(res_tower_pol)
+    policy_model = Model(inputs=inputs_pol, outputs=policy)
+
+    inputs_val = Input(shape=input_shape)
+    first_conv_val = conv_bn_relu_block(name="init")(inputs_val)
+    res_tower_val = residual_tower(blocks=blocks)(first_conv_val)
+    value = value_head()(res_tower_val)
+    value_model = Model(inputs=inputs_val, outputs=value)
+
+    return policy_model, value_model
+
+
+'''The separate convolutional network is structurally identical
+to the separate residual network, except that residual blocks
+are replaced by convolutional blocks.
+'''
+def separate_conv_network(input_shape, blocks=20):
+    inputs_pol = Input(shape=input_shape)
+    first_conv_pol = conv_bn_relu_block(name="init")(inputs_pol)
+    conv_tower_pol = convolutional_tower(blocks=blocks)(first_conv_pol)
+    policy = policy_head()(conv_tower_pol)
+    policy_model = Model(inputs=inputs_pol, outputs=policy)
+
+    inputs_val = Input(shape=input_shape)
+    first_conv_val = conv_bn_relu_block(name="init")(inputs_val)
+    conv_tower_val = convolutional_tower(blocks=blocks)(first_conv_val)
+    value = value_head()(conv_tower_val)
+    value_model = Model(inputs=inputs_val, outputs=value)
+
+    return policy_model, value_model
+
+
+def conv_bn_relu_block(name, activation=True, filters=256, kernel_size=(3, 3),
+                       strides=(1, 1), padding="same", init="he_normal"):
+    def f(inputs):
+        conv = Conv2D(filters=filters,
+                      kernel_size=kernel_size,
+                      strides=strides,
+                      padding=padding,
+                      kernel_initializer=init,
+                      data_format='channels_first',
+                      name="{}_conv_block".format(name))(inputs)
+        batch_norm = BatchNormalization(axis=1, name="{}_batch_norm".format(name))(conv)
+        return Activation("relu", name="{}_relu".format(name))(batch_norm) if activation else batch_norm
+    return f
+
+
+def residual_block(block_num, **args):
+    def f(inputs):
+        res = conv_bn_relu_block(name="residual_1_{}".format(block_num), activation=True, **args)(inputs)
+        res = conv_bn_relu_block(name="residual_2_{}".format(block_num), activation=False, **args)(res)
+        res = add([inputs, res], name="add_{}".format(block_num))
+        return Activation("relu", name="{}_relu".format(block_num))(res)
+    return f
+
+
+def residual_tower(blocks, **args):
+    def f(inputs):
+        x = inputs
+        for i in range(blocks):
+            x = residual_block(block_num=i)(x)
+        return x
+    return f
+
+def convolutional_tower(blocks, **args):
+    def f(inputs):
+        x = inputs
+        for i in range(blocks):
+            x = conv_bn_relu_block(name=i)(x)
+        return x
+    return f
+
+
+def policy_head():
+    def f(inputs):
+        conv = Conv2D(filters=2,
+                      kernel_size=(3, 3),
+                      strides=(1, 1),
+                      padding="same",
+                      name="policy_head_conv_block")(inputs)
+        batch_norm = BatchNormalization(axis=1, name="policy_head_batch_norm")(conv)
+        activation = Activation("relu", name="policy_head_relu")(batch_norm)
+        return Dense(units=19 * 19 + 1, name="policy_head_dense")(activation)
+    return f
+
+
+def value_head():
+    def f(inputs):
+        conv = Conv2D(filters=1,
+                      kernel_size=(1, 1),
+                      strides=(1, 1),
+                      padding="same",
+                      name="value_head_conv_block")(inputs)
+        batch_norm = BatchNormalization(axis=1, name="value_head_batch_norm")(conv)
+        activation = Activation("relu", name="value_head_relu")(batch_norm)
+        dense = Dense(units=256, name="value_head_dense", activation="relu")(activation)
+        return Dense(units=1, name="value_head_output", activation="tanh")(dense)
+    return f
\ No newline at end of file
diff --git a/code/dlgo/networks/fullyconnected.py b/code/dlgo/networks/fullyconnected.py
index 596fdf29..8a711878 100644
--- a/code/dlgo/networks/fullyconnected.py
+++ b/code/dlgo/networks/fullyconnected.py
@@ -1,6 +1,5 @@
 from __future__ import absolute_import
 from keras.layers.core import Dense, Activation, Flatten
-from keras.layers.convolutional import Conv2D, ZeroPadding2D
 
 
 def layers(input_shape):